In [3]:
%reload_ext autoreload
%autoreload 2

from lib.data_loading import load_measurements, MixedDataLoader

####

from pathlib import Path
from typing import Optional, Tuple, Dict, Any, Sequence, Callable

import jax
import jax.lax
from jax.random import PRNGKey
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState
from flax import linen as nn

import matplotlib.pyplot as plt


# Where the raw measurement data lives and where trained parameters are written.
data_dir = "data"
model_dir = Path("./models")
model_dir.mkdir(exist_ok=True, parents=True)  # no error if the directory already exists

for message in (f"Data resides in        : {data_dir}",
                f"Model will be saved to : {model_dir}"):
    print(message)

Data resides in        : data
Model will be saved to : models


In [4]:
class ClampedRBM(nn.Module):
    """Binary RBM trained with a free-energy (contrastive-divergence style) loss,
    where a subset of visible units is clamped to given values during Gibbs sampling.

    Attributes:
        num_visible: number of visible units (size of one flattened sample).
        num_hidden: number of hidden units.
        k: number of Gibbs steps run per training call.
        T: temperature dividing the pre-activations of both sampling
           conditionals during training; the free energy in the loss does not use it.
    """
    num_visible: int
    num_hidden: int
    k: int = 1
    T: float = 1.0

    def setup(self):
        # W ~ N(0, 0.01^2); visible (b) and hidden (c) biases start at zero.
        self.W = self.param("W", nn.initializers.normal(0.01), (self.num_visible, self.num_hidden))
        self.b = self.param("b", nn.initializers.zeros, (self.num_visible,))
        self.c = self.param("c", nn.initializers.zeros, (self.num_hidden,))

    def _free_energy(self, v: jnp.ndarray) -> jnp.ndarray:
        """Free energy F(v) = -v.b - sum_j softplus((v W + c)_j), one value per row of v."""
        return -(v @ self.b) - jnp.sum(jax.nn.softplus(v @ self.W + self.c), -1)

    @staticmethod
    def _gibbs_step(state, W, b, c, T, clamp_mask, clamp_vals):
        """One block-Gibbs sweep v -> h -> v at temperature T, then re-impose the clamped units.

        `state` is the `(v, key)` carry threaded through `jax.lax.fori_loop`;
        the new carry is returned.
        """
        v, key = state
        key, h_key, v_key = jax.random.split(key, 3)

        # for simplicity reasons we sample the full visible vector and overwrite the clamped units
        h = jax.random.bernoulli(h_key, jax.nn.sigmoid((v @ W + c) / T)).astype(jnp.float32)
        v = jax.random.bernoulli(v_key, jax.nn.sigmoid((h @ W.T + b) / T)).astype(jnp.float32)

        v = jnp.where(clamp_mask, clamp_vals, v)
        return v, key

    @staticmethod
    def _clamp_mask_for(clamp_mask, target_shape) -> jnp.ndarray:
        """Return `clamp_mask` as a bool array broadcast to the flattened visible shape (B, V).

        Masks that already broadcast to (B, V) are used unchanged. Masks laid out
        like the raw, unflattened data are flattened in the same row-major order
        as `data.reshape(B, -1)`:
          * rank > 2, e.g. (B, Q, C) or (1, Q, C): leading axis kept, the rest flattened;
          * rank 2 with exactly V elements whose last axis is neither 1 nor V,
            e.g. (Q, C): one per-sample mask shared by the whole batch.
        Only shapes that cannot broadcast to (B, V) directly are reshaped.
        """
        num_visible = target_shape[-1]
        mask = jnp.asarray(clamp_mask).astype(bool)
        if mask.ndim > 2:
            mask = mask.reshape(mask.shape[0], -1)
        elif mask.ndim == 2 and mask.shape[-1] not in (1, num_visible) and mask.size == num_visible:
            mask = mask.reshape(1, -1)
        return jnp.broadcast_to(mask, target_shape)

    def __call__(self, data: jnp.ndarray, aux_vars: Dict[str, Any]) -> Tuple[jnp.ndarray, Dict[str, Any]]:
        """Compute the training loss for one batch.

        Args:
            data: batch of binary samples; every axis after the first is
                flattened into the visible dimension.
            aux_vars: must contain "key" (PRNG key), "clamp_mask" (True marks a
                visible unit held at its data value; see `_clamp_mask_for` for
                accepted shapes) and "l2_strength" (weight-decay factor).

        Returns:
            `(loss, aux)` with
            loss = mean F(data) - mean F(chain) + l2_strength * (|W|^2 + |b|^2 + |c|^2),
            and aux holding the advanced "key", "free_energy_data",
            "free_energy_model" and the final chain state "model_samples".
        """
        key = aux_vars["key"]

        batch_size = data.shape[0]
        v_data = data.reshape(batch_size, -1).astype(jnp.float32)
        clamp_mask = self._clamp_mask_for(aux_vars["clamp_mask"], v_data.shape)
        clamp_vals = jnp.where(clamp_mask, v_data, 0.)

        # The negative chain starts from uniform noise with the clamped units fixed to the data.
        # NOTE(review): standard CD-k initialises the chain at the data; starting from noise
        # with a small k yields strongly biased negative samples — confirm this is intended.
        key, init_key = jax.random.split(key)
        gibbs_chain = jax.random.bernoulli(init_key, p=0.5, shape=v_data.shape).astype(jnp.float32)
        gibbs_chain = jnp.where(clamp_mask, clamp_vals, gibbs_chain)

        step_fn = lambda i, s: self._gibbs_step(s, self.W, self.b, self.c, self.T, clamp_mask, clamp_vals)
        vk, key = jax.lax.fori_loop(0, self.k, step_fn, (gibbs_chain, key))
        # Chain samples are constants for the gradient; it flows only through F(.).
        vk = jax.lax.stop_gradient(vk)

        fe_data  = jnp.mean(self._free_energy(v_data))
        fe_model = jnp.mean(self._free_energy(vk))

        l2_regularization = jnp.sum(self.W**2) + jnp.sum(self.b**2) + jnp.sum(self.c**2)
        loss  = fe_data - fe_model + aux_vars["l2_strength"] * l2_regularization

        aux_out = {
            "key": key,
            "free_energy_data": fe_data,
            "free_energy_model": fe_model,
            "model_samples": vk,
        }

        return loss, aux_out

    @nn.nowrap
    def generate(self,
                 clamp_vals: jnp.ndarray,  # (B, V) flattened or (B, …) to be flattened
                 clamp_mask: jnp.ndarray,  # bool; any shape accepted by _clamp_mask_for
                 T_schedule: jnp.ndarray,  # (L,)
                 key: PRNGKey) -> jnp.ndarray:
        """Sample visible configurations with the masked units held at `clamp_vals`.

        Runs one Gibbs step per entry of `T_schedule`; step i uses temperature
        `T_schedule[i]`. The chain starts from uniform noise on the unclamped units.

        Returns:
            The final flattened (B, V) state; callers can reshape it back if desired.
        """
        clamp_vals = jnp.asarray(clamp_vals)
        clamp_vals = clamp_vals.reshape(clamp_vals.shape[0], -1).astype(jnp.float32)
        clamp_mask = self._clamp_mask_for(clamp_mask, clamp_vals.shape)
        # Accept Python sequences too: indexing with the traced loop counter needs an array.
        T_schedule = jnp.asarray(T_schedule)

        key, init_key = jax.random.split(key)
        v = jax.random.bernoulli(init_key, 0.5, shape=clamp_vals.shape).astype(jnp.float32)
        v = jnp.where(clamp_mask, clamp_vals, v)

        def step(i, s):
            return self._gibbs_step(s, self.W, self.b, self.c,
                                    T_schedule[i], clamp_mask, clamp_vals)

        v_final, _ = jax.lax.fori_loop(0, T_schedule.shape[0], step, (v, key))
        return v_final

