# A Simple Convolutional Neural Network on the MNIST Dataset

In this example, we implement a CNN. Along the way, you will learn the following concepts:

    - How to create a custom neural network
    - Why and when it makes sense to use eqx.filter_[...] functions
    - How to train your network using Optax
    - What your Neural Network looks like "under the hood" (foreshadowing: like a PyTree!)

We also use [jaxtyping](https://github.com/google/jaxtyping) for type annotations. This makes it easier for collaborators to work on your code, improves code quality, and reduces the probability of unexpected bugs.

In [1]:
# Import all the necessary libraries

import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
import torchvision
import torchvision.transforms as transforms
from jaxtyping import Array, PyTree
from optax import GradientTransformation
from torch.utils.data.dataloader import DataLoader

import equinox as eqx

In [2]:
# Hyperparameters

BATCH_SIZE = 64
LEARNING_RATE = 1e-3

## Data

For this example, we will use the *regular* MNIST dataset. To load it, we use PyTorch and its excellent DataLoader class. Before we can use the data, we need to
    1. normalise the data
    2. one-hot encode the labels

We can perform the first step with some PyTorch utilities, while for the second step we can use JAX's `one_hot` function.

You can check the [PyTorch documentation](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) for more information on the dataset.
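
To make the two preprocessing steps concrete, here is a minimal NumPy sketch. It is not part of the pipeline itself, and `fake_pixels` and `manual_one_hot` are hypothetical names used only for illustration. `ToTensor` scales pixel values to the range [0, 1], and `Normalize((0.5,), (0.5,))` then computes `(x - 0.5) / 0.5`, which maps them to [-1, 1].

```python
import numpy as np

# Hypothetical pixel values as they would look after ToTensor (range [0, 1]).
fake_pixels = np.array([0.0, 0.25, 0.5, 1.0], dtype=np.float32)

# Normalize((0.5,), (0.5,)) computes (x - mean) / std for the single channel.
normalised = (fake_pixels - 0.5) / 0.5
print(normalised)


def manual_one_hot(label: int, num_classes: int = 10) -> np.ndarray:
    """NumPy stand-in for jax.nn.one_hot, for illustration only."""
    encoded = np.zeros(num_classes, dtype=np.float32)
    encoded[label] = 1.0
    return encoded


print(manual_one_hot(3))
```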

In [3]:
def one_hot_encode(label):
    """
    This function one-hot encodes the target label.
    E.g.:
        label = 3
        output = [0 0 0 1 0 0 0 0 0 0]
    """
    return np.array(jax.nn.one_hot(label, 10))


# transforms
# This normalises the data
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)


# Load the MNIST train data
train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=transform,
    target_transform=one_hot_encode,
)

# Load the MNIST test data
test_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=False,
    download=True,
    transform=transform,
    target_transform=one_hot_encode,
)

# Create the train and testloaders
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)

In [4]:
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like)
dummy_x, dummy_y = next(iter(trainloader))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  # 64x1x28x28
print(dummy_y.shape)  # 64x10

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


(64, 1, 28, 28)
(64, 10)


We can see that our data has the shape 64x1x28x28. 64 is the batch size, 1 is the number of input channels (this will be important soon) and 28x28 is the image dimension.

## The Network

Our Convolutional Neural Network (CNN) will have a list that contains all our layers. There is no explicit requirement to do it that way; it's simply convenient for this example. In between the layers (which in essence are just JAX arrays, i.e. **leaves** of a PyTree), we have some things which are **not** arrays! These functions, e.g. `relu`, are not things that JAX can differentiate! This means we will have to *filter* those things out. Luckily, Equinox provides several useful filter functions for this.

In [5]:
class NeuralNetwork(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # A simple neural network with 3 hidden layers
        self.layers = [
            eqx.nn.Conv2d(in_channels=1, out_channels=3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.softmax,
        ]

    def __call__(self, x):
        for layer in self.layers:
            # apply each layer (or activation function) in turn
            x = layer(x)
        return x
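
The `1728` input features of the first `Linear` layer follow from the shapes of the layers before it. As a rough check in plain Python, assuming no padding and a stride of 1 for both the convolution and the pooling layer (the helper `valid_output_size` is a hypothetical name):

```python
def valid_output_size(size: int, kernel_size: int, stride: int = 1) -> int:
    """Output length of an unpadded sliding window along one spatial axis."""
    return (size - kernel_size) // stride + 1


after_conv = valid_output_size(28, kernel_size=4)  # 28 - 4 + 1 = 25
after_pool = valid_output_size(after_conv, kernel_size=2)  # 25 - 2 + 1 = 24
out_channels = 3

# jnp.ravel flattens (channels, height, width) into a single vector.
flattened = out_channels * after_pool * after_pool
print(flattened)  # 1728
```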

In [6]:
def loss(model: PyTree, x: Array, y: Array):
    # Our input has the shape (BATCH_SIZE x 1 x 28 x 28) but our model expects an
    # input image with a single input channel;
    # the batch dimension was not defined in our model!
    # Therefore, we have to use jax.vmap, which - by default - maps over the 0th axis
    # (i.e. the batch in our example)
    pred_y = jax.vmap(model)(x)
    # mean squared error loss
    return jnp.mean((y - pred_y) ** 2)
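
As a small illustration of what `jax.vmap` does here, the sketch below maps a function that only understands a single image over a whole batch. The function `score_single_image` is a hypothetical stand-in for the model.

```python
import jax
import jax.numpy as jnp


def score_single_image(image):
    # Hypothetical stand-in for the model: expects ONE image of shape
    # (1, 28, 28) and returns 10 numbers.
    assert image.shape == (1, 28, 28)
    return jnp.ravel(image)[:10]


batch = jnp.zeros((64, 1, 28, 28))

# vmap maps over axis 0, so the function only ever sees a single image.
scores = jax.vmap(score_single_image)(batch)
print(scores.shape)  # (64, 10)
```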

In [7]:
# Init our model
key, subkey = jax.random.split(jax.random.PRNGKey(32))
model = NeuralNetwork(subkey)

### Our PyTree Model

We can print our model to see the PyTree that was generated by Equinox.

In [8]:
print(model)

NeuralNetwork(
  layers=[
    Conv2d(
      num_spatial_dims=2,
      weight=f32[3,1,4,4],
      bias=f32[3,1,1],
      in_channels=1,
      out_channels=3,
      kernel_size=(4, 4),
      stride=(1, 1),
      padding=((0, 0), (0, 0)),
      dilation=(1, 1),
      groups=1,
      use_bias=True
    ),
    MaxPool2d(
      init=-inf,
      operation=<function max>,
      num_spatial_dims=2,
      kernel_size=(2, 2),
      stride=(1, 1),
      padding=((0, 0), (0, 0)),
      use_ceil=False
    ),
    <wrapped function relu>,
    <wrapped function ravel>,
    Linear(
      weight=f32[512,1728],
      bias=f32[512],
      in_features=1728,
      out_features=512,
      use_bias=True
    ),
    <wrapped function sigmoid>,
    Linear(
      weight=f32[64,512],
      bias=f32[64],
      in_features=512,
      out_features=64,
      use_bias=True
    ),
    <wrapped function relu>,
    Linear(
      weight=f32[10,64],
      bias=f32[10],
      in_features=64,
      out_features=10,
      use_bi

If we wanted to perform inference on our model given some data, then we could do it as shown below.

(**Note** that you should typically avoid using JAX transformations (such as `vmap`) outside of `jit`, as this can be very slow.)

In [9]:
# Example inference
output = jax.vmap(model)(dummy_x)
output.shape

(64, 10)

### Filtering

In the next cells we can see an example of when we should use the filter methods provided by Equinox. For instance, the following code generates an error:

In [10]:
# This is an error!

jax.value_and_grad(loss)(model, dummy_x, dummy_y)

TypeError: Argument '<function max at 0x7f6d62c443a0>' of type <class 'function'> is not a valid JAX type.

And this makes sense, because when you look back up in the PyTree, you will see some lines that say e.g. `<wrapped function relu>`. This means there is a function in our PyTree. When JAX tries to compute the gradient, it arrives at this function and is confused (rightfully so!). Therefore, we have to filter out anything that is not a differentiable PyTree leaf!

A general rule of thumb is: since `eqx.filter_[...]` is essentially the same as the corresponding Jax functions, whenever you're unsure which one to use, you can simply fall back to the Equinox filter functions.

In [11]:
# This will work!

# Getting the initial loss value
value, grads = eqx.filter_value_and_grad(loss)(model, dummy_x, dummy_y)
print(value)

0.09025362


As with most machine learning tasks, we need some methods to evaluate our model on some test data. For this we create the following two functions.

Notice that we use `eqx.filter_jit` instead of `jax.jit`, since our PyTree (i.e. our model) contains non-arrays, such as `relu` functions, which can't be interpreted as arrays.

In [12]:
# Exchange eqx.filter_jit with jax.jit to see what kind of error we get
@eqx.filter_jit
def compute_accuracy(model: PyTree, x: Array, y: Array):
    """
    This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    preds = jax.vmap(model)(x)
    targets = jnp.argmax(y, axis=1)
    preds = jnp.argmax(preds, axis=1)
    return sum(targets == preds) / len(x)
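
A small worked example of the accuracy computation, using made-up values for three classes. Here `jnp.mean` over the boolean comparison is used as an equivalent of `sum(...) / len(x)`; all array names and values are hypothetical.

```python
import jax.numpy as jnp

# Hypothetical one-hot targets for the classes 2, 0, 1, 1.
targets_one_hot = jnp.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 1, 0]])

# Hypothetical predicted class probabilities for the same four samples.
predictions = jnp.array(
    [
        [0.1, 0.2, 0.7],
        [0.8, 0.1, 0.1],
        [0.3, 0.3, 0.4],
        [0.2, 0.6, 0.2],
    ]
)

target_classes = jnp.argmax(targets_one_hot, axis=1)  # [2, 0, 1, 1]
predicted_classes = jnp.argmax(predictions, axis=1)  # [2, 0, 2, 1]

# Three of the four predictions match their target.
accuracy = jnp.mean(target_classes == predicted_classes)
print(accuracy)  # 0.75
```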

In [13]:
def evaluate(model: PyTree, testloader: DataLoader):
    """
    This function evaluates the model on the testing dataset
    by computing the average loss and also the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        avg_loss += loss(model, x, y)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)

In [14]:
evaluate(model, testloader)

(Array(0.09036118, dtype=float32),
 Array(0.09862659, dtype=float32, weak_type=True))

## Training the Model

Now it's time to train our model using Optax! The most important thing to notice here is the `eqx.apply_updates` function! This is Equinox's counterpart to `optax.apply_updates`. But why do we need it? The reason is that `eqx.filter_value_and_grad` filters out anything that is not differentiable, e.g. functions in our PyTree, and replaces those fields with `None`. This means that when you compute the gradients and the updates (from `optim.update(...)`), there will be places with `None`s. The *original* function from Optax would then try to update `None`s using something like `None - learning_rate * None`, which is a big *no no*!

Therefore we can use the utility function from Equinox, which has the same functionality as `optax.apply_updates`, except that it skips the `None`s.
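
The idea of skipping `None`s can be sketched with `jax.tree_util.tree_map`. The helper `apply_updates_skipping_none` below is a hypothetical illustration of the behaviour described above, not Equinox's actual implementation, and the dictionaries stand in for a real model.

```python
import jax
import jax.numpy as jnp


def apply_updates_skipping_none(params, updates):
    """Add each update to its parameter; leave entries with a None update alone."""
    return jax.tree_util.tree_map(
        lambda u, p: p if u is None else p + u,
        updates,
        params,
        # Treat None as a leaf so that it is passed to the lambda above.
        is_leaf=lambda node: node is None,
    )


# Hypothetical parameters: one array and one non-differentiable function.
params = {"weight": jnp.array([1.0, 2.0]), "activation": jax.nn.relu}
# The matching updates contain None where there is nothing to update.
updates = {"weight": jnp.array([-0.1, -0.1]), "activation": None}

new_params = apply_updates_skipping_none(params, updates)
print(new_params["weight"])
```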

In [15]:
# Use Optax to optimise the parameters
optim = optax.adamw(LEARNING_RATE)

In [16]:
def fit(
    model: PyTree,
    trainloader: DataLoader,
    testloader: DataLoader,
    optim: GradientTransformation,
):
    # Get the optimiser state
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    @eqx.filter_jit
    def step(model: PyTree, opt_state: PyTree, x: Array, y: Array):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
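        # Pass only the array leaves as parameters, matching what optim.init received
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )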
        # We use eqx.apply_updates because our gradients contain Nones,
        # which Optax can't handle by default
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    for i, (x, y) in enumerate(trainloader):
        # PyTorch dataloaders give torch tensors by default,
        # so we have to convert them to Numpy arrays first
        x = x.numpy()
        y = y.numpy()
        model, opt_state, loss_value = step(model, opt_state, x, y)
        # Every 10% of the dataset, print a snapshot of the model's performance
        if i % (len(trainloader) // 10) == 0:
            eval_loss, acc = evaluate(model, testloader)
            current_percentage = np.round((i / len(trainloader)) * 100)
            print(f"step={current_percentage}%, {loss_value=}, {eval_loss=}, {acc=}")
    return model
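
The logging condition prints roughly every 10% of an epoch, but the percentages drift slightly because the number of batches is not divisible by 10. The following plain-Python check reproduces the arithmetic, assuming the 60,000 MNIST training images and the batch size of 64 used above (the variable names are hypothetical):

```python
import math

num_train_images = 60_000  # size of the MNIST training set
batch_size = 64

# len(trainloader): the last, smaller batch is kept by default.
num_batches = math.ceil(num_train_images / batch_size)  # 938
log_every = num_batches // 10  # 93

logged_percentages = [
    round(i / num_batches * 100)
    for i in range(num_batches)
    if i % log_every == 0
]
print(logged_percentages)  # [0, 10, 20, 30, 40, 50, 59, 69, 79, 89, 99]
```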

In [17]:
# Train the model
model = fit(model, trainloader, testloader, optim)

step=0.0%, loss_value=Array(0.0899663, dtype=float32), eval_loss=Array(0.09051719, dtype=float32), acc=Array(0.10360271, dtype=float32, weak_type=True)
step=10.0%, loss_value=Array(0.01999055, dtype=float32), eval_loss=Array(0.0159752, dtype=float32), acc=Array(0.9043591, dtype=float32, weak_type=True)
step=20.0%, loss_value=Array(0.0082739, dtype=float32), eval_loss=Array(0.01234425, dtype=float32), acc=Array(0.91988456, dtype=float32, weak_type=True)
step=30.0%, loss_value=Array(0.0094656, dtype=float32), eval_loss=Array(0.00933974, dtype=float32), acc=Array(0.9371019, dtype=float32, weak_type=True)
step=40.0%, loss_value=Array(0.01203251, dtype=float32), eval_loss=Array(0.00743253, dtype=float32), acc=Array(0.951035, dtype=float32, weak_type=True)
step=50.0%, loss_value=Array(0.00362388, dtype=float32), eval_loss=Array(0.00623893, dtype=float32), acc=Array(0.95879775, dtype=float32, weak_type=True)
step=59.0%, loss_value=Array(0.00316525, dtype=float32), eval_loss=Array(0.00616679, 