In [1]:
import numpy as np

import jax
import jax.lax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

from load_mnist import download_mnist_if_needed, load_images, load_labels

import matplotlib.pyplot as plt


data_dir = "./data"
device = jax.devices('cpu')[0]

print(f"Data resides in        : {data_dir}")
print(f"Training model on      : {str(device)}")

Data resides in        : ./data
Training model on      : TFRT_CPU_0


In [2]:
def preprocess(x):
    x = x.astype(np.float32) / 255.0    # normalize to [0, 1]
    x = x > 0.5                         # binarize
    x = x.reshape(x.shape[0], -1)       # flatten for RBM
    x = jnp.array(x, dtype=jnp.float32) # use jax numpy array of dtype float32, because RBM has float32 params
    return x


data_paths = download_mnist_if_needed(root=data_dir, train_only=True)
x_train_raw = load_images(data_paths['train_images'])
y_train = load_labels(data_paths['train_labels'])

x_train = preprocess(x_train_raw)
print(f"x_train dtype: {x_train.dtype}, shape: {x_train.shape}")
print(f"x_train min: {x_train.min()}, max: {x_train.max()}")

x_train dtype: float32, shape: (60000, 784)
x_train min: 0.0, max: 1.0


In [3]:
class RBM(nn.Module):
    # The following variables form part of the constructor since this is a dataclass.
    n_visible: int
    n_hidden: int

    def _sample_hidden(self, key: jax.random.PRNGKey, v, W, c, T=1.0):
        # Random key is passed explicitly to ensure purity and JIT-compatibility
        logits = (v @ W + c) / T
        h_probs = jax.nn.sigmoid(logits)
        h_sample = jax.random.bernoulli(key, h_probs)
        return h_sample.astype(jnp.float32), h_probs

    def _sample_visible(self, key: jax.random.PRNGKey, h, W, b, T=1.0):
        # Random key is passed explicitly to ensure purity and JIT-compatibility
        logits = (h @ W.T + b) / T
        v_probs = jax.nn.sigmoid(logits)
        v_sample = jax.random.bernoulli(key, v_probs)
        return v_sample.astype(jnp.float32), v_probs

    def sample_gibbs(self, v0_sample, k=1, T=1.0):
        # JAX follows a functional programming paradigm with pure functions. However,
        # methods like this one use self.param() and self.make_rng(), which require access
        # to Flax's internal scope. These scopes are only available during .apply() or .init(),
        # which safely inject parameters and RNGs while preserving functional purity externally.
        loop_key = self.make_rng("sample")
        W, b, c = self.param("W"), self.param("b"), self.param("c")

        def gibbs_step_body(i, carry):
            v_carry, key_carry = carry
            key_carry, hidden_key, visible_key = jax.random.split(key_carry, 3)
            h, _ = self._sample_hidden(hidden_key, v_carry, W, c, T)
            v_next, _ = self._sample_visible(visible_key, h, W, b, T)
            return (v_next, key_carry)

        initial_carry = (v0_sample, loop_key)
        final_carry = jax.lax.fori_loop(0, k, gibbs_step_body, initial_carry)
        final_v = final_carry[0]
        return final_v

    @nn.compact
    def __call__(self, v):
        # In Flax, __call__ is the main computation, analogous to PyTorch’s forward().
        # We have moved the implementation of the free energy here.
        #
        # Even though __call__ is invoked twice during contrastive divergence, the
        # parameters are only initialized once on the first call via @nn.compact.
        W = self.param("W", nn.initializers.normal(0.01), (self.n_visible, self.n_hidden))
        b = self.param("b", nn.initializers.zeros, (self.n_visible,))
        c = self.param("c", nn.initializers.zeros, (self.n_hidden,))

        # Assumes v has shape (batch_size, n_visible); ensure it before calling
        visible_term = jnp.dot(v, b)
        hidden_logits = v @ W + c
        hidden_term = jnp.sum(jax.nn.softplus(hidden_logits), axis=-1)
        return -visible_term - hidden_term

In [4]:
# test RBM model, make sure to call apply

key = jax.random.PRNGKey(0)
rbm = RBM(n_visible=28*28, n_hidden=256)
params = rbm.init(key, jnp.ones((1, 28*28)))['params']
v = jnp.ones((1, 28*28))
fe = rbm.apply({'params': params}, v)
print(fe)

[-179.0123]


In [10]:
import functools


@functools.partial(jax.jit, static_argnames=("rbm", "k", "T"))
def cd_loss(rbm, params, v0_sample, rng, k=1, T=1.0):
    # v_k ← Gibbs sample starting from data
    v_k = rbm.apply({'params': params}, v0_sample, method=rbm.sample_gibbs, rngs={'sample': rng}, k=k, T=T)

    # __call__ is already the free energy (negative log unnormalized prob)
    fe_data = rbm.apply({'params': params}, v0_sample)
    fe_model = rbm.apply({'params': params}, v_k)

    loss = jnp.mean(fe_data) - jnp.mean(fe_model)
    return loss

