In [None]:
import jax
import jax.lax
from jax.random import PRNGKey
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState
from flax import linen as nn

from typing import Optional, Tuple, Dict, Any, Sequence


# Location of the dataset files, relative to the notebook's working directory.
data_dir: str = "./data"
print(f"Data resides in        : {data_dir}")

In [None]:
class RBM(nn.Module):
    """Bernoulli-Bernoulli restricted Boltzmann machine with binary visible/hidden units.

    Calling the module returns a contrastive free-energy loss: the mean free
    energy of the data batch minus the mean free energy of model samples. The
    model samples are drawn by ``k`` Gibbs steps started from uniformly random
    visible states (not from the data).

    Attributes:
        n_visible: Number of visible units (size of the last axis of a data batch).
        n_hidden: Number of hidden units.
        k: Gibbs steps used to draw the negative-phase samples.
        n_chains: Number of independent Gibbs chains run per call.
    """
    n_visible: int
    n_hidden: int
    k: int = 1
    n_chains: int = 1000

    @nn.compact
    def __call__(self, data_batch: jnp.ndarray, aux_vars: Optional[Dict[str, Any]] = None) -> Tuple[jnp.ndarray, Dict[str, Any]]:
        """Compute the contrastive free-energy loss for ``data_batch``.

        Args:
            data_batch: Visible configurations with last axis of size ``n_visible``
                (assumed 0/1-valued floats -- TODO confirm against the data loader).
            aux_vars: Dict holding a JAX PRNG key under ``"key"``. It is required;
                the ``None`` default is kept only for signature compatibility.

        Returns:
            ``(loss, new_aux_vars)``. ``new_aux_vars`` is a shallow copy of
            ``aux_vars`` whose ``"key"`` entry is replaced by a fresh, unused key.
            The caller's dict is left unmodified.

        Raises:
            ValueError: if ``aux_vars`` is ``None`` or lacks a ``"key"`` entry.
        """
        # Fail with an explicit message instead of an opaque TypeError/KeyError.
        if aux_vars is None or "key" not in aux_vars:
            raise ValueError("RBM.__call__ requires aux_vars={'key': PRNGKey} for Gibbs sampling.")

        W = self.param("W", nn.initializers.normal(0.01), (self.n_visible, self.n_hidden))
        b = self.param("b", nn.initializers.zeros,        (self.n_visible,))
        c = self.param("c", nn.initializers.zeros,        (self.n_hidden,))

        key, bern_key = jax.random.split(aux_vars["key"], 2)
        # Negative phase: uniformly random initial chains followed by k Gibbs steps.
        v_chain_batch = jax.random.bernoulli(bern_key, p=0.5, shape=(self.n_chains, self.n_visible)).astype(jnp.float32)
        model_batch, key = self._gibbs_sample(W, b, c, v_chain_batch, key, k=self.k)
        # Samples are treated as constants: gradients flow only through the free energies.
        model_batch = jax.lax.stop_gradient(model_batch)

        free_energy_data = self._free_energy(W, b, c, data_batch)
        free_energy_model = self._free_energy(W, b, c, model_batch)

        loss = jnp.mean(free_energy_data) - jnp.mean(free_energy_model)

        # Return a new dict rather than mutating the caller's aux_vars in place.
        new_aux_vars = dict(aux_vars)
        new_aux_vars["key"] = key
        return loss, new_aux_vars

    @staticmethod
    def _free_energy(W, b, c, v_batch):
        """Free energy per row: F(v) = -v.b - sum_j softplus((v W + c)_j)."""
        visible_term = jnp.dot(v_batch, b)
        hidden_term  = jnp.sum(jax.nn.softplus(v_batch @ W + c), axis=-1)
        free_energy = -visible_term - hidden_term
        return free_energy

    @staticmethod
    def _gibbs_step(i, state, W, b, c, T=1.0):
        """One block Gibbs sweep v -> h -> v at temperature ``T``.

        ``i`` is unused; it exists to match the ``jax.lax.fori_loop`` body signature.
        ``state`` is ``(v_batch, key)``; returns the new ``(v, key)``.
        """
        v_batch, key = state
        key, h_key, v_key = jax.random.split(key, 3)

        # Dividing the logits by T flattens (T > 1) or sharpens (T < 1) the conditionals.
        h_logits = (v_batch @ W + c) / T
        h_probs = jax.nn.sigmoid(h_logits)
        h = jax.random.bernoulli(h_key, h_probs).astype(jnp.float32)

        v_logits = (h @ W.T + b) / T
        v_probs = jax.nn.sigmoid(v_logits)
        v = jax.random.bernoulli(v_key, v_probs).astype(jnp.float32)

        return v, key

    @staticmethod
    def _gibbs_sample(W, b, c, v_batch, rng, k=1, T=1.0):
        """Run ``k`` Gibbs sweeps from ``v_batch``; return (final visible states, advanced key)."""
        body_fun = lambda i, state: RBM._gibbs_step(i, state, W, b, c, T)
        v_final, key = jax.lax.fori_loop(0, k, body_fun, (v_batch, rng))
        return v_final, key

    @staticmethod
    def _annealing_step(i, state, W, b, c, T_schedule):
        """One Gibbs sweep at temperature ``T_schedule[i]`` (``i`` may be a traced index)."""
        v, rng = state
        T = T_schedule[i]
        v_next, rng_next = RBM._gibbs_sample(W, b, c, v, rng, k=1, T=T)
        return v_next, rng_next

    @nn.nowrap
    def generate(self, n_samples: int, T_schedule: jnp.ndarray, rng: PRNGKey) -> jnp.ndarray:
        """Draw ``n_samples`` visible configurations by annealed Gibbs sampling.

        This must be called on a module bound to parameters, because it reads
        ``self.variables["params"]`` (e.g. a module obtained via ``.bind``; TODO
        confirm the intended binding mechanism).

        Args:
            n_samples: Number of independent chains, which is also the number of samples returned.
            T_schedule: 1-D sequence of temperatures. One Gibbs sweep runs at each
                ``T_schedule[i]`` in order. Lists/tuples are accepted and converted
                to an array so they can be indexed by the traced loop counter.
            rng: PRNG key.

        Returns:
            Float32 array of 0/1 values with shape ``(n_samples, n_visible)``.
        """
        W = self.variables["params"]["W"]
        b = self.variables["params"]["b"]
        c = self.variables["params"]["c"]

        # fori_loop passes a traced index, which cannot index a Python list.
        T_schedule = jnp.asarray(T_schedule)

        rng, init_key = jax.random.split(rng)
        v = jax.random.bernoulli(init_key, p=0.5, shape=(n_samples, self.n_visible)).astype(jnp.float32)
        state = (v, rng)

        body_fun = lambda i, s: RBM._annealing_step(i, s, W, b, c, T_schedule)
        v_final, _ = jax.lax.fori_loop(0, len(T_schedule), body_fun, state)
        return v_final

