In [2]:
import jax
import jax.numpy as jnp
from jax.nn import softmax
from jax.scipy.special import logsumexp
from jax import tree

In [3]:
#### RECURRING HELPERS FOR ROTATION MATRIX

def get_computational_basis_vectors(num_qubits: int) -> jnp.ndarray:
    indices = jnp.arange(2 ** num_qubits, dtype=jnp.uint32)  # shape (2**n,)
    powers = 2 ** jnp.arange(num_qubits - 1, -1, -1, dtype=jnp.uint32)  # shape (n,)
    bits = (indices[:, None] & powers) > 0  # shape (2**n, n), bool
    return bits.astype(jnp.float32)

def construct_rotation_matrix(measurement_basis: tuple[int, ...]) -> jnp.ndarray:
    SQRT2 = jnp.sqrt(2.0)
    single_qubit_rotation_matrices = [
        jnp.array([[1, 0], [0, 1]], dtype=jnp.complex64),               # Z
        jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) / SQRT2,      # X
        jnp.array([[1, -1j], [1j, -1]], dtype=jnp.complex64) / SQRT2    # Y
    ]

    rotation_matrix = jnp.array([1.0+0j], dtype=jnp.complex64).reshape((1, 1))
    for idx in measurement_basis:
        rotation_matrix = jnp.kron(rotation_matrix, single_qubit_rotation_matrices[idx])
    return rotation_matrix # (2**n, 2**n)

def bitstring_to_int(bitstring: jnp.ndarray) -> jnp.ndarray:
    powers = 2 ** jnp.arange(bitstring.shape[-1] - 1, -1, -1)
    return jnp.sum(bitstring * powers, axis=-1).astype(jnp.int32)

def multiply_with_leaf(factor, leaf):
    if leaf.ndim == 1:
        return factor * leaf
    else:
        return factor[:, None] * leaf


#### MOCKING FREE ENERGY FUNCTIONS

def dummy_free_energy(sigma: jnp.ndarray, params: dict) -> jnp.ndarray:
    return jnp.dot(sigma, params["W"]) + params["b"]  # (2**n,)


#### TEST VARIABLES

measurement_basis = (1, 0, 2, 1)  # X Z Y X
measurements = jnp.array([
    [0,0,0,0],
    [1,1,1,1],
    [0,1,0,1],
    [1,0,1,0],
], dtype=jnp.int32)

params_lambda = {
    "W": jnp.linspace(0.05, 0.15, 4, dtype=jnp.float32),
    "b": jnp.array(0.3, dtype=jnp.float32)
}
params_mu = {
    "W": jnp.linspace(0.05, 0.15, 4, dtype=jnp.float32),
    "b": jnp.array(0.3, dtype=jnp.float32)
}

In [30]:
#### GRADIENTS FROM AUTODIFF OF LOSS FUNCTION

def rotated_log_probs(U_row: jnp.ndarray, F_lambda: jnp.ndarray, F_mu: jnp.ndarray) -> jnp.ndarray:
    exponent = -0.5 * F_lambda - 0.5j * F_mu                     # (2**n,)
    values = U_row * jnp.exp(exponent)                          # Complex vector

    abs_vals = jnp.abs(values)
    max_log = jnp.max(jnp.log(abs_vals + 1e-30))                # Scalar for stability
    scaled_values = values * jnp.exp(-max_log)                  # Scale down before summing

    return 2 * (max_log + jnp.log(jnp.abs(jnp.sum(scaled_values)) + 1e-30))

def loss_fn(
        measurements: jnp.ndarray,
        measurement_basis: tuple[int, ...],
        params_lambda: dict,
        params_mu: dict) -> jnp.ndarray:

    computational_basis_vectors = get_computational_basis_vectors(measurements.shape[1])  # (2**n, n)
    free_energy_lambda = jax.vmap(dummy_free_energy, (0, None))(computational_basis_vectors, params_lambda) # (2**n,)
    free_energy_mu  = jax.vmap(dummy_free_energy, (0, None))(computational_basis_vectors, params_mu) # (2**n,)

    rotation_matrix = construct_rotation_matrix(measurement_basis)  # (2**n, 2**n)

    get_log_prob = lambda m: rotated_log_probs(rotation_matrix[bitstring_to_int(m)], free_energy_lambda, free_energy_mu)

    log_probs = jax.vmap(get_log_prob)(measurements)
    loss = -jnp.mean(log_probs)
    return loss


