# MLP Implementation in mlax without Optax optimizers.
This notebook just uses the `mlax` package.

You can view the Pytorch reference implementation in `mlp_reference.ipynb`.

In [1]:
import jax
from jax import (
    numpy as jnp,
    nn,
    random
)
import numpy as np
from functools import partial
import torchvision
from torch.utils.data import DataLoader

In [2]:
from mlax import Module
from mlax.nn import Series, Linear, Bias, F

In [3]:
# Local python file containing an SGD optimizer written in JAX.
from optim import (
    sparse_categorical_crossentropy,
    sgd_init,
    sgd_step,
    apply_updates
)

### Load in and batch the MNIST datasets.
We follow this example
[Training a Simple Neural Network, with PyTorch Data Loading](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html) in using Pytorch dataloaders.

In [4]:
class ToNumpy:
  def __call__(self, pic):
    return np.array(pic)

mnist_train = torchvision.datasets.MNIST(
    root="../data",
    train=True,
    download=True,
    transform=ToNumpy()
)
mnist_test = torchvision.datasets.MNIST(
    root="../data",
    train=False,
    download=True,
    transform=ToNumpy()
)
print(mnist_train.data.shape)
print(mnist_test.data.shape)

torch.Size([60000, 28, 28])
torch.Size([10000, 28, 28])


In [5]:
def numpy_collate(batch):
  if isinstance(batch[0], np.ndarray):
    return np.stack(batch)
  elif isinstance(batch[0], (tuple,list)):
    transposed = zip(*batch)
    return [numpy_collate(samples) for samples in transposed]
  else:
    return np.array(batch)

batch_size=128
train_dataloader = DataLoader(
    mnist_train, batch_size=128, shuffle=True, collate_fn=numpy_collate, num_workers=6
)
test_dataloader = DataLoader(
    mnist_test, batch_size=128, shuffle=True, collate_fn=numpy_collate, num_workers=6
)
print(len(train_dataloader), len(test_dataloader))

469 79


## Build MLP using `mlax.module`

In [6]:
class MLP(Module):
    def __init__(self, rng):
        super().__init__()
        rngs_iter = iter(random.split(rng, 6))

        self.linear1 = Series([
            Linear(next(rngs_iter), out_features=512),
            Bias(next(rngs_iter), in_features=512),
            F(nn.relu)
        ])

        self.linear2 = Series([
            Linear(next(rngs_iter), out_features=512),
            Bias(next(rngs_iter), in_features=512),
            F(nn.relu)
        ])

        self.linear3 = Series([
            Linear(next(rngs_iter), out_features=10),
            Bias(next(rngs_iter), in_features=10),
            F(nn.softmax)
        ])

    @partial(
        jax.vmap,
        in_axes = (None, 0, None, None),
        out_axes = (0, None),
        axis_name = "batch"
    ) # Add leading batch dimension
    def __call__(self, x, rng=None, inference_mode=False):
        x = jnp.reshape(x.astype(jnp.float32) / 255.0, (-1,)) # Flatten and scale
        x, self.linear1 = self.linear1(x, rng=None, inference_mode=inference_mode)
        x, self.linear2 = self.linear2(x, rng=None, inference_mode=inference_mode)
        x, self.linear3 = self.linear3(x, rng=None, inference_mode=inference_mode)
        return x, self

model = MLP(random.PRNGKey(0))

# Induce lazy weight initialization
for x_batch, y_batch in train_dataloader:
    acts, model = model(x_batch, None, False)
    print(acts.shape)
    break

(128, 10)


### Define loss function.

In [7]:
loss_fn = sparse_categorical_crossentropy

We define two convenience functions.

``model_training_loss`` returns the batch loss and updated `model` from batched
inputs and targets.

``model_inference_preds_loss`` returns the batch loss and predictions from
batched inputs and targets.

