# MLP Implementation in mlax with Optax optimizers.
This notebook uses the [Optax](https://optax.readthedocs.io/en/latest/optax-101.html)
JAX optimization library.

You can view the PyTorch reference implementation in `mlp_reference.ipynb`.

In [1]:
import jax
from jax import (
    numpy as jnp,
    nn,
    random
)
import numpy as np
import optax
import torchvision
from torch.utils.data import DataLoader

In [2]:
from mlax.nn import Series, Linear, Bias, F

### Load in and batch the MNIST datasets.
We follow the example
[Training a Simple Neural Network, with PyTorch Data Loading](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html) and load the data with PyTorch dataloaders.

In [3]:
class ToNumpy:
    """Dataset transform that turns each sample into a NumPy array."""

    def __call__(self, pic):
        """Return `pic` converted with `np.array`."""
        as_array = np.array(pic)
        return as_array

# Build the train and test splits with identical settings; only the `train`
# flag differs between them.
mnist_train, mnist_test = (
    torchvision.datasets.MNIST(
        root="../data",
        train=is_train,
        download=True,
        transform=ToNumpy(),
    )
    for is_train in (True, False)
)
# One shape per line: train split first, then test split.
print(mnist_train.data.shape, mnist_test.data.shape, sep="\n")

torch.Size([60000, 28, 28])
torch.Size([10000, 28, 28])


In [4]:
def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays.

    - `np.ndarray` samples are stacked along a new leading axis.
    - tuple/list samples are collated field by field (recursively), giving a
      list with one collated entry per field.
    - anything else is passed to `np.array`.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        # zip(*batch) regroups the samples by field position.
        return [numpy_collate(field) for field in zip(*batch)]
    return np.array(batch)

batch_size = 128

# Only the training loader shuffles; both collate batches into NumPy arrays.
train_dataloader = DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=numpy_collate,
    num_workers=6,
)
test_dataloader = DataLoader(
    mnist_test,
    batch_size=batch_size,
    collate_fn=numpy_collate,
    num_workers=6,
)
# Number of batches per epoch for each loader.
print(len(train_dataloader), len(test_dataloader))

469 79


## Build MLP using `mlax.nn`

In [5]:
# Six keys, one for each Linear/Bias layer below, all derived from seed 0.
keys_iter = (
    random.fold_in(random.PRNGKey(0), layer_idx) for layer_idx in range(6)
)

model = Series([
    # Flatten and scale: cast to float32, divide by 255, reshape to 1-D
    F(lambda x: jnp.reshape(x.astype(jnp.float32) / 255.0, (-1,))),
    # Hidden layer 1
    Linear(next(keys_iter), out_features=512),
    Bias(next(keys_iter), in_features=512),
    F(nn.relu),
    # Hidden layer 2
    Linear(next(keys_iter), out_features=512),
    Bias(next(keys_iter), in_features=512),
    F(nn.relu),
    # Output layer: 10 values per sample
    Linear(next(keys_iter), out_features=10),
    Bias(next(keys_iter), in_features=10),
])

# Induce lazy initialization by running a single (unbatched) sample through
# the model; only the first batch is needed.
for X, _ in train_dataloader:
    activations, _ = model(X[0], None, inference_mode=True)
    print(activations.shape)
    break

(10,)


### Define loss function.

In [6]:
def loss_fn(batched_preds, batched_targets):
    """Mean softmax cross-entropy of a batch.

    `batched_preds` are the model outputs (logits) and `batched_targets` the
    integer class labels. Optax returns one loss per example; this averages
    them into a single batch loss.
    """
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
        batched_preds, batched_targets
    )
    return per_example_loss.mean()

### Define optimizer using Optax.

In [7]:
# SGD with momentum; the optimizer state is built from the model's filtered
# (trainable) parameters.
optimizer = optax.sgd(learning_rate=1e-2, momentum=0.9)
optim_state = optimizer.init(model.filter())

### Define training and testing steps.

In [8]:
@jax.jit
def train_step(X, y, model, optim_state):
    """Run one optimization step on a single batch.

    Args:
        X: batch of inputs; `jax.vmap` maps the model over axis 0.
        y: labels for the batch, passed to `loss_fn` together with the
            predictions.
        model: the model; split into trainables and non-trainables with
            `partition()`.
        optim_state: Optax optimizer state for the module-level `optimizer`.

    Returns:
        `(loss, model, optim_state)`: the batch loss, the model with its
        trainables updated, and the new optimizer state.
    """
    def _model_loss(X, y, trainables, non_trainables):
        # Recombine so the model can be called; only `trainables` is
        # differentiated (see `argnums` below).
        model = trainables.combine(non_trainables)
        # Map over the batch axis of X; the remaining arguments are shared
        # across the batch (presumably rng, inference_mode and the batch axis
        # name, going by the other calls in this notebook).
        preds, model = jax.vmap(
            model.__call__,
            in_axes = (0, None, None, None),
            out_axes = (0, None),
            axis_name = "N"
        )(X, None, False, "N")
        return loss_fn(preds, y), model

    # Find batch loss and gradients with respect to trainables
    trainables, non_trainables = model.partition()
    (loss, model), gradients = jax.value_and_grad(
        _model_loss,
        argnums=2, # gradients wrt trainables (argument 2 of _model_loss)
        has_aux=True # model is auxiliary data, loss is the true output
    )(X, y, trainables, non_trainables)

    # Turn raw gradients into parameter updates and advance the optimizer state
    updates, optim_state = optimizer.update(gradients, optim_state)

    # Apply the updates to the trainables of the model returned by the forward
    # pass. `optax.apply_updates` takes (params, updates): the result keeps the
    # dtype of the parameters, so the order matters.
    trainables, non_trainables = model.partition()
    trainables = optax.apply_updates(trainables, updates)
    return loss, trainables.combine(non_trainables), optim_state