autodiff_grad_fn = jax.grad(loss_fn, argnums=3)

autodiff_grad = autodiff_grad_fn(measurements, measurement_basis, params_lambda, params_mu)
autodiff_grad

{'W': Array([-9.99688721,  0.00006104,  0.01806641, -5.99365234,  0.00006104,
        -4.72842407,  0.02926636,  0.00003052,  0.00006104,  0.00006104],      dtype=float32),
 'b': Array( 0.00006294, dtype=float32)}

In [4]:
#### EXPLICIT GRADIENTS

def grad_fn(measurements, basis, params_lambda, params_mu):
    computational_basis_vectors = get_computational_basis_vectors(measurements.shape[1])

    free_energy_lambda = jax.vmap(dummy_free_energy, (0, None))(computational_basis_vectors, params_lambda)
    free_energy_mu = jax.vmap(dummy_free_energy, (0, None))(computational_basis_vectors, params_mu)
    free_energy_mu_grads = jax.vmap(lambda s: jax.grad(lambda p: dummy_free_energy(s, p))(params_mu))(computational_basis_vectors)

    rotation_matrix = construct_rotation_matrix(basis)

    def per_sample(bits):
        idx = bitstring_to_int(bits)
        rotated_exponent = jnp.log(rotation_matrix[idx]) - 0.5 * free_energy_lambda - 0.5j * free_energy_mu
        rotated_exponent = jnp.where(jnp.isfinite(rotated_exponent), rotated_exponent, -1e30 + 0j)  # safety
        gradient_weights = jnp.imag(softmax(rotated_exponent))

        def apply_to_leaf(leaf):
            if leaf.ndim == 1:             # bias gradient
                return -jnp.sum(gradient_weights * leaf)
            else:                          # weight gradient
                return -jnp.sum(gradient_weights[:, None] * leaf, axis=0)

        return tree.map(apply_to_leaf, free_energy_mu_grads)

    grads_batch = jax.vmap(per_sample)(measurements)         # stacked pytree
    return tree.map(lambda x: jnp.mean(x, 0), grads_batch)



explicit_grad = grad_fn(measurements, measurement_basis, params_lambda, params_mu)

print("explicit :", explicit_grad)
print("autodiff :", autodiff_grad)
print("close    :", tree.all(tree.map(lambda a, b: jnp.allclose(a, b, atol=1e-4), explicit_grad, autodiff_grad)))

explicit : {'W': Array([-9.9956303e+00, -1.7881393e-06,  2.9097542e-02, -3.3207827e+00],      dtype=float32), 'b': Array(2.1606684e-07, dtype=float32)}
autodiff : {'W': Array([-9.9958706e+00,  5.7220459e-06,  2.9106140e-02, -3.3208485e+00],      dtype=float32), 'b': Array(1.9073486e-06, dtype=float32)}
close    : False


In [5]:
def p_lambda(sigma, params_lambda):
    free_energy_lambda = dummy_free_energy(sigma, params_lambda)
    return jnp.exp(-free_energy_lambda)

def phi_mu(sigma, params_mu):
    free_energy_mu = dummy_free_energy(sigma, params_mu)
    return -free_energy_mu

def paper_loss_fn(
        measurements: jnp.ndarray,
        measurement_basis: tuple[int, ...],
        params_lambda: dict,
        params_mu: dict) -> jnp.ndarray:

    computational_basis_vectors = get_computational_basis_vectors(measurements.shape[1])  # (2**n, n)

    p_lambda_values = jax.vmap(p_lambda, (0, None))(computational_basis_vectors, params_lambda)  # (2**n,)
    phi_mu_values = jax.vmap(phi_mu, (0, None))(computational_basis_vectors, params_mu)  # (2**n,)

    rotation_matrix = construct_rotation_matrix(measurement_basis)  # (2**n, 2**n)

    def get_log_probability(measurement):
        idx = bitstring_to_int(measurement)
        rotation_vector = rotation_matrix[idx]

        sqrt_p_lambda = jnp.sqrt(p_lambda_values)
        exp_phi_mu = jnp.exp(1j * phi_mu_values / 2)

        rotated_amp = jnp.vdot(rotation_vector, sqrt_p_lambda * exp_phi_mu)

        log_probability = jnp.log(jnp.abs(rotated_amp))
        return log_probability

    log_probs = []
    for measurement in measurements:
        contribution = get_log_probability(measurement)
        # plus the complex conjugated part
        log_probs.append(contribution + contribution.conj())

    log_probs = jnp.array(log_probs)

    loss = -jnp.mean(log_probs)
    return loss

