# Discriminative PC on MNIST

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/thebuckleylab/jpc/blob/main/examples/discriminative_pc.ipynb)

This notebook demonstrates how to train a neural network with predictive coding (PC) to discriminate or classify MNIST digits.

In [1]:
%%capture
!pip install torch==2.3.1
!pip install torchvision==0.18.1

In [2]:
import jpc

import jax
import equinox as eqx
import equinox.nn as nn
import optax

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

## Hyperparameters

We define some global parameters, including network architecture, learning rate, batch size etc.

In [3]:
# Seed for the JAX PRNG key used when building the network.
SEED = 0

# Layer widths from input to output, passed to `jpc.get_fc_network`.
LAYER_SIZES = [784, 300, 300, 10]
# Name of the hidden-layer activation function ("relu" matches the networks
# built below).
ACT_FN = "relu"

# Learning rate handed to the Adam optimiser inside `train`.
LEARNING_RATE = 1e-3
# Batch size used for both the train and the test loader.
BATCH_SIZE = 64
# Forwarded to `jpc.make_pc_step` as `n_iters`.
N_INFER_ITERS = 20
# Evaluate on the test loader every this many training iterations.
TEST_EVERY = 100
# Number of training iterations after which `train` stops.
N_TRAIN_ITERS = 300

## Dataset

Some utils to fetch MNIST.

In [4]:
#@title data utils


def get_mnist_loaders(batch_size):
    """Build shuffled MNIST data loaders for the train and test splits.

    Both loaders wrap a normalised `MNIST` dataset, shuffle their samples and
    drop the final incomplete batch, so every batch has exactly `batch_size`
    items.

    Args:
        batch_size: number of samples per batch for both loaders.

    Returns:
        A `(train_loader, test_loader)` tuple.
    """
    loaders = []
    # Train split first, then test split.
    for is_train in (True, False):
        split = MNIST(train=is_train, normalise=True)
        loaders.append(
            DataLoader(
                dataset=split,
                batch_size=batch_size,
                shuffle=True,
                drop_last=True
            )
        )
    train_loader, test_loader = loaders
    return train_loader, test_loader


class MNIST(datasets.MNIST):
    """MNIST dataset that yields flattened images and one-hot labels.

    Args:
        train: whether to load the training split (otherwise the test split).
        normalise: if true, standardise the image tensor with a fixed mean
            and standard deviation after converting it to a tensor.
        save_dir: directory the dataset is downloaded to / loaded from.
    """

    def __init__(self, train, normalise=True, save_dir="data"):
        steps = [transforms.ToTensor()]
        if normalise:
            # One-element tuples (note the trailing commas): `Normalize`
            # documents per-channel sequences, and `(0.1307)` without the
            # comma is just a parenthesised float.
            # Presumably these are the MNIST pixel mean/std - TODO confirm.
            steps.append(
                transforms.Normalize(mean=(0.1307,), std=(0.3081,))
            )
        super().__init__(
            save_dir,
            download=True,
            train=train,
            transform=transforms.Compose(steps)
        )

    def __getitem__(self, index):
        """Return the `index`-th sample as `(flat_image, one_hot_label)`."""
        img, label = super().__getitem__(index)
        # Collapse the image tensor to one dimension for the fully connected
        # network.
        img = torch.flatten(img)
        label = one_hot(label)
        return img, label


def one_hot(labels, n_classes=10):
    """Return the one-hot encoding of `labels`.

    Rows of an `n_classes` x `n_classes` identity matrix are selected by
    indexing with `labels`, so an integer label gives a single vector and a
    tensor of labels gives one row per label.
    """
    return torch.eye(n_classes)[labels]
    

## Network

