In [21]:
import re
from pathlib import Path

import numpy as np

import jax
import jax.lax
from jax.random import PRNGKey
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState
from flax import linen as nn

import matplotlib.pyplot as plt

from typing import Optional, Tuple, Dict, Any, Sequence


data_dir = "./data"
print(f"Data resides in        : {data_dir}")

Data resides in        : ./data


In [22]:
class MultiBasisDataLoader:
    def __init__(self, data_dict: dict[str, jnp.ndarray],
                 batch_size: int = 128,
                 shuffle: bool = True,
                 drop_last: bool = False,
                 seed: int = 0):
        lengths = [len(v) for v in data_dict.values()]
        if len(set(lengths)) != 1:
            raise ValueError(f"All arrays must have the same length, got: {lengths}")

        self.data = data_dict
        self.n = lengths[0]
        self.bs = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.rng = np.random.default_rng(seed)

        self.idx_slices = [
            (i, i + batch_size)
            for i in range(0, self.n, batch_size)
            if not drop_last or i + batch_size <= self.n
        ]

    def __iter__(self):
        self.order = np.arange(self.n)
        if self.shuffle:
            self.rng.shuffle(self.order)
        self.slice_idx = 0
        return self

    def __next__(self):
        if self.slice_idx >= len(self.idx_slices):
            raise StopIteration
        s, e = self.idx_slices[self.slice_idx]
        self.slice_idx += 1
        return {k: v[self.order[s:e]] for k, v in self.data.items()}


def load_measurements(folder: str, file_pattern: str = "w_*.txt") -> dict[str, jnp.ndarray]:
    out: dict[str, jnp.ndarray] = {}

    for fp in Path(folder).glob(file_pattern):
        basis = fp.stem.split("_")[2]

        bitstrings = []
        with fp.open() as f:
            for line in f:
                bitstring = np.fromiter((c.islower() for c in line.strip()), dtype=np.float32)
                bitstrings.append(bitstring)

        arr = jnp.asarray(np.stack(bitstrings))
        if basis in out:
            out[basis] = jnp.concatenate([out[basis], arr], axis=0)
        else:
            out[basis] = arr

    return out

In [23]:
data_dict = load_measurements("data/", "w_*.txt")

keys_amp = [k for k in data_dict if re.fullmatch(r"^Z+$", k)]
keys_pha = [k for k in data_dict if re.fullmatch(r"^(?!Z+$).*", k)]
dict_amp = {k: data_dict[k] for k in keys_amp}
dict_pha = {k: data_dict[k] for k in keys_pha}

loader_amp = MultiBasisDataLoader(dict_amp, batch_size=128)
loader_pha = MultiBasisDataLoader(dict_pha, batch_size=128)

In [24]:
import jax, jax.numpy as jnp
from jax.random import PRNGKey
from flax import linen as nn
from typing import Tuple

class PcdRBM(nn.Module):
    n_visible: int
    n_hidden: int
    k: int = 1
    T: float = 1.0
    n_chains: int = 128

    def setup(self):
        self.W = self.param("W", nn.initializers.normal(0.01), (self.n_visible, self.n_hidden))
        self.b = self.param("b", nn.initializers.zeros, (self.n_visible,))
        self.c = self.param("c", nn.initializers.zeros, (self.n_hidden,))

        # initialise only once; RNG needed only here
        def init_chain():
            k = self.make_rng("pcd_init")
            return jax.random.bernoulli(
                k, p=0.5,
                shape=(self.n_chains, self.n_visible)
            ).astype(jnp.float32)

        self.v_chain = self.variable("pcd_state", "v_chain", init_chain)

    def __call__(self, data_batch: jnp.ndarray, key: PRNGKey) -> Tuple[jnp.ndarray, PRNGKey]:
        gibbs_step_fn = lambda i, s: self._gibbs_step(s, self.W, self.b, self.c, self.T)
        model_batch, key = jax.lax.fori_loop(0, self.k, gibbs_step_fn, (self.v_chain.value, key))
        self.v_chain.value = jax.lax.stop_gradient(model_batch) # cut off gradients after Gibbs sampling

        loss = jnp.mean(self._free_energy(data_batch)) - jnp.mean(self._free_energy(self.v_chain.value))
        return loss, key

    def _free_energy(self, v):
        return -(v @ self.b) - jnp.sum(jax.nn.softplus(v @ self.W + self.c), -1)

    @staticmethod
    def _gibbs_step(state, W, b, c, T):
        v, key = state
        key, h_key, v_key = jax.random.split(key, 3)
        h = jax.random.bernoulli(h_key, jax.nn.sigmoid((v @ W + c)/T)).astype(jnp.float32)
        v = jax.random.bernoulli(v_key, jax.nn.sigmoid((h @ W.T + b)/T)).astype(jnp.float32)
        return v, key