import jax.numpy as jnp
jnp.set_printoptions(suppress=True, formatter={'float_kind': '{: .8f}'.format})

autodiff_grad_fn_paper = jax.grad(paper_loss_fn, argnums=3)

autodiff_grad_paper = autodiff_grad_fn_paper(measurements, measurement_basis, params_lambda, params_mu)
print("explicit (mine) :", explicit_grad)
print("autodiff (mine) :", autodiff_grad)
print("autodiff (paper):", autodiff_grad_paper)

print("close (paper):", tree.all(tree.map(lambda a, b: jnp.allclose(a, b, atol=1e-4), explicit_grad, autodiff_grad_paper)))

explicit (mine) : {'W': Array([-9.99563026, -0.00000179,  0.02909754, -3.32078266], dtype=float32), 'b': Array( 0.00000022, dtype=float32)}
autodiff (mine) : {'W': Array([-9.99587059,  0.00000572,  0.02910614, -3.32084846], dtype=float32), 'b': Array( 0.00000191, dtype=float32)}
autodiff (paper): {'W': Array([-9.99580383, -0.00000572,  0.02909851, -3.32083893], dtype=float32), 'b': Array(-0.00000381, dtype=float32)}
close (paper): True


In [31]:
# testing how big we cna go
measurement_basis = (1, 0, 2, 1, 0, 1, 2, 0, 0, 0)  # X Z Y X Z Y X Z
measurements = jnp.array([
    [0,0,0,0,0,0,0,0,0,0],
    [1,1,1,1,1,1,1,1,1,1],
    [0,1,0,1,0,1,0,1,0,1],
    [1,0,1,0,1,0,1,0,1,0],
    [0,0,1,1,0,0,1,1,0,0],
    [1,1,0,0,1,1,0,0,1,1],
], dtype=jnp.int32)

params_lambda = {
    "W": jnp.linspace(0.05, 0.15, 10, dtype=jnp.float32),
    "b": jnp.array(0.3, dtype=jnp.float32)
}

params_mu = {
    "W": jnp.linspace(0.05, 0.15, 10, dtype=jnp.float32),
    "b": jnp.array(0.3, dtype=jnp.float32)
}

autodiff_grad_paper_fn_large = jax.grad(paper_loss_fn, argnums=3)
%time autodiff_grad_paper_large = autodiff_grad_paper_fn_large(measurements, measurement_basis, params_lambda, params_mu)

CPU times: user 32 ms, sys: 1.46 ms, total: 33.5 ms
Wall time: 31.9 ms


In [32]:
# explicit
%time explicit_grad_large = grad_fn(measurements, measurement_basis, params_lambda, params_mu)

CPU times: user 15 ms, sys: 3.17 ms, total: 18.2 ms
Wall time: 14.9 ms


In [33]:
autodiff_grad_fn_large = jax.grad(loss_fn, argnums=3)
%time autodiff_grad_large = autodiff_grad_fn_large(measurements, measurement_basis, params_lambda, params_mu)

CPU times: user 21.2 ms, sys: 3.81 ms, total: 25 ms
Wall time: 19.8 ms


In [34]:
# print out gradients of all three larger ones
print("Explicit Grad (mine) :", explicit_grad_large)
print("Autodiff Loss (mine) :", autodiff_grad_large)
print("Autodiff Loss (paper):", autodiff_grad_paper_large)

