# MLP Implementation in mlax.
To help with understanding, consider first reading the tutorial on
[Stateful Computations in JAX](https://jax.readthedocs.io/en/latest/jax-101/07-state.html),
especially the Linear Regression worked example. 

Also consider going over the PyTorch reference implementation in
`mlp_reference.ipynb`.

```python
import jax
import jax.numpy as jnp
from jax import nn
from jax import random
from jax import tree_util
import torchvision
import numpy as np
```

We import the `Linear` block from `mlax.nn.blocks` to build the MLP.

We import the `categorical_crossentropy` loss function from `mlax.losses` to
evaluate our MLP.

We import the `sgd` optimizer from `mlax.optim` to calculate updates to the
model weights, and we import the `apply_gradients` function to apply those
updates.

```python
from mlax.nn.blocks import Linear
from mlax.losses import categorical_crossentropy
from mlax.optim import sgd, apply_gradients
```

### Load the MNIST dataset from torchvision
Unlike in the reference implementation, the `ToTensor` transform is not applied.
This is because we will be converting the datasets to numpy arrays.

```python
mnist_train = torchvision.datasets.MNIST(
    root="data",
    train=True,
    download=True
)
mnist_test = torchvision.datasets.MNIST(
    root="data",
    train=False,
    download=True
)
```

### Convert the datasets into numpy arrays.

We could instead convert the datasets to JAX arrays. If we did so, the datasets
would be placed, uncommitted, on the default device (in my case, the GPU).
Read more about this [here](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices).

Since most datasets are too expensive to copy to each accelerator, they
should stay on the CPU and be streamed to the accelerators during training and
testing. While we could do this by using `jax.device_put()`, it's simpler to convert them to numpy arrays, which always stay in host (CPU) memory.
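The following is a minimal sketch of the placement behavior described above. It assumes a recent JAX release (0.4 or later, where `jax.Array` and `Array.devices()` exist). The small zero-filled batch is a stand-in for the MNIST arrays. `jnp.asarray` copies data to the default device, while `jax.device_put` with an explicit device commits the array to that device.

```python
import numpy as np
import jax
import jax.numpy as jnp

# A small uint8 "image" batch standing in for the MNIST numpy arrays.
host_batch = np.zeros((4, 28, 28), dtype=np.uint8)

# jnp.asarray copies the data onto JAX's default device (a GPU if one is
# visible, otherwise the CPU); the resulting array is uncommitted.
default_copy = jnp.asarray(host_batch)

# jax.device_put with an explicit device commits the array to that device.
cpu = jax.devices("cpu")[0]
cpu_copy = jax.device_put(host_batch, cpu)

print(type(host_batch).__name__)      # the original is still a NumPy ndarray
print(type(default_copy).__name__)    # a JAX array on the default device
print(cpu_copy.devices())             # the set of devices holding cpu_copy
```

On a CPU-only machine, the default device is the CPU, so both copies end up in the same place.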

```python
X_train, y_train = mnist_train.data.numpy(), mnist_train.targets.numpy()
X_test, y_test = mnist_test.data.numpy(), mnist_test.targets.numpy()
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)
```

```
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
```


### Batch the datasets
Data stay as numpy arrays, and therefore on the CPU.

```python
def batch(x, y, batch_size):
    """Split x and y into lists of consecutive batches of size batch_size."""
    batched_x, batched_y = [], []
    for i in range(0, len(x), batch_size):
        batched_x.append(x[i:i+batch_size])
        batched_y.append(y[i:i+batch_size])
    return batched_x, batched_y

batch_size = 64
X_train, y_train = batch(X_train, y_train, batch_size)
X_test, y_test = batch(X_test, y_test, batch_size)
print(len(X_train), len(y_train))
print(len(X_test), len(y_test))
```

```
938 938
157 157
```


Note that there are more elegant ways to load TensorFlow or PyTorch data.

Consider reading JAX's
[Training a Simple Neural Network with tensorflow/datasets Data Loading](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html) and
[Training a Simple Neural Network, with PyTorch Data Loading](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html).
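The `batch` function above builds every batch up front and always yields the batches in the same order. Its last batches are also smaller: 60000 = 937 * 64 + 32, and 10000 = 156 * 64 + 16. One common alternative is a generator that reshuffles the sample order every epoch. Below is a minimal NumPy sketch of that idea. `shuffled_batches` is a hypothetical helper, not part of mlax or torchvision.

```python
import numpy as np

def shuffled_batches(x, y, batch_size, seed=0):
    """Hypothetical helper: lazily yield (x_batch, y_batch) in a random order.

    Pass a different seed each epoch to get a different order.
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        # Index x and y with the same indices so inputs and labels stay paired.
        yield x[idx], y[idx]

# Toy data: 10 samples, so a batch size of 4 gives batches of 4, 4 and 2.
x = np.arange(10 * 3).reshape(10, 3)   # row i starts with the value 3 * i
y = np.arange(10)                      # label i belongs to row i
batches = list(shuffled_batches(x, y, 4, seed=1))
sizes = [len(xb) for xb, _ in batches]
print(sizes)  # [4, 4, 2]
```

The batches are still plain NumPy arrays, so they stay on the CPU until a jitted function consumes them.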

### Initialize MLP model weights.
`model_init` uses a `jax.random.PRNGKey` when initializing the weights.
Read more about random numbers in JAX [here](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html).

The initialized `model_weights` are JAX arrays, so they reside on the default
device, in my case the GPU.

Furthermore, by default the weights of `mlax` transformations and blocks are of
type `float32`. You can override this with the `dtype` parameter of the `init` functions.

```python
def model_init(key):
    key1, key2, key3 = random.split(key, 3)

    # Initialize weights for each linear block on the GPU, default type: float32
    w1 = Linear.init(key1, 28*28, 512)
    w2 = Linear.init(key2, 512, 512)
    w3 = Linear.init(key3, 512, 10)
    return [w1, w2, w3]

model_weights = model_init(random.PRNGKey(24))
```
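The snippet below sketches how key splitting and a `dtype`-aware initializer fit together. `linear_init` is a hypothetical stand-in for `Linear.init`, and its `(kernel, bias)` layout is an assumption. It draws uniformly from `[-1/sqrt(in_features), 1/sqrt(in_features))`, comparable to PyTorch's `nn.Linear` default. mlax's actual initializer and weight layout may differ.

```python
import jax.numpy as jnp
from jax import random

def linear_init(key, in_features, out_features, dtype=jnp.float32):
    """Hypothetical stand-in for Linear.init: returns a (kernel, bias) pair."""
    k_kernel, k_bias = random.split(key)   # one independent key per tensor
    bound = in_features ** -0.5
    kernel = random.uniform(k_kernel, (in_features, out_features), dtype,
                            minval=-bound, maxval=bound)
    bias = random.uniform(k_bias, (out_features,), dtype,
                          minval=-bound, maxval=bound)
    return kernel, bias

# Same pattern as model_init: split one root key into one key per layer.
key1, key2, key3 = random.split(random.PRNGKey(24), 3)
w1 = linear_init(key1, 28 * 28, 512)
w1_again = linear_init(key1, 28 * 28, 512)   # same key, so identical weights
w2 = linear_init(key2, 512, 512)
```

JAX random functions are pure, so reusing a key reproduces the same numbers. This is why every layer gets its own split key.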

### Define MLP dataflow
We first flatten the 2D numpy array with `jax.numpy.reshape`, converting it
into a JAX array, which gets sent to the default device (in my case, the GPU).

We also explicitly convert its type to `float32`, the type of our
model weights.

Knowing what types we are working with is important because mlax functions do
not promote types implicitly.

We then pass the flattened inputs through several `mlax.nn.blocks.Linear` blocks, the
last of which ends with a `jax.nn.softmax` activation.

mlax functions, including `Linear.fwd`, only accept a single unbatched sample
as input, so `model_fwd` only works on single samples. We therefore use JAX's
[vmap](https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html)
to define `batched_model_fwd`, which works on batched inputs: `in_axes=(0, None)`
maps over the first axis of `x` while sharing the same `weights` across all samples.

```python
@jax.jit
def model_fwd(
    x: np.ndarray,
    weights: list
):
    # Flatten numpy array and send it to GPU.
    x = jnp.reshape(x, (-1, ))
    # Explicit type conversion, since the following mlax functions
    # do not promote types implicitly
    x = x.astype("float32")

    w1, w2, w3 = weights
    x = Linear.fwd(x, w1, nn.relu)
    x = Linear.fwd(x, w2, nn.relu)
    x = Linear.fwd(x, w3, nn.softmax)
    return x

# Map over axis 0 of x; share the same weights across all samples
batched_model_fwd = jax.vmap(model_fwd, in_axes=(0, None))
```
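To see what `in_axes=(0, None)` does, the sketch below applies it to a hypothetical per-sample forward pass written with plain `jnp` operations instead of mlax. The weights are assumed to be `(kernel, bias)` pairs, and the layer sizes are shrunk for speed. The vmapped version gives the same result as looping over the samples one at a time.

```python
import jax
import jax.numpy as jnp
from jax import random

def per_sample_fwd(x, weights):
    """Hypothetical per-sample forward pass: x is a single (28, 28) image."""
    x = jnp.reshape(x, (-1,)).astype("float32")   # flatten to shape (784,)
    (k1, b1), (k2, b2) = weights
    x = jax.nn.relu(x @ k1 + b1)
    return jax.nn.softmax(x @ k2 + b2)

# Map over axis 0 of the images; None means every sample shares the weights.
batched_fwd = jax.vmap(per_sample_fwd, in_axes=(0, None))

key_a, key_b, key_x = random.split(random.PRNGKey(0), 3)
weights = [
    (random.normal(key_a, (784, 32)) * 0.01, jnp.zeros(32)),
    (random.normal(key_b, (32, 10)) * 0.01, jnp.zeros(10)),
]
images = random.uniform(key_x, (5, 28, 28))

probs = batched_fwd(images, weights)                                  # (5, 10)
looped = jnp.stack([per_sample_fwd(img, weights) for img in images])  # same
```

`in_axes=None` can apply to a whole pytree, so the list of weight tuples is broadcast as a unit rather than split along an axis.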


### Define loss function
The integer numpy target values are first one-hot encoded and explicitly converted to
a `float32` JAX array, which gets sent to the default device (in my case, the GPU).

Predicted probabilities are clipped to the range `[1e-7, 1 - 1e-7]` to remove
`0`s, which would otherwise cause a `NaN` loss when passed to the crossentropy function.

Like all mlax functions, `mlax.losses.categorical_crossentropy` only works on a
single unbatched sample. Again, we use JAX's `vmap` to change that.

We also divide each loss by the number of classes to match PyTorch's behavior.

The mean per-sample loss is returned.

```python
@jax.jit
def batched_loss(
    batched_preds: jnp.ndarray,
    batched_targets: np.ndarray
):
    n_classes = 10
    # One-hot encode numpy targets and send to GPU, converting them to float32
    batched_targets = nn.one_hot(batched_targets, n_classes, dtype="float32")
    # Clip predicted probabilities away from 0 (and 1) to avoid log(0)
    batched_preds = jnp.clip(batched_preds, 1e-7, 1 - 1e-7)

    # Calculate per-sample loss
    losses = jax.vmap(categorical_crossentropy)(
        batched_preds, batched_targets
    )
    # Match PyTorch behavior, optional
    losses = losses / n_classes
    return losses.mean()
```
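The sketch below shows why the clipping matters, using a hypothetical per-sample crossentropy `-sum(t * log(p))`. It stands in for `mlax.losses.categorical_crossentropy`, whose exact formula is not shown here. Any zero probability yields `0 * log(0) = 0 * -inf = nan` even on a non-target class, and that `nan` poisons the whole sum.

```python
import jax
import jax.numpy as jnp
from jax import nn

def crossentropy(pred, target):
    """Hypothetical per-sample categorical crossentropy: -sum(t * log(p))."""
    return -jnp.sum(target * jnp.log(pred))

targets = jnp.array([2, 0])                            # integer class labels
one_hot = nn.one_hot(targets, 3, dtype=jnp.float32)    # [[0,0,1], [1,0,0]]
preds = jnp.array([[0.0, 0.0, 1.0],                    # confident and correct
                   [0.5, 0.5, 0.0]], dtype=jnp.float32)

# Both rows contain a 0 probability, so both unclipped losses are nan.
raw = jax.vmap(crossentropy)(preds, one_hot)
# After clipping, the losses are finite: about 1e-7 and -log(0.5).
clipped = jax.vmap(crossentropy)(jnp.clip(preds, 1e-7, 1 - 1e-7), one_hot)
```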

We also define a convenience function that calculates model loss based on datasets' batched inputs.

```python
def batched_model_loss(
    x_batch: np.ndarray,
    y_batch: np.ndarray,
    weights: list
):
    return batched_loss(batched_model_fwd(x_batch, weights), y_batch)
```

### Define optimizer state and function
As with model weights, the optimizer state consists of JAX arrays that get sent to the
default device.

During initialization, the optimizer infers the type of its state from the
model weights. In our case then, the optimizer state is also of dtype `float32`.

We create an optimizer function that takes in gradients and an optimizer state,
and returns new gradients to be applied and a new optimizer state.

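Note that we use `jax.tree_util.Partial` instead of `functools.partial` so that
this function can be passed to `jit`-compiled functions: a `Partial` is a pytree,
whereas a `functools.partial` object is not a valid argument to a `jit`-compiled
function. Read more about this [here](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.Partial.html?highlight=Partial).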

```python
# Initialize optimizer state on the GPU
optim_state = sgd.init(model_weights)
optim_fn = tree_util.Partial(sgd.step, lr=1e-2, momentum=0.6)
```
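Here is a minimal sketch of the same pattern without mlax. It assumes a hypothetical `sgd_momentum_step` that follows PyTorch's momentum convention (`v <- momentum * v + g`, update `= lr * v`) and the same `(gradients, state) -> (updates, new_state)` shape as `optim_fn`. mlax's `sgd.step` may differ in detail. The `Partial` passes straight through `jax.jit`. Its wrapped function is treated as static, while `lr` and `momentum` become traced leaves of the pytree.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

def sgd_momentum_step(gradients, velocity, lr, momentum):
    """Hypothetical SGD-with-momentum step over arbitrary pytrees."""
    velocity = tree_util.tree_map(lambda v, g: momentum * v + g,
                                  velocity, gradients)
    updates = tree_util.tree_map(lambda v: lr * v, velocity)
    return updates, velocity

@jax.jit
def apply_step(params, gradients, optim_fn, optim_state):
    # optim_fn is a tree_util.Partial (a pytree), so jit accepts it as an input.
    updates, optim_state = optim_fn(gradients, optim_state)
    params = tree_util.tree_map(lambda p, u: p - u, params, updates)
    return params, optim_state

params = {"w": jnp.ones(3), "b": jnp.zeros(())}
grads = {"w": jnp.full(3, 2.0), "b": jnp.array(1.0)}
state = tree_util.tree_map(jnp.zeros_like, params)   # zero initial velocity
optim_fn = tree_util.Partial(sgd_momentum_step, lr=0.1, momentum=0.5)

params, state = apply_step(params, grads, optim_fn, state)  # w: 1 - 0.1*2 = 0.8
params, state = apply_step(params, grads, optim_fn, state)  # v=3, w: 0.8 - 0.3 = 0.5
```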

### Define training step
We first calculate the batch loss, which is only used for logging.

We then use JAX's `grad` to calculate the
gradients with respect to `model_weights`. Read more about JAX's autodiff
[here](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#taking-derivatives-with-grad).

We use the `optim_fn` to get new gradients and a new optimizer state.

We apply the new gradients to update the model weights.

Finally, we return the batch loss, new `model_weights`, and new `optim_state`.

```python
@jax.jit
def train_step(
    x_batch: np.ndarray,
    y_batch: np.ndarray,
    model_weights: list,
    optim_fn, # (gradients, optim_state) -> (new_gradients, new_optim_state)
    optim_state # optimizer state (JAX arrays)
):
    # Find batch loss, only useful for logging
    loss = batched_model_loss(x_batch, y_batch, model_weights)
    # Find gradients wrt model_weights (argument 2)
    gradients = jax.grad(batched_model_loss, argnums=2)(
        x_batch,
        y_batch,
        model_weights
    )

    # Get new gradients and optimizer state
    gradients, optim_state = optim_fn(gradients, optim_state)

    # Update model_weights with new gradients
    model_weights = apply_gradients(gradients, model_weights)
    return loss, model_weights, optim_state
```
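`train_step` evaluates the loss once for logging and then again inside `jax.grad`. `jax.value_and_grad` returns both from a single call. The sketch below demonstrates it on a toy mean-squared-error loss (a hypothetical stand-in for `batched_model_loss`) with the same `argnums=2` convention. Under `jax.jit`, XLA may already eliminate some of the duplicated work, but `value_and_grad` makes the sharing explicit.

```python
import jax
import jax.numpy as jnp

def loss_fn(x_batch, y_batch, w):
    """Toy mean-squared-error loss standing in for batched_model_loss."""
    return jnp.mean((x_batch @ w - y_batch) ** 2)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.array([0.5, -0.5])

# One call returns the loss and its gradient w.r.t. argument 2 (w).
loss, grads = jax.value_and_grad(loss_fn, argnums=2)(x, y, w)
print(loss, grads)   # 4.25 and [-9., -13.]
```

In `train_step`, the two calls could become `loss, gradients = jax.value_and_grad(batched_model_loss, argnums=2)(x_batch, y_batch, model_weights)`.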

### Define functions for training and testing loops

```python
def train_epoch(X_train, y_train, model_weights, optim_fn, optim_state):
    num_batches = len(X_train)
    train_loss = 0.0
    for i in range(num_batches):
        x_batch, y_batch = X_train[i], y_train[i]
        loss, model_weights, optim_state = train_step(
            x_batch, y_batch,
            model_weights,
            optim_fn, optim_state
        )
        train_loss += loss

    print(f"Train loss: {train_loss / num_batches}")
    return model_weights, optim_state
```

```python
def test(X_test, y_test, model_weights):
    num_batches = len(X_test)
    test_loss, accuracy = 0.0, 0.0
    for i in range(num_batches):
        x_batch, y_batch = X_test[i], y_test[i]
        preds = batched_model_fwd(
            x_batch, model_weights
        )
        loss = batched_loss(preds, y_batch)
        test_loss += loss
        accuracy += (jnp.argmax(preds, axis=1) == y_batch).sum() / len(x_batch)

    print(f"Test loss: {test_loss / num_batches}, accuracy: {accuracy / num_batches}")
```
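`test` averages the per-batch accuracies. This weights the final, smaller batch as heavily as a full batch of 64: 10000 = 156 * 64 + 16, so the last batch holds only 16 samples. The effect is small here. A sample-weighted average avoids it entirely. The NumPy sketch below uses made-up per-batch counts purely to show the difference.

```python
import numpy as np

# Hypothetical results: two full batches of 64 and a final batch of 16.
batch_sizes = np.array([64, 64, 16])
correct = np.array([64, 32, 0])           # correctly classified samples per batch

mean_of_batch_acc = np.mean(correct / batch_sizes)        # what `test` computes
sample_weighted_acc = correct.sum() / batch_sizes.sum()   # accuracy over samples
print(mean_of_batch_acc, sample_weighted_acc)
```

In `test`, the same fix would mean summing the correct predictions and dividing by the total number of test samples, rather than by `num_batches`.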

```python
def train_loop(
    X_train, y_train,
    X_test, y_test,
    model_weights,
    optim_fn, optim_state,
    epochs, test_every
):
    for i in range(epochs):
        epoch = i + 1
        print(f"Epoch {epoch}\n----------------")
        model_weights, optim_state = train_epoch(
            X_train, y_train,
            model_weights,
            optim_fn, optim_state
        )
        if (epoch % test_every == 0):
            test(X_test, y_test, model_weights)
        print("----------------")

    return model_weights, optim_state
```

### Train MLP on MNIST dataset
Achieves an accuracy of ~98%, similar to the PyTorch reference.

The mlax MLP does this in only 1 minute 18 seconds, as opposed to 4 minutes for the reference.

This is an unfair comparison, however, because we used `jax.jit` in strategic
places to JIT-compile functions with XLA. Read more about this
[here](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#using-jit-to-speed-up-functions).
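Timing JAX code fairly takes some care for two reasons. First, JAX dispatches work asynchronously, so a timer must wait on `block_until_ready()`. Second, the first call to a `jit`-compiled function includes tracing and compilation. The sketch below applies both rules to a toy function (not the MLP), so it does not reproduce the timings quoted above. The exact speedup depends on hardware.

```python
import time
import jax
import jax.numpy as jnp

def layer_stack(x, w):
    # A few matmuls and ReLUs, enough work for compilation to matter a little.
    for _ in range(3):
        x = jax.nn.relu(x @ w)
    return x

fast = jax.jit(layer_stack)
x = jnp.ones((256, 512))
w = jnp.eye(512) * 0.5

fast(x, w).block_until_ready()   # warm-up: the first call traces and compiles

start = time.perf_counter()
eager_out = layer_stack(x, w).block_until_ready()
eager_s = time.perf_counter() - start

start = time.perf_counter()
jit_out = fast(x, w).block_until_ready()   # reuses the compiled executable
jit_s = time.perf_counter() - start
print(f"eager: {eager_s:.4f}s, jit: {jit_s:.4f}s")
```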

```python
new_model_weights, new_optim_state = train_loop(
    X_train, y_train,
    X_test, y_test,
    model_weights,
    optim_fn, optim_state,
    epochs=50, test_every=5
)
```

```
Epoch 1
----------------
Train loss: 1.0663982629776
----------------
Epoch 2
----------------
Train loss: 0.781200647354126
----------------
Epoch 3
----------------
Train loss: 0.7016758322715759
----------------
Epoch 4
----------------
Train loss: 0.6889051795005798
----------------
Epoch 5
----------------
Train loss: 0.6860621571540833
Test loss: 0.6802540421485901, accuracy: 0.5765326619148254
----------------
Epoch 6
----------------
Train loss: 0.5907200574874878
----------------
Epoch 7
----------------
Train loss: 0.5317518711090088
----------------
Epoch 8
----------------
Train loss: 0.400712251663208
----------------
Epoch 9
----------------
Train loss: 0.3803996741771698
----------------
Epoch 10
----------------
Train loss: 0.3331351578235626
Test loss: 0.2388990819454193, accuracy: 0.846337616443634
----------------
Epoch 11
----------------
Train loss: 0.23317430913448334
----------------
Epoch 12
----------------
Train loss: 0.22237780690193176
----------------
Epoch
```

### Bringing it all together

We first created a dataset that resides on the CPU.

We then created a model whose weights reside on the GPU and which streams in
samples from the CPU to work on. We made sure to explicitly convert the dtype of
the streamed-in samples to match the dtype of the model weights.

We used `jax.vmap` on the model, so that it accepts batched inputs.

We defined a loss function, which streams in targets from the CPU. We also made
sure to explicitly convert the dtype of streamed-in targets to match the dtype
of model predictions.

We used `jax.vmap` on the loss function, so that it accepts batched predictions
and targets.

We then created an optimizer whose state is on the GPU, which operates on
gradients also on the GPU.

We used said optimizer to update the model weights, while also getting a new
optimizer state.

We repeated the predict-evaluate-optimize cycle with the new model weights and
the new optimizer state over many batches until we got our desired model.

Finally, we used `jax.jit` to compile each training step, which massively sped
up our computations.