In [8]:
def model_training_loss(
    x_batch: np.array,
    y_batch: np.array,
    model,
    trainables
):
    model = model.load_trainables(trainables)
    preds, model = model(x_batch, None, False)
    return loss_fn(preds, y_batch), model

@jax.jit
def model_inference_loss(
    x_batch: np.array,
    y_batch: np.array,
    model
):
    preds, _ = model(x_batch, None, True)
    return loss_fn(preds, y_batch), preds

### Define optimizer.

In [9]:
optim_state = sgd_init(model.trainables)

### Define training step.

In [10]:
@jax.jit
def train_step(
    x_batch: np.array, 
    y_batch: np.array,
    model,
    optim_state
):
    # Find batch loss and gradients with resect to trainables
    (loss, model), gradients = jax.value_and_grad(
        model_training_loss,
        argnums=3, # gradients wrt trainables (argument 3 of model_training_loss)
        has_aux=True # model is auxiliary data, loss is the true ouput
    )(x_batch, y_batch, model, model.trainables)

    # Get new gradients and optimizer state
    gradients, optim_state = sgd_step(
        gradients, optim_state, lr=1e-2, momentum=0.9
    )

    # Update model_weights with new gradients
    trainables = apply_updates(gradients, model.trainables)
    return loss, model.load_trainables(trainables), optim_state

### Define functions for training and testing loops.

In [11]:
def train_epoch(
    dataloader,
    model,
    optim_state
):
    num_batches = len(dataloader)
    train_loss = 0.0
    for x_batch, y_batch in dataloader:
        loss, model, optim_state = train_step(
            x_batch, y_batch,
            model,
            optim_state
        )
        train_loss += loss

    print(f"Train loss: {train_loss / num_batches}") 
    return model, optim_state

In [12]:
def test(
    dataloader,
    model
):
    num_batches = len(dataloader)
    test_loss, accuracy = 0, 0.0
    for x_batch, y_batch in dataloader:
        loss, preds = model_inference_loss(
            x_batch, y_batch, model
        )
        test_loss += loss
        accuracy += (jnp.argmax(preds, axis=1) == y_batch).sum() / len(x_batch)
    
    print(f"Test loss: {test_loss / num_batches}, accuracy: {accuracy / num_batches}")

In [13]:
def train_loop(
    train_dataloader,
    test_dataloader,
    model,
    optim_state,
    epochs, test_every
):
    for i in range(epochs):
        epoch = i + 1
        print(f"Epoch {epoch}\n----------------")
        model, optim_state = train_epoch(
            train_dataloader,
            model,
            optim_state
        )
        if (epoch % test_every == 0):
            test(test_dataloader, model)
        print(f"----------------")
    
    return model, optim_state

## Train MLP on the MNIST dataset.

In [14]:
new_model, new_optim_state = train_loop(
    train_dataloader,
    test_dataloader,
    model,
    optim_state,
    30, 5
)

Epoch 1
----------------
Train loss: 0.44252684712409973
----------------
Epoch 2
----------------
Train loss: 0.19888873398303986
----------------
Epoch 3
----------------
Train loss: 0.14479871094226837
----------------
Epoch 4
----------------
Train loss: 0.11203312128782272
----------------
Epoch 5
----------------
Train loss: 0.09040975570678711
Test loss: 0.0990489050745964, accuracy: 0.9703323245048523
----------------
Epoch 6
----------------
Train loss: 0.07537496834993362
----------------
Epoch 7
----------------
Train loss: 0.06405700743198395
----------------
Epoch 8
----------------
Train loss: 0.05468873679637909
----------------
Epoch 9
----------------
Train loss: 0.04696766287088394
----------------
Epoch 10
----------------
Train loss: 0.04060396924614906
Test loss: 0.07196811586618423, accuracy: 0.9769580960273743
----------------
Epoch 11
----------------
Train loss: 0.03543463721871376
----------------
Epoch 12
----------------
Train loss: 0.030744750052690506
----