In [25]:
import re
from pathlib import Path
from typing import Optional, Tuple, Dict, Any, Sequence
from datetime import datetime

import numpy as np

import jax
import jax.lax
from jax.random import PRNGKey
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState
from flax.training import checkpoints
from flax import linen as nn

import matplotlib.pyplot as plt



from flax.training import checkpoints
from pathlib import Path

import warnings

# Suppress specific Orbax sharding warning
warnings.filterwarnings(
    "ignore",
    message=(
        "Couldn't find sharding info under RestoreArgs.*"
    ),
    category=UserWarning,
    module="orbax.checkpoint.type_handlers"
)

data_dir = "./data"
model_dir = "./models"
model_prefix = "rbm_amp_202506031229_0"

print(f"Data resides in                         : {data_dir}")
print(f"Amplitude RBM checkpoint to be loaded   : {model_dir}/{model_prefix}")

Data resides in                         : ./data
Amplitude RBM checkpoint to be loaded   : ./models/rbm_amp_202506031229_0


In [26]:
class MultiBasisDataLoader:
    def __init__(self, data_dict: dict[str, jnp.ndarray],
                 batch_size: int = 128,
                 shuffle: bool = True,
                 drop_last: bool = False,
                 seed: int = 0):
        lengths = [len(v) for v in data_dict.values()]
        if len(set(lengths)) != 1:
            raise ValueError(f"All arrays must have the same length, got: {lengths}")

        self.data = data_dict
        self.n = lengths[0]
        self.bs = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.rng = np.random.default_rng(seed)

        self.idx_slices = [
            (i, i + batch_size)
            for i in range(0, self.n, batch_size)
            if not drop_last or i + batch_size <= self.n
        ]

    def __iter__(self):
        self.order = np.arange(self.n)
        if self.shuffle:
            self.rng.shuffle(self.order)
        self.slice_idx = 0
        return self

    def __next__(self):
        if self.slice_idx >= len(self.idx_slices):
            raise StopIteration
        s, e = self.idx_slices[self.slice_idx]
        self.slice_idx += 1
        return {k: v[self.order[s:e]] for k, v in self.data.items()}

    def __len__(self):
        return len(self.idx_slices)


def load_measurements(folder: str, file_pattern: str = "w_*.txt") -> dict[str, jnp.ndarray]:
    out: dict[str, jnp.ndarray] = {}

    for fp in Path(folder).glob(file_pattern):
        basis = fp.stem.split("_")[2]

        bitstrings = []
        with fp.open() as f:
            for line in f:
                bitstring = np.fromiter((c.islower() for c in line.strip()), dtype=np.float32)
                bitstrings.append(bitstring)

        arr = jnp.asarray(np.stack(bitstrings))
        if basis in out:
            out[basis] = jnp.concatenate([out[basis], arr], axis=0)
        else:
            out[basis] = arr

    return out

In [27]:
data_dict = load_measurements("data/", "w_*.txt")
keys_pha = [k for k in data_dict if re.fullmatch(r"^(?!Z+$).*", k)]
dict_pha = {k: data_dict[k] for k in keys_pha}

In [28]:
class PairPhaseRBM(nn.Module):
    n_visible: int
    n_hidden: int

    def setup(self):
        zeros = lambda shape: jnp.zeros(shape, dtype=jnp.float32)

        # Amplitude RBM (frozen parameters)
        self.W_amp = self.variable('amp_state', 'W_amp', zeros, (self.n_visible, self.n_hidden))
        self.b_amp = self.variable('amp_state', 'b_amp', zeros, (self.n_visible,))
        self.c_amp = self.variable('amp_state', 'c_amp', zeros, (self.n_hidden,))

        # Phase RBM (trainable parameters)
        self.W_pha = self.param('W_pha', nn.initializers.normal(0.01), (self.n_visible, self.n_hidden))
        self.b_pha = self.param('b_pha', nn.initializers.zeros, (self.n_visible,))
        self.c_pha = self.param('c_pha', nn.initializers.zeros, (self.n_hidden,))

        # Rotation matrices
        sqrt2 = jnp.sqrt(2.0)
        self.rotators = {
            'X': jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) / sqrt2,
            'Y': jnp.array([[1, -1j], [1, 1j]], dtype=jnp.complex64) / sqrt2,
        }

    @staticmethod
    def _free_energy(W, b, c, v):
        visible_term = jnp.dot(v, b)
        hidden_term = jnp.sum(jax.nn.softplus(jnp.dot(v, W) + c), axis=-1)
        return -visible_term - hidden_term

    @staticmethod
    def _free_energy_grad(W, b, c, v):
        pre = jnp.dot(v, W) + c
        sig = jax.nn.sigmoid(pre)
        dW = jnp.einsum('bi,bj->bij', v, sig)
        return jnp.concatenate([dW.reshape(v.shape[0], -1), v, sig], axis=1)

    def rotated_log_psi_and_grad(self, sigma_b, basis):
        W_amp, b_amp, c_amp = self.W_amp.value, self.b_amp.value, self.c_amp.value
        W_pha, b_pha, c_pha = self.W_pha, self.b_pha, self.c_pha

        B, n = sigma_b.shape
        non_z = [i for i, p in enumerate(basis) if p != 'Z']
        if len(non_z) != 2:
            raise ValueError("Basis must have exactly two non-Z qubits.")
        j, k = non_z

        Rj, Rk = self.rotators[basis[j]], self.rotators[basis[k]]
        U = jnp.kron(Rj, Rk)

        combos = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]], dtype=sigma_b.dtype)
        tiled = jnp.tile(sigma_b[:, None, :], (1, 4, 1))
        modified = tiled.at[:, :, [j, k]].set(combos[None, :, :])
        flat = modified.reshape(B * 4, n)

        F_amp = self._free_energy(W_amp, b_amp, c_amp, flat)
        F_pha = self._free_energy(W_pha, b_pha, c_pha, flat)

        log_mag = (-0.5 * F_amp).reshape(B, 4)
        angle = (-0.5 * F_pha).reshape(B, 4)
        M = jnp.max(log_mag, axis=1, keepdims=True)
        scaled = jnp.exp(log_mag - M + 1j * angle)

        idx = (sigma_b[:, j].astype(int) << 1) | sigma_b[:, k].astype(int)
        Uc = U[:, idx].T
        psi_rot = jnp.sum(Uc * scaled, axis=1)

        grad_F_amp = self._free_energy_grad(W_amp, b_amp, c_amp, flat).reshape(B, 4, -1)
        grad_F_pha = self._free_energy_grad(W_pha, b_pha, c_pha, flat).reshape(B, 4, -1)

        grad_logpsi = -0.5 * grad_F_amp + -0.5j * grad_F_pha
        psi_weighted_grad = jnp.einsum("bij,bi->bj", grad_logpsi, Uc * scaled)
        psi_ratio = psi_weighted_grad / (psi_rot[:, None] + 1e-12)

        split = grad_F_amp.shape[-1]
        grad_lambda = psi_ratio[:, :split].real
        grad_mu = -psi_ratio[:, split:].imag

        return grad_lambda, grad_mu, jnp.log(jnp.abs(psi_rot) + 1e-12), M.squeeze()

    def __call__(self, data_batch_dict):
        total_loss = 0.
        for basis, sigma_b in data_batch_dict.items():
            _, _, log_amp, M = self.rotated_log_psi_and_grad(sigma_b, basis)
            total_loss += -2. * jnp.mean(log_amp + M)
        return total_loss

