In [27]:
import re
from pathlib import Path

import numpy as np

import jax
import jax.lax
from jax.random import PRNGKey
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState
from flax import linen as nn

import matplotlib.pyplot as plt

from typing import Optional, Tuple, Dict, Any, Sequence


# Folder that holds the measurement text files, relative to the notebook's
# working directory. Printed so the run log records which folder was used.
data_dir = "./data"
print(f"Data resides in        : {data_dir}")

Data resides in        : ./data


In [28]:
class MultiBasisDataLoader:
    """Mini-batch iterator over a dictionary of equally long arrays.

    Every batch is a dictionary with the same keys as ``data_dict``. All values
    of one batch are selected with the same row indices, so rows stay aligned
    across keys.

    Args:
        data_dict: Mapping from key to array. All arrays must have the same
            length along their first axis.
        batch_size: Number of rows per batch; must be positive.
        shuffle: If True, a new row permutation is drawn at the start of every
            pass (every ``iter()`` call).
        drop_last: If True, a trailing batch shorter than ``batch_size`` is
            dropped.
        seed: Seed of the NumPy generator that draws the permutations.

    Raises:
        ValueError: If the arrays differ in length (this includes an empty
            ``data_dict``) or if ``batch_size`` is not positive.
    """

    def __init__(self, data_dict: dict[str, jnp.ndarray],
                 batch_size: int = 128,
                 shuffle: bool = True,
                 drop_last: bool = False,
                 seed: int = 0):
        lengths = [len(v) for v in data_dict.values()]
        if len(set(lengths)) != 1:
            raise ValueError(f"All arrays must have the same length, got: {lengths}")
        # Reject bad batch sizes up front: 0 would fail inside range() with an
        # unrelated message and a negative value would silently yield no batches.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got: {batch_size}")

        self.data = data_dict
        self.n = lengths[0]
        self.bs = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.rng = np.random.default_rng(seed)

        # Half-open (start, stop) index pairs. The last pair may run past n;
        # slicing the order array in __next__ truncates it.
        self.idx_slices = [
            (i, i + batch_size)
            for i in range(0, self.n, batch_size)
            if not drop_last or i + batch_size <= self.n
        ]

        # Iteration state. The permutation is drawn lazily (not here) so the
        # generator's random stream is untouched until the first pass starts.
        self.order = None
        self.slice_idx = 0

    def __len__(self) -> int:
        """Return the number of batches produced by one full pass."""
        return len(self.idx_slices)

    def __iter__(self):
        """Start a new pass: reset the cursor and (optionally) reshuffle rows."""
        self.order = np.arange(self.n)
        if self.shuffle:
            self.rng.shuffle(self.order)
        self.slice_idx = 0
        return self

    def __next__(self):
        """Return the next batch dictionary or raise ``StopIteration``."""
        if self.order is None:
            # next() was called before iter(): start the first pass implicitly
            # instead of failing on the missing iteration state.
            self.__iter__()
        if self.slice_idx >= len(self.idx_slices):
            raise StopIteration
        s, e = self.idx_slices[self.slice_idx]
        self.slice_idx += 1
        return {k: v[self.order[s:e]] for k, v in self.data.items()}


def load_measurements(folder: str, file_pattern: str = "w_*.txt") -> dict[str, jnp.ndarray]:
    """Load measurement files from ``folder`` and group the samples by basis.

    The basis label of a file is the third underscore-separated field of its
    stem. Every non-blank line is one sample: each character becomes 1.0 if it
    is lower-case and 0.0 otherwise. Files that share a basis label are
    concatenated along axis 0, in sorted file-name order.

    Args:
        folder: Directory that is searched (non-recursively) for files.
        file_pattern: Glob pattern selecting the files to read.

    Returns:
        Mapping from basis label to a float32 array with one row per sample.

    Raises:
        IndexError: If a matching file stem has fewer than three
            underscore-separated fields.
        ValueError: If a file contains no measurement lines, or if the lines of
            one file differ in length.
    """
    per_basis: dict[str, list[np.ndarray]] = {}

    # glob() does not promise any particular order; sorting keeps the row order
    # of concatenated files reproducible between runs and machines.
    for fp in sorted(Path(folder).glob(file_pattern)):
        basis = fp.stem.split("_")[2]

        with fp.open() as f:
            stripped = (line.strip() for line in f)
            # Blank lines carry no sample; keeping them would create zero-length
            # rows that cannot be stacked with the real ones.
            rows = [
                np.fromiter((c.islower() for c in text), dtype=np.float32)
                for text in stripped
                if text
            ]

        if not rows:
            raise ValueError(f"No measurements found in file: {fp}")

        per_basis.setdefault(basis, []).append(np.stack(rows))

    # One concatenation per basis instead of one per file avoids recopying the
    # accumulated data every time another file of the same basis is read.
    return {
        basis: jnp.concatenate([jnp.asarray(chunk) for chunk in chunks], axis=0)
        for basis, chunks in per_basis.items()
    }

