# MNIST Classification with blox

This tutorial demonstrates how to train a Convolutional Neural Network (CNN) on MNIST using **blox**.

We follow a probabilistic approach throughout:
1.  **Model**: The network defines a conditional distribution $P(Y \mid X)$ over labels $Y$ given an image $X$.
2.  **Objective**: Maximize the likelihood of the data, or equivalently minimize the negative log-likelihood (NLL).

We use **Distrax** to handle the probability distributions.

In [1]:
# Install necessary packages.
!pip install -q jax-blox optax tensorflow tensorflow-datasets matplotlib distrax

In [None]:
import functools

import blox as bx
import distrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

In [None]:
# Training hyperparameters shared by the cells below.
# Examples per batch, used for both the train and the test pipeline.
batch_size = 32
# Number of training batches the train pipeline yields (one per train step).
train_steps = 1200
# Evaluate on the test pipeline once every this many training steps.
eval_every = 200


def get_datasets(batch_size: int, train_steps: int):
  """Build the MNIST train and test input pipelines.

  Args:
    batch_size: Number of examples per batch. Incomplete trailing batches are
      dropped in both pipelines.
    train_steps: Number of batches the training pipeline yields before it is
      exhausted.

  Returns:
    A `(train_ds, test_ds)` pair of `tf.data` datasets. Each element is a dict
    with an `'image'` entry (cast to float32 and divided by 255) and the
    original `'label'` entry.
  """

  def to_unit_range(example):
    # Cast the image to float32 and rescale by 1/255; the label passes through.
    return {
        'image': tf.cast(example['image'], tf.float32) / 255.0,
        'label': example['label'],
    }

  # Training: repeat indefinitely, shuffle with a 1024-element buffer, batch,
  # then cap the stream at `train_steps` batches.
  train_ds = (
      tfds.load('mnist', split='train')
      .map(to_unit_range)
      .repeat()
      .shuffle(1024)
      .batch(batch_size, drop_remainder=True)
      .take(train_steps)
      .prefetch(1)
  )

  # Evaluation: a single pass over the test split; no shuffle is applied here.
  test_ds = (
      tfds.load('mnist', split='test')
      .map(to_unit_range)
      .batch(batch_size, drop_remainder=True)
      .prefetch(1)
  )

  return train_ds, test_ds


train_ds, test_ds = get_datasets(batch_size, train_steps)
# Pull one training batch up front; a slice of it is used further down to
# initialize the model parameters.
sample_batch = next(iter(train_ds.as_numpy_iterator()))
print(f'Batch shape: {sample_batch["image"].shape}')

In [None]:
class CNN(bx.Module):
  """Convolutional classifier that outputs logits over `num_classes` classes.

  Two conv -> ReLU -> average-pool stages are followed by a hidden linear
  layer with ReLU and dropout, and a final linear layer producing the logits.
  """

  def __init__(self, graph: bx.Graph, num_classes: int = 10):
    """Creates the layers, each under its own child node of `graph`.

    Args:
      graph: Graph node this module is attached to.
      num_classes: Output size of the final linear layer.
    """
    super().__init__(graph)

    # Convolutional feature extractor.
    self.conv1 = bx.Conv(
        graph.child('conv1'), output_channels=32, kernel_size=3
    )
    self.conv2 = bx.Conv(
        graph.child('conv2'), output_channels=64, kernel_size=3
    )

    # Dense head: hidden layer, then one logit per class.
    self.linear1 = bx.Linear(graph.child('linear1'), output_size=256)
    self.linear2 = bx.Linear(graph.child('linear2'), output_size=num_classes)

    # Applied to the hidden activations in `__call__`.
    self.dropout = bx.Dropout(graph.child('dropout'), rate=0.5)

  def __call__(self, params: bx.Params, x: jax.Array, is_training: bool = True):
    """Computes class logits for a batch of inputs.

    Args:
      params: Parameter container threaded through every layer call.
      x: Input batch. The leading axis is kept and all remaining axes are
        flattened before the dense layers.
      is_training: Forwarded to the dropout layer.

    Returns:
      A `(logits, params)` tuple, where `params` is the container returned by
      the final layer call.
    """
    # Two identical stages: conv -> ReLU -> average pool (window 2, stride 2).
    for conv in (self.conv1, self.conv2):
      x, params = conv(params, x)
      x = bx.avg_pool(jax.nn.relu(x), window_shape=2, strides=2)

    # Collapse everything except the leading axis.
    flat = x.reshape((x.shape[0], -1))

    hidden, params = self.linear1(params, flat)
    hidden, params = self.dropout(
        params, jax.nn.relu(hidden), is_training=is_training
    )

    # Logits that parameterize the Categorical distribution over classes.
    logits, params = self.linear2(params, hidden)
    return logits, params

