# MLP Implementation in mlax without Optax optimizers.
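This notebook builds and trains the model using only the `mlax` package; the SGD optimizer is a hand-written JAX implementation imported from a local `optim` module instead of Optax.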

You can view the PyTorch reference implementation in `mlp_reference.ipynb`.

In [1]:
import jax
from jax import (
    numpy as jnp,
    nn,
    random
)
import numpy as np
import torchvision
from torch.utils.data import DataLoader

In [2]:
from mlax.nn import Series, Linear, Bias, F

In [3]:
# Local python file containing an SGD optimizer written in JAX.
from optim import (
    sparse_categorical_crossentropy,
    sgd_init,
    sgd_step,
    apply_updates
)

### Load in and batch the MNIST datasets.
We follow the example
[Training a Simple Neural Network, with PyTorch Data Loading](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html) and use PyTorch data loaders to feed the model.

In [4]:
class ToNumpy:
    """Dataset transform that converts a sample into a NumPy array."""

    def __call__(self, pic):
        """Return `pic` converted with `np.array`."""
        as_array = np.array(pic)
        return as_array

# The train and test splits share the same root directory, download flag and
# transform; only the `train` flag differs. The train split is built first.
mnist_train, mnist_test = (
    torchvision.datasets.MNIST(
        root="../data",
        train=is_train,
        download=True,
        transform=ToNumpy()
    )
    for is_train in (True, False)
)

# Show the raw image tensor shape of each split as a sanity check.
for split in (mnist_train, mnist_test):
    print(split.data.shape)

torch.Size([60000, 28, 28])
torch.Size([10000, 28, 28])


In [5]:
def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays.

    - If the samples are tuples or lists, the batch is transposed and each
      field is collated recursively; the result is a list with one entry per
      field.
    - If the samples are NumPy arrays, they are stacked along a new leading
      axis.
    - Anything else is handed to `np.array` as-is.
    """
    first = batch[0]
    if isinstance(first, (tuple, list)):
        # zip(*batch) regroups the samples field by field.
        return [numpy_collate(field) for field in zip(*batch)]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    return np.array(batch)

# Mini-batch size shared by the training and test loaders. Both loaders read
# this variable, so changing it here actually changes the batching.
batch_size = 128

# Training batches are reshuffled every epoch; samples are collated into NumPy
# arrays by `numpy_collate`.
train_dataloader = DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True, collate_fn=numpy_collate, num_workers=6
)
# The test loader keeps the dataset order (no shuffling).
test_dataloader = DataLoader(
    mnist_test, batch_size=batch_size, collate_fn=numpy_collate, num_workers=6
)
# Number of batches per epoch for each loader.
print(len(train_dataloader), len(test_dataloader))

469 79


## Build MLP using `mlax.module`

In [6]:
# One PRNG key per parameterized layer (3 Linear + 3 Bias), each folded from
# the same fixed seed so that initialization is reproducible.
keys_iter = (random.fold_in(random.PRNGKey(0), i) for i in range(6))

model = Series([
    # Scale the pixel values by 1/255 as float32 and flatten the image.
    F(lambda x: jnp.reshape(x.astype(jnp.float32) / 255.0, (-1,))),
    # Hidden layer 1
    Linear(next(keys_iter), out_features=512),
    Bias(next(keys_iter), in_features=512),
    F(nn.relu),
    # Hidden layer 2
    Linear(next(keys_iter), out_features=512),
    Bias(next(keys_iter), in_features=512),
    F(nn.relu),
    # Output layer: 10 values passed through softmax
    Linear(next(keys_iter), out_features=10),
    Bias(next(keys_iter), in_features=10),
    F(nn.softmax)
])

# Induce lazy initialization by running a single (unbatched) sample from the
# first training batch through the model.
X, _ = next(iter(train_dataloader))
activations, _ = model(X[0], None, inference_mode=True)
print(activations.shape)

(10,)


### Define loss function.

In [7]:
# Loss shared by the training and evaluation steps, which call it as
# `loss_fn(preds, y)`.
loss_fn = sparse_categorical_crossentropy

### Define optimizer.

In [8]:
# Build the initial SGD optimizer state from `model.filter()`.
# NOTE(review): `filter()` presumably selects the model's trainable
# parameters — confirm against the mlax documentation.
optim_state = sgd_init(model.filter())

### Define training and testing steps.

In [9]:
@jax.jit
def train_step(X, y, model, optim_state):
    """Run one optimization step on a single batch.

    Args:
        X: Batch of inputs; the model is mapped over its leading axis.
        y: Targets passed to `loss_fn` together with the predictions.
        model: Current model.
        optim_state: Current optimizer state.

    Returns:
        A tuple `(loss, model, optim_state)` holding the batch loss, the model
        with updated trainables, and the new optimizer state.
    """
    def _batch_loss(params, frozen):
        # Rebuild the full model from its trainable and non-trainable parts.
        net = params.combine(frozen)
        # Apply the model to every sample in the batch. Only `X` is mapped;
        # the remaining arguments are shared across samples.
        # NOTE(review): the third positional argument (False) is presumably
        # `inference_mode` and "N" the batch axis name — confirm against mlax.
        batched_forward = jax.vmap(
            net.__call__,
            in_axes=(0, None, None, None),
            out_axes=(0, None),
            axis_name="N"
        )
        preds, net = batched_forward(X, None, False, "N")
        # The model returned by the forward pass is carried as auxiliary data.
        return loss_fn(preds, y), net

    # Batch loss and its gradients with respect to the trainables
    # (argument 0 of `_batch_loss`); the model is the auxiliary output.
    params, frozen = model.partition()
    loss_and_grad = jax.value_and_grad(_batch_loss, argnums=0, has_aux=True)
    (loss, stepped_model), gradients = loss_and_grad(params, frozen)

    # Turn the raw gradients into updates and advance the optimizer state.
    updates, optim_state = sgd_step(gradients, optim_state)

    # Apply the updates to the trainables of the model returned by the
    # forward pass, then reassemble the full model.
    params, frozen = stepped_model.partition()
    updated_model = apply_updates(updates, params).combine(frozen)
    return loss, updated_model, optim_state

In [10]:
@jax.jit
def test_step(X, y, model):
    """Evaluate the model on a single batch.

    Args:
        X: Batch of inputs; the model is mapped over its leading axis.
        y: Targets compared against the arg-max of the predictions.
        model: Model to evaluate.

    Returns:
        A tuple `(loss, accurate)` with the batch loss and the number of
        samples whose arg-max prediction equals the target.
    """
    # NOTE(review): the third positional argument (True) is presumably
    # `inference_mode` and "N" the batch axis name — confirm against mlax.
    batched_forward = jax.vmap(
        model.__call__,
        in_axes=(0, None, None, None),
        out_axes=(0, None),
        axis_name="N"
    )
    preds, _ = batched_forward(X, None, True, "N")

    predicted_labels = jnp.argmax(preds, axis=1)
    accurate = jnp.sum(predicted_labels == y)
    return loss_fn(preds, y), accurate

### Define training and testing loops.

In [11]:
def train_epoch(dataloader, model, optim_state):
    """Train for one pass over `dataloader` and print the mean batch loss.

    Args:
        dataloader: Iterable yielding `(X, y)` batches; must support `len()`.
        model: Model at the start of the epoch.
        optim_state: Optimizer state at the start of the epoch.

    Returns:
        A tuple `(model, optim_state)` as they stand after the last batch.
    """
    running_loss = 0.0
    for batch_inputs, batch_targets in dataloader:
        batch_loss, model, optim_state = train_step(
            batch_inputs, batch_targets, model, optim_state
        )
        running_loss = running_loss + batch_loss

    # Average over the number of batches, not the number of samples.
    mean_loss = running_loss / len(dataloader)
    print(f"Train loss: {mean_loss}")
    return model, optim_state

In [12]:
def test(dataloader, model):
    """Evaluate `model` over `dataloader` and print the loss and accuracy.

    The printed loss is the mean of the per-batch losses; the printed accuracy
    is the number of correct predictions divided by the dataset size.

    Args:
        dataloader: Iterable yielding `(X, y)` batches; must support `len()`
            and expose a `dataset` attribute that supports `len()`.
        model: Model to evaluate.
    """
    total_loss = 0.0
    total_correct = 0
    for batch_inputs, batch_targets in dataloader:
        batch_loss, batch_correct = test_step(batch_inputs, batch_targets, model)
        total_loss = total_loss + batch_loss
        total_correct = total_correct + batch_correct

    mean_loss = total_loss / len(dataloader)
    accuracy = total_correct / len(dataloader.dataset)
    print(f"Test loss: {mean_loss}, accuracy: {accuracy}")

In [13]:
def train_loop(
    train_dataloader,
    test_dataloader,
    model,
    optim_state,
    epochs,
    test_every
):
    """Train for `epochs` epochs, evaluating periodically.

    Args:
        train_dataloader: Batches used for training each epoch.
        test_dataloader: Batches used for evaluation.
        model: Initial model.
        optim_state: Initial optimizer state.
        epochs: Number of epochs to run.
        test_every: Evaluate after every epoch whose 1-based number is a
            multiple of this value.

    Returns:
        A tuple `(model, optim_state)` after the final epoch.
    """
    separator = "----------------"
    # Epochs are numbered from 1 in the printed log.
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}\n{separator}")
        model, optim_state = train_epoch(train_dataloader, model, optim_state)
        is_test_epoch = epoch % test_every == 0
        if is_test_epoch:
            test(test_dataloader, model)
        print(separator)
    return model, optim_state

## Train MLP on the MNIST dataset.

In [14]:
# Run the whole training loop with "float32" as the default matmul precision.
with jax.default_matmul_precision("float32"):
    # Train for 30 epochs and evaluate on the test set every 5th epoch.
    # `new_params` receives the trained model returned by `train_loop`.
    new_params, new_optim_state = train_loop(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        model=model,
        optim_state=optim_state,
        epochs=30,
        test_every=5
    )

Epoch 1
----------------
Train loss: 0.43574488162994385
----------------
Epoch 2
----------------
Train loss: 0.19324670732021332
----------------
Epoch 3
----------------
Train loss: 0.1411539912223816
----------------
Epoch 4
----------------
Train loss: 0.11070062220096588
----------------
Epoch 5
----------------
Train loss: 0.08962983638048172
Test loss: 0.09407661855220795, accuracy: 0.9695000648498535
----------------
Epoch 6
----------------
Train loss: 0.0755179151892662
----------------
Epoch 7
----------------
Train loss: 0.0630074217915535
----------------
Epoch 8
----------------
Train loss: 0.054314710199832916
----------------
Epoch 9
----------------
Train loss: 0.04635758325457573
----------------
Epoch 10
----------------
Train loss: 0.04083261638879776
Test loss: 0.06593061238527298, accuracy: 0.9785000681877136
----------------
Epoch 11
----------------
Train loss: 0.034974969923496246
----------------
Epoch 12
----------------
Train loss: 0.030560683459043503
----