# check pairwise closeness
abs_tol = 1e-1
print(f"Within tol. {abs_tol} (explicit mine vs autodiff paper):", tree.all(tree.map(lambda a, b: jnp.allclose(a, b, atol=abs_tol), explicit_grad_large, autodiff_grad_paper_large)))
print(f"Within tol. {abs_tol} (explicit mine vs autodiff mine):", tree.all(tree.map(lambda a, b: jnp.allclose(a, b, atol=abs_tol), explicit_grad_large, autodiff_grad_large)))

Explicit Grad (mine) : {'W': Array([-10.01608562, -0.00000151,  0.01833252, -6.00497437, -0.00000123,
       -4.73766327,  0.02894065,  0.00000006, -0.00000123, -0.00000151],      dtype=float32), 'b': Array(-0.00004204, dtype=float32)}
Autodiff Loss (mine) : {'W': Array([-9.99688721,  0.00006104,  0.01806641, -5.99365234,  0.00006104,
       -4.72842407,  0.02926636,  0.00003052,  0.00006104,  0.00006104],      dtype=float32), 'b': Array( 0.00006294, dtype=float32)}
Autodiff Loss (paper): {'W': Array([-10.00207520, -0.00018311,  0.01821899, -5.99688721, -0.00009155,
       -4.73104858,  0.02920532, -0.00021362, -0.00009155, -0.00018311],      dtype=float32), 'b': Array(-0.00012445, dtype=float32)}
Within tol. 0.1 (explicit mine vs autodiff paper): True
Within tol. 0.1 (explicit mine vs autodiff mine): True


In [4]:
#### GRADIENTS FROM AUTODIFF OF LOSS FUNCTION

def rotated_log_probs(U_row: jnp.ndarray, F_lambda: jnp.ndarray, F_mu: jnp.ndarray) -> jnp.ndarray:
    exponent = -0.5 * F_lambda - 0.5j * F_mu                     # (2**n,)
    values = U_row * jnp.exp(exponent)                           # Complex vector

    abs_vals = jnp.abs(values)
    max_log = jnp.max(jnp.log(abs_vals + 1e-30))                 # Scalar for stability
    scaled_values = values * jnp.exp(-max_log)                   # Scale down before summing

    return 2 * (max_log + jnp.log(jnp.abs(jnp.sum(scaled_values)) + 1e-30))


def single_measurement_log_prob(
        measurement: jnp.ndarray,
        rotation_matrix: jnp.ndarray,
        F_lambda: jnp.ndarray,
        F_mu: jnp.ndarray
) -> jnp.ndarray:
    """
    Compute ln|∑_σ U_b(σ^{[b]},σ) e^{-½Fλ(σ) - i/2 Fμ(σ)}|^2
    for a single bitstring 'measurement'.
    """
    idx = bitstring_to_int(measurement)                          # int index in [0,2**n)
    U_row = rotation_matrix[idx]                                 # (2**n,)
    return rotated_log_probs(U_row, F_lambda, F_mu)


def loss_fn(
        measurements: jnp.ndarray,
        measurement_basis: tuple[int, ...],
        params_lambda: dict,
        params_mu: dict
) -> jnp.ndarray:
    # build all 2^n basis vectors and their free energies
    n = measurements.shape[1]
    comp_basis = get_computational_basis_vectors(n)              # (2**n, n)
    F_lambda = jax.vmap(dummy_free_energy, (0, None))(comp_basis, params_lambda)  # (2**n,)
    F_mu     = jax.vmap(dummy_free_energy, (0, None))(comp_basis, params_mu)      # (2**n,)

    # rotation for this basis
    R = construct_rotation_matrix(measurement_basis)             # (2**n, 2**n)

    # per‐measurement log‐probabilities
    log_probs = jax.vmap(
        lambda m: single_measurement_log_prob(m, R, F_lambda, F_mu)
    )(measurements)                                              # (batch,)

    return -jnp.mean(log_probs)


# get gradient w.r.t. the 'mu' parameters
autodiff_grad_fn = jax.grad(loss_fn, argnums=3)
autodiff_grad    = autodiff_grad_fn(
    measurements, measurement_basis, params_lambda, params_mu
)

print("Autodiff Grad:", autodiff_grad)

Autodiff Grad: {'W': Array([-9.9958706e+00,  5.7220459e-06,  2.9106140e-02, -3.3208485e+00],      dtype=float32), 'b': Array(1.9073486e-06, dtype=float32)}
