In [1]:
import functools
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state
from jax.random import PRNGKey


In [2]:
class RBM(nn.Module):
    """Restricted Boltzmann Machine trained with a contrastive-divergence loss.

    The negative phase always runs ``k`` Gibbs steps starting from the
    ``v_persistent`` argument of ``__call__``.  Feeding back the chain returned
    by the previous call gives PCD-k; passing the data batch itself as
    ``v_persistent`` gives CD-k.
    """
    n_visible: int  # number of visible units (rows of W)
    n_hidden: int   # number of hidden units (columns of W)
    k: int = 1      # Gibbs steps per call (CD-k / PCD-k)

    # ─────────────────────── model forward ────────────────────────
    @nn.compact
    def __call__(self,
                 data_batch: jnp.ndarray,
                 v_persistent: jnp.ndarray,
                 rng: PRNGKey) -> tuple[jnp.ndarray, dict]:
        """Compute the free-energy difference loss for one batch.

        Args:
            data_batch: visible configurations for the positive phase
                (assumed (batch, n_visible) - TODO confirm with the caller).
            v_persistent: starting visible state of the negative Gibbs chain.
            rng: PRNG key consumed by the Gibbs sampler.

        Returns:
            ``(loss, aux)`` where ``loss`` is
            ``mean(F(data_batch)) - mean(F(v_k))`` and ``aux`` holds the new
            chain state under ``"v_persistent"`` and the advanced key under
            ``"key"``.
        """
        W = self.param("W", nn.initializers.normal(0.01), (self.n_visible, self.n_hidden))
        b = self.param("b", nn.initializers.zeros,         (self.n_visible,))
        c = self.param("c", nn.initializers.zeros,         (self.n_hidden,))
        params = {"W": W, "b": b, "c": c}

        # ── positive & negative phases ────────────────────────────
        v_k, key = self._gibbs_sample(params, v_persistent, rng, k=self.k)
        # The model sample is treated as a constant: gradients only flow
        # through the two free-energy evaluations below.
        v_k = jax.lax.stop_gradient(v_k)

        free_e_data  = self._free_energy(params, data_batch)
        free_e_model = self._free_energy(params, v_k)

        pcd_loss = jnp.mean(free_e_data) - jnp.mean(free_e_model)
        aux      = {"v_persistent": v_k, "key": key}
        return pcd_loss, aux

    # ─────────────────────── statics ──────────────────────────────
    @staticmethod
    def _free_energy(params, v):
        """Free energy ``F(v) = -(v·b + Σ_j softplus((v W + c)_j))``.

        The hidden-unit sum is taken over the last axis, so a leading batch
        axis on ``v`` is preserved in the result.
        """
        W, b, c   = params["W"], params["b"], params["c"]
        v_term    = jnp.dot(v, b)
        h_term    = jnp.sum(jax.nn.softplus(v @ W + c), axis=-1)
        return -(v_term + h_term)

    @staticmethod
    def _gibbs_step(_, state, params, T=1.0):
        """One block-Gibbs sweep ``v -> h -> v`` at temperature ``T``.

        Args:
            _: loop index supplied by ``fori_loop``; unused.
            state: ``(v, key)`` carry.
            params: dict with ``"W"``, ``"b"`` and ``"c"``.
            T: temperature dividing both pre-activations.

        Returns:
            The new ``(v, key)`` carry; ``v`` is a float32 0/1 sample.
        """
        v, key      = state
        W, b, c     = params["W"], params["b"], params["c"]
        # Fresh sub-keys for the hidden and visible draws; `key` is carried on.
        key, hk, vk = jax.random.split(key, 3)

        h_prob = jax.nn.sigmoid((v @ W + c) / T)
        h      = jax.random.bernoulli(hk, h_prob).astype(jnp.float32)

        v_prob = jax.nn.sigmoid((h @ W.T + b) / T)
        v      = jax.random.bernoulli(vk, v_prob).astype(jnp.float32)
        return v, key

    @staticmethod
    def _gibbs_sample(params, v0, rng, k=1, T=1.0):
        """Run ``k`` Gibbs sweeps from ``v0``; returns ``(v_k, advanced_key)``."""
        body = lambda i, st: RBM._gibbs_step(i, st, params, T)
        v_k, key = jax.lax.fori_loop(0, k, body, (v0, rng))
        return v_k, key

    # ────────────────────── sampling util ─────────────────────────
    @nn.nowrap
    def generate(self,
                 params: dict,
                 n_samples: int,
                 T_schedule: jnp.ndarray,
                 rng: PRNGKey) -> jnp.ndarray:
        """Draw samples by annealed Gibbs sampling.

        The chains start from uniform random bits and perform one Gibbs sweep
        per entry of ``T_schedule``, using that entry as the temperature.

        Args:
            params: dict with ``"W"``, ``"b"`` and ``"c"``.
            n_samples: number of independent chains.
            T_schedule: 1-D sequence of temperatures (JAX array, NumPy array
                or plain Python sequence).
            rng: PRNG key.

        Returns:
            Final visible states, shape ``(n_samples, n_visible)``, float32.
        """
        # The loop counter inside fori_loop is traced, and only JAX arrays can
        # be indexed by a traced value - coerce lists / NumPy inputs up front.
        T_schedule = jnp.asarray(T_schedule)

        rng, key = jax.random.split(rng)
        v = jax.random.bernoulli(key, p=0.5,
                                 shape=(n_samples, self.n_visible)).astype(jnp.float32)
        state = (v, rng)

        step = lambda i, st: RBM._gibbs_step(i, st, params, T_schedule[i])
        v_fin, _ = jax.lax.fori_loop(0, len(T_schedule), step, state)
        return v_fin


