In [1]:
import jax
import jax.numpy as jnp

In [2]:
# sqrt(2), used as the common divisor of the two rotation matrices below.
SQRT2 = jnp.sqrt(2.0)

# 2x2 complex single-site rotation matrices (selected later by basis id).
rot_X = jnp.array(((1, 1), (1, -1)), dtype=jnp.complex64) / SQRT2
rot_Y = jnp.array(((1, -1j), (1, 1j)), dtype=jnp.complex64) / SQRT2

# The four 0/1 assignments of a pair of sites, in binary counting order:
# (0, 0), (0, 1), (1, 0), (1, 1).
combos = jnp.array(
    [[hi, lo] for hi in (0., 1.) for lo in (0., 1.)],
    dtype=jnp.float32,
)


def dummy_free_energy_amp(v):
    """Placeholder amplitude free energy: sum of ``0.5 * v`` over the last axis."""
    half_weighted = 0.5 * v
    return jnp.sum(half_weighted, axis=-1)

def dummy_free_energy_pha(v):
    """Placeholder phase free energy: sum of ``0.1 * v`` over the last axis."""
    tenth_weighted = 0.1 * v
    return jnp.sum(tenth_weighted, axis=-1)

In [4]:
def compare_loss_expressions(data, basis_ids):
    """Compute a two-site rotated-basis log-probability in two ways.

    The first two positions where ``basis_ids`` is non-zero are treated as the
    rotated sites (id 1 selects ``rot_X``, id 2 selects ``rot_Y``). For every
    row of ``data`` those two entries are overwritten with each row of
    ``combos``; the dummy amplitude and phase free energies are evaluated on
    the resulting configurations and combined with entries of the Kronecker
    product of the two site rotations.

    Args:
        data: 2-D array of shape (B, n). The entries at the two rotated sites
            are cast to int and used as bits, so they are expected to be 0/1.
        basis_ids: 1-D integer array of per-site basis ids. Two non-zero
            entries are expected; missing ones are padded with index -1.

    Returns:
        ``(naive_log_probs, stable_log_probs)``: ``log(|S|**2)`` evaluated
        directly, and ``2 * (M + log(|S'| + 1e-12))`` where the per-row
        maximum log-magnitude ``M`` has been factored out of the sum.
    """
    batch, n_sites = data.shape

    # Indices of the first two non-zero basis ids (padded with -1 if absent).
    site_a, site_b = jnp.nonzero(basis_ids != 0, size=2, fill_value=-1)[0]

    def _rotation_for(basis_id):
        # Basis id 1 -> rot_X, basis id 2 -> rot_Y.
        return jax.lax.switch(basis_id - 1, [lambda: rot_X, lambda: rot_Y])

    two_site_U = jnp.kron(
        _rotation_for(basis_ids[site_a]),
        _rotation_for(basis_ids[site_b]),
    )  # (4, 4)

    # Per-sample coefficients: the column of U addressed by the two measured
    # bits, with site_a as the high bit.
    # NOTE(review): this selects a column of U; rot_Y is not symmetric, so a
    # row would give different numbers — confirm the intended convention.
    bits_a = data[:, site_a].astype(int)
    bits_b = data[:, site_b].astype(int)
    coeffs = two_site_U[:, (bits_a << 1) | bits_b].T  # (B, 4)

    # Four copies of each sample, one per 0/1 assignment of the rotated sites.
    expanded = jnp.tile(data[:, None, :], (1, 4, 1))
    expanded = expanded.at[:, :, [site_a, site_b]].set(combos[None, :, :])
    configs = expanded.reshape(batch * 4, n_sites)

    amp_energy = dummy_free_energy_amp(configs).reshape(batch, 4)
    pha_energy = dummy_free_energy_pha(configs).reshape(batch, 4)

    # Direct evaluation: exponentiate, sum the four terms, take log|S|^2.
    naive_terms = jnp.exp(-0.5 * amp_energy + 1j * -0.5 * pha_energy)
    naive_sum = jnp.sum(coeffs * naive_terms, axis=1)
    naive_log_probs = jnp.log(jnp.abs(naive_sum) ** 2)

    # Log-sum-exp form: subtract the per-row peak log-magnitude before
    # exponentiating, then add it back outside the log.
    log_mag = -0.5 * amp_energy
    phase = -0.5 * pha_energy
    peak = jnp.max(log_mag, axis=1, keepdims=True)
    stable_terms = jnp.exp(log_mag - peak + 1j * phase)
    stable_sum = jnp.sum(coeffs * stable_terms, axis=1)
    stable_log_probs = 2.0 * (peak.squeeze() + jnp.log(jnp.abs(stable_sum) + 1e-12))

    return naive_log_probs, stable_log_probs

In [10]:
# Fixed PRNG key so the sampled bits, and the numbers printed below, are reproducible.
key = jax.random.PRNGKey(0)
B, n = 6, 8  # data has B rows (samples) and n columns (sites)
data = jax.random.bernoulli(key, p=0.5, shape=(B, n)).astype(jnp.float32)
basis_ids = jnp.array([1, 2, 0, 0, 0, 0, 0, 0], dtype=jnp.int32)  # XYZZZZZZ: 1 -> rot_X, 2 -> rot_Y, 0 -> unrotated (Z)

# Both results should agree up to float32 rounding; the difference is printed to show that.
naive, stable = compare_loss_expressions(data, basis_ids)
print("Naive:", naive)
print("Stable:", stable)
print("Difference:", naive - stable)

Naive: [-4.2134013  -4.713401   -9.343031   -4.213401   -0.08376761 -4.713401  ]
Stable: [-4.213401   -4.713402   -9.343031   -4.213401   -0.08376765 -4.713401  ]
Difference: [-4.7683716e-07  9.5367432e-07  0.0000000e+00  0.0000000e+00
  4.4703484e-08  0.0000000e+00]
