# MLP Implementation without Optax.
Optax is a JAX optimization library by DeepMind. It is now the recommended way
to update weights in mlax. See `mlp_optax.ipynb` for this same example but using
Optax loss functions and optimizers.

This notebook uses the now-deprecated SGD optimizer that was available in mlax
prior to version 1.0.0.

A PyTorch reference implementation is available at `mlp_reference.ipynb`.

In [1]:
import jax
import jax.numpy as jnp
from jax import (
    nn,
    random,
    tree_util
)
import numpy as np

We import `Linear`, `Bias`, and `F` transformations from `mlax.nn` to build some
dense layers.

We import `Series` from `mlax.block` to stack the dense layers into an MLP.

In [2]:
from mlax.nn import Linear, Bias, F
from mlax.block import Series 

We import helpers to load PyTorch datasets.

We also import a sparse categorical crossentropy loss function and an SGD
optimizer.

In [3]:
from dataloader import batch, load_mnist
from optim import (
    sparse_categorical_crossentropy,
    sgd_init,
    sgd_step,
    apply_updates
)

### Load in and batch the MNIST datasets.
We use helper functions to load the PyTorch datasets as NumPy arrays and convert
them into lists of batches.

Check out
[Training a Simple Neural Network with tensorflow/datasets Data Loading](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html) and
[Training a Simple Neural Network, with PyTorch Data Loading](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html)
for other ways to load TensorFlow and PyTorch datasets.
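
The `batch` helper itself is not shown in this notebook. As a rough sketch of
what such a helper might do, assuming it slices the arrays into consecutive
chunks and keeps the smaller final chunk, the hypothetical `sketch_batch` below
is a stand-in and not the real `dataloader.batch`:

```python
import numpy as np

def sketch_batch(X, y, batch_size):
    # Hypothetical stand-in for dataloader.batch: slice into consecutive chunks,
    # keeping the final chunk even if it is smaller than batch_size.
    starts = range(0, len(X), batch_size)
    X_batches = [X[i:i + batch_size] for i in starts]
    y_batches = [y[i:i + batch_size] for i in starts]
    return X_batches, y_batches

# Toy data: 10 "images" batched in groups of 4
X_demo = np.zeros((10, 28, 28), dtype=np.uint8)
y_demo = np.arange(10)
X_batches, y_batches = sketch_batch(X_demo, y_demo, 4)
print(len(X_batches), [len(b) for b in X_batches])
```

Under that assumption, 60000 training samples with a batch size of 128 give 469
batches, the last of which holds 96 samples.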

In [4]:
# Load in datasets with helper
(X_train, y_train), (X_test, y_test) = load_mnist("../data")
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

# Batch datasets
batch_size = 128
X_train, y_train = batch(X_train, y_train, batch_size)
X_test, y_test = batch(X_test, y_test, batch_size)
print(len(X_train), len(y_train))
print(len(X_test), len(y_test))

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
469 469
79 79


### Initialize MLP model parameters.
`model_init` consumes a `jax.random.PRNGKey` when initializing the parameters.
Read more about random numbers in JAX [here](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html).
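
As a small illustration of why the key is split before use, the sketch below
derives independent subkeys from one parent key. The variable names are only
for illustration and do not come from mlax.

```python
from jax import random

key = random.PRNGKey(0)
subkeys = random.split(key, 3)  # three independent subkeys from one parent key

# Different subkeys give different samples
w1 = random.normal(subkeys[0], (2,))
w2 = random.normal(subkeys[1], (2,))

# Reusing a subkey reproduces exactly the same sample
w1_again = random.normal(subkeys[0], (2,))
```

This is why `model_init` hands a fresh subkey to every layer that has randomly
initialized weights.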

`F` is a wrapper around a stateless function. `Linear` is a linear
transformation without bias. `Bias` adds a bias term.

Each `init` function returns three values: `trainables`, `non_trainables`, and
`hyperparams`.

`trainables` are trainable weights. `non_trainables` are non-trainable
variables. `hyperparams` are additional parameters required by the forward pass.

The `trainables` and `non_trainables` are PyTrees of JAX arrays. Read more about
JAX PyTrees [here](https://jax.readthedocs.io/en/latest/pytrees.html).
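
A brief sketch of what working with a PyTree looks like. The nested structure
below is made up for illustration and is not the layout mlax actually uses.

```python
import jax.numpy as jnp
from jax import tree_util

# A made-up PyTree: a tuple of dicts holding JAX arrays
params = ({"kernel": jnp.ones((3, 2))}, {"bias": jnp.zeros((2,))})

# Flatten the tree into its array leaves
leaves = tree_util.tree_leaves(params)

# Apply a function to every leaf while keeping the structure
halved = tree_util.tree_map(lambda a: a * 0.5, params)
```

Functions such as `jax.grad` and `jax.jit` accept and return PyTrees, which is
what lets the whole set of `trainables` be passed around as one argument.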

`hyperparams` is a NamedTuple containing Python types.

In [5]:
def model_init(key):
    keys_iter = iter(random.split(key, 6))
    return Series.init(
        # Convert uint8 numpy inputs to float32 JAX arrays and flatten them
        F.init(lambda x: jnp.reshape(
            jnp.asarray(x, jnp.float32) / 255.0,
            (len(x), -1))
        ),

        # Dense layer with relu activation
        Linear.init(
            key=next(keys_iter),
            in_features=28 * 28,
            out_features=512,
        ),
        Bias.init(
            key=next(keys_iter),
            in_feature_shape=(512,)
        ),
        F.init(nn.relu),

        # Dense layer with relu activation
        Linear.init(next(keys_iter), 512, 512),
        Bias.init(next(keys_iter), (512,)),
        F.init(nn.relu),
        
        # Dense layer with softmax
        Linear.init(next(keys_iter), 512, 10),
        Bias.init(next(keys_iter), (10,)),
        F.init(nn.softmax),
    )

trainables, non_trainables, hyperparams = model_init(random.PRNGKey(0))

### Define MLP dataflow.
`Series.fwd` takes in batched input features and tuples of `trainables`,
`non_trainables`, and `hyperparams`. It figures out which layer each
`hyperparams` is for, and calls their forward pass functions on the input
features in sequence.

It returns the model predictions and updated `non_trainables`.
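
Conceptually, this is a loop that threads the activations through each layer in
turn and collects the updated non-trainable state. The plain-JAX sketch below
does not use mlax: `sketch_series_fwd` and the three layer functions are
hypothetical, and they take an explicit list of functions where the real
`Series.fwd` relies on `hyperparams`.

```python
import jax.numpy as jnp
from jax import nn

# Hypothetical layer forward functions: (input, trainable, state) -> (output, state)
def linear_fwd(x, kernel, state):
    return x @ kernel, state

def bias_fwd(x, bias, state):
    return x + bias, state

def relu_fwd(x, _unused, state):
    return nn.relu(x), state

def sketch_series_fwd(x, layer_fns, trainables, non_trainables):
    # Apply each layer in sequence, collecting the (possibly updated) states
    new_non_trainables = []
    for fn, tr, ntr in zip(layer_fns, trainables, non_trainables):
        x, ntr = fn(x, tr, ntr)
        new_non_trainables.append(ntr)
    return x, tuple(new_non_trainables)

layer_fns = (linear_fwd, bias_fwd, relu_fwd)
demo_trainables = (jnp.ones((4, 3)), -3.0 * jnp.ones((3,)), None)
demo_non_trainables = (None, None, None)
demo_preds, demo_non_trainables = sketch_series_fwd(
    jnp.ones((2, 4)), layer_fns, demo_trainables, demo_non_trainables
)
```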

In [6]:
model_fwd = Series.fwd

### Define loss function.

In [7]:
loss_fn=sparse_categorical_crossentropy
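
The helper's source is not shown here. As a minimal sketch of what a sparse
categorical crossentropy typically computes, assuming the predictions are
already probabilities (the model above ends in a softmax) and the targets are
integer class labels, the hypothetical `sketch_sparse_categorical_crossentropy`
could look like this. The real function in `optim` may differ, for example in
how it handles numerical stability.

```python
import jax.numpy as jnp

def sketch_sparse_categorical_crossentropy(probs, labels, eps=1e-7):
    # probs: (batch, classes) probabilities; labels: (batch,) integer classes.
    # Pick the predicted probability of the true class for every sample.
    picked = probs[jnp.arange(probs.shape[0]), labels]
    # Mean negative log-likelihood; eps (an assumption) guards against log(0)
    return -jnp.mean(jnp.log(picked + eps))

probs = jnp.array([[0.9, 0.1], [0.2, 0.8]])
labels = jnp.array([0, 1])
demo_loss = sketch_sparse_categorical_crossentropy(probs, labels)
```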

We define two convenience functions.

``model_training_loss`` returns the batch loss and updated `non_trainables` from
batched inputs and targets.

``model_inference_preds_loss`` returns the predictions and batch loss from
batched inputs and targets.

We jit-compile `model_inference_preds_loss` for significant speedups. Note
that `hyperparams` is a static argument because it is made of Python types, not
valid JAX types, and it is also used internally for control flow. Read more about
jit-compilation [here](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html).
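
To illustrate static arguments in isolation, here is a small sketch that is
unrelated to mlax. The function `reduce_with` and its `mode` argument are made
up for this example.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="mode")
def reduce_with(x, mode):
    # `mode` is a plain Python string used for Python-level control flow,
    # so it must be static: JAX compiles one version per distinct value.
    if mode == "sum":
        return jnp.sum(x)
    return jnp.mean(x)

x = jnp.arange(4.0)
total = reduce_with(x, mode="sum")
average = reduce_with(x, mode="mean")
```

Static arguments must be hashable, and each new value triggers a fresh
compilation, so they suit configuration that rarely changes.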

In [8]:
def model_trainig_loss(
    x_batch: np.array,
    y_batch: np.array,
    trainables,
    non_trainables,
    hyperparams,
):
    preds, non_trainables = model_fwd(
        x_batch, trainables, non_trainables, hyperparams
    )
    return loss_fn(preds, y_batch), non_trainables

@tree_util.Partial(jax.jit, static_argnames="hyperparams")
def model_inference_preds_loss(
    x_batch: np.array,
    y_batch: np.array,
    trainables,
    non_trainables,
    hyperparams
):
    preds, _ = model_fwd(
        x_batch, trainables, non_trainables, hyperparams
    )
    return preds, loss_fn(preds, y_batch)

### Define optimizer state and function.
We pass the `trainables` to `sgd_init` to initialize an optimizer state.

We create a function that takes in the gradients of the `trainables` and an
`optim_state`, and returns the updates to be applied to the `trainables` along
with a new `optim_state`.

Note that we use `jax.tree_util.Partial` instead of `functools.partial` when
defining `optim_fn` so that it can be passed to jit-compiled functions, notably
`train_step`.
Read more about this [here](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.Partial.html?highlight=Partial).
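
A short sketch of the difference, using made-up functions `scale` and
`apply_fn`: a `jax.tree_util.Partial` object is itself a PyTree, so it can be
passed as a regular argument to a jit-compiled function, whereas a
`functools.partial` object is not a valid JAX type.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

def scale(x, factor):
    return x * factor

@jax.jit
def apply_fn(fn, x):
    # `fn` arrives as a PyTree: the wrapped function is static metadata,
    # and its bound arguments are the leaves.
    return fn(x)

double = tree_util.Partial(scale, factor=2.0)
doubled = apply_fn(double, jnp.arange(3.0))
```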

In [9]:
optim_state = sgd_init(trainables)
optim_fn = tree_util.Partial(sgd_step, lr=1e-2, momentum=0.9)
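
The `optim` helpers are not shown in this notebook. As a rough sketch of SGD
with momentum over PyTrees, the hypothetical functions below mirror the call
signatures used above. Their internals are assumptions, including the choice to
return negated updates and to add them in `sketch_apply_updates`.

```python
import jax.numpy as jnp
from jax import tree_util

def sketch_sgd_init(trainables):
    # Optimizer state: one zero-initialized velocity per weight array
    return tree_util.tree_map(jnp.zeros_like, trainables)

def sketch_sgd_step(gradients, velocity, lr, momentum):
    # Accumulate velocity, then scale by the negative learning rate
    velocity = tree_util.tree_map(
        lambda v, g: momentum * v + g, velocity, gradients
    )
    updates = tree_util.tree_map(lambda v: -lr * v, velocity)
    return updates, velocity

def sketch_apply_updates(updates, trainables):
    # Assumption: updates are added to the weights
    return tree_util.tree_map(lambda p, u: p + u, trainables, updates)

demo_params = {"w": jnp.array([1.0, 2.0])}
demo_grads = {"w": jnp.array([0.5, -0.5])}
demo_state = sketch_sgd_init(demo_params)
demo_updates, demo_state = sketch_sgd_step(
    demo_grads, demo_state, lr=0.1, momentum=0.9
)
demo_params = sketch_apply_updates(demo_updates, demo_params)
```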

### Define training step.
We use JAX's `value_and_grad` to calculate the batch loss and
gradients with respect to the `trainables`. Read more about JAX's autodiff
[here](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#taking-derivatives-with-grad).

The batch loss is only used for logging, but the gradients are passed to
`optim_fn` to get the updates and a new `optim_state`.

We apply the updates to the model weights.

Finally, we return the batch loss, the new `trainables` and `non_trainables`,
and the new `optim_state`.
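
The `argnums` and `has_aux` options can be seen in isolation in the sketch
below. The function `loss_with_aux` and its arguments are made up for
illustration.

```python
import jax
import jax.numpy as jnp

def loss_with_aux(x, w, state):
    # Returns (loss, aux): only the loss is differentiated
    pred = x @ w
    return jnp.mean(pred ** 2), state + 1

x = jnp.ones((4, 3))
w = jnp.ones((3,))
state = jnp.zeros(())

(demo_loss_value, new_state), grads = jax.value_and_grad(
    loss_with_aux,
    argnums=1,     # differentiate with respect to `w` (argument index 1)
    has_aux=True,  # the second return value is auxiliary data
)(x, w, state)
```

The gradients have the same structure and shape as the argument selected by
`argnums`, here `w`.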

In [10]:
@tree_util.Partial(jax.jit, static_argnames="hyperparams")
def train_step(
    x_batch: np.array, 
    y_batch: np.array,
    trainables, non_trainables, hyperparams,
    optim_fn, optim_state
):
    # Find batch loss and gradients
    (loss, non_trainables), gradients = jax.value_and_grad(
        model_trainig_loss,
        argnums=2, # gradients wrt trainables (argument index 2)
        has_aux=True # non_trainables is auxiliary data, loss is the true output
    )(x_batch, y_batch, trainables, non_trainables, hyperparams)

    # Get new gradients and optimizer state
    gradients, optim_state = optim_fn(gradients, optim_state)

    # Update model_weights with new gradients
    trainables = apply_updates(gradients, trainables)
    return loss, trainables, non_trainables, optim_state

### Define functions for training and testing loops.

In [11]:
def train_epoch(
    X_train, y_train,
    trainables, non_trainables, hyperparams,
    optim_fn, optim_state
):
    num_batches = len(X_train)
    train_loss = 0.0
    for i in range(num_batches):
        x_batch, y_batch = X_train[i], y_train[i]
        loss, trainables, non_trainables, optim_state = train_step(
            x_batch, y_batch,
            trainables, non_trainables, hyperparams,
            optim_fn, optim_state
        )
        train_loss += loss

    print(f"Train loss: {train_loss / num_batches}") 
    return trainables, non_trainables, optim_state

In [12]:
def test(
    X_test, y_test,
    trainables, non_trainables, hyperparams
):
    num_batches = len(X_test)
    test_loss, accuracy = 0, 0.0
    for i in range(num_batches):
        x_batch, y_batch = X_test[i], y_test[i]
        preds, loss = model_inference_preds_loss(
            x_batch, y_batch, trainables, non_trainables, hyperparams
        )
        test_loss += loss
        accuracy += (jnp.argmax(preds, axis=1) == y_batch).sum() / len(x_batch)
    
    print(f"Test loss: {test_loss / num_batches}, accuracy: {accuracy / num_batches}")
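
Note that `test` averages per-batch accuracies, so a smaller final batch counts
as much as a full one. With 10000 test samples and a batch size of 128, the
last of the 79 batches would hold only 16 samples if the helper keeps the
remainder. A sketch of a sample-weighted alternative, using the hypothetical
`sketch_exact_accuracy`:

```python
import numpy as np

def sketch_exact_accuracy(pred_batches, label_batches):
    # Count correct predictions over all samples instead of averaging batches
    correct = sum(
        int((np.argmax(p, axis=1) == y).sum())
        for p, y in zip(pred_batches, label_batches)
    )
    total = sum(len(y) for y in label_batches)
    return correct / total

# Toy example: a full batch of 4 (all correct) and a final batch of 1 (wrong)
pred_batches = [np.eye(2)[[0, 1, 0, 1]], np.array([[0.9, 0.1]])]
label_batches = [np.array([0, 1, 0, 1]), np.array([1])]

exact = sketch_exact_accuracy(pred_batches, label_batches)
per_batch_mean = np.mean([
    (np.argmax(p, axis=1) == y).mean()
    for p, y in zip(pred_batches, label_batches)
])
```

In this toy case the sample-weighted accuracy is 0.8 while the mean of
per-batch accuracies is 0.5. On MNIST the gap should be much smaller.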

In [13]:
def train_loop(
    X_train, y_train,
    X_test, y_test,
    trainables, non_trainables, hyperparams,
    optim_fn, optim_state,
    epochs, test_every
):
    for i in range(epochs):
        epoch = i + 1
        print(f"Epoch {epoch}\n----------------")
        trainables, non_trainables, optim_state = train_epoch(
            X_train, y_train,
            trainables, non_trainables, hyperparams,
            optim_fn, optim_state
        )
        if (epoch % test_every == 0):
            test(X_test, y_test, trainables, non_trainables, hyperparams)
        print("----------------")
    
    return trainables, non_trainables, optim_state

### Train MLP on the MNIST dataset.

In [14]:
new_trainables, new_non_trainables, new_optim_state = train_loop(
    X_train, y_train,
    X_test, y_test,
    trainables, non_trainables, hyperparams,
    optim_fn, optim_state,
    30, 5
)

Epoch 1
----------------
Train loss: 0.45100072026252747
----------------
Epoch 2
----------------
Train loss: 0.2034211903810501
----------------
Epoch 3
----------------
Train loss: 0.14783120155334473
----------------
Epoch 4
----------------
Train loss: 0.11520185321569443
----------------
Epoch 5
----------------
Train loss: 0.09326479583978653
Test loss: 0.09866884350776672, accuracy: 0.9701344966888428
----------------
Epoch 6
----------------
Train loss: 0.0771799311041832
----------------
Epoch 7
----------------
Train loss: 0.06497106701135635
----------------
Epoch 8
----------------
Train loss: 0.055386949330568314
----------------
Epoch 9
----------------
Train loss: 0.04769798368215561
----------------
Epoch 10
----------------
Train loss: 0.041314348578453064
Test loss: 0.07340061664581299, accuracy: 0.9761669635772705
----------------
Epoch 11
----------------
Train loss: 0.03593740612268448
----------------
Epoch 12
----------------
Train loss: 0.03138212487101555
----