# In [None]:
from typing import Tuple, Dict, Any
from typing import Optional
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from flax.core import freeze, unfreeze
from jax.nn.initializers import normal, zeros
from collections.abc import Callable, Sequence


class DoubleRBM(nn.Module):
    """Wavefunction ansatz built from two RBMs: one for the amplitude, one for the phase.

    psi(sigma) = exp(-F_amp(sigma) / 2) * exp(-i * F_pha(sigma) / 2), where F_* is the
    free energy of the corresponding RBM (see ``get_amplitude``).

    Attributes:
        n_visible: number of visible units; the computational basis string is
            ``'Z' * n_visible``.
        n_hidden: number of hidden units in each of the two RBMs.
        k: number of block-Gibbs sweeps per amplitude-loss evaluation.
        n_chains: number of chains started from noise when no persistent chains
            are supplied.
    """
    n_visible: int
    n_hidden: int
    k: int = 1
    n_chains: int = 100

    def setup(self):
        # amplitude RBM parameters
        self.W_amp = self.param("W_amp", nn.initializers.normal(0.01), (self.n_visible, self.n_hidden))
        self.b_amp = self.param("b_amp", nn.initializers.zeros, (self.n_visible,))
        self.c_amp = self.param("c_amp", nn.initializers.zeros, (self.n_hidden,))

        # phase RBM parameters
        self.W_pha = self.param("W_pha", nn.initializers.normal(0.01), (self.n_visible, self.n_hidden))
        self.b_pha = self.param("b_pha", nn.initializers.zeros, (self.n_visible,))
        self.c_pha = self.param("c_pha", nn.initializers.zeros, (self.n_hidden,))

        # Single-site basis rotations. Row r of each matrix is the bra of outcome r
        # expressed in the computational basis (X: <+|, <-|; Y: <+y|, <-y|).
        self.rotators = {
            'X': jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) / jnp.sqrt(2),
            'Y': jnp.array([[1, -1j], [1, 1j]], dtype=jnp.complex64) / jnp.sqrt(2),
        }

    def __call__(
            self,
            data_dict: Dict[str, jnp.ndarray],
            aux_vars: Dict[str, Any]) -> Tuple[jnp.ndarray, Dict[str, Any]]:
        """Compute the training loss for one batch.

        A batch containing only computational-basis data (``'Z' * n_visible``) yields
        the contrastive-divergence loss of the amplitude RBM. A batch in which every
        basis has exactly two non-'Z' letters yields the rotated-basis negative
        log-likelihood (``_loss_phase``).

        Args:
            data_dict: maps a basis string (one letter per visible unit) to a batch
                of measured bit strings in that basis.
            aux_vars: carried state; reads "random_key" and, optionally,
                "v_persistent".

        Returns:
            ``(loss, new_aux_vars)``. ``new_aux_vars`` carries "random_key" and, when
            available, "v_persistent", so it can be fed back into the next call
            regardless of which branch ran.

        Raises:
            ValueError: if the batch's bases match neither training mode.
        """
        random_key = aux_vars.get("random_key")
        persistent_chains = aux_vars.get("v_persistent", None)

        new_aux_vars: Dict[str, Any] = {}

        # amplitude training
        comp_basis = 'Z' * self.n_visible
        if len(data_dict) == 1 and comp_basis in data_dict:
            data_batch = data_dict[comp_basis]
            pcd_loss, persistent_chains, random_key = self._loss_amp(data_batch, persistent_chains, random_key)

            if persistent_chains is not None:
                new_aux_vars["v_persistent"] = persistent_chains
            new_aux_vars["random_key"] = random_key
            return pcd_loss, new_aux_vars

        # phase training
        all_two_different = all(sum(b != 'Z' for b in basis) == 2 for basis in data_dict)
        if all_two_different:
            phase_loss = self._loss_phase(data_dict)
            # The phase loss uses no randomness and no Gibbs chains; forward both
            # untouched so a following amplitude step still finds them.
            if persistent_chains is not None:
                new_aux_vars["v_persistent"] = persistent_chains
            if random_key is not None:
                new_aux_vars["random_key"] = random_key
            return phase_loss, new_aux_vars

        raise ValueError("Encountered batch of unsupported basis.")

    def _loss_amp(self, data_batch, persistent_chains, random_key):
        """Contrastive-divergence loss for the amplitude RBM.

        loss = mean F(data) - mean F(model samples), with model samples produced by
        ``self.k`` Gibbs sweeps and detached from the gradient.

        Args:
            data_batch: visible configurations measured in the computational basis.
            persistent_chains: previous chain states, or ``None`` to start
                ``self.n_chains`` chains from uniform Bernoulli(0.5) noise.
            random_key: PRNG key; split internally.

        Returns:
            ``(loss, persistent_chains, random_key)``. ``persistent_chains`` is the
            advanced chain state when chains were supplied, and ``None`` when fresh
            chains were used (fresh chains are not returned for reuse).
        """
        W = self.variables["params"]["W_amp"]
        b = self.variables["params"]["b_amp"]
        c = self.variables["params"]["c_amp"]

        if persistent_chains is None:
            random_key, random_key_bern = jax.random.split(random_key)
            init_chains = jax.random.bernoulli(random_key_bern, p=0.5, shape=(self.n_chains, self.n_visible))
        else:
            init_chains = persistent_chains

        # fori_loop requires the carry dtype to match what _gibbs_step returns
        # (float32); bernoulli yields bool, so cast before sampling.
        init_chains = jnp.asarray(init_chains, dtype=jnp.float32)
        model_batch, random_key = self._gibbs_sample(W, b, c, init_chains, random_key, k=self.k)

        if persistent_chains is not None:
            persistent_chains = model_batch

        model_batch = jax.lax.stop_gradient(model_batch)  # negative phase must not be differentiated

        free_energy_data = self._free_energy(W, b, c, data_batch)
        free_energy_model = self._free_energy(W, b, c, model_batch)
        pcd_loss = jnp.mean(free_energy_data) - jnp.mean(free_energy_model)

        return pcd_loss, persistent_chains, random_key

    def _loss_phase(self, data_dict: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Sum over bases of the mean negative log of |psi_b(sigma_b)|^2.

        ``1e-10`` guards against log(0). The probabilities are not divided by a
        partition function.

        NOTE(review): gradients also flow into the amplitude parameters here, where
        the missing normalisation matters. Confirm that the optimiser masks them
        during phase training.
        """
        total_loss = 0.0

        for basis, batch in data_dict.items():  # batch: (B, n)
            amps = self.get_rotated_amplitude(batch, basis)  # (B,)
            log_probs = jnp.log(jnp.abs(amps) ** 2 + 1e-10)  # (B,)
            total_loss -= jnp.mean(log_probs)  # NLL

        return total_loss

    @staticmethod
    def _free_energy(W, b, c, v):
        """RBM free energy F(v) = -v.b - sum_j softplus((v W + c)_j), over the last axis."""
        visible_term = jnp.dot(v, b)
        hidden_term = jnp.sum(jax.nn.softplus(v @ W + c), axis=-1)
        return -visible_term - hidden_term

    # amplitude RBM specific

    @staticmethod
    def _gibbs_step(i, state, W, b, c, T=1.0):
        """One block-Gibbs sweep v -> h -> v at temperature T.

        ``i`` is the fori_loop index and is unused. ``state`` is ``(v, key)``; the
        returned ``v`` is float32.
        """
        v, key = state

        key, h_key, v_key = jax.random.split(key, 3)

        h_probs = jax.nn.sigmoid((v @ W + c) / T)
        h = jax.random.bernoulli(h_key, h_probs).astype(jnp.float32)

        v_probs = jax.nn.sigmoid((h @ W.T + b) / T)
        v = jax.random.bernoulli(v_key, v_probs).astype(jnp.float32)
        return v, key

    @staticmethod
    def _gibbs_sample(W, b, c, v_init, rng, k=1, T=1.0):
        """Run ``k`` Gibbs sweeps from ``v_init``; return ``(v_final, rng)``.

        ``v_init`` must already be float32 so the loop carry type is stable.
        """
        body_fun = lambda i, state: DoubleRBM._gibbs_step(i, state, W, b, c, T)
        v_final, key = jax.lax.fori_loop(0, k, body_fun, (v_init, rng))
        return v_final, key

    # phase RBM specific

    def get_amplitude(self, sigma_batch: jnp.ndarray) -> jnp.ndarray:
        """Complex amplitudes exp(-F_amp/2) * exp(-i F_pha/2) for each configuration.

        Args:
            sigma_batch: visible configurations; the last axis has ``n_visible`` entries.

        Returns:
            Complex array with the leading (batch) shape of ``sigma_batch``.
        """
        W_amp = self.variables["params"]["W_amp"]
        b_amp = self.variables["params"]["b_amp"]
        c_amp = self.variables["params"]["c_amp"]

        W_pha = self.variables["params"]["W_pha"]
        b_pha = self.variables["params"]["b_pha"]
        c_pha = self.variables["params"]["c_pha"]

        F_amp = self._free_energy(W_amp, b_amp, c_amp, sigma_batch)
        F_pha = self._free_energy(W_pha, b_pha, c_pha, sigma_batch)

        return jnp.exp(-0.5 * F_amp) * jnp.exp(-0.5j * F_pha)

    def get_rotated_amplitude(self, sigma_b: jnp.ndarray, basis: Sequence[str]) -> jnp.ndarray:
        """Amplitudes of measured bit strings in a basis with exactly two rotated sites.

        psi_b(sigma_b) = sum_sigma U[sigma_b, sigma] * psi(sigma), where U = R_j (x) R_k
        and the sum runs over the four assignments of the rotated sites j, k. All
        other sites keep their measured ('Z') values.

        Args:
            sigma_b: ``(B, n)`` measured bit strings (0/1) in ``basis``.
            basis: one letter per site; exactly two entries are non-'Z' and must be
                keys of ``self.rotators``.

        Returns:
            ``(B,)`` complex amplitudes.

        Raises:
            ValueError: if ``basis`` does not have exactly two non-'Z' entries.
        """
        B, n = sigma_b.shape

        non_z = [i for i, b in enumerate(basis) if b != 'Z']
        if len(non_z) != 2:
            raise ValueError("Only bases with exactly two non-Z entries are supported.")

        j, k = non_z
        # (4, 4); row/column index = 2 * bit_j + bit_k
        U = jnp.kron(self.rotators[basis[j]], self.rotators[basis[k]])

        # row i holds (bit_j, bit_k) = (i >> 1, i & 1), matching U's index order
        local_bit_combos = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]], dtype=sigma_b.dtype)

        # (B, 4, n): sample-major, so each sample's four variants stay adjacent after flattening
        sites = jnp.array([j, k])
        variants = jnp.broadcast_to(sigma_b[:, None, :], (B, 4, n))
        variants = variants.at[:, :, sites].set(local_bit_combos[None, :, :])

        psis = self.get_amplitude(variants.reshape(B * 4, n)).reshape(B, 4)  # (B, 4)

        # measured outcome on the two rotated sites, as an index into U's rows
        idx_in = (sigma_b[:, j].astype(jnp.int32) << 1) | sigma_b[:, k].astype(jnp.int32)  # (B,)

        # <sigma_b| U |psi>: take row idx_in of U (not the column, which would apply U^T)
        U_rows = U[idx_in, :]  # (B, 4)
        return jnp.sum(U_rows * psis, axis=1)  # (B,)