In [152]:
# Import all the necessary libraries

import torchvision.transforms as transforms
import torch
import torchvision
import equinox as eqx
import jax
import jax.numpy as jnp
import optax
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data.dataloader import DataLoader
import numpy as np

In [154]:
# Hyperparameters

BATCH_SIZE = 64
LEARNING_RATE = 1e-3

In [155]:
def one_hot_encode(label):
    """
        This function one-hot encodes the target label.
        E.g.:
            label = 3
            output = [0 0 0 1 0 0 0 0 0 0]
    """
    one_hot_encoder = OneHotEncoder(sparse_output=False)
    one_hot_encoder.fit(np.arange(10).reshape(-1, 1))
    return one_hot_encoder.transform(np.array(label).reshape(1, 1)).reshape(-1)

# transforms
# This normalises the data and then flattens the 28x28 image to a 784 dimensional vector
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)), transforms.Lambda(lambda x: torch.flatten(x))]
)


# Load the MNIST train data
train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=transform,
    target_transform=one_hot_encode,
)

# Load the MNIST test data
test_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=False,
    download=True,
    transform=transform,
    target_transform=one_hot_encode,
)

# Create the train and testloaders
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)

In [156]:
class NeuralNetwork(eqx.Module):
    layers: list
    
    def __init__(self, key):
        key1, key2, key3 = jax.random.split(key, 3)
        # A simple neural network with 3 hidden layers
        self.layers = [
            eqx.nn.Linear(784, 64, key=key1),
            eqx.nn.Linear(64, 32, key=key2),
            eqx.nn.Linear(32, 10, key=key3)
        ]
    
    def __call__(self, x):
        for layer in self.layers:
            # call relu on each layer
            x = jax.nn.relu(layer(x))
        return x

In [157]:
@jax.jit
def loss(model, x, y):
    # Our input has the shape (BATCH_SIZE x 784) but our model expects an input of (784, ) 
    # Therefore, we have to use jax.vmap, which - by default - maps over the 0th axis (i.e. the batch in our example)
    pred_y = jax.vmap(model)(x)
    # mean squared error loss
    return jnp.mean((y - pred_y) ** 2)

In [158]:
# Init our model
key, subkey = jax.random.split(jax.random.PRNGKey(32))
model = NeuralNetwork(subkey)

In [159]:
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like)
dummy_x, dummy_y = next(iter(trainloader))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()

dummy_x.shape

(64, 784)

In [160]:
# Example inference
output = jax.vmap(model)(dummy_x)
output.shape

(64, 10)

In [162]:
# Getting the initial loss value
value, grads = jax.value_and_grad(loss)(model, dummy_x, dummy_y)
print(value)

0.09782601


In [163]:
def evaluate(model, testloader: DataLoader):
    """
        This function evaluates the model on the testing dataset by computing the average loss and also the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        avg_loss += loss(model, x, y)
        preds = jax.vmap(model)(x)
        targets = jnp.argmax(y, axis=1)
        preds = jnp.argmax(preds, axis=1)
        avg_acc += sum(targets == preds) / len(x)
    return avg_loss / len(testloader), avg_acc / len(testloader)

In [164]:
evaluate(model, testloader)

(Array(0.0979604, dtype=float32),
 Array(0.11445063, dtype=float32, weak_type=True))

In [166]:
# Use Optax to optimise the parameters
optim = optax.adamw(LEARNING_RATE)

In [167]:
def fit(model, trainloader: DataLoader, testloader: DataLoader, optim):
    # Get the optimiser state
    opt_state = optim.init(model)
    
    @jax.jit
    def step(model, opt_state, x, y):
        loss_value, grads = jax.value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = optax.apply_updates(model, updates)
        return model, opt_state, loss_value
    
    for i, (x, y) in enumerate(trainloader):
        # PyTorch dataloaders give torch tensors by default, so we have to convert them to Numpy arrays first
        x = x.numpy()
        y = y.numpy()
        model, opt_state, loss_value = step(model, opt_state, x, y)
        # Every 10% of the dataset, print a snapshot of the model's performance
        if i % (len(trainloader) // 10) == 0:
            eval_loss, acc = evaluate(model, testloader)
            print(f"step={jnp.round((i / len(trainloader)) * 100) }%, {loss_value=}, {eval_loss=}, {acc=}")
    return model

In [168]:
# Train the model
model = fit(model, trainloader, testloader, optim)

step=0.0%, loss_value=Array(0.1006145, dtype=float32), eval_loss=Array(0.09500747, dtype=float32), acc=Array(0.09922373, dtype=float32, weak_type=True)
step=10.0%, loss_value=Array(0.02966868, dtype=float32), eval_loss=Array(0.03291377, dtype=float32), acc=Array(0.7995621, dtype=float32, weak_type=True)
step=20.0%, loss_value=Array(0.03069584, dtype=float32), eval_loss=Array(0.02885927, dtype=float32), acc=Array(0.81041, dtype=float32, weak_type=True)
step=30.0%, loss_value=Array(0.01746745, dtype=float32), eval_loss=Array(0.02669366, dtype=float32), acc=Array(0.82573646, dtype=float32, weak_type=True)
step=40.0%, loss_value=Array(0.03112168, dtype=float32), eval_loss=Array(0.02632064, dtype=float32), acc=Array(0.8225517, dtype=float32, weak_type=True)
step=50.0%, loss_value=Array(0.02785086, dtype=float32), eval_loss=Array(0.02399013, dtype=float32), acc=Array(0.83588773, dtype=float32, weak_type=True)
step=59.0%, loss_value=Array(0.02864545, dtype=float32), eval_loss=Array(0.02533414