In [None]:
# Create model components: the graph root, the CNN attached to it, and an RNG
# module seeded with 0 that lives under its own child node.
graph = bx.Graph('mnist_cnn')
model = CNN(graph)
rng = bx.Rng(graph.child('rng'), seed=0)

# Initialize params by running the model once on a single-example slice of the
# sample batch (`[:1]` keeps the leading axis), with `is_training=False`.
params = bx.Params(rng=rng)
_, params = model(params, sample_batch['image'][:1], is_training=False)
# NOTE(review): presumably `finalized()` locks the parameter structure once
# initialization is done -- confirm against the blox documentation.
params = params.finalized()

# Separate base key that `train_model` installs on the params used for
# evaluation.
eval_key = jax.random.key(1)

bx.display(model.graph, params)

In [None]:
# `optimizer` is a container of Python functions, not an array, so it cannot be
# traced by `jax.jit`; it must be declared static.
@functools.partial(jax.jit, static_argnames=('optimizer',))
def train_step(params, opt_state, batch_images, batch_labels, optimizer):
  """Runs one gradient step on a single batch.

  Args:
    params: Current model parameters (trainable and non-trainable).
    opt_state: Current optimizer state.
    batch_images: Input images for this batch.
    batch_labels: Integer class labels for this batch.
    optimizer: The optax gradient transformation. Treated as a static
      (hashable) argument of the jitted function.

  Returns:
    A `(new_params, new_opt_state, loss)` tuple, where `loss` is the mean
    negative log-likelihood of the batch before the update.
  """
  trainable, non_trainable = params.split()

  def loss_fn(trainable):
    params = trainable.merge(non_trainable)
    logits, new_params = model(params, batch_images, is_training=True)

    # Probabilistic Loss: Negative Log Likelihood.
    # We model the output as a Categorical distribution.
    dist = distrax.Categorical(logits=logits)
    loss = -dist.log_prob(batch_labels).mean()

    # Hand the updated non-trainable part back out as auxiliary output so it
    # is carried over to the next step without being differentiated.
    _, new_non_trainable = new_params.split()
    return loss, new_non_trainable

  # `value_and_grad` with `has_aux=True` returns `((loss, aux), grads)`.
  # Plain `jax.grad` would return `(grads, aux)` and drop the loss value.
  (loss, new_non_trainable), grads = jax.value_and_grad(
      loss_fn, has_aux=True
  )(trainable)
  # The current params are passed because `adamw` needs them for weight decay.
  updates, new_opt_state = optimizer.update(grads, opt_state, trainable)
  new_trainable = optax.apply_updates(trainable, updates)

  return new_trainable.merge(new_non_trainable), new_opt_state, loss

In [None]:
@jax.jit
def eval_step(params, batch_images, batch_labels):
  """Computes classification accuracy on one batch with `is_training=False`.

  Args:
    params: Model parameters to evaluate with.
    batch_images: Input images for this batch.
    batch_labels: Ground-truth class labels for this batch.

  Returns:
    An `(accuracy, params)` tuple: the fraction of examples whose predicted
    class equals the label, and the params returned by the model call.
  """
  logits, params = model(params, batch_images, is_training=False)
  # Predict the most probable class under the Categorical distribution.
  predictions = distrax.Categorical(logits=logits).mode()
  is_correct = predictions == batch_labels
  return jnp.mean(is_correct), params