In [29]:
# Read all measurement files from the configured data folder (data_dir is set
# in the first cell) instead of repeating the path as a separate literal.
data_dict = load_measurements(data_dir, "w_*.txt")

# Split the basis labels into two disjoint groups:
#   keys_amp - labels made only of 'Z' characters
#   keys_pha - every other label (the negative look-ahead rejects pure-Z labels)
keys_amp = [k for k in data_dict if re.fullmatch(r"^Z+$", k)]
keys_pha = [k for k in data_dict if re.fullmatch(r"^(?!Z+$).*", k)]
dict_amp = {k: data_dict[k] for k in keys_amp}
dict_pha = {k: data_dict[k] for k in keys_pha}

# One loader per group; each loader requires all of its arrays to have the
# same number of rows.
loader_amp = MultiBasisDataLoader(dict_amp, batch_size=128)
loader_pha = MultiBasisDataLoader(dict_pha, batch_size=128)

In [49]:
import jax, jax.numpy as jnp
from jax.random import PRNGKey
from flax import linen as nn
from typing import Tuple

class RBM(nn.Module):
    """Restricted Boltzmann machine trained on a free-energy difference.

    Calling the module advances the supplied Gibbs chain by ``k`` steps and
    returns the mean free energy of the data minus the mean free energy of the
    chain samples, together with the updated chain and PRNG key.

    NOTE(review): ``T`` rescales the pre-activations during sampling only; the
    free energy itself does not use it - confirm this is intended.
    """
    # number of visible units (first dimension of W, length of b)
    n_visible: int
    # number of hidden units (second dimension of W, length of c)
    n_hidden: int
    # number of Gibbs steps executed per call
    k: int = 1
    # divisor applied to the pre-activations inside the sampling sigmoids
    T: float = 1.0

    def setup(self):
        """Register the weight matrix and both bias vectors as parameters."""
        weight_shape = (self.n_visible, self.n_hidden)
        self.W = self.param("W", nn.initializers.normal(0.01), weight_shape)
        self.b = self.param("b", nn.initializers.zeros, (self.n_visible,))
        self.c = self.param("c", nn.initializers.zeros, (self.n_hidden,))

    def __call__(self, data: jnp.ndarray, aux_vars: Dict[str, Any]) -> Tuple[jnp.ndarray, Dict[str, Any]]:
        """Compute the loss for ``data`` and advance the Gibbs chain.

        Args:
            data: Batch of visible configurations.
            aux_vars: Dictionary with the entries ``"gibbs_chain"`` (current
                chain state) and ``"key"`` (PRNG key used for sampling).

        Returns:
            ``(loss, aux)`` where ``aux`` holds the new ``"gibbs_chain"`` and
            the new ``"key"``.
        """
        def advance(_, chain_state):
            # The loop index is unused; every step applies the same transition.
            return self._gibbs_step(chain_state, self.W, self.b, self.c, self.T)

        start_state = (aux_vars["gibbs_chain"], aux_vars["key"])
        samples, next_key = jax.lax.fori_loop(0, self.k, advance, start_state)
        # The samples are treated as constants when differentiating the loss.
        samples = jax.lax.stop_gradient(samples)

        data_term = jnp.mean(self._free_energy(data))
        sample_term = jnp.mean(self._free_energy(samples))
        return data_term - sample_term, {"gibbs_chain": samples, "key": next_key}

    def _free_energy(self, v):
        """Return ``-(v @ b) - sum(softplus(v @ W + c))`` over the last axis."""
        visible_term = v @ self.b
        hidden_term = jnp.sum(jax.nn.softplus(v @ self.W + self.c), -1)
        return -visible_term - hidden_term

    @staticmethod
    def _gibbs_step(state, W, b, c, T):
        """Sample hidden units from the visibles, then new visibles from them.

        ``state`` is a ``(visible, key)`` pair; the returned pair carries the
        freshly sampled visibles and the remaining PRNG key.
        """
        visible, rng = state
        rng, rng_hidden, rng_visible = jax.random.split(rng, 3)

        p_hidden = jax.nn.sigmoid((visible @ W + c)/T)
        hidden = jax.random.bernoulli(rng_hidden, p_hidden).astype(jnp.float32)

        p_visible = jax.nn.sigmoid((hidden @ W.T + b)/T)
        visible = jax.random.bernoulli(rng_visible, p_visible).astype(jnp.float32)
        return visible, rng