In [48]:
from flax import struct
from flax.training import train_state
from flax.core import FrozenDict
from typing import Any

@struct.dataclass
class PcdTrainState(train_state.TrainState):
    pcd_state: FrozenDict[str, Any] = struct.field(pytree_node=True,
                                                   default_factory=lambda: FrozenDict())


@jax.jit
def train_step_amp(
        state: PcdTrainState,
        batch_dict: Dict[str, jnp.ndarray],
        key: PRNGKey
) -> Tuple[PcdTrainState, jnp.ndarray, PRNGKey]:

    if len(batch_dict) != 1:
        raise ValueError("Batch dictionary must contain exactly one entry.")

    (name, batch), = batch_dict.items()
    if set(name) != {'Z'}:
        raise ValueError(f"Batch key must consist only of 'Z', got: {name}")

    # Correct loss_fn — returns ((loss, key), mutable)
    loss_fn = lambda params: state.apply_fn(
        {"params": params, "pcd_state": state.pcd_state},
        batch,
        key,
        mutable=["pcd_state"]
    )

    ((loss, key), mutable), grads = jax.value_and_grad(
        lambda p: loss_fn(p)[0],  # grads w.r.t. scalar loss only
        has_aux=True
    )(state.params)

    # 🔥 This is now a traced FrozenDict — safe to index!
    state = state.apply_gradients(grads=grads).replace(
        pcd_state=mutable["pcd_state"]
    )
    return state, loss, key





def train_amp_rbm(
        state: TrainState,
        loader: MultiBasisDataLoader,
        num_epochs: int,
        key: PRNGKey) -> Tuple[TrainState, Dict[int, float]]:

    metrics = {}

    for epoch in range(num_epochs):
        tot_loss = 0.0
        batches = 0

        for batch_dict in loader:
            state, loss, key = train_step_amp(state, batch_dict, key)
            tot_loss += loss
            batches += 1

        metrics[epoch] = {"loss_amp": float(tot_loss / batches)}
        print(f"Epoch {epoch+1}/{num_epochs} │ Loss: {metrics[epoch]['loss_amp']:.4f}")

    return state, metrics

In [49]:
# ---- hyperparameters ----
batch_size    = 128
visible_units = 10
hidden_units  = visible_units
k_steps       = 5
lr            = 1e-2
num_epochs    = 50
chains        = batch_size

key_seed = PRNGKey(42)
key, key_params, key_chains, key_dummy = jax.random.split(key_seed, 4)

model_amp = PcdRBM(visible_units, hidden_units, k=k_steps, n_chains=chains)
batch_dummy = jnp.zeros((batch_size, visible_units), dtype=jnp.float32)
chain_dummy = jnp.zeros((chains, visible_units), dtype=jnp.float32)
variables_amp = model_amp.init({"params": key_params, "pcd_init": key_chains}, batch_dummy, key_dummy)

optimizer_amp = optax.adam(learning_rate=lr)
state_amp = PcdTrainState.create(
    apply_fn=model_amp.apply,
    params=variables_amp["params"],
    tx=optimizer_amp,
    pcd_state=variables_amp["pcd_state"])

In [50]:
train_amp_rbm(state_amp, loader_amp, num_epochs, key)

TypeError: iteration over a 0-d array