# A Simple Convolutional Neural Network on the MNIST Dataset

In this example, we will implement a CNN. Along the way, you will learn the following important concepts:

    - How to create a custom neural network
    - Why and when it makes sense to use the `eqx.filter_[...]` functions
    - How to train your network using Optax
    - What your neural network looks like "under the hood" (foreshadowing: like a PyTree!)

In [None]:
# Import all the necessary libraries

import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
import torchvision
import torchvision.transforms as transforms
from jaxtyping import PyTree
from torch.utils.data.dataloader import DataLoader

import equinox as eqx

In [None]:
# Hyperparameters

BATCH_SIZE = 64
LEARNING_RATE = 1e-3

## Data

For this example, we will use the *regular* MNIST dataset. To load it, we use PyTorch and its excellent DataLoader class. Before we can feed the data to our model, we need to
    1. normalise the data
    2. one-hot encode the labels

We can perform the first step with some PyTorch utilities, while for the second step we can use JAX's `one_hot` function.

You can check the [PyTorch documentation](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) for more information on datasets and data loaders.
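
To make the two preprocessing steps concrete, here is a minimal sketch on a handful of made-up values. It assumes that `Normalize((0.5,), (0.5,))` computes `(x - mean) / std` per channel and that `ToTensor()` has already scaled the pixels to the range [0, 1]; the names `pixels`, `normalised` and `encoded` are only illustrative.

```python
import jax
import jax.numpy as jnp

# Made-up pixel values in [0, 1], as they would look after ToTensor().
pixels = jnp.array([0.0, 0.25, 0.5, 1.0])

# Normalisation with mean = 0.5 and std = 0.5 maps [0, 1] onto [-1, 1].
normalised = (pixels - 0.5) / 0.5
print(normalised)  # values: -1, -0.5, 0, 1

# One-hot encoding the label 3 over 10 classes puts a 1 at index 3.
encoded = jax.nn.one_hot(3, 10)
print(encoded)
```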

In [None]:
def one_hot_encode(label):
    """
    This function one-hot encodes the target label.
    E.g.:
        label = 3
        output = [0 0 0 1 0 0 0 0 0 0]
    """
    return np.array(jax.nn.one_hot(label, 10))


# transforms
# This normalises the data
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)


# Load the MNIST train data
train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=transform,
    target_transform=one_hot_encode,
)

# Load the MNIST test data
test_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=False,
    download=True,
    transform=transform,
    target_transform=one_hot_encode,
)

# Create the train and testloaders
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)

In [None]:
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like)
dummy_x, dummy_y = next(iter(trainloader))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  # 64x1x28x28
print(dummy_y.shape)  # 64x10

We can see that our data has the shape 64x1x28x28. 64 is the batch size, 1 is the number of input channels (this will be important soon) and 28x28 is the image dimension.

## The Network

Our Convolutional Neural Network (CNN) stores all of its layers in a list. There is no requirement to do it that way; it is simply convenient for this example. In between the layers (whose parameters are JAX arrays, i.e. **leaves** of a PyTree), the list contains some things which are **not** arrays! These functions, e.g. `relu`, are not things that JAX can differentiate with respect to! This means we will have to *filter* those things out. Luckily, Equinox provides several useful filter functions for this.
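
The difference between array leaves and function leaves can be seen without any neural network at all. The following sketch uses a hypothetical `toy_layers` list, with two arrays standing in for parameters, and asks JAX for its leaves. Equinox's `eqx.is_array` predicate, used later in this example, draws a similar distinction.

```python
import jax
import jax.numpy as jnp

# A toy "layer list": two arrays standing in for parameters and an
# activation function sitting between them.
toy_layers = [jnp.ones((2, 2)), jax.nn.relu, jnp.zeros(3)]

# Anything JAX does not recognise as a container is treated as a leaf,
# so the function shows up next to the arrays.
leaves = jax.tree_util.tree_leaves(toy_layers)
is_array = [isinstance(leaf, jax.Array) for leaf in leaves]
print(is_array)  # [True, False, True]
```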

In [None]:
class NeuralNetwork(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # A convolutional layer with max pooling, followed by three linear layers
        self.layers = [
            eqx.nn.Conv2d(in_channels=1, out_channels=3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.softmax,
        ]

    def __call__(self, x):
        for layer in self.layers:
            # Apply each layer (or activation function) in turn
            x = layer(x)
        return x
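
The input size of the first linear layer, 1728, follows from the shapes of the preceding layers. Below is a small sketch of that arithmetic. It assumes the convolution and the pooling layer both use a stride of 1 and no padding, which I believe are the Equinox defaults; the helper `conv_output_size` is hypothetical and not part of any library.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution or pooling window (no dilation)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# 28x28 input, 4x4 convolution kernel (assumed stride 1, no padding)
after_conv = conv_output_size(28, kernel_size=4)
# 2x2 max pooling (assumed stride 1, no padding)
after_pool = conv_output_size(after_conv, kernel_size=2)
# The convolution produces 3 output channels, which jnp.ravel flattens
flattened = 3 * after_pool * after_pool
print(after_conv, after_pool, flattened)  # 25 24 1728
```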

In [None]:
def loss(model, x, y):
    # Our input has the shape (BATCH_SIZE x 1 x 28 x 28), but our model expects a
    # single image of shape (1 x 28 x 28);
    # the batch dimension was not defined in our model!
    # Therefore, we have to use jax.vmap, which - by default - maps over the 0th axis
    # (i.e. the batch in our example)
    pred_y = jax.vmap(model)(x)
    # mean squared error loss
    return jnp.mean((y - pred_y) ** 2)
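
The effect of `jax.vmap` in the loss function can be sketched in isolation. Here `per_example` is a hypothetical stand-in for the model: it is written for a single image and never sees the batch axis.

```python
import jax
import jax.numpy as jnp


def per_example(image):
    # Written for one image of shape (1, 28, 28); returns a scalar.
    assert image.ndim == 3
    return jnp.sum(image)


batch = jnp.ones((64, 1, 28, 28))

# vmap maps per_example over axis 0 (the batch) and stacks the results.
batched_out = jax.vmap(per_example)(batch)
print(batched_out.shape)  # (64,)
```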

In [None]:
# Init our model
key, subkey = jax.random.split(jax.random.PRNGKey(32))
model = NeuralNetwork(subkey)

### Our PyTree Model

We can check that our model really is just a PyTree by using Python's built-in `isinstance` function together with the `PyTree` type from jaxtyping:

In [None]:
isinstance(model, PyTree)

And when we *pretty print* our model, we can see what it looks like:

In [None]:
eqx.tree_pprint(model)

In [None]:
# Example inference
output = jax.vmap(model)(dummy_x)
output.shape

### Filtering

In the next cells, we can see an example of when we should use the filter functions provided by Equinox. For instance, the following code raises an error:

In [None]:
# This is an error!

jax.value_and_grad(loss)(model, dummy_x, dummy_y)

And this makes sense: when you look back up at the pretty-printed PyTree, you will see some lines that say e.g. `<wrapped function relu>,`. This means there is a function in our PyTree. When JAX tries to compute the gradient, it arrives at this function and is confused (rightfully so!). Therefore, we have to filter out anything that is not a differentiable PyTree leaf!

A general rule of thumb: the `eqx.filter_[...]` functions behave essentially like the corresponding JAX functions, except that they also cope with non-array leaves. So whenever you're unsure which one to use, you can simply fall back to the Equinox filter functions.

In [None]:
# This will work!

# Getting the initial loss value
value, grads = eqx.filter_value_and_grad(loss)(model, dummy_x, dummy_y)
print(value)

In [None]:
def evaluate(model, testloader: DataLoader):
    """
    This function evaluates the model on the testing dataset
    by computing the average loss and also the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    jitted_loss = eqx.filter_jit(loss)
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        avg_loss += jitted_loss(model, x, y)
        preds = jax.vmap(model)(x)
        targets = jnp.argmax(y, axis=1)
        preds = jnp.argmax(preds, axis=1)
        avg_acc += jnp.mean(targets == preds)
    return avg_loss / len(testloader), avg_acc / len(testloader)

In [None]:
evaluate(model, testloader)

## Training the Model

Now it's time to train our model using Optax! The most important thing to notice here is the `eqx.apply_updates` function! It plays the same role as `optax.apply_updates`. But why do we need it? The reason is that `eqx.filter_value_and_grad` filters out anything that is not differentiable, e.g. functions in our PyTree, and replaces those fields with `None`. This means that when you compute the gradients and the updates (from `optim.update(...)`), there will be places with `None`s. The *original* function from Optax would then try to update `None`s using something like `None - learning_rate * None`, which is a big *no no*!

Therefore, we use the utility function from Equinox, which does the same thing as `optax.apply_updates`, except that it skips the `None`s.

In [None]:
# Use Optax to optimise the parameters
optim = optax.adamw(LEARNING_RATE)

In [None]:
def fit(model, trainloader: DataLoader, testloader: DataLoader, optim):
    # Get the optimiser state
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    @eqx.filter_jit
    def step(model, opt_state, x, y):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
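        # AdamW's weight decay needs the parameters, so we pass only the array leaves
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )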
        # We use eqx.apply_updates because our gradients contain Nones,
        # which Optax can't handle by default
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    for i, (x, y) in enumerate(trainloader):
        # PyTorch dataloaders give torch tensors by default,
        # so we have to convert them to Numpy arrays first
        x = x.numpy()
        y = y.numpy()
        model, opt_state, loss_value = step(model, opt_state, x, y)
        # Every 10% of the dataset, print a snapshot of the model's performance
        if i % (len(trainloader) // 10) == 0:
            eval_loss, acc = evaluate(model, testloader)
            current_percentage = jnp.round((i / len(trainloader)) * 100)
            print(f"step={current_percentage}%, {loss_value=}, {eval_loss=}, {acc=}")
    return model

In [None]:
# Train the model
model = fit(model, trainloader, testloader, optim)