For `jpc` to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like `nn.Sequential()` in [Equinox](https://github.com/patrick-kidger/equinox). For example, we can define a ReLU MLP with two hidden layers as follows:

In [5]:
# Derive one PRNG key per linear layer; the first of the four splits is
# discarded.
key = jax.random.PRNGKey(SEED)
_, *subkeys = jax.random.split(key, 4)

# Two hidden blocks (a linear map followed by a ReLU) and a linear output
# layer.
hidden_1 = nn.Sequential(
    [nn.Linear(784, 300, key=subkeys[0]), nn.Lambda(jax.nn.relu)]
)
hidden_2 = nn.Sequential(
    [nn.Linear(300, 300, key=subkeys[1]), nn.Lambda(jax.nn.relu)]
)
output_layer = nn.Linear(300, 10, key=subkeys[2])

# `jpc` takes the network as a plain list of callable layers.
network = [hidden_1, hidden_2, output_layer]
print(network)

[Sequential(
  layers=(
    Linear(
      weight=f32[300,784],
      bias=f32[300],
      in_features=784,
      out_features=300,
      use_bias=True
    ),
    Lambda(fn=<wrapped function relu>)
  )
), Sequential(
  layers=(
    Linear(
      weight=f32[300,300],
      bias=f32[300],
      in_features=300,
      out_features=300,
      use_bias=True
    ),
    Lambda(fn=<wrapped function relu>)
  )
), Linear(
  weight=f32[10,300],
  bias=f32[10],
  in_features=300,
  out_features=10,
  use_bias=True
)]


You can also use the utility `jpc.get_fc_network` to define an MLP or fully connected network with some activation functions.

In [6]:
# Build the MLP from the notebook's hyperparameters instead of repeating the
# activation name inline, so changing ACT_FN actually changes the network.
network = jpc.get_fc_network(key, LAYER_SIZES, act_fn=ACT_FN)
print(network)

[Sequential(
  layers=(
    Linear(
      weight=f32[300,784],
      bias=f32[300],
      in_features=784,
      out_features=300,
      use_bias=True
    ),
    Lambda(fn=<wrapped function relu>)
  )
), Sequential(
  layers=(
    Linear(
      weight=f32[300,300],
      bias=f32[300],
      in_features=300,
      out_features=300,
      use_bias=True
    ),
    Lambda(fn=<wrapped function relu>)
  )
), Linear(
  weight=f32[10,300],
  bias=f32[10],
  in_features=300,
  out_features=10,
  use_bias=True
)]


## Train and test

A PC network can be trained in a single line of code with `jpc.make_pc_step()`. See the documentation for more. Similarly, we can use `jpc.test_discriminative_pc()` to compute the network accuracy. Note that these functions are already "jitted" for performance.

Below we simply wrap each of these functions in our training and test loops, respectively.

In [7]:
def evaluate(model, test_loader):
    """Return the test accuracy of `model` averaged over all test batches.

    Each batch is converted from torch tensors to NumPy arrays and scored
    with `jpc.test_discriminative_pc`; the per-batch results are summed and
    divided by the number of batches in `test_loader`.
    """
    batch_accuracies = (
        jpc.test_discriminative_pc(
            model=model,
            y=labels.numpy(),
            x=images.numpy()
        )
        for images, labels in test_loader
    )
    return sum(batch_accuracies) / len(test_loader)


def train(
    model,
    lr,
    batch_size,
    n_infer_iters,
    test_every,
    n_train_iters
):
    """Train a PC network on MNIST, periodically reporting test accuracy.

    Training makes at most one pass over the training loader and stops as
    soon as `n_train_iters` parameter updates have been performed.

    Args:
        model: network to train, passed to `jpc.make_pc_step`.
        lr: learning rate of the Adam optimiser.
        batch_size: batch size for both the train and the test loader.
        n_infer_iters: forwarded to `jpc.make_pc_step` as `n_iters`.
        test_every: evaluate on the test loader and print the train loss and
            average test accuracy every this many training iterations.
        n_train_iters: number of training iterations after which to stop.

    Returns:
        The result of the last `jpc.make_pc_step` call (read here via its
        "model", "optim", "opt_state" and "loss" keys), or `None` if the
        training loader yielded no batches.
    """
    optim = optax.adam(lr)
    opt_state = optim.init(eqx.filter(model, eqx.is_array))
    train_loader, test_loader = get_mnist_loaders(batch_size)

    # Defined up front so an empty loader returns None instead of raising.
    result = None
    # Count steps from 1 and avoid shadowing the builtin `iter`.
    for step, (img_batch, label_batch) in enumerate(train_loader, start=1):
        img_batch = img_batch.numpy()
        label_batch = label_batch.numpy()

        result = jpc.make_pc_step(
            model,
            optim,
            opt_state,
            y=label_batch,
            x=img_batch,
            n_iters=n_infer_iters
        )
        model, optim, opt_state = result["model"], result["optim"], result["opt_state"]
        train_loss = result["loss"]
        if step % test_every == 0:
            avg_test_acc = evaluate(model, test_loader)
            print(
                f"Train iter {step}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )
        # Checked on every step, not only on evaluation steps, so training
        # stops on time even when n_train_iters is not a multiple of
        # test_every.
        if step >= n_train_iters:
            break

    return result


## Run

In [8]:
# Train the network with the hyperparameters defined above; `train` prints
# the train loss and average test accuracy every TEST_EVERY iterations.
result = train(
    model=network,
    lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    n_infer_iters=N_INFER_ITERS,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS
)



Train iter 100, train loss=0.010833, avg test accuracy=0.938001
Train iter 200, train loss=0.013239, avg test accuracy=0.955529
Train iter 300, train loss=0.013557, avg test accuracy=0.957933
