In [7]:
import numpy as np

import jax
import jax.lax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

import functools

from load_mnist import download_mnist_if_needed, load_images, load_labels

import matplotlib.pyplot as plt


data_dir = "./data"
device = jax.devices('cpu')[0]

print(f"Data resides in        : {data_dir}")
print(f"Training model on      : {str(device)}")

Data resides in        : ./data
Training model on      : TFRT_CPU_0


In [8]:
def preprocess(x):
    x = x.astype(np.float32) / 255.0    # normalize to [0, 1]
    x = x > 0.5                         # binarize
    x = x.reshape(x.shape[0], -1)       # flatten for RBM
    x = jnp.array(x, dtype=jnp.float32) # use jax numpy array of dtype float32, because RBM has float32 params
    return x


data_paths = download_mnist_if_needed(root=data_dir, train_only=True)
x_train_raw = load_images(data_paths['train_images'])
y_train = load_labels(data_paths['train_labels'])

x_train = preprocess(x_train_raw)
print(f"x_train dtype: {x_train.dtype}, shape: {x_train.shape}")
print(f"x_train min: {x_train.min()}, max: {x_train.max()}")

x_train dtype: float32, shape: (60000, 784)
x_train min: 0.0, max: 1.0


In [9]:
import jax
import jax.numpy as jnp
import flax.linen as nn

class RBM(nn.Module):
    n_visible: int
    n_hidden: int

    @nn.compact
    def __call__(self, v):
        # In Flax, __call__ is the main computation, analogous to PyTorch’s forward().
        # We have moved the implementation of the free energy here.
        #
        # Even though __call__ is invoked twice during contrastive divergence, the
        # parameters are only initialized once on the first call via @nn.compact.
        W = self.param("W", nn.initializers.normal(0.01), (self.n_visible, self.n_hidden))
        b = self.param("b", nn.initializers.zeros, (self.n_visible,))
        c = self.param("c", nn.initializers.zeros, (self.n_hidden,))

        # Assumes v has shape (batch_size, n_visible); ensure it before calling
        visible_term = jnp.dot(v, b)
        hidden_term = jnp.sum(jax.nn.softplus(v @ W + c), axis=-1)
        return -visible_term - hidden_term

    @staticmethod
    def _gibbs_step(i, state, params, T=1.0):
        v, key = state
        W, b, c = params["W"], params["b"], params["c"]

        key, h_key, v_key = jax.random.split(key, 3)

        h_probs = jax.nn.sigmoid((v @ W + c) / T)
        h = jax.random.bernoulli(h_key, h_probs).astype(jnp.float32)

        v_probs = jax.nn.sigmoid((h @ W.T + b) / T)
        v = jax.random.bernoulli(v_key, v_probs).astype(jnp.float32)

        return v, key

    @staticmethod
    def gibbs_sample(params, v0, rng, k=1, T=1.0):
        # The fori_loop enables JIT compilation of loops. It basically unrolls the loop over the fixed length k.
        return jax.lax.fori_loop(0, k, lambda i, state: RBM._gibbs_step(i, state, params, T), (v0, rng))[0]

In [5]:
from flax.training import train_state
from flax.core import FrozenDict


class RBMTrainState(train_state.TrainState):
    """A value object bundling parameters to be passed between training steps.
    It holds all the necessary state information (model parameters, optimizer state, step count, etc.).
    Since it is immutable, the training function needs to return a new instance after each update step.
    """
    pass


# this function will be compiled by the function below
def pcd_loss(
        params: FrozenDict, # parameters to calculate loss for
        apply_fn: callable, # The model's apply function stored in the TrainState
        data_batch: jnp.ndarray,
        current_fantasy_particles: jnp.ndarray,
        key_loss: jax.random.PRNGKey,
        k: int              # Number of Gibbs sampling steps
) -> jnp.ndarray:

    v_k = RBM.gibbs_sample(params, current_fantasy_particles, key_loss, k)
    # even though the gibbs_sample function is perfectly differentiable, we are only interested in v_k as a sample
    v_k = jax.lax.stop_gradient(v_k)

    # retrieves the models parameters and makes sure they are available inside the specific function
    fe_data = apply_fn({'params': params}, data_batch)
    fe_model = apply_fn({'params': params}, v_k)

    loss = jnp.mean(fe_data) - jnp.mean(fe_model)
    return loss

