In [27]:
%reload_ext autoreload
%autoreload 2

from lib.basis import get_computational_basis_vectors, construct_rotation_matrix
from lib.formatting import bitstring_to_int

####

import jax
from jax.nn import softmax
from jax import tree

import jax.numpy as jnp
jnp.set_printoptions(suppress=True, formatter={'float_kind': '{: .8f}'.format})

In [34]:
#### DUMMY FREE ENERGY AND PYTREE HELPERS

def dummy_free_energy(sigma: jnp.ndarray, params: dict) -> jnp.ndarray:
    """Linear toy free energy F(sigma) = sigma . W + b.

    With a 1-D ``params["W"]`` of length n, a single configuration ``sigma`` of shape (n,)
    gives a scalar, and a stack of configurations of shape (m, n) gives shape (m,).
    """
    weights = params["W"]
    bias = params["b"]
    linear_term = jnp.dot(sigma, weights)
    return linear_term + bias

def multiply_with_leaf(factor, leaf):
    """Scale each leading-axis slice of ``leaf`` by the matching entry of ``factor``.

    ``factor`` is expected to be 1-D with length ``leaf.shape[0]``. It is broadcast along
    every trailing axis of ``leaf``, so scalar-parameter leaves (k,), vector leaves (k, n)
    and higher-rank leaves (k, n, m, ...) are all scaled per leading index. A 0-d ``leaf``
    is simply multiplied by ``factor``.

    Returns:
        ``factor`` broadcast against ``leaf``; for ``leaf.ndim >= 1`` the shape is ``leaf.shape``.
    """
    # Append one singleton axis per trailing leaf axis: (k,) -> (k, 1, ..., 1).
    # (The previous `factor[:, None]` only aligned correctly for rank-2 leaves.)
    trailing_axes = max(leaf.ndim - 1, 0)
    return jnp.reshape(factor, jnp.shape(factor) + (1,) * trailing_axes) * leaf


#### TEST VARIABLES
# Per-qubit basis codes: 0 = Z, 1 = X, 2 = Y (as spelled out in the trailing comments).
# NOTE(review): the measurement rows come in complementary pairs (0/1, 2/3, 4/5), so every
# column has exactly three 0s and three 1s. Swapping the outcome labels of any qubit
# (e.g. a conjugated Y row) therefore leaves batch-mean quantities unchanged and cannot be
# detected by the comparisons in this notebook.

basis = jnp.array([1, 0, 2, 1, 0, 1, 2, 0, 0, 0])  # X Z Y X Z X Y Z Z Z
measurements = jnp.array([
    [0,0,0,0,0,0,0,0,0,0],
    [1,1,1,1,1,1,1,1,1,1],
    [0,1,0,1,0,1,0,1,0,1],
    [1,0,1,0,1,0,1,0,1,0],
    [0,0,1,1,0,0,1,1,0,0],
    [1,1,0,0,1,1,0,0,1,1],
], dtype=jnp.int32)

# Only qubits 5 (X) and 6 (Y) are rotated; the local-rotator code relies on exactly two non-Z qubits.
basis_local = jnp.array([0, 0, 0, 0, 0, 1, 2, 0, 0, 0])  # Z Z Z Z Z X Y Z Z Z
measurements_local = jnp.array([
    [0,0,0,0,0,0,0,0,0,0],
    [1,1,1,1,1,1,1,1,1,1],
    [0,1,0,1,0,1,0,1,0,1],
    [1,0,1,0,1,0,1,0,1,0],
    [0,0,1,1,0,0,1,1,0,0],
    [1,1,0,0,1,1,0,0,1,1],
], dtype=jnp.int32)

# Identical parameter values for the modulus (lambda) and phase (mu) free energies.
params_lambda = {
    "W": jnp.linspace(0.05, 0.15, 10, dtype=jnp.float32),
    "b": jnp.array(0.3, dtype=jnp.float32)
}

params_mu = {
    "W": jnp.linspace(0.05, 0.15, 10, dtype=jnp.float32),
    "b": jnp.array(0.3, dtype=jnp.float32)
}

In [29]:
#### GRADIENTS USING AUTODIFF OF MY DERIVATION (STUPID VERSION)


