In [29]:
import jax
import jax.numpy as jnp
from jax.nn import softmax
from jax.scipy.special import logsumexp
from jax import tree

In [30]:
#### RECURRING HELPERS FOR ROTATION MATRIX

def get_computational_basis_vectors(num_qubits: int) -> jnp.ndarray:
    """Enumerate every bitstring of length ``num_qubits`` as a row of 0.0/1.0 values.

    Row ``k`` holds the binary expansion of ``k`` with the most significant
    bit in column 0.

    Args:
        num_qubits: Number of bits per row.

    Returns:
        A float32 array of shape ``(2**num_qubits, num_qubits)``.
    """
    state_labels = jnp.arange(2 ** num_qubits, dtype=jnp.uint32)      # (2**n,)
    # Column c reads bit (n - 1 - c), so the shift amounts run from high to low.
    shift_amounts = jnp.arange(num_qubits, dtype=jnp.uint32)[::-1]    # (n,)
    bit_table = jnp.right_shift(state_labels[:, None], shift_amounts[None, :]) & 1  # (2**n, n)
    return bit_table.astype(jnp.float32)

def construct_rotation_matrix(measurement_basis: tuple[int, ...]) -> jnp.ndarray:
    """Build the full-register rotation as a Kronecker product of 2x2 factors.

    Each entry of ``measurement_basis`` selects one single-qubit matrix:
    ``0`` -> Z (identity), ``1`` -> X, ``2`` -> Y. The factors are combined
    left to right, so the first entry acts on the most significant qubit.

    Args:
        measurement_basis: One index in {0, 1, 2} per qubit.

    Returns:
        A complex64 array of shape ``(2**n, 2**n)`` where ``n`` is
        ``len(measurement_basis)``; a ``(1, 1)`` matrix holding 1 when the
        tuple is empty.
    """
    root_two = jnp.sqrt(2.0)
    z_factor = jnp.eye(2, dtype=jnp.complex64)
    x_factor = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) / root_two
    y_factor = jnp.array([[1, -1j], [1j, -1]], dtype=jnp.complex64) / root_two
    factor_for_index = [z_factor, x_factor, y_factor]

    # Start from the 1x1 identity so the first kron simply yields the first factor.
    accumulated = jnp.ones((1, 1), dtype=jnp.complex64)
    for basis_index in measurement_basis:
        accumulated = jnp.kron(accumulated, factor_for_index[basis_index])
    return accumulated  # (2**n, 2**n)

def bitstring_to_int(bitstring: jnp.ndarray) -> jnp.ndarray:
    """Interpret the last axis of ``bitstring`` as a big-endian binary number.

    Args:
        bitstring: Array whose last axis holds the bits, most significant first.

    Returns:
        An int32 array with the last axis reduced away.
    """
    num_bits = bitstring.shape[-1]
    exponents = jnp.arange(num_bits)[::-1]   # n-1, ..., 1, 0
    place_values = 2 ** exponents
    weighted_bits = bitstring * place_values
    return weighted_bits.sum(axis=-1).astype(jnp.int32)

def multiply_with_leaf(factor, leaf):
    """Scale ``leaf`` along its leading axis by the entries of ``factor``.

    ``factor`` gets enough trailing singleton axes to broadcast against every
    non-leading axis of ``leaf``. The previous version added exactly one
    trailing axis, so leaves of rank 3 or higher either failed to broadcast
    or were scaled along the wrong axis.

    Args:
        factor: Array of per-row scale factors.
        leaf: Array whose leading axis lines up with ``factor``.

    Returns:
        ``factor`` broadcast-multiplied with ``leaf``.
    """
    if leaf.ndim == 1:
        return factor * leaf
    # A rank-0 leaf keeps the historical behaviour of adding one trailing axis.
    trailing_axes = max(leaf.ndim - 1, 1)
    expand = (slice(None),) + (None,) * trailing_axes
    return factor[expand] * leaf


#### MOCKING FREE ENERGY FUNCTIONS

def dummy_free_energy(sigma: jnp.ndarray, params: dict) -> jnp.ndarray:
    """Linear stand-in for a free energy: ``sigma . W + b``.

    Args:
        sigma: Configuration vector to evaluate.
        params: Mapping with a weight entry ``"W"`` and a bias entry ``"b"``.

    Returns:
        The dot product of ``sigma`` with ``params["W"]`` plus ``params["b"]``.
    """
    weights = params["W"]
    bias = params["b"]
    linear_term = jnp.dot(sigma, weights)
    return linear_term + bias


