In [1]:
import numpy as np

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

import matplotlib.pyplot as plt
from tqdm import tqdm

from load_mnist import download_mnist_if_needed, load_images, load_labels


# Root directory handed to the MNIST download/load helpers below.
data_dir = "./data"
# First CPU device reported by JAX.
# NOTE(review): in the visible cells `device` is only printed; nothing places
# arrays or computations on it, so the "Training model on" message may not
# reflect the backend JAX actually uses — confirm.
device = jax.devices('cpu')[0]

print(f"Data resides in        : {data_dir}")
print(f"Training model on      : {str(device)}")

Data resides in        : ./data
Training model on      : TFRT_CPU_0


In [2]:
def print_samples(samples, elements_per_row=10, fig_width=10, cmap="binary"):
    """Display a grid of labelled images with matplotlib.

    Args:
        samples: Sized iterable of ``(label, image)`` pairs. Each image is
            squeezed before being handed to ``plt.imshow``.
        elements_per_row: Number of columns in the grid.
        fig_width: Width of the figure; the height is derived from it so that
            every grid cell gets the same width and height.
        cmap: Colormap forwarded to ``plt.imshow``.
    """
    total = len(samples)
    # Ceiling division: just enough rows to fit every sample.
    rows_needed = (total + elements_per_row - 1) // elements_per_row
    fig_height = fig_width / elements_per_row * rows_needed

    plt.figure(figsize=(fig_width, fig_height))
    for slot, (label, image) in enumerate(samples, start=1):
        # Subplot slots are 1-based.
        plt.subplot(rows_needed, elements_per_row, slot)
        plt.imshow(image.squeeze(), cmap=cmap)
        plt.title(label, fontsize=12)
        plt.axis('off')

    plt.tight_layout()
    plt.show()


def preprocess(x):
    """Turn raw image data into flat, binary float32 vectors.

    Args:
        x: Array whose first axis indexes samples; values are divided by 255
            before thresholding.

    Returns:
        A ``jnp`` float32 array of shape ``(x.shape[0], -1)`` holding 1.0
        where the scaled value exceeds 0.5 and 0.0 elsewhere.
    """
    scaled = x.astype(np.float32) / 255.0
    is_on = scaled > 0.5
    # One row per sample, all remaining axes flattened.
    flat = is_on.reshape(is_on.shape[0], -1)
    return jnp.array(flat, dtype=jnp.float32)


# Obtain the paths of the MNIST files under `data_dir` (training split only,
# per `train_only=True`) and read the raw images and labels from them.
# NOTE(review): presumably the helper downloads the files when they are
# missing — confirm in `load_mnist`.
data_paths = download_mnist_if_needed(root=data_dir, train_only=True)
x_train_raw = load_images(data_paths['train_images'])
y_train = load_labels(data_paths['train_labels'])

# Binarise and flatten the images via `preprocess`, then print dtype, shape
# and value range as a quick sanity check.
x_train = preprocess(x_train_raw)
print(f"x_train dtype: {x_train.dtype}, shape: {x_train.shape}")
print(f"x_train min: {x_train.min()}, max: {x_train.max()}")

x_train dtype: float32, shape: (60000, 784)
x_train min: 0.0, max: 1.0