def rotated_log_prob_vanilla(rotation_weights, free_energy_lambda, free_energy_mu):
    """Unnormalized log-probability log|<w, psi>|^2 of one rotated outcome (direct evaluation).

    psi(sigma) = exp(-F_lambda(sigma)/2) * exp(-i F_mu(sigma)/2) is left unnormalized, and a
    1e-30 floor keeps the log finite when the overlap vanishes.

    NOTE(review): ``jnp.vdot`` conjugates ``rotation_weights``, while ``rotated_log_prob_stable``
    multiplies without conjugation. Which convention is right depends on
    ``construct_rotation_matrix`` -- TODO confirm.
    """
    magnitudes = jnp.exp(-0.5 * free_energy_lambda)
    amplitudes = magnitudes * jnp.exp(-0.5j * free_energy_mu)
    overlap = jnp.vdot(rotation_weights, amplitudes)
    probability = jnp.abs(overlap) ** 2
    return jnp.log(probability + 1e-30)


def loss_fn_mine_vanilla(measurements, basis, params_lambda, params_mu):
    """Batch loss -mean(log p) built from the full 2**n state vector (direct evaluation).

    Args:
        measurements: bitstrings, one per row; the column count is the number of qubits.
        basis: per-qubit basis codes passed to ``construct_rotation_matrix``.
        params_lambda, params_mu: free-energy parameter dicts.

    Returns:
        Scalar negative mean log-probability.
    """
    num_qubits = measurements.shape[1]

    # Enumerate every computational basis state so the full (unnormalized) state can be built.
    all_sigmas = get_computational_basis_vectors(num_qubits)
    f_lambda = dummy_free_energy(all_sigmas, params_lambda)
    f_mu = dummy_free_energy(all_sigmas, params_mu)

    rotation_matrix = construct_rotation_matrix(basis)          # (2**n, 2**n)

    def sample_log_prob(measurement):
        # The row selected by the outcome holds the weights of every computational amplitude.
        row = rotation_matrix[bitstring_to_int(measurement)]
        return rotated_log_prob_vanilla(row, f_lambda, f_mu)

    per_sample = jax.vmap(sample_log_prob)(measurements)
    return -jnp.mean(per_sample)


# Gradient of the *batched* vanilla loss w.r.t. params_mu (argnums=3).
# Uses `_vanilla`-suffixed names so it is not confused with (or overwritten by) the
# per-sample `autodiff_grad_fn_mine` / `autodiff_grad_mine` defined in the stable-version cell.
autodiff_grad_fn_mine_vanilla = jax.grad(loss_fn_mine_vanilla, argnums=3)

autodiff_grad_mine_vanilla = autodiff_grad_fn_mine_vanilla(measurements, basis, params_lambda, params_mu)
autodiff_grad_mine_vanilla

{'W': Array([-9.99774170, -0.00006104,  0.01800537, -5.99432373,  0.00003052,
        -4.72903442,  0.02932739,  0.00006104,  0.00003052, -0.00006104],      dtype=float32),
 'b': Array( 0.00002766, dtype=float32)}

In [30]:
#### GRADIENTS USING AUTODIFF OF MY DERIVATION (STABLE VERSION FOR FISHER)


def rotated_log_prob_stable(rotation_weights, free_energy_lambda, free_energy_mu):
    """Log-probability 2*log|sum_j w_j psi_j| of one rotated outcome, via a log-sum-exp shift.

    Each term w_j * psi_j is divided by exp(max_j log|w_j psi_j|) before summing, and the
    shift is added back in log space. This is equivalent to subtracting the max inside
    the exponent.

    NOTE(review): unlike ``rotated_log_prob_vanilla`` (``jnp.vdot``), the rotation weights are
    NOT conjugated here. Which convention matches ``construct_rotation_matrix`` -- TODO confirm.
    """
    amplitudes = jnp.exp(-0.5 * free_energy_lambda) * jnp.exp(-0.5j * free_energy_mu)
    terms = rotation_weights * amplitudes

    # Largest log-magnitude among the terms; 1e-30 keeps zero-weight terms at a finite log.
    log_scale = jnp.log(jnp.abs(terms) + 1e-30).max()
    normalized_sum = (terms * jnp.exp(-log_scale)).sum()

    log_abs_amplitude = log_scale + jnp.log(jnp.abs(normalized_sum) + 1e-30)
    return 2 * log_abs_amplitude


