# Annotated MNIST Example Augmented

This notebook is based on [Annotated MNIST](https://colab.sandbox.google.com/github/google/flax/blob/main/docs/getting_started.ipynb#scrollTo=KvuEA8Tw-MYa) example from [FLAX](https://github.com/google/flax)

The primary objective is to illustrate spatial partitioning.

We begin with the default example, i.e. the model without spatial partitioning.
Then we illustrate the spatially partitioned version using the JAX pjit API.

In [1]:
import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers
import tensorflow_datasets as tfds     # TFDS for MNIST

import os

# Optional [For profiling only]
# FLAX_PROFILE is set to '1' in the process environment before any model code
# runs; presumably this enables Flax's profiler annotations -- TODO confirm
# against the Flax documentation.
os.environ['FLAX_PROFILE'] = '1'
# Start a JAX profiler server on port 1234 and keep the returned handle in
# `server`.
server = jax.profiler.start_server(1234)



2022-11-02 17:37:46.589625: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2022-11-02 17:37:47.360589: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2022-11-02 17:37:47.360668: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib


## Download and prepare dataset from TFDS

In [2]:
def get_datasets():
    """Fetch the 'fashion_mnist' train and test splits through TFDS.

    Each split is requested with ``batch_size=-1`` and converted with
    ``tfds.as_numpy``. The 'image' entry of each split is then cast to
    float32 and divided by 255 (pixel values are presumably 0-255 integers;
    verify against the dataset).

    Returns:
      A ``(train_ds, test_ds)`` tuple of dict-like datasets.
    """
    builder = tfds.builder('fashion_mnist')
    builder.download_and_prepare()

    def _load_split(split_name):
        # Load one split and rescale its images.
        split = tfds.as_numpy(
            builder.as_dataset(split=split_name, batch_size=-1))
        split['image'] = jnp.float32(split['image']) / 255.
        return split

    return _load_split('train'), _load_split('test')

# Model definition

In [3]:
class CNN(nn.Module):
    """Small convolutional classifier: two conv/pool stages, then two dense layers."""

    @nn.compact
    def __call__(self, x):
        # Two conv -> ReLU -> 2x2 average-pool stages, with 32 then 64 features.
        # Submodules are still created in the same order, so Flax's automatic
        # parameter names are unchanged.
        for num_features in (32, 64):
            x = nn.Conv(features=num_features, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Keep the leading axis and flatten everything else.
        flat = x.reshape((x.shape[0], -1))
        hidden = nn.relu(nn.Dense(features=256)(flat))
        # Final layer produces 10 outputs per example.
        return nn.Dense(features=10)(hidden)

### Define cross entropy loss and metrics functions

In [4]:
def cross_entropy_loss(*, logits, labels):
    """Mean softmax cross-entropy between `logits` and integer `labels`.

    Args:
      logits: array whose last axis holds the per-class scores.
      labels: integer class ids, one per example.

    Returns:
      The cross-entropy averaged over all examples.
    """
    # Take the class count from the logits instead of hard-coding 10, so the
    # one-hot targets always match the logits' last axis and can never be
    # silently broadcast against logits of a different width.
    num_classes = logits.shape[-1]
    labels_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

In [5]:
def compute_metrics(*, logits, labels):
    """Return a dict with the 'loss' and 'accuracy' of `logits` against `labels`.

    'loss' comes from ``cross_entropy_loss``; 'accuracy' is the mean of the
    element-wise comparison between the arg-max over the last axis of
    `logits` and `labels`.
    """
    predicted_classes = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits=logits, labels=labels),
        'accuracy': jnp.mean(predicted_classes == labels),
    }

### Define the train_step 

Following is the un-partitioned version of the train_step.
It works on a single batch. We create the gradient function of the loss function,
and then compute the loss and the gradient of the loss for the given batch and network state.

In [6]:
@jax.jit
def train_step(state, batch):
    """Apply one gradient update computed from a single batch.

    Args:
      state: train state exposing ``params`` and ``apply_gradients``.
      batch: mapping with 'image' and 'label' entries.

    Returns:
      ``(new_state, metrics)`` where ``metrics`` is the dict produced by
      ``compute_metrics`` from the logits of the forward pass that was used
      for the gradient (i.e. computed with the pre-update parameters).
    """
    images, labels = batch['image'], batch['label']

    def loss_and_logits(params):
        # Return the logits as auxiliary output so they can be reused for metrics.
        logits = CNN().apply({'params': params}, images)
        return cross_entropy_loss(logits=logits, labels=labels), logits

    (_, logits), grads = jax.value_and_grad(
        loss_and_logits, has_aux=True)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, compute_metrics(logits=logits, labels=labels)