In [8]:
# === MODIFIED RBM CLASS ===
class RBM(nn.Module):
    """Restricted Boltzmann machine with Bernoulli visible and hidden units.

    Sampling methods pull their randomness from the ``'sample'`` RNG stream
    via ``make_rng``, so callers must pass ``rngs={'sample': key}`` to
    ``apply``.

    NOTE: a later cell in this notebook defines another ``RBM`` class, which
    replaces this one once that cell has been executed.

    Attributes:
        n_visible: Number of visible units.
        n_hidden: Number of hidden units.
    """
    n_visible: int
    n_hidden: int

    def setup(self):
        """Register the weights ``W`` and the visible/hidden biases ``b``/``c``."""
        self.W = self.param(
            "W", nn.initializers.normal(0.01), (self.n_visible, self.n_hidden)
        )
        self.b = self.param("b", nn.initializers.zeros, (self.n_visible,))
        self.c = self.param("c", nn.initializers.zeros, (self.n_hidden,))

    def _sample_hidden(self, v, T=1.0):
        """Sample hidden units given visible units ``v`` at temperature ``T``.

        Consumes one key from the ``'sample'`` RNG stream.

        Returns:
            ``(h_sample, h_probs)``: float32 Bernoulli draws and the
            activation probabilities they were drawn from.
        """
        rng = self.make_rng("sample")
        activation = jax.nn.sigmoid((v @ self.W + self.c) / T)
        draws = jax.random.bernoulli(rng, activation)
        return draws.astype(jnp.float32), activation

    def _sample_visible(self, h, T=1.0):
        """Sample visible units given hidden units ``h`` at temperature ``T``.

        Consumes one key from the ``'sample'`` RNG stream.

        Returns:
            ``(v_sample, v_probs)``: float32 Bernoulli draws and the
            activation probabilities they were drawn from.
        """
        rng = self.make_rng("sample")
        activation = jax.nn.sigmoid((h @ self.W.T + self.b) / T)
        draws = jax.random.bernoulli(rng, activation)
        return draws.astype(jnp.float32), activation

    def sample_gibbs(self, v0_sample, k=1, T=1.0):
        """Run ``k`` hidden-then-visible sampling sweeps starting at ``v0_sample``.

        Uses a Python loop, so ``k`` must be usable with ``range``. Each sweep
        draws fresh keys through the two sampling helpers.

        Returns:
            The visible sample after the last sweep (``v0_sample`` if ``k`` is 0).
        """
        visible = v0_sample
        for _sweep in range(k):
            hidden, _ = self._sample_hidden(visible, T)
            visible, _ = self._sample_visible(hidden, T)
        return visible

    def free_energy(self, v):
        """Return the free energy of each row of ``v``; needs no RNG.

        A 1-D input is treated as a batch of one.
        """
        batch = v[None, :] if v.ndim == 1 else v
        visible_term = jnp.dot(batch, self.b)                    # (batch,)
        # softplus(x) = log(1 + exp(x)), evaluated stably.
        hidden_term = jnp.sum(
            jax.nn.softplus(batch @ self.W + self.c), axis=-1
        )                                                        # (batch,)
        return -visible_term - hidden_term

    def __call__(self, v):
        """Return hidden activation probabilities for ``v``.

        Goes through ``_sample_hidden``, so the ``'sample'`` RNG stream must be
        supplied even though the drawn sample itself is discarded.
        """
        return self._sample_hidden(v)[1]

In [20]:
import jax
import jax.numpy as jnp
import jax.lax # Make sure lax is imported
import flax.linen as nn