def loss_fn_mine(measurement, basis, params_lambda, params_mu):
    """Log-probability of a single measurement under the rotated state (stable evaluation).

    Despite the name this returns log p(measurement), not a loss. The notebook negates it
    and averages over samples after vmapping its gradient.

    Args:
        measurement: one bitstring; its length is the number of qubits.
        basis: per-qubit basis codes passed to ``construct_rotation_matrix``.
        params_lambda, params_mu: free-energy parameter dicts.
    """
    n_qubits = measurement.shape[0]

    # Free energies of every computational basis state -> full (unnormalized) state vector.
    sigmas = get_computational_basis_vectors(n_qubits)
    energies_lambda = dummy_free_energy(sigmas, params_lambda)
    energies_mu = dummy_free_energy(sigmas, params_mu)

    full_rotation = construct_rotation_matrix(basis)          # (2**n, 2**n)

    # Row for the observed outcome: its entries weight each computational basis amplitude.
    weights_row = full_rotation[bitstring_to_int(measurement)]       # (2**n,)

    return rotated_log_prob_stable(weights_row, energies_lambda, energies_mu)


# per sample gradient function (sadly we cannot take batch gradients later for the natural gradient)
# loss_fn_mine returns the per-sample log-probability; differentiate it w.r.t. params_mu (argnums=3).
autodiff_grad_fn_mine = jax.grad(loss_fn_mine, argnums=3)

# use vmap to compute the gradient for all measurements
# in_axes: map over the leading (sample) axis of `measurements`; basis and both parameter dicts are shared.
autodiff_grads_mine = jax.vmap(autodiff_grad_fn_mine, in_axes=(0, None, None, None))(measurements, basis, params_lambda, params_mu)

# vmap returns a single pytree whose leaves carry a leading batch axis, so tree.map reduces each leaf over axis 0.
# The minus sign turns the mean log-probability gradient into the gradient of the loss -mean(log p).
autodiff_grad_mine = tree.map(lambda x: -jnp.mean(x, axis=0), autodiff_grads_mine)
autodiff_grad_mine

{'W': Array([-9.99703217, -0.00019717,  0.01772640, -5.99372959, -0.00019697,
        -4.72853422,  0.02901845, -0.00020343, -0.00019697, -0.00019717],      dtype=float32),
 'b': Array(-0.00019916, dtype=float32)}

In [31]:
#### EXPLICIT GRADIENTS FROM MY DERIVATION


def grad_fn_mine(measurements, basis, params_lambda, params_mu):
    """Closed-form gradient of the batch loss -mean(log p) with respect to ``params_mu``.

    For a measurement with rotation row U_m, the rotated amplitude is
    A = sum_j exp(z_j) with z_j = log U_m[j] - F_lambda_j/2 - i F_mu_j/2, hence
        d log p / d mu = d (2 Re log A) / d mu = sum_j Im(softmax(z)_j) * dF_mu_j / d mu.
    The result is minus the batch mean of that quantity.

    Args:
        measurements: bitstrings, one per row; the column count is the number of qubits.
        basis: per-qubit basis codes passed to ``construct_rotation_matrix``.
        params_lambda: parameters of the modulus free energy.
        params_mu: parameters of the phase free energy (the differentiated ones).

    Returns:
        Pytree with the same structure as ``params_mu``.
    """
    computational_basis_vectors = get_computational_basis_vectors(measurements.shape[1])

    free_energy_lambda = jax.vmap(dummy_free_energy, (0, None))(computational_basis_vectors, params_lambda)
    free_energy_mu = jax.vmap(dummy_free_energy, (0, None))(computational_basis_vectors, params_mu)
    # dF_mu(sigma)/d mu for every basis state: a pytree like params_mu whose leaves gain a leading (2**n) axis.
    free_energy_mu_grads = jax.vmap(lambda s: jax.grad(lambda p: dummy_free_energy(s, p))(params_mu))(computational_basis_vectors)

    rotation_matrix = construct_rotation_matrix(basis)

    def per_sample(measurement):
        idx = bitstring_to_int(measurement)
        # `+ 0j` promotes a real-valued rotation row to complex, so a negative weight gives
        # log|w| + i*pi instead of NaN. It is a no-op when the row is already complex.
        rotated_exponent = jnp.log(rotation_matrix[idx] + 0j) - 0.5 * free_energy_lambda - 0.5j * free_energy_mu
        # Zero weights give log(0) = -inf; map them to a huge negative exponent so softmax gives them ~0 weight.
        rotated_exponent = jnp.where(jnp.isfinite(rotated_exponent), rotated_exponent, -1e30 + 0j)
        gradient_weights = jnp.imag(softmax(rotated_exponent))

        # Contract the weights with the leading (basis-state) axis of each gradient leaf.
        # Works for scalar-parameter leaves (2**n,) and vector/matrix leaves (2**n, ...) alike.
        return tree.map(lambda leaf: -jnp.tensordot(gradient_weights, leaf, axes=1), free_energy_mu_grads)

    grads_batch = jax.vmap(per_sample)(measurements)         # stacked pytree
    return tree.map(lambda x: jnp.mean(x, 0), grads_batch)