Evaluation step function, although "step" is misleading here since we are not working with a batch but with the entirety of the eval dataset.

In [7]:
@jax.jit
def eval_step(params, eval_ds):
    """Run the CNN on ``eval_ds['image']`` and score it against ``eval_ds['label']``.

    Returns the metrics dict produced by ``compute_metrics``.
    """
    model = CNN()
    predictions = model.apply({'params': params}, eval_ds['image'])
    return compute_metrics(logits=predictions, labels=eval_ds['label'])

Now we are ready to define a train epoch in terms of train_step function

In [8]:
def train_epoch(state, train_ds, batch_size, epoch, rng, step_fn=None):
    """Run one shuffled pass over `train_ds` and return the updated state.

    Args:
      state: training state; passed to the step function and replaced by the
        state it returns after every batch.
      train_ds: mapping of arrays indexed along their first axis; must contain
        an 'image' entry, whose length defines the dataset size.
      batch_size: number of examples per step. Examples that do not fill a
        complete batch are dropped for this epoch.
      epoch: epoch number, used only in the printed summary line.
      rng: JAX PRNG key used to draw the shuffling permutation.
      step_fn: optional callable ``(state, batch) -> (state, metrics)`` run on
        every batch. Defaults to ``train_step``; pass another step function
        (for example a partitioned one) to reuse this loop unchanged.

    Returns:
      The state returned by the last step.

    Raises:
      ValueError: if `batch_size` is not between 1 and the dataset size, in
        which case no complete batch could be formed.
    """
    if step_fn is None:
        step_fn = train_step

    train_ds_size = len(train_ds['image'])
    # Without this check a too-large batch size yields zero steps and the
    # metric averaging below fails with an unhelpful IndexError.
    if not 0 < batch_size <= train_ds_size:
        raise ValueError(
            f'batch_size must be in [1, {train_ds_size}], got {batch_size}')
    steps_per_epoch = train_ds_size // batch_size

    # Shuffle the example indices, drop the incomplete tail, and arrange the
    # rest as one row of indices per step.
    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))

    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = step_fn(state, batch)
        batch_metrics.append(metrics)

    # Bring the per-batch metrics to the host and average them per key.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }
    print(f"train epoch: {epoch}, loss: {epoch_metrics_np['loss']}, accuracy: {epoch_metrics_np['accuracy'] * 100}")
    return state

And eval_model in terms of eval_step function.

In [9]:
def eval_model(params, test_ds):
    """Evaluate `params` on `test_ds` and return ``(loss, accuracy)`` as Python scalars."""
    host_metrics = jax.device_get(eval_step(params, test_ds))
    # Convert every leaf to a plain Python number via .item().
    scalars = jax.tree_util.tree_map(lambda leaf: leaf.item(), host_metrics)
    return scalars['loss'], scalars['accuracy']

Creating the train state achieves two objectives:
1. Initialize the model
2. Extract the param tree (train state)

In [10]:
def create_train_state(rng, learning_rate, momentum):
    """Initialise the CNN and bundle it with an SGD optimiser in a TrainState.

    Args:
      rng: PRNG key used for parameter initialisation.
      learning_rate: learning rate passed to ``optax.sgd``.
      momentum: momentum passed to ``optax.sgd``.

    Returns:
      A ``train_state.TrainState`` holding the model's apply function, its
      initial parameters and the optimiser.
    """
    model = CNN()
    # Parameters are initialised from an all-ones input of shape [1, 28, 28, 1].
    dummy_input = jnp.ones([1, 28, 28, 1])
    variables = model.init(rng, dummy_input)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optax.sgd(learning_rate, momentum),
    )

# Run training (with No Spatial Partitioning)

In [11]:
# Load the train and test splits once; the training loops further down read
# from these two variables.
train_ds, test_ds = get_datasets()

2022-11-02 17:37:52.317168: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2022-11-02 17:37:52.317206: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)


In [12]:
# Root PRNG key with a fixed seed of 0.
rng = jax.random.PRNGKey(0)
# Split off `init_rng` for parameter initialisation; `rng` is kept and split
# again later for per-epoch shuffling.
rng, init_rng = jax.random.split(rng)