class RBM(nn.Module):
    """Restricted Boltzmann machine whose samplers take explicit PRNG keys.

    Only ``sample_gibbs`` and ``__call__`` touch the ``'sample'`` RNG stream;
    the private samplers receive their key as an argument so they can be used
    inside ``jax.lax.fori_loop``.

    Attributes:
        n_visible: Number of visible units.
        n_hidden: Number of hidden units.
    """
    n_visible: int
    n_hidden: int

    def setup(self):
        """Register the weights ``W`` and the visible/hidden biases ``b``/``c``."""
        self.W = self.param(
            "W", nn.initializers.normal(0.01), (self.n_visible, self.n_hidden)
        )
        self.b = self.param("b", nn.initializers.zeros, (self.n_visible,))
        self.c = self.param("c", nn.initializers.zeros, (self.n_hidden,))

    def _sample_hidden(self, key: jax.random.PRNGKey, v, T=1.0):
        """Sample hidden units given ``v`` at temperature ``T`` using ``key``.

        Returns:
            ``(h_sample, h_probs)``: float32 Bernoulli draws and their
            activation probabilities.
        """
        activation = jax.nn.sigmoid((v @ self.W + self.c) / T)
        draws = jax.random.bernoulli(key, activation)
        return draws.astype(jnp.float32), activation

    def _sample_visible(self, key: jax.random.PRNGKey, h, T=1.0):
        """Sample visible units given ``h`` at temperature ``T`` using ``key``.

        Returns:
            ``(v_sample, v_probs)``: float32 Bernoulli draws and their
            activation probabilities.
        """
        activation = jax.nn.sigmoid((h @ self.W.T + self.b) / T)
        draws = jax.random.bernoulli(key, activation)
        return draws.astype(jnp.float32), activation

    def sample_gibbs(self, v0_sample, k=1, T=1.0):
        """Run ``k`` hidden-then-visible sampling sweeps starting at ``v0_sample``.

        A single key is drawn from the ``'sample'`` RNG stream and threaded
        through ``jax.lax.fori_loop``; every sweep splits it three ways
        (carry, hidden, visible).

        Returns:
            The visible sample after the last sweep.
        """
        def sweep(step_index, chain):
            visible, rng = chain
            # Keep the first sub-key as the carry for the following sweep.
            rng, rng_hidden, rng_visible = jax.random.split(rng, 3)
            hidden, _ = self._sample_hidden(rng_hidden, visible, T)
            visible, _ = self._sample_visible(rng_visible, hidden, T)
            return (visible, rng)

        start = (v0_sample, self.make_rng("sample"))
        final_visible, _ = jax.lax.fori_loop(0, k, sweep, start)
        return final_visible

    def free_energy(self, v):
        """Return the free energy of each row of ``v``; needs no RNG.

        A 1-D input is treated as a batch of one.
        """
        batch = v[None, :] if v.ndim == 1 else v
        visible_term = jnp.dot(batch, self.b)
        hidden_term = jnp.sum(jax.nn.softplus(batch @ self.W + self.c), axis=-1)
        return -visible_term - hidden_term

    def __call__(self, v):
        """Return hidden activation probabilities for ``v``.

        Draws a key from the ``'sample'`` RNG stream because it goes through
        ``_sample_hidden``; the sampled states themselves are discarded.
        """
        _, h_probs = self._sample_hidden(self.make_rng("sample"), v)
        return h_probs

In [21]:
# === TRAINING CODE ===

# Define TrainState - simple wrapper around flax's TrainState
class TrainState(train_state.TrainState):
    """Training state for the RBM: Flax's ``TrainState`` with no extra fields."""

# Define the loss function (Contrastive Divergence using PCD)
def contrastive_divergence_loss(params, model_apply_fn, v_data, v_fantasy, k, rng_key, T=1.0):
    """Free-energy surrogate loss for (persistent) contrastive divergence.

    Args:
        params: RBM parameters, passed to ``model_apply_fn`` as
            ``{'params': params}``.
        model_apply_fn: The model's ``apply`` function.
        v_data: Batch of data samples (positive phase).
        v_fantasy: Starting states of the Gibbs chain (negative phase).
        k: Number of Gibbs sweeps to run from ``v_fantasy``.
        rng_key: PRNG key supplied as the ``'sample'`` stream for the sweeps.
        T: Sampling temperature forwarded to ``RBM.sample_gibbs``.

    Returns:
        ``(loss, v_k)`` where ``loss`` is the mean free energy of ``v_data``
        minus the mean free energy of the chain end states ``v_k``, and
        ``v_k`` (gradient-stopped) can seed the next chain.
    """
    variables = {'params': params}

    # Negative phase: advance the chain k sweeps from the fantasy particles.
    v_k = model_apply_fn(
        variables,
        v_fantasy,
        k=k,
        T=T,
        method=RBM.sample_gibbs,
        rngs={'sample': rng_key},
    )
    # The CD gradient treats the negative samples as constants. Detach them
    # *before* they enter the loss so this holds by construction, rather than
    # relying on Bernoulli sampling being non-differentiable.
    v_k = jax.lax.stop_gradient(v_k)

    # Free energies need no RNG.
    fe_data = model_apply_fn(variables, v_data, method=RBM.free_energy)
    fe_fantasy_k = model_apply_fn(variables, v_k, method=RBM.free_energy)

    loss = jnp.mean(fe_data) - jnp.mean(fe_fantasy_k)
    return loss, v_k