In [50]:
from flax import struct
from flax.training.train_state import TrainState
from flax.core import FrozenDict
from typing import Any, Dict, Tuple

@jax.jit
def train_step_amp(
        state: TrainState,
        batch_dict: Dict[str, jnp.ndarray],
        gibbs_chain: jnp.ndarray,
        key: PRNGKey) -> Tuple[TrainState, jnp.ndarray, jnp.ndarray, PRNGKey]:
    """Run one optimisation step on a single pure-Z batch.

    Args:
        state: Train state providing ``apply_fn``, ``params`` and the optimiser.
        batch_dict: Dictionary with exactly one entry whose key consists only
            of the character ``'Z'``; its value is the data batch.
        gibbs_chain: Current Gibbs chain handed to the model.
        key: PRNG key handed to the model.

    Returns:
        ``(new_state, loss, new_gibbs_chain, new_key)``.

    Raises:
        ValueError: If ``batch_dict`` does not hold exactly one entry, or if
            its key contains anything other than ``'Z'``.
    """
    # These are plain Python checks on the dictionary structure; under jax.jit
    # they are evaluated while the function is being traced.
    if len(batch_dict) != 1:
        raise ValueError("Batch dictionary must contain exactly one entry.")

    basis_key, batch = next(iter(batch_dict.items()))
    if set(basis_key) != {'Z'}:
        raise ValueError(f"Batch key must consist only of 'Z', got: {basis_key}")

    def loss_fn(params):
        # The model returns (loss, aux); aux carries the advanced chain and key.
        model_inputs = {"gibbs_chain": gibbs_chain, "key": key}
        return state.apply_fn({'params': params}, batch, model_inputs)

    (loss, aux_out), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    updated_state = state.apply_gradients(grads=grads)
    return updated_state, loss, aux_out["gibbs_chain"], aux_out["key"]

def train_amp_rbm(
        state: TrainState,
        loader: MultiBasisDataLoader,
        gibbs_chain: jnp.ndarray,
        num_epochs: int,
        key: PRNGKey) -> Tuple[TrainState, Dict[int, Dict[str, float]]]:
    """Train for ``num_epochs`` passes over ``loader`` and log the mean loss.

    Args:
        state: Initial train state.
        loader: Iterable yielding batch dictionaries accepted by
            ``train_step_amp``.
        gibbs_chain: Initial Gibbs chain; it is carried over from step to step
            and from epoch to epoch.
        num_epochs: Number of passes over ``loader``.
        key: Initial PRNG key; it is threaded through every training step.

    Returns:
        ``(state, metrics)`` where ``metrics[epoch]`` is a dictionary
        ``{"loss_amp": mean_loss}`` for the zero-based epoch index.

    Raises:
        ValueError: If ``loader`` yields no batch during an epoch, because no
            mean loss can be computed in that case.
    """
    metrics: Dict[int, Dict[str, float]] = {}

    for epoch in range(num_epochs):
        tot_loss = 0.0
        batches = 0

        for batch_dict in loader:
            state, loss, gibbs_chain, key = train_step_amp(state, batch_dict, gibbs_chain, key)
            # Accumulate without converting to a Python float; the conversion
            # happens once per epoch below.
            tot_loss += loss
            batches += 1

        if batches == 0:
            # Fail with an explicit message rather than a ZeroDivisionError.
            raise ValueError("Loader produced no batches; cannot compute the epoch loss.")

        metrics[epoch] = {"loss_amp": float(tot_loss / batches)}
        print(f"Epoch {epoch+1}/{num_epochs} │ Loss: {metrics[epoch]['loss_amp']:.4f}")

    return state, metrics