In [13]:
# SGD hyperparameters handed to create_train_state.
learning_rate = 0.1
momentum = 0.9

In [14]:
# Build the initial train state (model parameters + optimiser) from init_rng.
state = create_train_state(init_rng, learning_rate, momentum)

In [15]:
# Number of passes over the training set and examples per training step.
num_epochs = 2
batch_size = 2048

In [16]:
# Write profiler traces under the current user's home directory instead of a
# path that only exists on one machine.
profile_log_dir = os.path.join(os.path.expanduser('~'), 'profile-log')

for epoch in range(1, num_epochs + 1):
    # Use a separate PRNG key to permute image data during shuffling
    rng, input_rng = jax.random.split(rng)
    # Start the profiler trace just before the first training epoch.
    if epoch == 1:
        jax.profiler.start_trace(log_dir=profile_log_dir)
    # Run one full training epoch (many optimisation steps).
    state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
    # Stop the trace after the last training epoch, whatever num_epochs is.
    # Comparing against a literal 2 would leave the trace running for
    # num_epochs == 1 and cut it short for num_epochs > 2.
    if epoch == num_epochs:
        jax.profiler.stop_trace()
    # Evaluate on the test set after each training epoch
    test_loss, test_accuracy = eval_model(state.params, test_ds)
    print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
      epoch, test_loss, test_accuracy * 100))

tcmalloc: large alloc 1178992640 bytes == 0x14f7ea000 @  0x7f2ecbe00680 0x7f2ecbe21824 0x7f2d296390da 0x7f2d270544e0 0x7f2d27054459 0x7f2d26ed8d38 0x7f2d26ed8c39 0x7f2d26145908 0x7f2d26175acb 0x7f2d26b13e75 0x7f2d26b164b9 0x7f2d2923c665 0x7f2d292416ab 0x7f2d292484e5 0x7f2d29464dbe 0x7f2ecbbd4609 0x7f2ecbd0e163


train epoch: 1, loss: 2.324336290359497, accuracy: 31.84604048728943
 test epoch: 1, loss: 1.85, accuracy: 46.36
train epoch: 2, loss: 1.2135814428329468, accuracy: 59.375
 test epoch: 2, loss: 0.66, accuracy: 75.51


# Spatial Partitioning

Spatial partitioning can be viewed as a special case of data parallelism where we partition input images along the X and Y dimensions instead of the batch dimension (as in common data parallelism).

In the following sections we will use JAX PJIT API to express spatial partitioning.
PJIT is a general-purpose API which allows expressing partitioning intent for the inputs and outputs of a function.
This intent is then automatically propagated through the function graph (using XLA SPMD) to create a partitioned version of the function with almost no manual effort required to update the model. This gives a powerful way to express a variety of parallelisms including SPMD-based tensor and pipeline parallelisms, fully sharded data parallelism and spatial partitioning.
For more details on PJIT please refer to [this tutorial](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html).

In [17]:
import jax
from jax.experimental import maps
from jax.experimental import PartitionSpec
from jax.experimental.pjit import pjit
import numpy as np
import functools