#### TEST VARIABLES

# One basis index per qubit (0 = Z, 1 = X, 2 = Y): here X, Z, Y, X.
measurement_basis = (1, 0, 2, 1)

# Four measured 4-bit outcomes, one per row.
measurements = jnp.asarray(
    [
        (0, 0, 0, 0),
        (1, 1, 1, 1),
        (0, 1, 0, 1),
        (1, 0, 1, 0),
    ],
    dtype=jnp.int32,
)

# Both parameter sets are deliberately built from the same numbers.
params_lambda = dict(
    W=jnp.linspace(0.05, 0.15, 4, dtype=jnp.float32),
    b=jnp.array(0.3, dtype=jnp.float32),
)
params_mu = dict(
    W=jnp.linspace(0.05, 0.15, 4, dtype=jnp.float32),
    b=jnp.array(0.3, dtype=jnp.float32),
)

In [31]:
#### GRADIENTS FROM AUTODIFF OF LOSS FUNCTION

def loss_fn(
        measurements: jnp.ndarray,
        measurement_basis: tuple[int, ...],
        params_lambda: dict,
        params_mu: dict) -> jnp.ndarray:
    """Mean negative log of the unnormalized rotated-basis probabilities.

    The amplitude over all computational basis states is
    ``psi_j = exp(-0.5 * F_lambda(j) - 0.5j * F_mu(j))``. For each measured
    bitstring with integer index ``i`` the rotated amplitude is
    ``sum_j U[i, j] * psi_j``, where ``U`` comes from
    ``construct_rotation_matrix``. The loss is
    ``-mean(log(|rotated amplitude|**2 + 1e-30))``.

    The normalization ``sum_j |psi_j|**2`` is not subtracted. It equals
    ``sum_j exp(-F_lambda(j))`` and so does not depend on ``params_mu``.

    Args:
        measurements: Integer array of shape ``(batch, n)``, one bitstring per row.
        measurement_basis: One basis index per qubit, passed to
            ``construct_rotation_matrix``.
        params_lambda: Parameters of the free energy that sets the amplitude modulus.
        params_mu: Parameters of the free energy that sets the amplitude phase.

    Returns:
        Scalar loss.
    """
    computational_basis_vectors = get_computational_basis_vectors(measurements.shape[1])  # (2**n, n)
    evaluate_free_energy = jax.vmap(dummy_free_energy, (0, None))
    free_energy_lambda = evaluate_free_energy(computational_basis_vectors, params_lambda)  # (2**n,)
    free_energy_mu = evaluate_free_energy(computational_basis_vectors, params_mu)  # (2**n,)

    unnormalized_amplitude = jnp.exp(-0.5 * free_energy_lambda - 0.5j * free_energy_mu)  # (2**n,)

    rotation_matrix = construct_rotation_matrix(measurement_basis)  # (2**n, 2**n)

    def get_log_probability(measurement):
        idx = bitstring_to_int(measurement)
        # Use a plain contraction here. jnp.vdot would complex-conjugate the
        # rotation row first, which swaps the two outcomes of every Y-basis
        # qubit (the only rows with complex entries).
        unnormalized_rotated_amplitude = jnp.dot(rotation_matrix[idx], unnormalized_amplitude)
        # The 1e-30 offset keeps the log finite when the amplitude is exactly zero.
        log_probability = jnp.log(jnp.abs(unnormalized_rotated_amplitude)**2 + 1e-30)
        return log_probability

    log_probs = jax.vmap(get_log_probability)(measurements)  # (batch,)
    loss = -jnp.mean(log_probs)
    return loss


# Differentiate the loss with respect to its fourth positional argument (params_mu).
autodiff_grad_fn = jax.grad(loss_fn, argnums=3)

autodiff_grad = autodiff_grad_fn(
    measurements,
    measurement_basis,
    params_lambda,
    params_mu,
)
autodiff_grad

{'W': Array([-9.9958172e+00, -3.8146973e-06,  2.9100418e-02, -3.3208370e+00],      dtype=float32),
 'b': Array(-5.722046e-06, dtype=float32)}