In [None]:
@jax.jit
def train_step_amp(
        state: TrainState,
        batch_dict: Dict[str, jnp.ndarray],
        aux_vars: Dict[str, Any]) -> Tuple[TrainState, jnp.ndarray, Dict[str, Any]]:
    """Run one jitted optimisation step of the amplitude RBM.

    ``batch_dict`` must hold exactly one ``label -> batch`` entry. Every character
    of the label must be ``'Z'`` (presumably a measurement-basis label), and an
    empty label is rejected. Both checks see only the dict's keys, which are
    static pytree structure under ``jax.jit``, so they run at trace time.

    Returns ``(new_state, loss, aux_vars)``. Here ``aux_vars`` is the auxiliary
    dict returned by the model's apply function.
    """
    entries = list(batch_dict.items())
    if len(entries) != 1:
        raise ValueError("Batch dictionary must contain exactly one entry.")

    key, batch = entries[0]
    if set(key) != {'Z'}:
        raise ValueError(f"Batch key must consist only of 'Z', got: {key}")

    def loss_fn(params):
        return state.apply_fn({'params': params}, batch, aux_vars)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, updated_aux), grads = grad_fn(state.params)
    return state.apply_gradients(grads=grads), loss, updated_aux


# String (forward-reference) annotations: MultiBasisDataLoader is not imported
# in this notebook's import cell, so evaluating it when the function is defined
# would raise NameError.
def train_amp_rbm(
        state: "TrainState",
        loader: "MultiBasisDataLoader",
        num_epochs: int,
        rng: PRNGKey) -> Tuple["TrainState", Dict[int, Dict[str, float]], PRNGKey]:
    """Train the amplitude RBM for ``num_epochs`` passes over ``loader``.

    Args:
        state: Train state holding the RBM parameters and optimiser.
        loader: Iterable of batch dicts accepted by ``train_step_amp``. It is
            iterated once per epoch (assumed re-iterable -- TODO confirm).
        num_epochs: Number of passes over ``loader``.
        rng: PRNG key seeding Gibbs sampling. The key is threaded through every
            step via the ``aux_vars`` dict.

    Returns:
        ``(state, metrics, key)``: the trained state,
        ``{epoch: {"loss_amp": mean batch loss}}``, and the last key produced by
        the training steps (``rng`` itself if ``num_epochs == 0``).

    Raises:
        ValueError: if ``loader`` yields no batches in an epoch, since the mean
            loss would otherwise divide by zero.
    """
    metrics: Dict[int, Dict[str, float]] = {}
    aux_vars = {"key": rng}

    for epoch in range(num_epochs):
        tot_loss = 0.0
        batches = 0

        for data_batch in loader:
            state, loss, aux_vars = train_step_amp(state, data_batch, aux_vars)
            tot_loss += loss
            batches += 1

        if batches == 0:
            raise ValueError(f"Loader yielded no batches in epoch {epoch + 1}; cannot compute mean loss.")

        metrics[epoch] = {"loss_amp": float(tot_loss / batches)}
        print(f"Epoch {epoch+1}/{num_epochs} │ Loss: {metrics[epoch]['loss_amp']:.4f}")

    return state, metrics, aux_vars["key"]

In [None]:
# ---- hyperparameters ----
# Model / sampler sizes
visible_units = 10      # n_visible of the RBM
hidden_units = 30       # n_hidden of the RBM
k_steps = 10            # Gibbs steps per negative-phase sample
chains = 1000           # parallel Gibbs chains per loss evaluation

# Optimisation
batch_size = 128        # rows of the dummy batch used below for init
lr = 1e-2               # Adam learning rate
num_epochs = 50

# Seeding: `init_key` drives parameter initialisation; `rng` seeds the Gibbs
# sampling that runs inside `init` (the aux output of that call is discarded).
random_seed = PRNGKey(42)
rng, init_key = jax.random.split(random_seed)

# ---- model initialisation ----
model_amp = RBM(
    n_visible=visible_units,
    n_hidden=hidden_units,
    k=k_steps,
    n_chains=chains,
)
# The batch contents do not affect parameter shapes (those come from the module
# attributes); zeros are enough to trace the forward pass.
dummy_batch = jnp.zeros((batch_size, visible_units), dtype=jnp.float32)
variables_amp = model_amp.init(init_key, dummy_batch, {"key": rng})

# ---- optimiser / train state ----
optimizer_amp = optax.adam(learning_rate=lr)
state_amp = TrainState.create(
    apply_fn=model_amp.apply,
    params=variables_amp["params"],
    tx=optimizer_amp,
)