## Step-1 Define a device mesh
The device mesh is a logical view of the physical device array expressed with named axes.
These axes will subsequently be referenced to express partitioning intent or annotation.
It is important that each of these axes corresponds to a physical torus (1-D or 2-D).
In the current example we are working with a TPU device, i.e. a 4-chip configuration.
For larger slice shapes, awareness of [mesh topologies](https://cloud.google.com/tpu/docs/types-topologies) is critical for constructing the optimal device mesh. We recommend using [mesh utils](https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L221) in such scenarios.

In [18]:
# Lay the available devices out as a 2x2 grid and name its two axes 'x' and
# 'y'. The reshape only succeeds when exactly four devices are present.
mesh_shape = (2, 2)
devices = np.array(jax.devices()).reshape(mesh_shape)
mesh = maps.Mesh(devices, ('x', 'y'))
mesh

Mesh(array([[0, 1],
       [2, 3]]), ('x', 'y'))

## Step-2 Using PJIT to express partition intent
Create a version of train_step with partitioning intent on the inputs.
Notice that we are using the 'x' and 'y' mesh axes to define the partitioning for the image's X and Y dimensions (spatial partitioning).

In [19]:
# Partitioned train step. The two positional inputs of train_step are
# (state, batch): the state and the labels get no partition spec, while axes 1
# and 2 of the images are split over the 'x' and 'y' mesh axes. Outputs get no
# partition spec either.
p_train_step = pjit(
    train_step,
    in_axis_resources=(
        None,
        dict(image=PartitionSpec(None, 'x', 'y', None), label=None),
    ),
    out_axis_resources=None,
)

## Step-3 update other functions to use pjit version of train_step

In [20]:
def p_train_epoch(state, train_ds, batch_size, epoch, rng):
    """Run one shuffled pass over `train_ds` using the pjit-wrapped train step.

    Args:
      state: training state; passed to ``p_train_step`` and replaced by the
        state it returns after every batch.
      train_ds: mapping of arrays indexed along their first axis; must contain
        an 'image' entry, whose length defines the dataset size.
      batch_size: number of examples per step. Examples that do not fill a
        complete batch are dropped for this epoch.
      epoch: epoch number, used only in the printed summary line.
      rng: JAX PRNG key used to draw the shuffling permutation.

    Returns:
      The state returned by the last step.

    Raises:
      ValueError: if `batch_size` is not between 1 and the dataset size, in
        which case no complete batch could be formed.
    """
    train_ds_size = len(train_ds['image'])
    # Without this check a too-large batch size yields zero steps and the
    # metric averaging below fails with an unhelpful IndexError.
    if not 0 < batch_size <= train_ds_size:
        raise ValueError(
            f'batch_size must be in [1, {train_ds_size}], got {batch_size}')
    steps_per_epoch = train_ds_size // batch_size

    # Shuffle the example indices, drop the incomplete tail, and arrange the
    # rest as one row of indices per step.
    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))

    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = p_train_step(state, batch)
        batch_metrics.append(metrics)

    # Bring the per-batch metrics to the host and average them per key.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }
    print(f"train epoch: {epoch}, loss: {epoch_metrics_np['loss']}, accuracy: {epoch_metrics_np['accuracy'] * 100}")
    return state

Express similar partition intent for eval_step

In [21]:
# Partitioned eval step. The two positional inputs of eval_step are
# (params, eval_ds): params and labels get no partition spec, while axes 1 and
# 2 of the images are split over the 'x' and 'y' mesh axes. Outputs get no
# partition spec either.
p_eval_step = pjit(
    eval_step,
    in_axis_resources=(
        None,
        dict(image=PartitionSpec(None, 'x', 'y', None), label=None),
    ),
    out_axis_resources=None,
)
def p_eval_model(params, test_ds):
    """Evaluate `params` on `test_ds` with the pjit-wrapped eval step.

    Returns ``(loss, accuracy)`` as Python scalars.
    """
    host_metrics = jax.device_get(p_eval_step(params, test_ds))
    # Convert every leaf to a plain Python number via .item().
    scalars = jax.tree_util.tree_map(lambda leaf: leaf.item(), host_metrics)
    return scalars['loss'], scalars['accuracy']
    

# Run training with spatial partitioning

In [22]:
# Write profiler traces under the current user's home directory instead of a
# path that only exists on one machine.
profile_log_dir = os.path.join(os.path.expanduser('~'), 'profile-log')

# NOTE: `state` and `rng` are not re-created here, so this loop continues from
# wherever the previous training loop left them. The reported metrics are
# therefore not a from-scratch comparison with the un-partitioned run.
for epoch in range(1, num_epochs + 1):
    # Use a separate PRNG key to permute image data during shuffling
    rng, input_rng = jax.random.split(rng)
    # Start the profiler trace just before the first training epoch.
    if epoch == 1:
        jax.profiler.start_trace(log_dir=profile_log_dir)
    # The pjit-wrapped functions must run inside the device mesh context.
    with mesh:
        # Run one full training epoch (many optimisation steps).
        state = p_train_epoch(state, train_ds, batch_size, epoch, input_rng)
        # Evaluate on the test set after each training epoch
        test_loss, test_accuracy = p_eval_model(state.params, test_ds)
        print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
          epoch, test_loss, test_accuracy * 100))
    # Stop the trace after the last epoch, whatever num_epochs is. Comparing
    # against a literal 2 would leave the trace running for num_epochs == 1
    # and cut it short for num_epochs > 2.
    if epoch == num_epochs:
        jax.profiler.stop_trace()


train epoch: 1, loss: 0.5606998801231384, accuracy: 78.46511602401733
 test epoch: 1, loss: 0.52, accuracy: 80.20
train epoch: 2, loss: 0.46477749943733215, accuracy: 82.48080611228943
 test epoch: 2, loss: 0.45, accuracy: 83.08
