# MLP Implementation in mlax with Optax optimizers.
To help with understanding, consider first reading the tutorial on
[Stateful Computations in JAX](https://jax.readthedocs.io/en/latest/jax-101/07-state.html),
especially the Linear Regression worked example. 

Also consider going over Optax's [Quick Start](https://optax.readthedocs.io/en/latest/optax-101.html).

You can view the PyTorch reference implementation in `mlp_reference.ipynb`.

See `mlp_no_optax.ipynb` for an implementation without using Optax.

See `mlp_optax_loss.ipynb` for an example using only Optax loss functions.

In [1]:
import jax
import jax.numpy as jnp
from jax import (
    nn,
    random,
    tree_util
)
import numpy as np
import optax

We import the `Linear` block from `mlax.blocks` to build the MLP.

In [2]:
from mlax.blocks import Linear

We import helpers to load data.

In [3]:
from dataloader import batch, load_mnist

### Load the MNIST datasets as numpy arrays.
We could instead load the datasets as JAX arrays. If we did so, the datasets
would be sent uncommitted to the default device, in my case, the GPU.
Read more about this [here](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices).

Since most datasets are too expensive to copy to each accelerator, they
should stay on the CPU and be streamed to the accelerators during training and
testing. While we could do this by using `jax.device_put()` on JAX arrays, it's
simpler to keep them as numpy arrays, which always stay on the CPU.
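
The placement rules described above can be checked on a small dummy array. This is a minimal sketch, not part of the original notebook: the array shape is arbitrary, and the default device depends on your installation (it falls back to the CPU when no accelerator is present).

```python
import jax
import jax.numpy as jnp
import numpy as np

# A NumPy array always lives in host (CPU) memory.
x_host = np.ones((4, 28, 28), dtype=np.uint8)

# Converting it to a JAX array places it on the default device
# (an accelerator if one is available, otherwise the CPU).
x_default = jnp.asarray(x_host)

# jax.device_put lets us pick a device explicitly; here, the first CPU device.
cpu = jax.devices("cpu")[0]
x_cpu = jax.device_put(x_host, cpu)

print(type(x_host).__name__)
print(x_default.devices())
print(x_cpu.devices())
```

Keeping the dataset as NumPy arrays avoids having to manage these placements by hand for every batch.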

In [4]:
(X_train, y_train), (X_test, y_test) = load_mnist("../data")
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)


### Batch the datasets
Data stay as numpy arrays, and therefore on the CPU.

In [5]:
batch_size = 64
X_train, y_train = batch(X_train, y_train, batch_size)
X_test, y_test = batch(X_test, y_test, batch_size)
print(len(X_train), len(y_train))
print(len(X_test), len(y_test))

938 938
157 157
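
The batch counts above are consistent with keeping the last, partially filled batch: 60000 / 64 rounds up to 938 and 10000 / 64 rounds up to 157. The `batch` helper from `dataloader` is not shown in this notebook, so the function below (`batch_sketch`, a hypothetical name) is only a guess at an equivalent behavior, written with plain NumPy slicing.

```python
import numpy as np

def batch_sketch(X, y, batch_size):
    """Split X and y into lists of batches, keeping the last partial batch."""
    starts = range(0, len(X), batch_size)
    X_batches = [X[i:i + batch_size] for i in starts]
    y_batches = [y[i:i + batch_size] for i in starts]
    return X_batches, y_batches

# Dummy data standing in for MNIST: 150 samples, batch size 64.
X_demo = np.zeros((150, 28, 28), dtype=np.uint8)
y_demo = np.zeros((150,), dtype=np.uint8)

Xb, yb = batch_sketch(X_demo, y_demo, 64)
# Two full batches of 64 plus one partial batch of 150 - 128 = 22 samples.
print(len(Xb), Xb[0].shape, Xb[-1].shape)
```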


Note that there are more elegant ways to load in TensorFlow or PyTorch data.

Consider reading JAX's
[Training a Simple Neural Network with tensorflow/datasets Data Loading](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html) and
[Training a Simple Neural Network, with PyTorch Data Loading](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html).

### Initialize MLP model weights.
`model_init` uses a `jax.random.PRNGKey` when initializing the weights.
Read more about random numbers in JAX [here](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html).

The initialized `model_weights` are JAX arrays, so they reside on the default
device, in my case the GPU.

Furthermore, default weights of `mlax` transformations and blocks are of the type
`float32`. You can override that with `init` functions' `dtype` parameter.

In [6]:
def model_init(key):
    key1, key2, key3 = random.split(key, 3)
    
    # Initialize weights for each linear block on the GPU, default type: float32
    w1 = Linear.init(key1, 28*28, 512)
    w2 = Linear.init(key2, 512, 512)
    w3 = Linear.init(key3, 512, 10)
    return [w1, w2, w3]

model_weights = model_init(random.PRNGKey(24))
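
The key splitting in `model_init` can be looked at in isolation. The sketch below uses only `jax.random` (no mlax) and draws small dummy vectors instead of weights: the same key always yields the same numbers, while the split subkeys yield independent streams, which is why each block gets its own key.

```python
from jax import random

key = random.PRNGKey(24)
key1, key2, key3 = random.split(key, 3)

# Reusing a key reproduces the same draw.
a = random.normal(key1, (2,))
a_again = random.normal(key1, (2,))

# A different subkey gives a different draw.
b = random.normal(key2, (2,))

print(a, a_again, b)
```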

### Define MLP dataflow
We first flatten the 2D numpy array with `jax.numpy.reshape`, converting it
into a JAX array, which gets sent to the default device, in my case the GPU.

We also explicitly convert its type to `float32`, the type of our
model weights.

Knowing what types we are working with is important because mlax functions do
not promote types implicitly.

We then pass the flattened inputs through three `Linear` blocks. The first two
apply a ReLU activation, but the last one does not end with an activation.

mlax functions, including `Linear`, only accept a single unbatched
sample as input, so `model_fwd` only works on single unbatched samples. We
therefore use JAX's [vmap](https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html)
to define a `batched_model_fwd`, which works on batched inputs.

In [7]:
@jax.jit
def model_fwd(
    x: np.array,
    weights: jnp.array
):
    # Flatten numpy array and send it to GPU.
    x = jnp.reshape(x, (-1, ))
    # Explicit type promotion for the following mlax functions
    x = x.astype("float32") 

    w1, w2, w3 = weights
    x = Linear.fwd(x, w1, nn.relu)
    x = Linear.fwd(x, w2, nn.relu)
    x = Linear.fwd(x, w3, None)
    return x

batched_model_fwd = jax.vmap(model_fwd, in_axes=[0, None])
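
To see what `in_axes` does without mlax, here is a minimal sketch with a hypothetical single-sample function `single_fwd` (a plain matrix product, not the notebook's model). The first argument is mapped over its leading axis, while `None` means the weights are shared by every sample in the batch.

```python
import jax
import jax.numpy as jnp

def single_fwd(x, w):
    # x: one flattened sample of shape (4,); w: weight matrix of shape (4, 3)
    return jnp.dot(x, w)

# Map over axis 0 of x; broadcast w unchanged to every sample.
batched_fwd = jax.vmap(single_fwd, in_axes=(0, None))

xs = jnp.ones((8, 4))   # a batch of 8 samples
w = jnp.ones((4, 3))
out = batched_fwd(xs, w)
print(out.shape)
```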


### Define batch loss function
We use Optax's cross-entropy loss, which is already vectorized (batched).

Note that the target values are numpy arrays, which get streamed from the CPU
and are implicitly converted into JAX arrays on the GPU.

In [8]:
@jax.jit
def batched_loss(
    batched_preds: jnp.array,
    batched_targets: np.array
):
    return optax.softmax_cross_entropy_with_integer_labels(
        batched_preds,
        batched_targets
    ).mean()
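
For intuition, the per-sample quantity behind this loss is the negative log-softmax of the logits at the target class. The sketch below re-implements that idea in plain JAX; `cross_entropy_sketch` is a hypothetical helper for illustration and is assumed, not verified here, to agree with the Optax function up to floating-point error.

```python
import jax.numpy as jnp
from jax import nn

def cross_entropy_sketch(logits, labels):
    # logits: (batch, classes); labels: (batch,) integer class ids
    log_probs = nn.log_softmax(logits, axis=-1)
    # Pick the log-probability of the target class for each sample.
    picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)
    return -picked[:, 0]

# With all-zero logits every class has probability 1/10,
# so each per-sample loss is log(10).
logits = jnp.zeros((2, 10))
labels = jnp.array([3, 7])
losses = cross_entropy_sketch(logits, labels)
print(losses, losses.mean())
```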

We also define a convenience function that calculates the model loss based on
datasets' batched inputs.

In [9]:
def batched_model_loss(
    x_batch: np.array,
    y_batch: np.array,
    weights: jnp.array
):
    return batched_loss(batched_model_fwd(x_batch, weights), y_batch)

### Define optimizer state and function
We use Optax to create an optimizer.

We create an optimizer state and optimizer function.

Note that we use `jax.tree_util.Partial` to wrap the update function. Doing this
allows the update function to be passed as an argument to `jit` compiled functions.
Read more about this [here](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.Partial.html?highlight=Partial).

In [10]:
# Initialize optimizer state on the GPU
optimizer = optax.sgd(1e-3, momentum=0.7)
optim_state = optimizer.init(model_weights)
optim_fn = tree_util.Partial(optimizer.update)
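
The effect of `tree_util.Partial` can be shown on a toy function. In this sketch, `scale` and `apply_fn` are hypothetical names: wrapping a callable in `Partial` turns it into a pytree, so it can be passed as a regular argument to a `jit` compiled function, whereas a bare Python function is not a valid JAX argument type.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

def scale(factor, x):
    return factor * x

@jax.jit
def apply_fn(fn, x):
    # fn arrives as a pytree: its bound arguments are leaves,
    # the wrapped Python function is static metadata.
    return fn(x)

double = tree_util.Partial(scale, 2.0)
result = apply_fn(double, jnp.arange(3.0))
print(result)  # values: 0, 2, 4
```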

### Define training step
We use JAX's `value_and_grad` to calculate the batch loss and the
gradients with respect to `model_weights` in a single call. Read more about JAX's autodiff
[here](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#taking-derivatives-with-grad).

The batch loss is only used for logging, but the gradients are passed to
`optim_fn` to get new gradients and a new optimizer state.

We apply the new gradients to update the model weights.

Finally, we return the batch loss, new `model_weights`, and new `optim_state`.

In [11]:
@jax.jit
def train_step(
    x_batch: np.array, 
    y_batch: np.array,
    model_weights: jnp.array,
    optim_fn, # (gradients, optim_state) -> (new_gradients, new_optim_state)
    optim_state: jnp.array
):
    # Find batch loss and gradients
    loss, gradients = jax.value_and_grad(
        batched_model_loss,
        argnums=2 # gradients wrt model_weights (argument 2)
    )(x_batch, y_batch, model_weights)

    # Get new gradients and optimizer state
    gradients, optim_state = optim_fn(gradients, optim_state)

    # Update model_weights with new gradients
    model_weights = optax.apply_updates(model_weights, gradients)
    return loss, model_weights, optim_state

### Define functions for training and testing loops

In [12]:
def train_epoch(X_train, y_train, model_weights, optim_fn, optim_state):
    num_batches = len(X_train)
    train_loss = 0.0
    for i in range(num_batches):
        x_batch, y_batch = X_train[i], y_train[i]
        loss, model_weights, optim_state = train_step(
            x_batch, y_batch,
            model_weights,
            optim_fn, optim_state
        )
        train_loss += loss

    print(f"Train loss: {train_loss / num_batches}") 
    return model_weights, optim_state

In [13]:
def test(X_test, y_test, model_weights):
    num_batches = len(X_test)
    test_loss, accuracy = 0.0, 0.0
    for i in range(num_batches):
        x_batch, y_batch = X_test[i], y_test[i]
        preds = batched_model_fwd(
            x_batch, model_weights
        )
        loss = batched_loss(preds, y_batch)
        test_loss += loss
        accuracy += (jnp.argmax(preds, axis=1) == y_batch).sum() / len(x_batch)
    
    print(f"Test loss: {test_loss / num_batches}, accuracy: {accuracy / num_batches}")
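
One design note on the accuracy above: it is the mean of per-batch accuracies, so a smaller final batch counts as much as a full one. The NumPy sketch below, with made-up predictions and a hypothetical helper `accuracy_counts`, contrasts that with a sample-weighted accuracy. With 157 test batches the difference is likely small, but it can matter for tiny datasets.

```python
import numpy as np

def accuracy_counts(preds, targets):
    """Return (number of correct predictions, number of samples)."""
    correct = int((np.argmax(preds, axis=1) == targets).sum())
    return correct, len(targets)

# Two made-up batches: a full one (4 of 4 correct) and a partial one (0 of 1).
batches = [
    (np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]),
     np.array([0, 1, 0, 1])),
    (np.array([[0.9, 0.1]]), np.array([1])),
]

correct_total, sample_total, per_batch = 0, 0, []
for preds, targets in batches:
    correct, n = accuracy_counts(preds, targets)
    correct_total += correct
    sample_total += n
    per_batch.append(correct / n)

mean_of_batches = sum(per_batch) / len(per_batch)  # (1.0 + 0.0) / 2
overall = correct_total / sample_total             # 4 / 5
print(mean_of_batches, overall)
```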

In [14]:
def train_loop(
    X_train, y_train,
    X_test, y_test,
    model_weights,
    optim_fn, optim_state,
    epochs, test_every
):
    for i in range(epochs):
        epoch = i + 1
        print(f"Epoch {epoch}\n----------------")
        model_weights, optim_state = train_epoch(
            X_train, y_train,
            model_weights,
            optim_fn, optim_state
        )
        if epoch % test_every == 0:
            test(X_test, y_test, model_weights)
        print("----------------")
    
    return model_weights, optim_state

### Train MLP on MNIST dataset

In [15]:
new_model_weights, new_optim_state = train_loop(
    X_train, y_train,
    X_test, y_test,
    model_weights,
    optim_fn, optim_state,
    50, 5
)

Epoch 1
----------------
Train loss: 2.0185959339141846
----------------
Epoch 2
----------------
Train loss: 0.1646140217781067
----------------
Epoch 3
----------------
Train loss: 0.10156477242708206
----------------
Epoch 4
----------------
Train loss: 0.06812717765569687
----------------
Epoch 5
----------------
Train loss: 0.04950431361794472
Test loss: 0.1562071591615677, accuracy: 0.9608877301216125
----------------
Epoch 6
----------------
Train loss: 0.03605528175830841
----------------
Epoch 7
----------------
Train loss: 0.02721119113266468
----------------
Epoch 8
----------------
Train loss: 0.02074531279504299
----------------
Epoch 9
----------------
Train loss: 0.015856629237532616
----------------
Epoch 10
----------------
Train loss: 0.012181155383586884
Test loss: 0.14902129769325256, accuracy: 0.9658638834953308
----------------
Epoch 11
----------------
Train loss: 0.009694949723780155
----------------
Epoch 12
----------------
Train loss: 0.007761626038700342
---