In [9]:
@jax.jit
def test_step(X, y, model):
    """Evaluate the model on one batch.

    Returns `(loss, accurate)`: the batch loss from `loss_fn` and the number
    of samples whose arg-max prediction equals the label.
    """
    # Map the model over axis 0 of X; the other arguments are shared across
    # the batch. The model returned by the call is discarded.
    batched_forward = jax.vmap(
        model.__call__,
        in_axes=(0, None, None, None),
        out_axes=(0, None),
        axis_name="N",
    )
    preds, _ = batched_forward(X, None, True, "N")

    batch_loss = loss_fn(preds, y)
    accurate = jnp.sum(jnp.argmax(preds, axis=1) == y)
    return batch_loss, accurate

### Define training and testing loops.

In [10]:
def train_epoch(dataloader, model, optim_state):
    """Train for one pass over `dataloader`.

    Calls `train_step` on every batch, prints the loss averaged over the
    number of batches, and returns the updated `(model, optim_state)`.
    """
    train_loss = 0.0
    for inputs, labels in dataloader:
        batch_loss, model, optim_state = train_step(
            inputs, labels, model, optim_state
        )
        train_loss = train_loss + batch_loss

    print(f"Train loss: {train_loss / len(dataloader)}")
    return model, optim_state

In [11]:
def test(dataloader, model):
    """Evaluate `model` on every batch of `dataloader` and print the results.

    The printed loss is the sum of batch losses divided by the number of
    batches; the printed accuracy is the number of correct predictions
    divided by the dataset size.
    """
    test_loss = 0.0
    accurate = 0
    for inputs, labels in dataloader:
        batch_loss, batch_correct = test_step(inputs, labels, model)
        test_loss = test_loss + batch_loss
        accurate = accurate + batch_correct

    print(f"Test loss: {test_loss / len(dataloader)}, accuracy: {accurate / len(dataloader.dataset)}")

In [12]:
def train_loop(
    train_dataloader,
    test_dataloader,
    model,
    optim_state,
    epochs,
    test_every
):
    """Train for `epochs` epochs, evaluating every `test_every` epochs.

    Epochs are numbered from 1. Each epoch runs `train_epoch`; when the epoch
    number is a multiple of `test_every`, `test` is run as well. Returns the
    final `(model, optim_state)`.
    """
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}\n----------------")
        model, optim_state = train_epoch(train_dataloader, model, optim_state)
        due_for_eval = epoch % test_every == 0
        if due_for_eval:
            test(test_dataloader, model)
        print("----------------")
    return model, optim_state

## Train MLP on the MNIST dataset.

In [13]:
# Use "float32" as the default matmul precision for everything run inside
# this context, then train for 30 epochs and evaluate every 5.
with jax.default_matmul_precision("float32"):
    new_model, new_optim_state = train_loop(
        train_dataloader,
        test_dataloader,
        model,
        optim_state,
        epochs=30,
        test_every=5,
    )

Epoch 1
----------------
Train loss: 0.4857460558414459
----------------
Epoch 2
----------------
Train loss: 0.2093609720468521
----------------
Epoch 3
----------------
Train loss: 0.1522853970527649
----------------
Epoch 4
----------------
Train loss: 0.11907586455345154
----------------
Epoch 5
----------------
Train loss: 0.0976567417383194
Test loss: 0.09684841334819794, accuracy: 0.9693000316619873
----------------
Epoch 6
----------------
Train loss: 0.0806221291422844
----------------
Epoch 7
----------------
Train loss: 0.06772962957620621
----------------
Epoch 8
----------------
Train loss: 0.058621108531951904
----------------
Epoch 9
----------------
Train loss: 0.050001949071884155
----------------
Epoch 10
----------------
Train loss: 0.04330332204699516
Test loss: 0.0692794993519783, accuracy: 0.9784000515937805
----------------
Epoch 11
----------------
Train loss: 0.03750257566571236
----------------
Epoch 12
----------------
Train loss: 0.032796185463666916
-------