In [51]:
# ---- hyperparameters ----
batch_size    = 128
visible_units = 10  # NOTE(review): presumably the length of one measurement row - confirm against the loaded data
hidden_units  = visible_units * 3
k_steps       = 5     # Gibbs steps per model call (passed as k to RBM)
lr            = 1e-2  # Adam learning rate
num_epochs    = 100
chains        = batch_size  # number of Gibbs chains is tied to the batch size

# Derive four keys from one seed: `key` is handed to the training loop,
# `key_params` initialises the parameters, `key_chains` draws the initial
# Gibbs chain and `key_dummy` only fills the dummy inputs used for init.
key_seed = PRNGKey(42)
key, key_params, key_chains, key_dummy = jax.random.split(key_seed, 4)

# Dummy inputs are passed to model.init to create the parameters; their values
# are all zeros.
model_amp = RBM(visible_units, hidden_units, k=k_steps)
batch_dummy = jnp.zeros((batch_size, visible_units), dtype=jnp.float32)
aux_vars_dummy = {"gibbs_chain": jnp.zeros((batch_size, visible_units), dtype=jnp.float32), "key": key_dummy}
variables_amp = model_amp.init({"params": key_params}, batch_dummy, aux_vars_dummy)

optimizer_amp = optax.adam(learning_rate=lr)
state_amp = TrainState.create(apply_fn=model_amp.apply, params=variables_amp["params"], tx=optimizer_amp)
# Initial chain state: independent Bernoulli(0.5) draws cast to float32.
gibbs_chain = jax.random.bernoulli(key_chains, p=0.5, shape=(chains, visible_units)).astype(jnp.float32)

In [55]:
# Train the amplitude RBM and time the run. The call rebinds `state_amp`, so
# re-running this cell continues from the already-trained state, while
# `gibbs_chain` and `key` are not rebound and start from the same values again.
%time state_amp, metrics_amp = train_amp_rbm(state_amp, loader_amp, gibbs_chain, num_epochs, key)

Epoch 1/100 │ Loss: 0.0467
Epoch 2/100 │ Loss: 0.3607
Epoch 3/100 │ Loss: 0.1674
Epoch 4/100 │ Loss: 0.0638
Epoch 5/100 │ Loss: 0.2074
Epoch 6/100 │ Loss: 0.0881
Epoch 7/100 │ Loss: 0.0958
Epoch 8/100 │ Loss: 0.2018
Epoch 9/100 │ Loss: 0.1452
Epoch 10/100 │ Loss: 0.0755
Epoch 11/100 │ Loss: 0.2002
Epoch 12/100 │ Loss: 0.0791
Epoch 13/100 │ Loss: 0.1727
Epoch 14/100 │ Loss: 0.0488
Epoch 15/100 │ Loss: 0.0535
Epoch 16/100 │ Loss: 0.1995
Epoch 17/100 │ Loss: 0.1300
Epoch 18/100 │ Loss: 0.2105
Epoch 19/100 │ Loss: 0.1713
Epoch 20/100 │ Loss: -0.0011
Epoch 21/100 │ Loss: 0.1648
Epoch 22/100 │ Loss: 0.0858
Epoch 23/100 │ Loss: 0.1657
Epoch 24/100 │ Loss: 0.1111
Epoch 25/100 │ Loss: 0.1437
Epoch 26/100 │ Loss: 0.1500
Epoch 27/100 │ Loss: 0.2081
Epoch 28/100 │ Loss: 0.0744
Epoch 29/100 │ Loss: 0.1683
Epoch 30/100 │ Loss: 0.1901
Epoch 31/100 │ Loss: 0.0507
Epoch 32/100 │ Loss: 0.1598
Epoch 33/100 │ Loss: 0.1303
Epoch 34/100 │ Loss: 0.1198
Epoch 35/100 │ Loss: 0.3186
Epoch 36/100 │ Loss: 0.1426