In [3]:
class RBMTrainState(train_state.TrainState):
    """Train state used for the RBM.

    Adds no fields or methods of its own; everything is inherited from the
    ``TrainState`` base class.
    """


@jax.jit
def train_step(state: RBMTrainState,
               data_batch: jnp.ndarray,
               v_persistent: jnp.ndarray,
               key: PRNGKey):
    """Run one jit-compiled gradient update.

    Args:
        state: current train state; its ``apply_fn`` is called with
            ``{"params": params}``, the data batch, the chain state and the key,
            and must return ``(loss, aux)``.
        data_batch: batch of training data.
        v_persistent: current state of the negative-phase chain.
        key: PRNG key forwarded to ``apply_fn``.

    Returns:
        ``(new_state, loss, aux["v_persistent"], aux["key"])``.
    """
    def loss_with_aux(params):
        # Close over the batch, chain and key so only `params` is differentiated.
        return state.apply_fn({"params": params}, data_batch, v_persistent, key)

    value_and_grad_fn = jax.value_and_grad(loss_with_aux, has_aux=True)
    (loss, aux), grads = value_and_grad_fn(state.params)

    updated_state = state.apply_gradients(grads=grads)
    return updated_state, loss, aux["v_persistent"], aux["key"]


def train_rbm(state: RBMTrainState,
              train_loader,
              num_epochs: int,
              rng: PRNGKey,
              pcd_reset: int = 5,
              scheduler=None):
    """Training loop for the RBM.

    Args:
        state: initial train state, updated by ``train_step`` on every batch.
        train_loader: iterable yielding batches that expose ``.shape``; it is
            iterated exactly once per epoch.
        num_epochs: number of passes over ``train_loader``.
        rng: PRNG key; split for chain (re)initialisation and threaded through
            ``train_step``.
        pcd_reset: re-draw the persistent chain from uniform random bits every
            ``pcd_reset`` batches (including batch 0).  ``None`` disables the
            periodic reset, in which case the chain is drawn once at the start
            of each epoch.  Must not be 0.
        scheduler: currently unused.

    Returns:
        ``(state, metrics, rng)`` where ``metrics[epoch]`` is
        ``{"free_energy_loss": <mean loss over the epoch>}``.

    Raises:
        ValueError: if ``train_loader`` yields no batches in an epoch.
    """
    metrics = {}
    for epoch in range(num_epochs):
        # Key for the epoch's initial chain.  It is split off once per epoch
        # regardless of `pcd_reset`, but only consumed when periodic resets
        # are disabled (otherwise batch 0 triggers a reset with its own key).
        rng, init_key = jax.random.split(rng)
        # Created lazily from the first batch seen in the loop, so the loader
        # is never peeked at (which would drop a batch from one-shot iterables).
        v_persistent = None

        tot_loss, n_batches = 0.0, 0
        for b_idx, data in enumerate(train_loader):
            if (pcd_reset is not None) and (b_idx % pcd_reset == 0):
                rng, sk = jax.random.split(rng)
                v_persistent = jax.random.bernoulli(sk, p=0.5,
                                                    shape=data.shape).astype(jnp.float32)
            elif v_persistent is None:
                v_persistent = jax.random.bernoulli(init_key, p=0.5,
                                                    shape=data.shape).astype(jnp.float32)

            # train_step hands back the advanced key, which becomes the new rng.
            state, loss, v_persistent, rng = train_step(state,
                                                        data,
                                                        v_persistent,
                                                        rng)
            tot_loss += loss
            n_batches += 1

        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")

        avg = tot_loss / n_batches
        metrics[epoch] = {"free_energy_loss": float(avg)}
        print(f"Epoch [{epoch+1}/{num_epochs}] – FE-loss: {avg:.4f}")
    return state, metrics, rng


In [4]:
def get_cosine_schedule(T_high: float,
                        T_low: float,
                        n_steps: int) -> jnp.ndarray:
    """Cosine annealed temperature schedule.

    Args:
        T_high: temperature at the first step.
        T_low: temperature at the last step.
        n_steps: length of the schedule.

    Returns:
        float32 array of shape ``(n_steps,)`` going from ``T_high`` to
        ``T_low`` along half a cosine period.  ``n_steps == 1`` yields
        ``[T_high]`` and ``n_steps == 0`` yields an empty array.
    """
    steps = jnp.arange(n_steps, dtype=jnp.float32)
    # With a single step there is no interval to anneal over; clamp the
    # denominator so the result is [T_high] rather than 0/0 = NaN.
    span  = max(n_steps - 1, 1)
    cos   = 0.5 * (1 + jnp.cos(jnp.pi * steps / span))
    return T_low + (T_high - T_low) * cos