In [5]:
@jax.jit
def train_step(
        state: TrainState,
        batch: jnp.ndarray,
        aux_vars: Dict[str, Any]) -> Tuple[TrainState, jnp.ndarray, Dict[str, Any]]:
    """Apply one optimizer update to `state` using the gradient of the model's loss on `batch`.

    The model's `__call__` must return `(loss, aux)`; `aux_vars` is forwarded to it
    unchanged. Returns the updated state, the loss evaluated with the pre-update
    parameters, and the model's aux dict.
    """
    def objective(params):
        # Rebind the module to candidate parameters so gradients are taken w.r.t. them.
        return state.apply_fn({"params": params}, batch, aux_vars)

    grad_fn = jax.value_and_grad(objective, has_aux=True)
    (loss_value, model_aux), gradients = grad_fn(state.params)
    return state.apply_gradients(grads=gradients), loss_value, model_aux


def train(
        state: TrainState,
        loader: MixedDataLoader,
        num_epochs: int,
        key: PRNGKey,
        l2_strength: float,
        lr_schedule_fn: Callable[[int], float]) -> Tuple[TrainState, Dict[int, Dict[str, float]]]:
    """Run `num_epochs` passes over `loader`, applying one `train_step` per batch.

    Batches are unpacked as (batch, num_qubits, channels) — TODO confirm against
    MixedDataLoader. Channel 0 is resampled by the Gibbs chain; every other
    channel is clamped to the data.

    Args:
        state: initial train state (parameters + optimizer).
        loader: iterable of batches; iterated once per epoch.
        num_epochs: number of passes over `loader`.
        key: PRNG key driving the Gibbs sampling.
        l2_strength: weight-decay factor forwarded to the model loss.
        lr_schedule_fn: maps the optimizer step count to a learning rate (used for reporting only).

    Returns:
        The final state and, per epoch index, a dict with the mean loss, the mean
        model/data free energies over that epoch's batches, and the learning rate
        at the end of the epoch.

    Raises:
        ValueError: if `loader` yields no batches in an epoch.
    """
    metrics: Dict[int, Dict[str, float]] = {}
    clamp_mask = None

    for epoch in range(num_epochs):
        tot_loss = tot_fe_model = tot_fe_data = 0.0
        batches = 0

        for batch in loader:
            if clamp_mask is None:
                # One mask shared by every sample, flattened in the same row-major order
                # as the model's `data.reshape(B, -1)`. Shape (1, V) broadcasts over any
                # batch size, so a shorter final batch still works.
                _, num_qubits, num_channels = batch.shape
                clamp_mask = (jnp.ones((num_qubits, num_channels), dtype=bool)
                              .at[:, 0].set(False)
                              .reshape(1, -1))

            key, subkey = jax.random.split(key)
            aux_vars = {
                "key": subkey,
                "clamp_mask": clamp_mask,
                "l2_strength": l2_strength,
            }

            state, loss, aux_out = train_step(state, batch, aux_vars)
            key = aux_out["key"]

            tot_loss     += float(loss)
            tot_fe_model += float(aux_out["free_energy_model"])
            tot_fe_data  += float(aux_out["free_energy_data"])
            batches      += 1

        if batches == 0:
            raise ValueError(f"loader yielded no batches in epoch {epoch + 1}")

        avg_loss     = tot_loss / batches
        free_E_model = tot_fe_model / batches
        free_E_data  = tot_fe_data / batches
        lr           = float(lr_schedule_fn(state.step))

        metrics[epoch] = dict(
            loss=avg_loss,
            free_energy_model=free_E_model,
            free_energy_data=free_E_data,
            lr=lr
        )

        print(f"Epoch {epoch+1}/{num_epochs} │ "
              f"Loss: {avg_loss:+.4f} │ "
              f"Free En. Model: {free_E_model:.4f} │ "
              f"Free En. Data: {free_E_data:.4f} │ "
              f"Learning Rate: {lr:.5f}")

    return state, metrics

In [None]:
# TODO next steps:
#   - train on Z-basis measurements only
#   - fine-tune on the mixed-basis measurements