In [29]:
@jax.jit
def train_step_pha(state: TrainState, amp_vars: Dict[str, jnp.ndarray], batch_dict: Dict[str, jnp.ndarray]) -> Tuple[TrainState, jnp.ndarray]:
    model_var_dict = { 'params': state.params, 'amp_state': amp_vars }
    def loss_fn(params):
        model_var_dict['params'] = params
        return state.apply_fn(model_var_dict, batch_dict)

    grads = {}
    loss = 0.0
    for basis, sigma_b in batch_dict.items():
        grad_lambda, grad_mu, log_amp, M = state.apply_fn(model_var_dict, method=PairPhaseRBM.rotated_log_psi_and_grad, sigma_b=sigma_b, basis=basis)
        loss += -2.0 * jnp.mean(log_amp + M)

        flat_mu = jnp.mean(grad_mu, axis=0)

        W_shape = state.params['W_pha'].shape
        b_shape = state.params['b_pha'].shape
        c_shape = state.params['c_pha'].shape

        import numpy as np
        nW = np.prod(W_shape)
        nb = np.prod(b_shape)
        nc = np.prod(c_shape)

        W_mu = flat_mu[:nW].reshape(W_shape)
        b_mu = flat_mu[nW:nW+nb].reshape(b_shape)
        c_mu = flat_mu[nW+nb:].reshape(c_shape)

        grads_basis = {
            'W_pha': W_mu,
            'b_pha': b_mu,
            'c_pha': c_mu
        }

        grads = jax.tree_util.tree_map(lambda g1, g2: g1 + g2, grads, grads_basis) if grads else grads_basis

    grads = jax.tree_util.tree_map(lambda g: g / len(batch_dict), grads)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss


def train_phase_rbm(state_pha: TrainState, amp_vars: Dict[str, jnp.ndarray], loader: MultiBasisDataLoader, num_epochs: int) -> Tuple[TrainState, Dict[int, float]]:
    metrics = {}
    for epoch in range(num_epochs):
        tot_loss = 0.0
        batches = 0
        for batch_dict in loader:
            state_pha, loss = train_step_pha(state_pha, amp_vars, batch_dict)
            tot_loss += loss
            batches += 1
        metrics[epoch] = {"loss_pha": float(tot_loss / batches)}
        print(f"Epoch {epoch+1}/{num_epochs} │ Loss: {metrics[epoch]['loss_pha']:.4f}")
    return state_pha, metrics

In [30]:
params_amp = checkpoints.restore_checkpoint(
    ckpt_dir=str(Path(model_dir).resolve()),
    target=None,
    prefix=model_prefix
)

W_amp = params_amp["W"]
b_amp = params_amp["b"]
c_amp = params_amp["c"]

amp_vars = {
    "W_amp": W_amp,
    "b_amp": b_amp,
    "c_amp": c_amp,
}

In [32]:
batch_size    = 6400
lr            = 1e-3
num_epochs    = 50 # will be increased but currently no conclusive downwards trend

rng = jax.random.PRNGKey(42)

loader_pha = MultiBasisDataLoader(dict_pha, batch_size=128)


model_pha = PairPhaseRBM(n_visible=visible_units, n_hidden=hidden_units)
dummy_dict = next(iter(loader_pha))
variables_pha  = model_pha.init(rng, dummy_dict)

#tx = optax.chain(
#    natgrad_diag(damping=1e-3),
#    optax.scale(-lr),  # standard gradient descent on natural step
#)

#state_pha = TrainState.create(
#    apply_fn=model_pha.apply,
#    params=variables_pha['params'],
#    tx=tx,
#)
optimizer_pha = optax.adam(learning_rate=lr)
state_pha = TrainState.create(apply_fn=model_pha.apply, params=variables_pha['params'], tx=optimizer_pha)


TypeError: dot_general requires contracting dimensions to have the same shape, got (10,) and (16,).