# Closed-form gradient of the batch loss w.r.t. params_mu, to be compared leaf-by-leaf with the autodiff result.
# NOTE(review): relies on `autodiff_grad_mine` holding the full-basis result from the stable-version cell;
# re-check if any later cell reuses that name before re-running this one.
explicit_grad_mine = grad_fn_mine(measurements, basis, params_lambda, params_mu)

print("explicit mine:", explicit_grad_mine)
print("autodiff mine:", autodiff_grad_mine)

# jnp.allclose also applies its default rtol on top of this absolute tolerance.
abs_tol = 1e-1
print(f"Within tol. {abs_tol} (explicit mine vs autodiff mine):", tree.all(tree.map(lambda a, b: jnp.allclose(a, b, atol=abs_tol), explicit_grad_mine, autodiff_grad_mine)))

explicit mine: {'W': Array([-9.98472214,  0.00001884,  0.01825257, -5.98650360,  0.00001876,
       -4.72285175,  0.02916548,  0.00002033,  0.00001876,  0.00001884],      dtype=float32), 'b': Array(-0.00004276, dtype=float32)}
autodiff mine: {'W': Array([-9.99703217, -0.00019717,  0.01772640, -5.99372959, -0.00019697,
       -4.72853422,  0.02901845, -0.00020343, -0.00019697, -0.00019717],      dtype=float32), 'b': Array(-0.00019916, dtype=float32)}
Within tol. 0.1 (explicit mine vs autodiff mine): True


In [32]:
#### GRADIENTS USING AUTODIFF OF PAPER DERIVATION (VERY STUPID 1 TO 1)


def p_lambda(sigma, params_lambda):
    """Unnormalized probability p_lambda(sigma) = exp(-F_lambda(sigma))."""
    return jnp.exp(-dummy_free_energy(sigma, params_lambda))

def phi_mu(sigma, params_mu):
    """Phase function phi_mu(sigma) = -F_mu(sigma)."""
    energy = dummy_free_energy(sigma, params_mu)
    return -energy

def loss_fn_paper(
        measurements: jnp.ndarray,
        basis: jnp.ndarray,
        params_lambda: dict,
        params_mu: dict) -> jnp.ndarray:
    """Batch loss -mean(log p) written term-by-term with ``p_lambda`` and ``phi_mu``.

    The state is psi(sigma) = sqrt(p_lambda(sigma)) * exp(i * phi_mu(sigma) / 2). Each
    measurement's log-probability is log|A| + conj(log|A|), where A is the overlap of
    the outcome's rotation row with psi.

    Args:
        measurements: bitstrings, one per row; the column count is the number of qubits.
        basis: per-qubit basis codes passed to ``construct_rotation_matrix`` (the
            notebook passes a jnp array, not a tuple).
        params_lambda, params_mu: free-energy parameter dicts.

    Returns:
        Scalar negative mean log-probability.

    NOTE(review): unlike the other variants there is no epsilon inside the log, so an exactly
    zero overlap gives an infinite loss. Also, sqrt(exp(-F)) can underflow earlier than
    exp(-F/2). Both are kept to stay a literal transcription.
    """
    computational_basis_vectors = get_computational_basis_vectors(measurements.shape[1])  # (2**n, n)

    p_lambda_values = jax.vmap(p_lambda, (0, None))(computational_basis_vectors, params_lambda)  # (2**n,)
    phi_mu_values = jax.vmap(phi_mu, (0, None))(computational_basis_vectors, params_mu)  # (2**n,)

    # The unrotated state does not depend on the measurement, so build it once, not per sample.
    sqrt_p_lambda = jnp.sqrt(p_lambda_values)
    exp_phi_mu = jnp.exp(1j * phi_mu_values / 2)
    state_vector = sqrt_p_lambda * exp_phi_mu

    rotation_matrix = construct_rotation_matrix(basis)  # (2**n, 2**n)

    def get_log_probability(measurement):
        idx = bitstring_to_int(measurement)
        rotation_vector = rotation_matrix[idx]

        # NOTE(review): jnp.vdot conjugates rotation_vector -- confirm against construct_rotation_matrix.
        rotated_amp = jnp.vdot(rotation_vector, state_vector)

        log_abs_amplitude = jnp.log(jnp.abs(rotated_amp))
        # log p = log|A|^2 written as log|A| + conj(log|A|); the conj is a no-op on this real value.
        return log_abs_amplitude + log_abs_amplitude.conj()

    # vmap instead of a Python loop: one traced computation rather than one unrolled copy per measurement.
    log_probs = jax.vmap(get_log_probability)(measurements)

    loss = -jnp.mean(log_probs)
    return loss


