<a href="https://colab.research.google.com/github/soumik12345/jax-series/blob/simple-train-loop/season_1/1-Simple-Training-Loop-in-JAX-and-Flax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Writing a Simple Training Loop in JAX and Flax
<!--- @wandbcode{jax-season-1-episode-1} -->

Welcome to the first exercise of our JAX journey. In this notebook, we will build a simple training and evaluation loop for a baseline image classification task, end to end, using [JAX](https://github.com/google/jax), [Flax](https://github.com/google/flax), and [Optax](https://github.com/deepmind/optax). We will also demonstrate how to use [Weights & Biases](https://wandb.ai/site) for experiment tracking in a Flax-based pipeline. Now, let us jump into the code...



## 📦 Packages and Basic Setup

In [None]:
# Installing Flax and Weights & Biases
!pip install -q wandb flax

In [None]:
import jax
import jax.numpy as jnp

import optax

from flax import linen as nn
from flax.training import train_state
from flax.serialization import (
    to_state_dict, msgpack_serialize, from_bytes
)

import os
import wandb
import numpy as np
from typing import Callable
from tqdm.notebook import tqdm

import tensorflow as tf
import tensorflow_datasets as tfds

### ü™Ñ Setting up Weights & Biases üêù

We will now initialize a new Weights & Biases job by calling [`wandb.init`](https://docs.wandb.ai/guides/track/launch). This creates a new run in Weights & Biases and launches a background process to sync data. We will also sync all the configs of our experiment with the W&B run, which makes it far easier to reproduce the results of the experiment later.

#### Task 1

Initialize a Weights & Biases run by calling `wandb.init` and set the following parameters:
- `project` to be `"simple-training-loop"`
- `entity` to be `"jax-series"`
- `job_type` to be `"submission"`

Next, set each of the configs listed in the code cell below to a value of your choice.

In [None]:
# Initializing a Weights & Biases Run
wandb.init(
    project=...,
    entity=...,
    job_type=...
)

# Setting up configs to be synced by the Weights & Biases run
config = wandb.config
config.seed = ...
config.batch_size = ...
config.validation_split = ... # A fractional number between 0 and 1
config.pooling = ... # Either "avg" for average pooling or "max" for max pooling 
config.learning_rate = ... # The learning rate
config.epochs = ... # The number of epochs you wish to train for

MODULE_DICT = {
    "avg": nn.avg_pool,
    "max": nn.max_pool,
}

## üíø The Dataset

Neither JAX nor Flax has a native API for building data-loading pipelines. One could use either PyTorch's `torch.utils.data` API with Torchvision datasets, or TensorFlow's `tf.data` API with TensorFlow Datasets, to build an input pipeline. Many JAX practitioners choose the `tf.data` API for JAX- and Flax-based machine learning workflows. In this exercise, we will build a simple data-loading pipeline for the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) image classification dataset using [TensorFlow Datasets](https://www.tensorflow.org/datasets).

**TensorFlow Datasets** is a collection of ready-to-use datasets for TensorFlow, JAX, and other Python ML frameworks. All datasets are exposed as [`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) objects, enabling easy-to-use and high-performance input pipelines. TensorFlow Datasets also provides a large catalog of [ready-to-use datasets](https://www.tensorflow.org/datasets/overview). For a quick introduction to its general usage, refer to the [official quick start guide](https://www.tensorflow.org/datasets/overview).
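The train/validation split in the next cell boils down to integer bookkeeping on the number of examples: `take()` and `skip()` need whole-number counts, and the validation examples must be the ones *not* taken for training. A NumPy sketch of that arithmetic, assuming CIFAR-10's 50,000 training images and a hypothetical `validation_split` of 0.1 (`split_sizes` is a helper introduced only for illustration):

```python
from typing import Tuple

import numpy as np


def split_sizes(num_data: int, validation_split: float) -> Tuple[int, int]:
    """Hypothetical helper: return integer (num_train, num_val) summing to num_data."""
    # tf.data's take()/skip() expect integer counts, so round explicitly.
    num_val = int(round(num_data * validation_split))
    return num_data - num_val, num_val


def normalize_img(image: np.ndarray) -> np.ndarray:
    # Map uint8 pixels in [0, 255] to float32 values in [0, 1].
    return image.astype(np.float32) / 255.0


num_train, num_val = split_sizes(50_000, 0.1)
print(num_train, num_val)  # 45000 5000

# Disjoint split: the first num_train items train, the remainder validates.
indices = np.arange(50_000)
train_idx, val_idx = indices[:num_train], indices[num_train:]
print(len(np.intersect1d(train_idx, val_idx)))  # 0
```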

In [None]:
(full_train_set, test_dataset), ds_info = tfds.load(
    'cifar10',
    split=['train', 'test'],
    # Keep the input file order fixed so that take()/skip() below always yield
    # the same, non-overlapping split; examples are shuffled later with shuffle().
    shuffle_files=False,
    as_supervised=True,
    with_info=True,
)

def normalize_img(image, label):
    # Scale uint8 pixels from [0, 255] to float32 values in [0, 1]
    image = tf.cast(image, tf.float32) / 255.
    return image, label

full_train_set = full_train_set.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE
)

num_data = tf.data.experimental.cardinality(
    full_train_set
).numpy()
print("Total number of data points:", num_data)

# take() and skip() expect integer counts
num_train = int(num_data * (1 - config.validation_split))
# The first `num_train` examples form the training set and the remaining ones
# form the validation set, so the two splits never overlap.
train_dataset = full_train_set.take(num_train)
val_dataset = full_train_set.skip(num_train)
print(
    "Number of train data points:",
    tf.data.experimental.cardinality(train_dataset).numpy()
)
print(
    "Number of val data points:",
    tf.data.experimental.cardinality(val_dataset).numpy()
)

# Cache the decoded images, reshuffle every epoch, then batch
train_dataset = train_dataset.cache()
train_dataset = train_dataset.shuffle(num_train)
train_dataset = train_dataset.batch(config.batch_size)
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

# Validation data is only evaluated, so it does not need shuffling
val_dataset = val_dataset.cache()
val_dataset = val_dataset.batch(config.batch_size)
val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)


test_dataset = test_dataset.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE
)
print(
    "Number of test data points:",
    tf.data.experimental.cardinality(test_dataset).numpy()
)
test_dataset = test_dataset.cache()
test_dataset = test_dataset.batch(config.batch_size)
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)

## ✍️ Time for Modelling

Initially introduced after v0.4.0 of Flax, the **Linen API** makes it easy to build *modules* for various deep learning methods while respecting the functional paradigm and providing excellent support for JAX transformations such as `vmap`, `remat`, or `scan`. Linen lets developers still write Python objects (modules are `dataclasses` defined by subclassing `nn.Module`) while keeping the functional, single-method style of computation.

If you are familiar with the Keras or PyTorch ecosystem, the Linen API has easy-to-understand analogies.

- Instead of a PyTorch `nn.Module` or a TensorFlow `keras.Model`, Flax has a `linen.Module`.

- In PyTorch or Keras subclassed models, we define all submodules and layers in the `__init__` method. Flax has a similar method called `setup()` that we override.

- Instead of a `forward` method in PyTorch or a `call` method in Keras, Flax uses a `__call__` method.
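The key difference from PyTorch and Keras is that a Flax module does not own its weights: the parameters live in a separate pytree that is created once and passed explicitly into every call. A pure-JAX sketch of that idea, using hypothetical `dense_init`/`dense_apply` functions rather than the Flax API, loosely mirroring the `model.init`/`model.apply` split:

```python
import jax
import jax.numpy as jnp


# Hypothetical dense layer written as two pure functions (not Flax code).
def dense_init(key, in_features: int, out_features: int):
    # Parameters are an ordinary dict (a pytree) returned to the caller.
    return {
        "kernel": jax.random.normal(key, (in_features, out_features)) * 0.01,
        "bias": jnp.zeros((out_features,)),
    }


def dense_apply(params, x):
    # The output depends only on the arguments: there is no hidden mutable state.
    return x @ params["kernel"] + params["bias"]


params = dense_init(jax.random.PRNGKey(0), in_features=3, out_features=2)
x = jnp.ones((4, 3))
y = dense_apply(params, x)
print(y.shape)  # (4, 2)
```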


Let us now define a very simple convolutional neural network for classification. Rather than reproducing a well-known architecture, we'll create a small custom one by subclassing `linen.Module`.


In [None]:
class CNN(nn.Module):
    pool_module: Callable = nn.avg_pool

    def setup(self):
        self.conv_1 = nn.Conv(features=32, kernel_size=(3, 3))
        self.conv_2 = nn.Conv(features=32, kernel_size=(3, 3))
        self.conv_3 = nn.Conv(features=64, kernel_size=(3, 3))
        self.conv_4 = nn.Conv(features=64, kernel_size=(3, 3))
        self.conv_5 = nn.Conv(features=128, kernel_size=(3, 3))
        self.conv_6 = nn.Conv(features=128, kernel_size=(3, 3))
        self.dense_1 = nn.Dense(features=1024)
        self.dense_2 = nn.Dense(features=512)
        self.dense_output = nn.Dense(features=10)

    # Submodules are declared in setup(), so __call__ does not need @nn.compact
    def __call__(self, x):
        x = nn.relu(self.conv_1(x))
        x = nn.relu(self.conv_2(x))
        x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.relu(self.conv_3(x))
        x = nn.relu(self.conv_4(x))
        x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.relu(self.conv_5(x))
        x = nn.relu(self.conv_6(x))
        x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.relu(self.dense_1(x))
        x = nn.relu(self.dense_2(x))
        return self.dense_output(x)
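Because Flax's `nn.Conv` defaults to `'SAME'` padding, the 3x3 convolutions above only change the channel count, while each 2x2, stride-2 pooling halves the spatial size: 32 to 16 to 8 to 4. So `dense_1` receives 4 × 4 × 128 = 2048 features per image. A shape-only sketch in plain JAX, where `pool_2x2` is a hypothetical reshape-based stand-in for `nn.avg_pool`/`nn.max_pool` that is valid only for even heights and widths:

```python
import jax.numpy as jnp


def pool_2x2(x, reduce_fn=jnp.mean):
    """2x2 pooling with stride 2 on an NHWC array with even H and W."""
    n, h, w, c = x.shape
    x = x.reshape(n, h // 2, 2, w // 2, 2, c)
    # Reduce over the two within-window axes (use jnp.max for max pooling)
    return reduce_fn(x, axis=(2, 4))


# Track only the shapes: 'SAME'-padded convs keep H and W and set the channels.
x = jnp.ones((8, 32, 32, 3))
for features in (32, 64, 128):
    x = jnp.ones(x.shape[:3] + (features,))  # stand-in for two conv layers
    x = pool_2x2(x)
    print(x.shape)
# (8, 16, 16, 32), then (8, 8, 8, 64), then (8, 4, 4, 128)

flat = x.reshape((x.shape[0], -1))
print(flat.shape)  # (8, 2048)
```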

## 🎬 Initializing the Module

#### Task 2

Now that we have defined our convolutional neural network as a Flax `linen.Module`, let us initialize the model and its parameters.

- Create a pseudo-random number generator (PRNG) key using `jax.random.PRNGKey`, with the seed set to `config.seed`.
- Instantiate the `CNN` module defined above:
    - Set its `pool_module` parameter to the pooling function that `MODULE_DICT` maps `config.pooling` to.
    - Initialize its parameters with a dummy input of shape `(config.batch_size, 32, 32, 3)`.
- **[Optional]** Check the model and parameter shapes in tabular form using the [`nn.tabulate`](https://flax.readthedocs.io/en/latest/flax.linen.html?highlight=tabulate) function. Shoutout to [Cristian Garcia](https://twitter.com/cgarciae88) for adding this amazing feature 🥳.
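Unlike NumPy's global random state, JAX randomness is driven by explicit keys, which is why `model.init` takes a key as its first argument. A small sketch of the two properties that matter here, with a hypothetical `seed` standing in for `config.seed`:

```python
import jax
import jax.numpy as jnp

seed = 0  # hypothetical; the notebook uses config.seed
rng = jax.random.PRNGKey(seed)

# Same key in, same numbers out: randomness is a pure function of the key.
draw_1 = jax.random.normal(rng, (3,))
draw_2 = jax.random.normal(jax.random.PRNGKey(seed), (3,))
print(jnp.array_equal(draw_1, draw_2))  # True

# For independent randomness, split the key instead of reusing it.
init_key, other_key = jax.random.split(rng)
sample_a = jax.random.normal(init_key, (3,))
sample_b = jax.random.normal(other_key, (3,))
print(jnp.array_equal(sample_a, sample_b))  # False
```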

In [None]:
rng = ... # Create a pseudo-random number generator key

x = jnp.ones(shape=...) # Input shape
model = ... # Instantiate the model
params = ... # Initialize the parameters

# nn.tabulate returns a string, so print it for a readable table
print(nn.tabulate(model, rng)(x))

As discussed previously, the JAX ecosystem is built around a functional paradigm with pure functions. Our training and evaluation steps are simply functions that we call repeatedly. But how do we update the parameters if there is no global state? Simple: we introduce an intermediate variable that is passed into, and returned from, each function call.

That's where the model's `init` method comes in: given a PRNG key and a sample input, `model.init` returns the model's initial variables. After initializing the model, we'll use these variables to create a [**TrainState**](https://flax.readthedocs.io/en/latest/flax.training.html#flax.training.train_state.TrainState), a utility class that bundles the parameters with the optimizer and handles gradient updates. Instead of re-initializing the model with new variables at every step, we simply produce an updated "state" of the model and pass it as input to the next function call.
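The "pass the state in, get a new state out" pattern is easy to see in a stripped-down version. The sketch below is a hypothetical `SimpleTrainState` (plain SGD, no Optax), not Flax's actual class; it only illustrates that an update returns a new immutable object instead of mutating the old one:

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class SimpleTrainState(NamedTuple):
    """Hypothetical, minimal stand-in for flax.training.train_state.TrainState."""
    step: int
    params: Any
    learning_rate: float

    def apply_gradients(self, grads):
        # Plain SGD over every leaf of the params pytree.
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - self.learning_rate * g, self.params, grads
        )
        # Return a *new* state; the original stays untouched.
        return self._replace(step=self.step + 1, params=new_params)


state = SimpleTrainState(step=0, params={"w": jnp.array(1.0)}, learning_rate=0.25)
new_state = state.apply_gradients({"w": jnp.array(2.0)})
print(state.step, float(state.params["w"]))          # 0 1.0
print(new_state.step, float(new_state.params["w"]))  # 1 0.5
```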

Let's walk through how one would create a TrainState...

#### Task 3

Let us create a Flax [`TrainState`](https://flax.readthedocs.io/en/latest/flax.training.html#flax.training.train_state.TrainState) by:

- Initializing the model's variables using `model.init()`
- Creating the optimizer using `optax`; for example, you could use `optax.adam()`
- Returning a TrainState instance built from the following parameters:
  - `apply_fn`: in our case, the model's `apply` method
  - `tx`: the Optax optimizer
  - `params`: the parameters from the initialized model (accessed via the `"params"` key)

In [None]:
def init_train_state(
    model, random_key, shape, learning_rate
) -> train_state.TrainState:
    # Initialize the Model
    variables = ...
    # Create the Optimizer
    optimizer = ...
    # Return a TrainState
    return train_state.TrainState.create(
        apply_fn=...,
        tx=...,
        params=...
    )


state = init_train_state(
    model, rng, (config.batch_size, 32, 32, 3), config.learning_rate
)

### Utility Functions

We create two utility functions: one computes the cross-entropy loss, and the other computes the metrics (loss and accuracy) from the logits and labels.

In [None]:
def cross_entropy_loss(*, logits, labels):
    one_hot_encoded_labels = jax.nn.one_hot(labels, num_classes=10)
    return optax.softmax_cross_entropy(
        logits=logits, labels=one_hot_encoded_labels
    ).mean()

In [None]:
def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics
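As a sanity check on these helpers: softmax cross-entropy against one-hot labels reduces to the negative log-probability assigned to the correct class, and accuracy is the fraction of rows whose arg-max matches the label. A pure-JAX sketch on made-up logits (it re-derives what `optax.softmax_cross_entropy` computes rather than calling it):

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 0.2, 3.0]])
labels = jnp.array([0, 1])
log_probs = jax.nn.log_softmax(logits)

# Cross-entropy against one-hot labels...
one_hot = jax.nn.one_hot(labels, num_classes=3)
ce_one_hot = -jnp.sum(one_hot * log_probs, axis=-1)
# ...equals the negative log-probability of the true class.
ce_gather = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]

# Row 0 predicts class 0 (correct), row 1 predicts class 2 (wrong).
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
print(jnp.allclose(ce_one_hot, ce_gather))  # True
print(float(accuracy))  # 0.5
```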

### 🏋️‍♂️ Train Step

In this code cell, we'll create a `train_step` function.

#### Task 4

- Create the gradient function using `jax.value_and_grad()` (make sure to set the `has_aux` parameter to `True`)
- Call the gradient function with the parameters stored in the TrainState
- Update the parameters using the `.apply_gradients()` method of the TrainState, passing the gradients to its `grads` parameter
- Calculate the metrics using the `compute_metrics` utility function by passing in the logits and labels
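`jax.value_and_grad` with `has_aux=True` expects the function to return a `(loss, aux)` pair; it differentiates only `loss` and hands `aux` back unchanged. A toy sketch on a linear model, unrelated to the CNN, showing the return structure that Task 4 unpacks as `(_, logits), grads`:

```python
import jax
import jax.numpy as jnp


# Toy linear model; `params` plays the role of state.params.
def loss_fn(params, x, y):
    preds = x @ params["w"]
    loss = jnp.mean((preds - y) ** 2)
    return loss, preds  # (value to differentiate, auxiliary output)


params = {"w": jnp.array([1.0, 2.0])}
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([0.0, 0.0])

# has_aux=True: differentiate the first output, pass the second through.
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, preds), grads = grad_fn(params, x, y)
print(float(loss))  # 2.5
print(grads["w"])   # [1. 2.]
```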

In [None]:
@jax.jit
def train_step(
    state: train_state.TrainState, batch: tuple
):
    image, label = batch  # batch is an (images, labels) pair

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, image)
        loss = cross_entropy_loss(logits=logits, labels=label)
        return loss, logits

    # Create the gradient function by passing in loss_fn
    gradient_fn = ...
    # Pass in the params from the TrainState
    (_, logits), grads = ...
    # Update Parameters
    state = ...
    # Compute Metrics
    metrics = compute_metrics(logits=..., labels=...)
    return state, metrics

### 🧑‍⚖️ Eval Step

Similar to our `train_step`, this function also takes the state and the batch. We simply perform a forward pass on the data to obtain the logits and then compute the corresponding metrics. Since this is the evaluation step, we neither compute gradients nor update the parameters of the `TrainState`.

In [None]:
@jax.jit
def eval_step(state, batch):
    image, label = batch
    logits = state.apply_fn({'params': state.params}, image)
    return compute_metrics(logits=logits, labels=label)

### W&B Artifacts

Using our Artifacts API, you can log artifacts as outputs of W&B runs, or use artifacts as input to runs.

![](https://1039519455-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-Lqya5RvLedGEWPhtkjU%2F-M94QAXA-oJmE6q07_iT%2F-M94QJCXLeePzH1p_fW1%2Fsimple%20artifact%20diagram%202.png?alt=media&token=94bc438a-bd3b-414d-a4e4-aa4f6f359f21)

Since a run can use another run’s output artifact as input, artifacts and runs together form a directed graph. You don’t need to define pipelines ahead of time: just use and log artifacts, and we’ll stitch everything together.
You can store different versions of your datasets and models in the cloud as Artifacts. Think of an Artifact as a folder of data to which we can add individual files and then upload to the cloud as part of our W&B project, with automatic versioning of datasets and models. This is also how Artifacts track training pipelines as DAGs. Here's an example of an artifact graph.

![](https://i.imgur.com/QQULnpP.gif)
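The checkpoint helpers below write the `TrainState` to a file that is then logged as an artifact, and restore it by filling in a template state with the same structure. The same round trip can be sketched with plain JAX and NumPy, swapping Flax's msgpack serialization for NumPy's `.npz` format; `tree_to_bytes` and `tree_from_bytes` are hypothetical helpers, not part of Flax or W&B:

```python
import io

import jax
import jax.numpy as jnp
import numpy as np


def tree_to_bytes(tree) -> bytes:
    # Flatten the pytree and store its leaves as arr_0, arr_1, ... in an .npz
    leaves = jax.tree_util.tree_leaves(tree)
    buffer = io.BytesIO()
    np.savez(buffer, *[np.asarray(leaf) for leaf in leaves])
    return buffer.getvalue()


def tree_from_bytes(template, data: bytes):
    # Rebuild a tree with the same structure as `template` from the saved leaves
    treedef = jax.tree_util.tree_structure(template)
    with np.load(io.BytesIO(data)) as archive:
        leaves = [archive[f"arr_{i}"] for i in range(treedef.num_leaves)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


params = {"dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros(3)}}
restored = tree_from_bytes(params, tree_to_bytes(params))
print(restored["dense"]["kernel"].shape)  # (2, 3)
```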

In [None]:
def save_checkpoint(ckpt_path, state, epoch):
    with open(ckpt_path, "wb") as outfile:
        outfile.write(msgpack_serialize(to_state_dict(state)))
    # Log the checkpoint file as a versioned W&B artifact of type "model"
    artifact = wandb.Artifact(
        f'{wandb.run.name}-checkpoint', type='model'
    )
    artifact.add_file(ckpt_path)
    wandb.log_artifact(artifact, aliases=["latest", f"epoch_{epoch}"])


def load_checkpoint(ckpt_file, state):
    artifact = wandb.use_artifact(
        f'{wandb.run.name}-checkpoint:latest'
    )
    artifact_dir = artifact.download()
    ckpt_path = os.path.join(artifact_dir, ckpt_file)
    with open(ckpt_path, "rb") as data_file:
        byte_data = data_file.read()
    return from_bytes(state, byte_data)


def accumulate_metrics(metrics):
    metrics = jax.device_get(metrics)
    return {
        k: np.mean([metric[k] for metric in metrics])
        for k in metrics[0]
    }
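`accumulate_metrics` averages the per-batch means, so every batch counts equally. Because the batching above does not drop the final partial batch, each example in that last batch carries slightly more weight than the others. When exact dataset-level numbers matter, a sample-weighted mean is one alternative; a NumPy sketch with made-up numbers, assuming you also record each batch's size:

```python
import numpy as np

# Per-batch mean losses and the number of examples in each batch;
# the last batch is a smaller, partial batch.
batch_losses = np.array([1.0, 1.0, 4.0])
batch_sizes = np.array([4, 4, 2])

# Every batch weighted equally (what accumulate_metrics does)
unweighted = batch_losses.mean()
# Every example weighted equally
weighted = (batch_losses * batch_sizes).sum() / batch_sizes.sum()
print(unweighted, weighted)  # 2.0 1.6
```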

## 🦾 A JAX-based Train and Validation Loop

Now, let us create a simple training and validation loop for training our image classification model using JAX and Flax...
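The loop saves a checkpoint only when the validation loss improves. That rule works only if the best loss seen so far is initialized once, before the epoch loop, and updated whenever a checkpoint is written; otherwise every epoch looks like an improvement. A plain-Python sketch of the bookkeeping, where `select_checkpoints` is a hypothetical helper and appending to a list stands in for `save_checkpoint`:

```python
def select_checkpoints(val_losses):
    """Hypothetical helper: epochs at which 'save on improvement' would save."""
    best_eval_loss = float("inf")  # initialized once, before the epoch loop
    saved_epochs = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_eval_loss:
            best_eval_loss = loss       # remember the new best loss
            saved_epochs.append(epoch)  # stands in for save_checkpoint(...)
    return saved_epochs


print(select_checkpoints([0.9, 0.7, 0.8, 0.6, 0.65]))  # [1, 2, 4]
```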

In [None]:
def train_and_validate(
    train_dataset,
    eval_dataset,
    test_dataset,
    state: train_state.TrainState,
    epochs: int,
):
    # Track the best validation loss across all epochs, so initialize it
    # once, before the epoch loop starts.
    best_eval_loss = float("inf")

    for epoch in tqdm(range(1, epochs + 1)):

        # Training: one optimizer update per batch
        train_batch_metrics = []
        for batch in tfds.as_numpy(train_dataset):
            state, metrics = train_step(state, batch)
            train_batch_metrics.append(metrics)

        train_batch_metrics = accumulate_metrics(train_batch_metrics)
        print(
            'TRAIN (%d/%d): Loss: %.4f, accuracy: %.2f' % (
                epoch, epochs, train_batch_metrics['loss'],
                train_batch_metrics['accuracy'] * 100
            )
        )

        # Validation: forward passes only, no parameter updates
        eval_batch_metrics = []
        for batch in tfds.as_numpy(eval_dataset):
            metrics = eval_step(state, batch)
            eval_batch_metrics.append(metrics)

        eval_batch_metrics = accumulate_metrics(eval_batch_metrics)
        print(
            'EVAL (%d/%d):  Loss: %.4f, accuracy: %.2f\n' % (
                epoch, epochs, eval_batch_metrics['loss'],
                eval_batch_metrics['accuracy'] * 100
            )
        )

        wandb.log({
            "Train Loss": train_batch_metrics['loss'],
            "Train Accuracy": train_batch_metrics['accuracy'],
            "Validation Loss": eval_batch_metrics['loss'],
            "Validation Accuracy": eval_batch_metrics['accuracy']
        }, step=epoch)

        # Checkpoint only when the validation loss beats the best so far
        if eval_batch_metrics['loss'] < best_eval_loss:
            best_eval_loss = eval_batch_metrics['loss']
            save_checkpoint("checkpoint.msgpack", state, epoch)

    # Restore the best (most recently saved) checkpoint and evaluate it on the test set
    restored_state = load_checkpoint("checkpoint.msgpack", state)
    test_batch_metrics = []
    for batch in tfds.as_numpy(test_dataset):
        metrics = eval_step(restored_state, batch)
        test_batch_metrics.append(metrics)

    test_batch_metrics = accumulate_metrics(test_batch_metrics)
    print(
        'Test: Loss: %.4f, accuracy: %.2f' % (
            test_batch_metrics['loss'],
            test_batch_metrics['accuracy'] * 100
        )
    )

    wandb.log({
        "Test Loss": test_batch_metrics['loss'],
        "Test Accuracy": test_batch_metrics['accuracy']
    })

    return state, restored_state

In [None]:
state, best_state = train_and_validate(
    train_dataset,
    val_dataset,
    test_dataset,
    state,
    epochs=config.epochs,
)

In [None]:
wandb.finish()