In [41]:
#### EXPLICIT GRADIENTS

def grad_fn(measurements, basis, params_lambda, params_mu):
    """Closed-form gradient, with respect to ``params_mu``, of the mean negative log-probability.

    With ``psi_j = exp(-0.5 * F_lambda(j) - 0.5j * F_mu(j))`` and rotated
    amplitude ``A_i = sum_j U[i, j] * psi_j``, define the complex weights
    ``w_j = U[i, j] * psi_j / A_i``. They are computed as a softmax over
    ``log(U[i, j]) - 0.5 * F_lambda(j) - 0.5j * F_mu(j)``. For one sample,

        d(-log |A_i|**2) / d(theta) = -sum_j Im(w_j) * dF_mu(j) / d(theta)

    and the result is averaged over the batch.

    Parameter leaves of any rank are supported: the weights are broadcast
    along the leading (basis-state) axis of each stacked per-state gradient.

    Args:
        measurements: Integer array of shape ``(batch, n)``, one bitstring per row.
        basis: One basis index per qubit, passed to ``construct_rotation_matrix``.
        params_lambda: Parameters of the free energy that sets the amplitude modulus.
        params_mu: Parameters of the free energy that sets the amplitude phase;
            the gradient is taken with respect to these.

    Returns:
        A pytree with the same structure as ``params_mu`` holding the
        batch-averaged gradient.
    """
    computational_basis_vectors = get_computational_basis_vectors(measurements.shape[1])  # (2**n, n)

    evaluate_free_energy = jax.vmap(dummy_free_energy, (0, None))
    free_energy_lambda = evaluate_free_energy(computational_basis_vectors, params_lambda)  # (2**n,)
    free_energy_mu = evaluate_free_energy(computational_basis_vectors, params_mu)  # (2**n,)
    # Per-state gradient of F_mu: every leaf gains a leading axis of length 2**n.
    free_energy_mu_grads = jax.vmap(
        lambda s: jax.grad(lambda p: dummy_free_energy(s, p))(params_mu)
    )(computational_basis_vectors)

    rotation_matrix = construct_rotation_matrix(basis)  # (2**n, 2**n)

    def per_sample(bits):
        idx = bitstring_to_int(bits)
        rotated_exponent = jnp.log(rotation_matrix[idx]) - 0.5 * free_energy_lambda - 0.5j * free_energy_mu
        # Zero matrix entries give log(0) = -inf. Replace non-finite exponents
        # with a huge negative real so those states get zero softmax weight.
        rotated_exponent = jnp.where(jnp.isfinite(rotated_exponent), rotated_exponent, -1e30 + 0j)
        gradient_weights = jnp.imag(softmax(rotated_exponent))  # (2**n,)

        def apply_to_leaf(leaf):
            # Append singleton axes so the weights line up with the leading
            # axis of the leaf regardless of the parameter's own rank.
            aligned_weights = gradient_weights.reshape((-1,) + (1,) * (leaf.ndim - 1))
            return -jnp.sum(aligned_weights * leaf, axis=0)

        return tree.map(apply_to_leaf, free_energy_mu_grads)

    grads_batch = jax.vmap(per_sample)(measurements)         # stacked pytree
    return tree.map(lambda x: jnp.mean(x, 0), grads_batch)



explicit_grad = grad_fn(
    measurements,
    measurement_basis,
    params_lambda,
    params_mu,
)

# Show both gradients, then whether every leaf pair agrees within tolerance.
print("explicit :", explicit_grad)
print("autodiff :", autodiff_grad)
print(
    "close    :",
    tree.all(
        tree.map(
            lambda lhs, rhs: jnp.allclose(lhs, rhs, rtol=1e-3, atol=1e-5),
            explicit_grad,
            autodiff_grad,
        )
    ),
)

explicit : {'W': Array([-9.9956303e+00, -1.7881393e-06,  2.9097542e-02, -3.3207827e+00],      dtype=float32), 'b': Array(2.1606684e-07, dtype=float32)}
autodiff : {'W': Array([-9.9958172e+00, -3.8146973e-06,  2.9100418e-02, -3.3208370e+00],      dtype=float32), 'b': Array(-5.722046e-06, dtype=float32)}
close    : True