# Gradient of the paper-style batch loss w.r.t. params_mu (argnums=3); already a loss gradient, so no negation needed.
autodiff_grad_fn_paper = jax.grad(loss_fn_paper, argnums=3)
autodiff_grad_paper = autodiff_grad_fn_paper(measurements, basis, params_lambda, params_mu)

print("autodiff paper:", autodiff_grad_paper)
print("autodiff mine :", autodiff_grad_mine)
print("explicit mine :", explicit_grad_mine)

# Tighter than the previous cell's 1e-1; jnp.allclose also applies its default rtol.
# The recorded output shows the explicit-vs-paper check failing at this tolerance while autodiff-vs-paper passes.
abs_tol = 1e-2
print(f"Within tol. {abs_tol} (explicit mine vs autodiff paper):", tree.all(tree.map(lambda a, b: jnp.allclose(a, b, atol=abs_tol), explicit_grad_mine, autodiff_grad_paper)))
print(f"Within tol. {abs_tol} (autodiff mine vs autodiff paper):", tree.all(tree.map(lambda a, b: jnp.allclose(a, b, atol=abs_tol), autodiff_grad_mine, autodiff_grad_paper)))

autodiff paper: {'W': Array([-10.00207520, -0.00018311,  0.01821899, -5.99688721, -0.00009155,
       -4.73104858,  0.02920532, -0.00021362, -0.00009155, -0.00018311],      dtype=float32), 'b': Array(-0.00012445, dtype=float32)}
autodiff mine : {'W': Array([-9.99703217, -0.00019717,  0.01772640, -5.99372959, -0.00019697,
       -4.72853422,  0.02901845, -0.00020343, -0.00019697, -0.00019717],      dtype=float32), 'b': Array(-0.00019916, dtype=float32)}
explicit mine : {'W': Array([-9.98472214,  0.00001884,  0.01825257, -5.98650360,  0.00001876,
       -4.72285175,  0.02916548,  0.00002033,  0.00001876,  0.00001884],      dtype=float32), 'b': Array(-0.00004276, dtype=float32)}
Within tol. 0.01 (explicit mine vs autodiff paper): False
Within tol. 0.01 (autodiff mine vs autodiff paper): True


In [46]:
#### GRADIENTS USING THE REDUCED ROTATOR

def rotated_log_prob_stable_real(rotation_weights, free_energy_lambda, free_energy_mu):
    """Log-probability 2*log|<w, psi>| of one rotated outcome, shifting only the log-magnitudes.

    The rotation weights are known, well-scaled values and exp(-i F_mu/2) has unit modulus.
    Only the magnitudes exp(-F_lambda/2) can over/underflow, so only -F_lambda/2 is shifted
    by its maximum before exponentiating; the shift is added back in log space.

    NOTE(review): ``jnp.vdot`` conjugates ``rotation_weights``, whereas ``rotated_log_prob_stable``
    does not. Which one matches ``construct_rotation_matrix`` -- TODO confirm.
    """
    log_magnitudes = -0.5 * free_energy_lambda
    phases = -0.5j * free_energy_mu

    shift = jnp.max(log_magnitudes)
    shifted_amplitudes = jnp.exp(log_magnitudes - shift + phases)

    overlap = jnp.vdot(rotation_weights, shifted_amplitudes)
    return 2 * (shift + jnp.log(jnp.abs(overlap) + 1e-30))