# Define the JIT-compiled training step
@jax.jit
def train_step(state, batch, fantasy_particles, k, rng_key, pcd_reset_key, batch_idx, pcd_reset_interval, T=1.0):
    """Run one jitted PCD update.

    Args:
        state: Current ``TrainState`` (parameters, optimizer state, apply_fn).
        batch: Data batch for the positive phase.
        fantasy_particles: Persistent chain states carried between steps.
        k: Number of Gibbs sweeps per step.
        rng_key: PRNG key for this step; split into a Gibbs key and a key
            that is handed back to the caller.
        pcd_reset_key: PRNG key used only when the chain is re-initialised.
        batch_idx: Index of the current batch; the chain is reset whenever
            ``batch_idx % pcd_reset_interval == 0``.
        pcd_reset_interval: Reset period, in batches.
        T: Sampling temperature.

    Returns:
        ``(state, next_fantasy_particles, next_rng, metrics)`` where
        ``metrics`` is ``{'loss': loss}``.

    NOTE(review): no argument is marked static, so ``k`` is traced here and
    the Gibbs loop presumably runs with a dynamic trip count — confirm this
    is intended.
    """
    # Either restart the chain from uniform random bits or keep it as is.
    current_fantasy_particles = jax.lax.cond(
        batch_idx % pcd_reset_interval == 0,
        lambda fp: jax.random.bernoulli(
            pcd_reset_key, p=0.5, shape=fp.shape
        ).astype(jnp.float32),
        lambda fp: fp,
        fantasy_particles,
    )

    # First half drives Gibbs sampling, second half goes back to the caller.
    gibbs_rng, next_rng = jax.random.split(rng_key)

    def loss_fn(params):
        # Returns (loss, next fantasy particles); the latter is the aux output.
        return contrastive_divergence_loss(
            params,
            state.apply_fn,
            batch,
            current_fantasy_particles,
            k,
            gibbs_rng,
            T,
        )

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, next_fantasy_particles), grads = grad_fn(state.params)

    # Optimizer step.
    updated_state = state.apply_gradients(grads=grads)

    return updated_state, next_fantasy_particles, next_rng, {'loss': loss}

In [22]:
import time

#### TRAINING SETUP ####

# Hyperparameters (matching PyTorch setup)
batch_size      = 128
visible_units   = 28*28 # 784
hidden_units    = 256
k               = 1      # Gibbs steps for PCD
lr              = 1e-3
num_epochs      = 40
pcd_reset       = 75     # Reset persistent chain every N batches
weight_decay    = 1e-5   # L2 regularization
lr_decay_rate   = 0.95   # Learning rate decay factor PER EPOCH
temperature     = 1.0    # Sampling temperature (fixed)

# Calculate steps per epoch and total steps (partial trailing batch is dropped)
steps_per_epoch = len(x_train) // batch_size
total_steps     = num_epochs * steps_per_epoch
print(f"Total training steps: {total_steps}")

# RNG Key Management
seed = 0
key = jax.random.PRNGKey(seed)
# One independent key each for: model init, training steps, data shuffling, PCD init
model_key, train_key, data_key, pcd_init_key = jax.random.split(key, 4)

# Initialize RBM Model
rbm_model = RBM(n_visible=visible_units, n_hidden=hidden_units)
# Create dummy input for parameter initialization
dummy_input = jnp.zeros((1, visible_units))

# `init` runs `__call__`, which needs both the 'params' stream (initializers)
# and the 'sample' stream. The two streams must not share a key, so the
# 'sample' key is derived with fold_in; the 'params' key is left as-is, which
# keeps the initial parameters unchanged.
params = rbm_model.init(
    {'params': model_key, 'sample': jax.random.fold_in(model_key, 1)},
    dummy_input,
)['params']

# Create Optimizer with Learning Rate Schedule and Weight Decay
# Exponential decay schedule applied once per epoch
lr_schedule = optax.exponential_decay(
    init_value=lr,
    transition_steps=steps_per_epoch, # Number of steps before decay applied
    decay_rate=lr_decay_rate,
    staircase=True # Apply decay discretely at the transition steps
)