In [None]:



@functools.partial(jax.jit, static_argnames=['k'])
def train_step(
        state: RBMTrainState,
        data_batch: jnp.ndarray,
        current_fantasy_particles: jnp.ndarray,
        key: jax.random.PRNGKey,
        k: int):
    """
    Performs a single, JIT-compilable training step using PCD.

    Args:
        state: The current RBMTrainState value object containing parameters,
               optimizer state, apply_fn, etc.
        data_batch: A batch of input data (visible units).
        current_fantasy_particles: The persistent chain samples from the previous step or a new one after reset.
        key: A JAX PRNG key for this step's random operations.
        k: The number of Gibbs sampling steps (static for JIT compilation).

    Returns:
        A tuple containing:
          - new_state: The updated RBMTrainState value object.
          - loss_value: The scalar loss calculated for this step (for metrics).
          - new_fantasy_particles: The updated persistent chain samples for the next step.
    """

    key_loss, key_update = jax.random.split(key)

    # Define the loss function. It captures 'state', 'data_batch',
    # 'current_fantasy_particles', 'key_loss', and 'k' from the outer scope.
    # It takes only 'params' as input, which is what jax.grad needs.
    def pcd_loss_fn(params):
        # --- Negative Phase Calculation ---
        # Generate k samples starting from the persistent chain (fantasy particles)
        # using the *current* parameters 'params'.
        v_k = RBM.gibbs_sample(params, current_fantasy_particles, key_loss, k)
        # Stop gradients from flowing back through the Gibbs sampling process itself.
        # The gradient should only depend on how params affect the free energy.
        v_k = jax.lax.stop_gradient(v_k)

        # --- Free Energy Calculation ---
        # Calculate free energy for the data (positive phase)
        fe_data = state.apply_fn({'params': params}, data_batch, method='free_energy')
        # Calculate free energy for the model samples (negative phase)
        fe_model = state.apply_fn({'params': params}, v_k, method='free_energy')

        # --- Loss Calculation ---
        # The PCD loss is the difference in mean free energies.
        loss = jnp.mean(fe_data) - jnp.mean(fe_model)
        return loss

    # Calculate the loss value and the gradients w.r.t. 'state.params'
    loss_value, grads = jax.value_and_grad(pcd_loss_fn)(state.params)

    # Apply the gradients to the state.
    # This uses the optimizer defined in state.tx and updates
    # state.params, state.opt_state, and state.step.
    # It returns a NEW state object.
    new_state = state.apply_gradients(grads=grads)

    # Update the fantasy particles for the *next* training step.
    # Important: Sample using the parameters *before* the gradient update (`state.params`),
    # consistent with how v_k was generated for the loss calculation and mirroring
    # the detach() -> step() order in PyTorch.
    new_fantasy_particles = RBM.gibbs_sample(state.params, current_fantasy_particles, key_update, k)
    # Ensure no gradients accidentally flow through this state update path either.
    new_fantasy_particles = jax.lax.stop_gradient(new_fantasy_particles)

    # Return the new state, the loss, and the updated particles
    return new_state, loss_value, new_fantasy_particles



In [8]:
@functools.partial(jax.jit, static_argnames=("rbm", "k", "T"))
def cd_loss(rbm, params, v0, rng, k=1, T=1.0):
    v_k = rbm.gibbs_sample(params, v0, rng, k, T)
    fe_data = rbm.apply({'params': params}, v0)
    fe_model = rbm.apply({'params': params}, v_k)
    return jnp.mean(fe_data) - jnp.mean(fe_model)

In [9]:
key = jax.random.PRNGKey(0)
rbm = RBM(n_visible=784, n_hidden=128)
v0 = jax.random.bernoulli(key, shape=(32, 784)).astype(jnp.float32)
params = rbm.init(key, v0)['params']

loss = cd_loss(rbm, params, v0, key)
print("CD-1 Loss:", loss)

CD-1 Loss: 0.7732544