In [11]:
key = jax.random.PRNGKey(0)
batch_size = 32
n_visible = 784
n_hidden = 128

# Dummy binarized input
v0_sample = jax.random.bernoulli(key, p=0.5, shape=(batch_size, n_visible)).astype(jnp.float32)

# Init RBM
rbm = RBM(n_visible=n_visible, n_hidden=n_hidden)
params = rbm.init(key, v0_sample)['params']

# Run test
loss = cd_loss(rbm, params, v0_sample, key, k=1)
print("CD-1 Loss:", loss)


TypeError: param() missing 1 required positional argument: 'init_fn'

In [6]:
import functools
import jax
import jax.numpy as jnp
from flax.training import train_state
from flax.core import FrozenDict
import optax
from tqdm import trange

class TrainState(train_state.TrainState):
    params: FrozenDict
    rng: jax.random.PRNGKey
    particles: jnp.ndarray  # persistent fantasy particles

def free_energy(rbm, params, v):
    return rbm.apply({'params': params}, v)

@functools.partial(jax.jit, static_argnums=(0, 5))
def train_step(rbm, state, batch, k, weight_decay, pcd_reset):
    def loss_fn(params, rng, particles):
        v_data = batch
        v_model = jax.lax.cond(
            state.step % pcd_reset == 0,
            lambda _: jax.random.bernoulli(rng, p=0.5, shape=v_data.shape),
            lambda _: particles,
            operand=None
        )
        rng, subkey = jax.random.split(rng)
        v_model = rbm.apply({'params': params}, v_model, method=rbm.sample_gibbs, rngs={'sample': subkey}, k=k)
        fe_data = free_energy(rbm, params, v_data).mean()
        fe_model = free_energy(rbm, params, v_model).mean()
        loss = fe_data - fe_model
        l2_penalty = sum(jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
        return loss + weight_decay * l2_penalty, (loss, rng, v_model)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (fe_loss, rng, new_particles)), grads = grad_fn(state.params, state.rng, state.particles)
    new_state = state.apply_gradients(grads=grads, rng=rng, particles=new_particles)
    return new_state, fe_loss

def train_rbm_jax(rbm, x_train, batch_size, num_epochs, k, lr, weight_decay, lr_decay, pcd_reset):
    rng = jax.random.PRNGKey(0)
    init_rng, sample_rng = jax.random.split(rng)
    init_data = jnp.ones((batch_size, rbm.n_visible), dtype=jnp.float32)
    variables = rbm.init(init_rng, init_data)
    params = variables['params']

    tx = optax.chain(
        optax.add_decayed_weights(weight_decay),
        optax.exponential_decay(lr, transition_steps=1, decay_rate=lr_decay, staircase=True),
        optax.adam(learning_rate=lr)
    )

    init_particles = jax.random.bernoulli(sample_rng, p=0.5, shape=(batch_size, rbm.n_visible)).astype(jnp.float32)

    state = TrainState.create(
        apply_fn=rbm.apply,
        params=params,
        tx=tx,
        rng=sample_rng,
        particles=init_particles
    )

    num_batches = x_train.shape[0] // batch_size
    metrics = {}

    for epoch in trange(num_epochs, desc="Training"):
        perm = jax.random.permutation(state.rng, x_train.shape[0])
        x_shuffled = x_train[perm]
        epoch_loss = 0.0

        for i in range(num_batches):
            batch = x_shuffled[i * batch_size:(i + 1) * batch_size]
            state, loss = train_step(rbm, state, batch, k, weight_decay, pcd_reset)
            epoch_loss += loss

        avg_loss = epoch_loss / num_batches
        metrics[epoch] = {"free_energy_loss": float(avg_loss)}
        print(f"Epoch [{epoch+1}/{num_epochs}] - Free Energy Loss: {avg_loss:.4f}")

    return metrics, state


In [8]:
batch_size      = 128
visible_units   = 28*28 # 784
hidden_units    = 256
k               = 1      # Gibbs steps for PCD
lr              = 1e-3
num_epochs      = 40
pcd_reset       = 75     # Reset persistent chain every N batches
weight_decay    = 1e-5   # L2 regularization
lr_decay_rate   = 0.95   # Learning rate decay factor PER EPOCH
temperature     = 1.0    # Sampling temperature (fixed)



rbm = RBM(n_visible=visible_units, n_hidden=hidden_units)
metrics, state = train_rbm_jax(
    rbm,
    x_train,
    batch_size=batch_size,
    num_epochs=num_epochs,
    k=k,
    lr=lr,
    weight_decay=weight_decay,
    lr_decay=lr_decay_rate,
    pcd_reset=pcd_reset
)

AttributeError: 'function' object has no attribute 'init'