# AdamW optimizer (decoupled weight decay)
optimizer = optax.adamw(learning_rate=lr_schedule, weight_decay=weight_decay)

# Create the Training State
tx_state = TrainState.create(apply_fn=rbm_model.apply, params=params, tx=optimizer)

# Initialize Fantasy Particles (Persistent Chain for PCD)
# Use a dedicated RNG key for initial particles
fantasy_particles = jax.random.bernoulli(pcd_init_key, p=0.5, shape=(batch_size, visible_units)).astype(jnp.float32)

# Simple NumPy/JAX Data Loader Function
def numpy_loader(rng_key, data, batch_size, drop_last=True):
    """Yield shuffled mini-batches of ``data``.

    Args:
        rng_key: JAX PRNG key used to draw the shuffling permutation.
        data: Array indexed along its first axis.
        batch_size: Number of rows per full batch.
        drop_last: If True, rows that do not fill a complete batch are
            discarded. If False, they are yielded as one final, smaller batch.

    Yields:
        ``data[indices]`` for each batch of shuffled indices.
    """
    n_data = data.shape[0]
    n_full = n_data // batch_size
    cutoff = n_full * batch_size

    # Generate a random permutation for shuffling
    perm = jax.random.permutation(rng_key, n_data)

    # Only the part that divides evenly can be reshaped into
    # (n_full, batch_size); reshaping the whole permutation fails whenever
    # n_data is not a multiple of batch_size.
    for indices in perm[:cutoff].reshape((n_full, batch_size)):
        yield data[indices]

    # Remaining rows form the last (short) batch when they are not dropped.
    if not drop_last and cutoff < n_data:
        yield data[perm[cutoff:]]


#### TRAINING LOOP ####
# Requires `tqdm` (imported with the other dependencies in the first cell).

print("Starting RBM training...")
metrics_history = {'loss': []}

# Loop over epochs
for epoch in range(num_epochs):
    epoch_start_time = time.time()
    total_epoch_loss = 0.0

    # Get a new key for data shuffling this epoch
    data_key, epoch_data_key = jax.random.split(data_key)
    # Create batch generator for the current epoch
    batch_generator = numpy_loader(epoch_data_key, x_train, batch_size, drop_last=True)

    # Progress bar over batches
    pbar = tqdm(enumerate(batch_generator), total=steps_per_epoch, desc=f"Epoch {epoch+1}/{num_epochs}")

    # Loop over batches in the current epoch
    for batch_idx, batch in pbar:
        # Per-step keys: one for a potential PCD reset, one for the step itself.
        # The first sub-key of the split is deliberately unused: train_step
        # returns the key that seeds the next iteration.
        pcd_reset_key, step_key = jax.random.split(train_key, 3)[1:]

        # Execute one training step
        tx_state, fantasy_particles, train_key, metrics = train_step(
            tx_state,           # Current training state (params, opt_state)
            batch,              # Current batch of data
            fantasy_particles,  # Current fantasy particles
            k,                  # Number of Gibbs steps
            step_key,           # RNG key for this training step (used for Gibbs)
            pcd_reset_key,      # RNG key specifically for potential PCD reset
            batch_idx,          # Current batch index (for PCD reset condition)
            pcd_reset,          # How often to reset PCD chain
            temperature         # Sampling temperature
        )

        # Pull the scalar to the host once, so the history holds plain floats
        # rather than device arrays.
        batch_loss = float(metrics['loss'])
        total_epoch_loss += batch_loss
        pbar.set_postfix(loss=f"{batch_loss:.4f}")

    # Calculate average loss for the epoch
    avg_epoch_loss = total_epoch_loss / steps_per_epoch
    metrics_history['loss'].append(avg_epoch_loss)
    epoch_time = time.time() - epoch_start_time

    # Print epoch summary; the schedule is evaluated at the optimizer's step count
    current_lr = float(lr_schedule(tx_state.step))
    print(f"Epoch [{epoch+1}/{num_epochs}] - Avg Loss: {avg_epoch_loss:.4f}, Current LR: {current_lr:.6f}, Time: {epoch_time:.2f}s")