def loss_fn_local(measurement, basis, params_lambda, params_mu, num_local=2):
    """Log-probability of a single measurement using only its rotated ("local") qubits.

    Amplitudes of computational basis states that disagree with the Z-measured bits are
    zero. Only the 2**num_local states that vary the non-Z qubits contribute, which avoids
    building the full 2**n x 2**n rotation.

    Args:
        measurement: one bitstring; its length is the number of qubits.
        basis: per-qubit basis codes; 0 is treated as Z.
        params_lambda, params_mu: free-energy parameter dicts.
        num_local: number of non-Z qubits in ``basis`` (static Python int). Defaults to 2,
            the value previously hard-coded.

    Returns:
        The (not negated) log-probability from ``rotated_log_prob_stable_real``.

    NOTE(review): ``basis`` must contain exactly ``num_local`` non-Z entries.
    - With fewer, the padding index -1 selects the last qubit. If that qubit is already
      local, the duplicate scatter indices below make the result ill-defined.
    - With more, the extra rotated qubits are silently ignored.
    """
    # Static-size nonzero keeps this traceable under jit/vmap.
    local_indices = jnp.nonzero(basis != 0, size=num_local, fill_value=-1)[0]

    # Rotation acting only on the local qubits: (2**num_local, 2**num_local).
    local_rotation_matrix = construct_rotation_matrix(basis[local_indices])

    # for outcome 00 we pick first row, for 01 second row, etc.
    local_rotation_weights = local_rotation_matrix[bitstring_to_int(measurement[local_indices])]

    # All local bit patterns in big-endian order (00, 01, 10, 11 for num_local=2), matching the
    # column order assumed for `local_rotation_matrix`.
    num_states = 2 ** num_local
    bit_shifts = jnp.arange(num_local - 1, -1, -1)
    local_measurement_combos = ((jnp.arange(num_states)[:, None] >> bit_shifts) & 1).astype(measurement.dtype)

    # Every candidate state copies the measured Z bits and overwrites only the local qubits.
    local_computational_basis_vectors = jnp.tile(measurement, (num_states, 1)).at[:, local_indices].set(local_measurement_combos)  # (2**num_local, n)

    local_free_energy_lambda = dummy_free_energy(local_computational_basis_vectors, params_lambda)
    local_free_energy_mu = dummy_free_energy(local_computational_basis_vectors, params_mu)

    rotated_log_prob = rotated_log_prob_stable_real(local_rotation_weights, local_free_energy_lambda, local_free_energy_mu)
    return rotated_log_prob


# Per-sample gradient of loss_fn_local's log-probability w.r.t. params_mu (argnums=3); vmapped over measurements later.
autodiff_grad_fn_local = jax.grad(loss_fn_local, argnums=3)

In [47]:
#### COMPARE WITH SOME LOCAL MEASUREMENTS
# `_local`-suffixed names keep this cell from overwriting `autodiff_grads_mine` / `autodiff_grad_mine`
# (the full-basis results). Otherwise, re-running the earlier comparison cells after this one
# would silently print and compare local-basis gradients.
# Assumes `autodiff_grad_fn_mine` is still the per-sample gradient of `loss_fn_mine`.
# NOTE(review): the "mine" path multiplies rotation weights without conjugation, while the local
# path uses jnp.vdot (conjugating). Every measurement column is balanced 3/3, so this comparison
# cannot detect that difference.

autodiff_grads_mine_local = jax.vmap(autodiff_grad_fn_mine, in_axes=(0, None, None, None))(measurements_local, basis_local, params_lambda, params_mu)
autodiff_grads_local = jax.vmap(autodiff_grad_fn_local, in_axes=(0, None, None, None))(measurements_local, basis_local, params_lambda, params_mu)

# Per-sample grads are of log p; the loss gradient is minus their batch mean.
autodiff_grad_mine_local = tree.map(lambda x: -jnp.mean(x, axis=0), autodiff_grads_mine_local)
autodiff_grad_local = tree.map(lambda x: -jnp.mean(x, axis=0), autodiff_grads_local)

print(f"autodiff mine (local): {autodiff_grad_mine_local}")
print(f"autodiff local (local): {autodiff_grad_local}")

abs_tol = 1e-6
print(f"Within tol. {abs_tol} (autodiff mine vs autodiff local):", tree.all(tree.map(lambda a, b: jnp.allclose(a, b, atol=abs_tol), autodiff_grad_mine_local, autodiff_grad_local)))

autodiff mine (local): {'W': Array([-0.00000017, -0.00000017, -0.00000017, -0.00000017, -0.00000017,
       -4.72804928,  0.02910010, -0.00000017, -0.00000017, -0.00000017],      dtype=float32), 'b': Array(-0.00000019, dtype=float32)}
autodiff local (local): {'W': Array([-0.00000031, -0.00000018, -0.00000017, -0.00000004, -0.00000031,
       -4.72804594,  0.02910029, -0.00000004, -0.00000031, -0.00000018],      dtype=float32), 'b': Array(-0.00000017, dtype=float32)}
Within tol. 1e-06 (autodiff mine vs autodiff local): True