In [None]:
def train_model(params):
  """Trains the model over `train_ds`, evaluating on `test_ds` periodically.

  Args:
    params: Initial model parameters.

  Returns:
    A `(params, history)` tuple. `history['loss']` holds the loss of every
    training step; `history['val_acc']` holds the mean per-batch test accuracy
    recorded once every `eval_every` steps.
  """
  optimizer = optax.adamw(1e-3)
  trainable, _ = params.split()
  opt_state = optimizer.init(trainable)

  history = {'loss': [], 'val_acc': []}

  # RNG counter carried from one evaluation to the next so that evaluation
  # does not reuse keys.
  eval_counter = 0

  for step, batch in enumerate(train_ds.as_numpy_iterator()):
    params, opt_state, loss = train_step(
        params, opt_state, batch['image'], batch['label'], optimizer
    )
    history['loss'].append(float(loss))

    # Only every `eval_every`-th step runs an evaluation.
    if (step + 1) % eval_every != 0:
      continue

    # Evaluate on a copy of the params that uses the dedicated eval key and
    # resumes from the saved counter, leaving the training RNG untouched.
    eval_params = rng.set_base_key(params, eval_key)
    eval_params = rng.set_counter(eval_params, eval_counter)

    batch_accs = []
    for test_batch in test_ds.as_numpy_iterator():
      batch_acc, eval_params = eval_step(
          eval_params, test_batch['image'], test_batch['label']
      )
      batch_accs.append(batch_acc)

    # Remember where the eval RNG stopped for the next evaluation.
    eval_counter, _ = rng.get_counter(eval_params)

    test_acc = np.mean(batch_accs)
    history['val_acc'].append(test_acc)

    print(f'Step {step + 1}, Loss: {loss:.4f}, Test Acc: {test_acc:.4f}')

  return params, history


# Run the training loop from the initialized params; `history` feeds the plots
# below.
trained_params, history = train_model(params)

In [None]:
# Visualize Training Progress.
plt.figure(figsize=(12, 4))
# Left panel: training loss, one point per training step.
plt.subplot(1, 2, 1)
plt.plot(history['loss'], label='Training Loss')
plt.xlabel('Step')
plt.ylabel('Negative Log Likelihood')
plt.legend()

# Right panel: test accuracy, one point per periodic evaluation.
plt.subplot(1, 2, 2)
plt.plot(
    # Evaluation i (0-based) ran after training step eval_every * (i + 1).
    [eval_every * (i + 1) for i in range(len(history['val_acc']))],
    history['val_acc'],
    label='Test Accuracy',
    color='orange',
)
plt.xlabel('Step')
plt.ylabel('Accuracy')
plt.legend()
plt.show()


# Visualize Predictions.
def show_predictions(params, count=5):
  """Plots test images titled with their true and predicted labels.

  Images are taken from the start of the first batch of `test_ds`.

  Args:
    params: Model parameters used for inference.
    count: Maximum number of images to show; must be at least 1. If the batch
      holds fewer images, only those are shown.

  Raises:
    ValueError: If `count` is less than 1.
  """
  if count < 1:
    raise ValueError(f'count must be at least 1, got {count}')

  test_batch = next(iter(test_ds.as_numpy_iterator()))
  images = test_batch['image'][:count]
  labels = test_batch['label'][:count]
  # The slice is shorter than `count` when the batch has fewer images.
  num_shown = len(images)

  logits, _ = model(params, images, is_training=False)

  # Probabilistic prediction: the mode of the Categorical distribution.
  dist = distrax.Categorical(logits=logits)
  preds = np.asarray(dist.mode())

  # `squeeze=False` always returns a 2-D array of axes, so a single image is
  # handled the same way as several. Each panel gets a 3x3 inch slot, which
  # matches the previous fixed (15, 3) figure for the default of five images.
  _, axes = plt.subplots(
      1, num_shown, figsize=(3 * num_shown, 3), squeeze=False
  )
  for ax, image, label, pred in zip(axes[0], images, labels, preds):
    ax.imshow(image.squeeze(), cmap='gray')
    ax.set_title(f'True: {label}, Pred: {pred}')
    ax.axis('off')
  plt.show()


# Show predictions from the trained model with the default number of images.
show_predictions(trained_params)