<a href="https://colab.research.google.com/github/soumik12345/examples/blob/master/colabs/jax/Simple_Training_Loop_in_JAX_and_Flax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Writing a Simple Training Loop in JAX and Flax
<!--- @wandbcode{stylegan-nada-colab} -->

In [None]:
!pip install -q wandb flax

[K     |████████████████████████████████| 1.8 MB 7.7 MB/s 
[K     |████████████████████████████████| 197 kB 53.3 MB/s 
[K     |████████████████████████████████| 146 kB 65.7 MB/s 
[K     |████████████████████████████████| 181 kB 44.6 MB/s 
[K     |████████████████████████████████| 63 kB 1.5 MB/s 
[K     |████████████████████████████████| 145 kB 71.7 MB/s 
[K     |████████████████████████████████| 217 kB 74.7 MB/s 
[K     |████████████████████████████████| 596 kB 68.0 MB/s 
[K     |████████████████████████████████| 51 kB 5.7 MB/s 
[K     |████████████████████████████████| 72 kB 516 kB/s 
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [None]:
import os
from typing import Callable, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
import wandb
from flax import linen as nn
from flax.serialization import (
    to_state_dict, msgpack_serialize, from_bytes
)
from flax.training import train_state
from tqdm.notebook import tqdm

In [None]:
# Start the W&B run and register every tunable hyper-parameter in one place.
# NOTE(review): `entity` is hard-coded; change it to your own W&B user/team
# if you do not have access to "jax-series".
wandb.init(
    project="simple-training-loop",
    entity="jax-series",
    job_type="simple-train-loop",
    config=dict(
        seed=42,
        batch_size=64,
        validation_split=0.2,
        pooling="avg",
        learning_rate=1e-4,
        epochs=15,
    ),
)
config = wandb.config

# Maps the `config.pooling` string to the Flax pooling function used by the CNN.
MODULE_DICT = dict(
    avg=nn.avg_pool,
    max=nn.max_pool,
)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
# Hold out the last `validation_split` fraction of the official CIFAR-10
# training split for validation. TFDS split slicing is resolved by TFDS itself
# (before any file shuffling), so train and validation are disjoint.
# Calling `.take()` twice on the same dataset would instead return the *same*
# leading examples for both subsets and leak training data into validation.
assert 0 < config.validation_split < 1, "validation_split must be in (0, 1)"
train_percent = int(round(100 * (1 - config.validation_split)))

(train_dataset, val_dataset, test_dataset), ds_info = tfds.load(
    'cifar10',
    split=[
        f'train[:{train_percent}%]',
        f'train[{train_percent}%:]',
        'test',
    ],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)


def normalize_img(image, label):
    """Cast pixels to float32 and rescale them to [0, 1]; the label is unchanged."""
    image = tf.cast(image, tf.float32) / 255.
    return image, label


def prepare_split(dataset, shuffle=False):
    """Normalize, cache, optionally shuffle, then batch and prefetch `dataset`."""
    dataset = dataset.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.cache()
    if shuffle:
        # A buffer as large as the split gives a full reshuffle every epoch.
        dataset = dataset.shuffle(
            tf.data.experimental.cardinality(dataset).numpy()
        )
    return dataset.batch(config.batch_size).prefetch(tf.data.AUTOTUNE)


num_data = ds_info.splits['train'].num_examples
num_train = tf.data.experimental.cardinality(train_dataset).numpy()
num_val = tf.data.experimental.cardinality(val_dataset).numpy()
num_test = tf.data.experimental.cardinality(test_dataset).numpy()
print("Total number of data points:", num_data)
print("Number of train data points:", num_train)
print("Number of val data points:", num_val)
print("Number of test data points:", num_test)
assert num_train + num_val == num_data, "train/val split must partition the train set"

train_dataset = prepare_split(train_dataset, shuffle=True)
val_dataset = prepare_split(val_dataset)
test_dataset = prepare_split(test_dataset)

[1mDownloading and preparing dataset cifar10/3.0.2 (download: 162.17 MiB, generated: 132.40 MiB, total: 294.58 MiB) to /root/tensorflow_datasets/cifar10/3.0.2...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]






0 examples [00:00, ? examples/s]

Shuffling and writing examples to /root/tensorflow_datasets/cifar10/3.0.2.incomplete1MSV7T/cifar10-train.tfrecord


  0%|          | 0/50000 [00:00<?, ? examples/s]

0 examples [00:00, ? examples/s]

Shuffling and writing examples to /root/tensorflow_datasets/cifar10/3.0.2.incomplete1MSV7T/cifar10-test.tfrecord


  0%|          | 0/10000 [00:00<?, ? examples/s]

[1mDataset cifar10 downloaded and prepared to /root/tensorflow_datasets/cifar10/3.0.2. Subsequent calls will reuse this data.[0m
Total number of data points: 50000
Number of train data points: 40000
Number of val data points: 10000
Number of test data points: 10000


In [None]:
class CNN(nn.Module):
    """Small VGG-style convolutional classifier.

    Three stages of (conv -> relu -> conv -> relu -> pool) with 32, 64 and
    128 filters, followed by a 1024 -> 512 -> 10 dense head that returns raw
    (unnormalized) logits. In this notebook it is fed 32x32x3 CIFAR-10 images.

    Attributes:
        pool_module: pooling function applied after each conv stage. It is
            called as ``pool_module(x, window_shape=(2, 2), strides=(2, 2))``,
            e.g. ``nn.avg_pool`` or ``nn.max_pool``.
    """
    pool_module: Callable = nn.avg_pool

    def setup(self):
        # All submodules are declared here, so `__call__` must NOT also be
        # decorated with `@nn.compact` (Flax modules use one style or the other).
        self.conv_1 = nn.Conv(features=32, kernel_size=(3, 3))
        self.conv_2 = nn.Conv(features=32, kernel_size=(3, 3))
        self.conv_3 = nn.Conv(features=64, kernel_size=(3, 3))
        self.conv_4 = nn.Conv(features=64, kernel_size=(3, 3))
        self.conv_5 = nn.Conv(features=128, kernel_size=(3, 3))
        self.conv_6 = nn.Conv(features=128, kernel_size=(3, 3))
        self.dense_1 = nn.Dense(features=1024)
        self.dense_2 = nn.Dense(features=512)
        self.dense_output = nn.Dense(features=10)

    def __call__(self, x):
        """Map a batch of channels-last images to logits of shape (batch, 10)."""
        x = nn.relu(self.conv_1(x))
        x = nn.relu(self.conv_2(x))
        x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.relu(self.conv_3(x))
        x = nn.relu(self.conv_4(x))
        x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.relu(self.conv_5(x))
        x = nn.relu(self.conv_6(x))
        x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything but the batch axis before the dense head.
        x = x.reshape((x.shape[0], -1))
        x = nn.relu(self.dense_1(x))
        x = nn.relu(self.dense_2(x))
        return self.dense_output(x)

In [None]:
# Initialize the model once on a dummy batch and display the parameter shapes.
rng = jax.random.PRNGKey(config.seed)
x = jnp.ones(shape=(config.batch_size, 32, 32, 3))
model = CNN(pool_module=MODULE_DICT[config.pooling])
variables = model.init(rng, x)
# `jax.tree_map` is deprecated (and removed in recent JAX releases);
# `jax.tree_util.tree_map` is the stable spelling.
jax.tree_util.tree_map(lambda leaf: leaf.shape, variables)

FrozenDict({
    params: {
        conv_1: {
            bias: (32,),
            kernel: (3, 3, 3, 32),
        },
        conv_2: {
            bias: (32,),
            kernel: (3, 3, 32, 32),
        },
        conv_3: {
            bias: (64,),
            kernel: (3, 3, 32, 64),
        },
        conv_4: {
            bias: (64,),
            kernel: (3, 3, 64, 64),
        },
        conv_5: {
            bias: (128,),
            kernel: (3, 3, 64, 128),
        },
        conv_6: {
            bias: (128,),
            kernel: (3, 3, 128, 128),
        },
        dense_1: {
            bias: (1024,),
            kernel: (2048, 1024),
        },
        dense_2: {
            bias: (512,),
            kernel: (1024, 512),
        },
        dense_output: {
            bias: (10,),
            kernel: (512, 10),
        },
    },
})

In [None]:
# `nn.tabulate(...)(x)` returns the summary as a string; print it so the table
# renders instead of showing the string's escaped repr.
print(nn.tabulate(model, rng)(x))

'\n\n'

In [None]:
def init_train_state(
    model, random_key, shape, learning_rate
) -> train_state.TrainState:
    """Build a fresh TrainState for `model` with an Adam optimizer.

    Args:
        model: Flax module to initialize.
        random_key: PRNG key used for parameter initialization.
        shape: shape of the all-ones dummy input used to trace the model.
        learning_rate: Adam learning rate.

    Returns:
        A ``TrainState`` holding ``model.apply``, the ``'params'`` collection
        and the optimizer.
    """
    dummy_input = jnp.ones(shape)
    initial_params = model.init(random_key, dummy_input)['params']
    adam = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        tx=adam,
        params=initial_params,
    )


# Create the initial training state from the same seed used above.
input_shape = (config.batch_size, 32, 32, 3)
state = init_train_state(
    model,
    random_key=rng,
    shape=input_shape,
    learning_rate=config.learning_rate,
)
print(type(state))

<class 'flax.training.train_state.TrainState'>


In [None]:
def cross_entropy_loss(*, logits, labels):
    """Mean softmax cross-entropy between logits and integer class labels.

    Args:
        logits: unnormalized scores whose last axis indexes the classes.
        labels: integer class indices, one per row of ``logits``.

    Returns:
        Scalar mean loss over the batch.
    """
    # Derive the class count from the logits instead of hard-coding 10, so the
    # loss stays correct for models with a different number of outputs.
    one_hot_encoded_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
    return optax.softmax_cross_entropy(
        logits=logits, labels=one_hot_encoded_labels
    ).mean()

In [None]:
def compute_metrics(*, logits, labels):
    """Return the batch ``{'loss', 'accuracy'}`` metrics for the given logits.

    Accuracy is the fraction of rows whose arg-max logit equals the label.
    """
    predicted_classes = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits=logits, labels=labels),
        'accuracy': jnp.mean(predicted_classes == labels),
    }

In [None]:
@jax.jit
def train_step(
    state: train_state.TrainState,
    batch: Tuple[np.ndarray, np.ndarray],
):
    """Apply one gradient update on a single batch.

    Args:
        state: current training state.
        batch: ``(images, labels)`` tuple, as yielded by ``tfds.as_numpy`` on
            an ``as_supervised`` dataset.

    Returns:
        ``(new_state, metrics)``. The metrics are computed from the logits of
        the forward pass *before* the update, i.e. with the old parameters.
    """
    image, label = batch

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, image)
        loss = cross_entropy_loss(logits=logits, labels=label)
        # Return logits as auxiliary output so metrics reuse this forward pass.
        return loss, logits

    gradient_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = gradient_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=label)
    return state, metrics

In [None]:
@jax.jit
def eval_step(state, batch):
    """Compute metrics for one ``(images, labels)`` batch without updating ``state``."""
    images, labels = batch
    variables = {'params': state.params}
    predictions = state.apply_fn(variables, images)
    return compute_metrics(logits=predictions, labels=labels)

In [None]:
def save_checkpoint(ckpt_path, state, epoch):
    """Serialize ``state`` to ``ckpt_path`` and log it as a W&B artifact.

    The artifact is named ``<run name>-checkpoint`` and tagged with the
    aliases ``latest`` and ``epoch_<epoch>``.

    Args:
        ckpt_path: local file path the msgpack bytes are written to.
        state: training state to serialize.
        epoch: epoch number, used only for the version alias.
    """
    with open(ckpt_path, "wb") as outfile:
        outfile.write(msgpack_serialize(to_state_dict(state)))
    # A checkpoint is a model, not a dataset; typing it as 'model' keeps W&B
    # lineage and the model views correct. Loading by name is unaffected.
    artifact = wandb.Artifact(
        f'{wandb.run.name}-checkpoint', type='model'
    )
    artifact.add_file(ckpt_path)
    wandb.log_artifact(artifact, aliases=["latest", f"epoch_{epoch}"])


def load_checkpoint(ckpt_file, state):
    """Restore a training state from this run's ``:latest`` checkpoint artifact.

    Args:
        ckpt_file: file name of the checkpoint inside the artifact.
        state: template state whose structure the bytes are decoded into.

    Returns:
        A new state populated from the downloaded checkpoint.
    """
    artifact_ref = f'{wandb.run.name}-checkpoint:latest'
    download_dir = wandb.use_artifact(artifact_ref).download()
    with open(os.path.join(download_dir, ckpt_file), "rb") as checkpoint_file:
        raw_bytes = checkpoint_file.read()
    return from_bytes(state, raw_bytes)


def accumulate_metrics(metrics):
    """Average a list of per-batch metric dicts into a single dict.

    Args:
        metrics: non-empty list of dicts that share the same keys (e.g. the
            dicts returned by ``compute_metrics``); values may be device arrays.

    Returns:
        Dict mapping each key to ``np.mean`` of that key over all batches.
        Every batch has equal weight, including a smaller final batch.

    Raises:
        ValueError: if ``metrics`` is empty (e.g. the dataset yielded no batches).
    """
    if not metrics:
        raise ValueError(
            "accumulate_metrics() got no batch metrics; is the dataset empty?"
        )
    # Pull everything to host memory once, instead of one transfer per value.
    host_metrics = jax.device_get(metrics)
    return {
        key: np.mean([batch_metrics[key] for batch_metrics in host_metrics])
        for key in host_metrics[0]
    }

In [None]:
def _train_one_epoch(state, dataset):
    """Run ``train_step`` over every batch of ``dataset``; return (state, mean metrics)."""
    batch_metrics = []
    for batch in tfds.as_numpy(dataset):
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)
    return state, accumulate_metrics(batch_metrics)


def _evaluate(state, dataset):
    """Run ``eval_step`` over every batch of ``dataset`` and return the mean metrics."""
    batch_metrics = [eval_step(state, batch) for batch in tfds.as_numpy(dataset)]
    return accumulate_metrics(batch_metrics)


def train_and_evaluate(
    train_dataset,
    eval_dataset,
    test_dataset,
    state: train_state.TrainState,
    epochs: int,
):
    """Train for ``epochs`` epochs, checkpoint the best model, then test it.

    After each epoch, train/validation metrics are printed and logged to W&B.
    A checkpoint is saved only when the validation loss improves on the best
    seen so far. After training, that best checkpoint is restored and
    evaluated on ``test_dataset``.

    Args:
        train_dataset: batched ``tf.data`` dataset of ``(images, labels)``.
        eval_dataset: batched validation dataset used for model selection.
        test_dataset: batched held-out dataset evaluated once at the end.
        state: initial training state.
        epochs: number of passes over ``train_dataset``.

    Returns:
        ``(final_state, best_state)``: the state after the last epoch, and the
        state restored from the best-validation-loss checkpoint. If no
        checkpoint was written, ``best_state`` is the final state.
    """
    # Tracked across epochs; it must not be reset inside the loop, or every
    # epoch would count as an "improvement".
    best_eval_loss = float("inf")
    best_epoch = None

    for epoch in tqdm(range(1, epochs + 1)):
        # Iterate the datasets directly rather than over `range(cardinality)`,
        # which breaks when the cardinality is unknown or infinite.
        state, train_batch_metrics = _train_one_epoch(state, train_dataset)
        print(
            'TRAIN (%d/%d): Loss: %.4f, accuracy: %.2f' % (
                epoch, epochs, train_batch_metrics['loss'],
                train_batch_metrics['accuracy'] * 100
            )
        )

        eval_batch_metrics = _evaluate(state, eval_dataset)
        print(
            'EVAL (%d/%d):  Loss: %.4f, accuracy: %.2f\n' % (
                epoch, epochs, eval_batch_metrics['loss'],
                eval_batch_metrics['accuracy'] * 100
            )
        )

        wandb.log({
            "Train Loss": train_batch_metrics['loss'],
            "Train Accuracy": train_batch_metrics['accuracy'],
            "Validation Loss": eval_batch_metrics['loss'],
            "Validation Accuracy": eval_batch_metrics['accuracy']
        }, step=epoch)

        # Save only on improvement, so the checkpoint tagged "latest" is always
        # the best model seen so far.
        if eval_batch_metrics['loss'] < best_eval_loss:
            best_eval_loss = eval_batch_metrics['loss']
            best_epoch = epoch
            save_checkpoint("checkpoint.msgpack", state, epoch)

    if best_epoch is None:
        # Nothing was checkpointed (epochs < 1, or every validation loss was
        # NaN), so there is nothing to restore; test the final state instead.
        restored_state = state
    else:
        # NOTE(review): artifact logging is asynchronous in W&B; confirm that
        # ':latest' resolves to the checkpoint saved above at this point.
        restored_state = load_checkpoint("checkpoint.msgpack", state)

    test_batch_metrics = _evaluate(restored_state, test_dataset)
    print(
        'Test: Loss: %.4f, accuracy: %.2f' % (
            test_batch_metrics['loss'],
            test_batch_metrics['accuracy'] * 100
        )
    )

    wandb.log({
        "Test Loss": test_batch_metrics['loss'],
        "Test Accuracy": test_batch_metrics['accuracy']
    })

    return state, restored_state

In [None]:
# Run the full train / validate / test loop; keep both the final and best states.
state, best_state = train_and_evaluate(
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    test_dataset=test_dataset,
    state=state,
    epochs=config.epochs,
)

  0%|          | 0/15 [00:00<?, ?it/s]

TRAIN (1/15): Loss: 1.7181, accuracy: 37.21
EVAL (1/15):  Loss: 1.4790, accuracy: 46.24

TRAIN (2/15): Loss: 1.4085, accuracy: 48.77
EVAL (2/15):  Loss: 1.3431, accuracy: 51.79

TRAIN (3/15): Loss: 1.2606, accuracy: 54.90
EVAL (3/15):  Loss: 1.1837, accuracy: 58.17

TRAIN (4/15): Loss: 1.1449, accuracy: 58.95
EVAL (4/15):  Loss: 1.0741, accuracy: 62.06

TRAIN (5/15): Loss: 1.0522, accuracy: 62.66
EVAL (5/15):  Loss: 1.0286, accuracy: 63.34

TRAIN (6/15): Loss: 0.9736, accuracy: 65.67
EVAL (6/15):  Loss: 0.9352, accuracy: 67.13

TRAIN (7/15): Loss: 0.8999, accuracy: 68.37
EVAL (7/15):  Loss: 0.8314, accuracy: 71.67

TRAIN (8/15): Loss: 0.8357, accuracy: 70.56
EVAL (8/15):  Loss: 0.7574, accuracy: 74.37

TRAIN (9/15): Loss: 0.7688, accuracy: 73.04
EVAL (9/15):  Loss: 0.7337, accuracy: 74.52

TRAIN (10/15): Loss: 0.7068, accuracy: 75.24
EVAL (10/15):  Loss: 0.6548, accuracy: 77.63

TRAIN (11/15): Loss: 0.6402, accuracy: 77.50
EVAL (11/15):  Loss: 0.5705, accuracy: 80.18

TRAIN (12/15): Lo

In [None]:
# Close the W&B run started at the top of the notebook and flush pending uploads.
wandb.finish()

VBox(children=(Label(value='501.843 MB of 501.843 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
Test Accuracy,▁
Test Loss,▁
Train Accuracy,▁▃▃▄▅▅▅▆▆▆▇▇▇██
Train Loss,█▆▆▅▅▄▄▃▃▃▂▂▂▁▁
Validation Accuracy,▁▂▃▃▄▄▅▅▅▆▆▆▇██
Validation Loss,█▇▆▆▅▅▄▄▄▃▃▃▂▁▁

0,1
Test Accuracy,0.70402
Test Loss,0.92716
Train Accuracy,0.87332
Train Loss,0.36891
Validation Accuracy,0.91312
Validation Loss,0.27159