print("Training finished.")

Total training steps: 18720
Starting RBM training...


Epoch 1/40: 100%|██████████| 468/468 [00:02<00:00, 156.95it/s, loss=-458.3484]


Epoch [1/40] - Avg Loss: -80.2311, Current LR: 0.000950, Time: 2.99s


Epoch 2/40: 100%|██████████| 468/468 [00:02<00:00, 203.18it/s, loss=-649.1641]


Epoch [2/40] - Avg Loss: -152.3282, Current LR: 0.000903, Time: 2.30s


Epoch 3/40: 100%|██████████| 468/468 [00:02<00:00, 202.92it/s, loss=-816.8418]


Epoch [3/40] - Avg Loss: -218.6608, Current LR: 0.000857, Time: 2.31s


Epoch 4/40: 100%|██████████| 468/468 [00:02<00:00, 195.38it/s, loss=-919.9186]


Epoch [4/40] - Avg Loss: -235.0880, Current LR: 0.000815, Time: 2.40s


Epoch 5/40: 100%|██████████| 468/468 [00:02<00:00, 214.36it/s, loss=-717.7847]


Epoch [5/40] - Avg Loss: -192.9754, Current LR: 0.000774, Time: 2.18s


Epoch 6/40: 100%|██████████| 468/468 [00:02<00:00, 192.43it/s, loss=-430.1732]


Epoch [6/40] - Avg Loss: -159.3421, Current LR: 0.000735, Time: 2.43s


Epoch 7/40: 100%|██████████| 468/468 [00:02<00:00, 193.23it/s, loss=-289.3395]


Epoch [7/40] - Avg Loss: -129.5526, Current LR: 0.000698, Time: 2.42s


Epoch 8/40: 100%|██████████| 468/468 [00:02<00:00, 215.22it/s, loss=-133.3070]


Epoch [8/40] - Avg Loss: -104.6099, Current LR: 0.000663, Time: 2.18s


Epoch 9/40: 100%|██████████| 468/468 [00:02<00:00, 205.87it/s, loss=-111.0704]


Epoch [9/40] - Avg Loss: -83.7010, Current LR: 0.000630, Time: 2.27s


Epoch 10/40: 100%|██████████| 468/468 [00:02<00:00, 207.70it/s, loss=-54.9473] 


Epoch [10/40] - Avg Loss: -71.3379, Current LR: 0.000599, Time: 2.25s


Epoch 11/40:  19%|█▉        | 90/468 [00:00<00:02, 184.02it/s, loss=30.4053]  

In [None]:
import matplotlib.pyplot as plt

# --- Plotting Training Loss ---
#
# Expects `metrics_history['loss']` to hold one average loss per trained
# epoch, e.g. metrics_history = {'loss': [epoch1_loss, epoch2_loss, ...]}.

if 'metrics_history' not in locals() or 'loss' not in metrics_history or not metrics_history['loss']:
    # Nothing to draw: the history is missing or still empty.
    print("Unable to plot: 'metrics_history' dictionary or 'loss' list not found or empty.")
    print("Please ensure the training loop ran correctly and populated 'metrics_history['loss']'.")

else:
    num_epochs_trained = len(metrics_history['loss'])
    # Epoch numbers on the x-axis start at 1.
    epochs = range(1, num_epochs_trained + 1)

    plt.figure(figsize=(10, 6))
    plt.plot(
        epochs,
        metrics_history['loss'],
        marker='o',
        linestyle='-',
    )

    plt.xlabel("Epoch")
    plt.ylabel("Average Loss (Free Energy Difference)")
    plt.title("RBM Training Loss per Epoch")
    plt.grid(True)

    # With many epochs, thin the x ticks out to roughly ten.
    if num_epochs_trained > 10:
        plt.xticks(
            range(0, num_epochs_trained + 1, max(1, num_epochs_trained // 10))
        )

    plt.tight_layout()
    plt.show()