# JAX & FLAX

> **Author**: Gustavo L. F. Walbon / **Date**: July 2022.

In this work we will have the inference of resnet18 with Jax engines.

## Objective
_In the fourth project you will have to implement the ResNet-18 model using JAX and Flax libraries for inference only. The FLAX library, discussed in the second lecture, provides primitives to stack multiple kinds of layers in order to form a neural network architecture._

_After you have implemented the model and loaded the weights, it is time to test your code on a few test images. You may use the images provided in [5](https://github.com/MO436-MC934/work). Finally, you job is to obtain the maximum amount of performance improvement from your network using JAX transformations: jit, vmap, pmap, etc. You may use your CPU, GPU or TPU (in Google Collab). In the end, you should write a report (PDF) describing how you implemented your model, which transformations you applied and why. If you could not apply some transformations, discuss the problems you found while trying to use it. Finally, do a performance analysis showing how fast your model has become compared to the non-transformed model (tables and graphs are welcome)._

_Reference: https://github.com/MO436-MC934/notebooks/wiki/5.JAX-Library_

### Loading Resnet18
Pre-trained network is used to see the weights of kernel and bias of Resnet.

In [83]:
import onnx
from onnx import numpy_helper

model = onnx.load("resnet18.onnx")

for initializer in model.graph.initializer:
    array = numpy_helper.to_array(initializer)
    if 'weight' in initializer.name:
        print(f"- Tensor: {initializer.name!r:45} shape={array.shape}")

- Tensor: 'resnetv15_conv0_weight'                      shape=(64, 3, 7, 7)
- Tensor: 'resnetv15_stage1_conv0_weight'               shape=(64, 64, 3, 3)
- Tensor: 'resnetv15_stage1_conv1_weight'               shape=(64, 64, 3, 3)
- Tensor: 'resnetv15_stage1_conv2_weight'               shape=(64, 64, 3, 3)
- Tensor: 'resnetv15_stage1_conv3_weight'               shape=(64, 64, 3, 3)
- Tensor: 'resnetv15_stage2_conv2_weight'               shape=(128, 64, 1, 1)
- Tensor: 'resnetv15_stage2_conv0_weight'               shape=(128, 64, 3, 3)
- Tensor: 'resnetv15_stage2_conv1_weight'               shape=(128, 128, 3, 3)
- Tensor: 'resnetv15_stage2_conv3_weight'               shape=(128, 128, 3, 3)
- Tensor: 'resnetv15_stage2_conv4_weight'               shape=(128, 128, 3, 3)
- Tensor: 'resnetv15_stage3_conv2_weight'               shape=(256, 128, 1, 1)
- Tensor: 'resnetv15_stage3_conv0_weight'               shape=(256, 128, 3, 3)
- Tensor: 'resnetv15_stage3_conv1_weight'               shape=(25

With that the next step is to create the dicionary with the same shape to be used in the interference.

### Creating de CNN

Resnet models has the format NCHW(Batch size, Channels, Height, Width), so that means the eg. (64, 3, 7, 7) refers to N=64, Channels=3, Height=7 and Width=7.

In [2]:
import jax.numpy as jnp
import jax.scipy as jsp
import jax
import jax.numpy as jnp
from functools import partial

In [3]:
import flax.linen as nn
import optax

# We need this to hold the training state
from flax.training.train_state import TrainState

from typing import Any, Callable, Sequence, Tuple

class ResNetBlock(nn.Module):
    """ResNet block."""
    filters: int
    conv: Any
    norm: Any
    act: Callable
    strides: Tuple[int, int] = (1, 1)

    @nn.compact
    def __call__(self, x,):
        residual = x
        y = self.conv(self.filters, (3, 3), self.strides)(x)
        y = self.norm()(y)
        y = self.act(y)
        y = self.conv(self.filters, (3, 3))(y)
        y = self.norm(scale_init=nn.initializers.zeros)(y)

        if residual.shape != y.shape:
            residual = self.conv(self.filters, (1, 1),
                               self.strides)(residual)
            residual = self.norm()(residual)

        return self.act(residual + y)


class ResNet18(nn.Module):
    num_filters: int = 64
    dtype = jnp.float32
    conv = nn.Conv
    norm = nn.BatchNorm
    act = nn.relu
    
    
    @nn.compact
    def __call__(self, x):
        conv = partial(nn.Conv,
                    use_bias=False,
                    dtype=self.dtype
                   )
        norm = partial(nn.BatchNorm,
                       use_running_average=False,
                       momentum=0.9,
                       epsilon=1e-5,
                       dtype=self.dtype
                      )
        x = norm(name='batchnorm')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, #Input
                        window_shape=(3,3),
                        strides=(2,2),
                        padding='SAME')
        
        # -- Stage 1 weights (64, 64, 3, 3)--
        x = ResNetBlock(64, conv=conv,norm=norm, strides=(1,1), act=nn.relu)(x)
        x = ResNetBlock(64, conv=conv,norm=norm, strides=(1,1), act=nn.relu)(x)
        x = ResNetBlock(64, conv=conv,norm=norm, strides=(1,1), act=nn.relu)(x)
        x = ResNetBlock(64, conv=conv,norm=norm, strides=(1,1), act=nn.relu)(x)
        # -- Stage 2 weights (128, 64, 1, 1)--
        x = ResNetBlock(128, conv=conv,norm=norm, strides=(2,2), act=nn.relu)(x)
        x = ResNetBlock(128, conv=conv,norm=norm, strides=(1,1), act=nn.relu)(x)
        x = ResNetBlock(128, conv=conv,norm=norm, strides=(1,1), act=nn.relu)(x)
        x = ResNetBlock(128, conv=conv,norm=norm, strides=(1,1), act=nn.relu)(x)
        # -- Stage 3 weights (256, 128, 1, 1)--
        x = ResNetBlock(256, conv=conv,norm=norm, strides=(2,2), act=nn.relu)(x)
        x = ResNetBlock(256, conv=conv,norm=norm, strides=(1,1), act=nn.relu)(x)
        x = ResNetBlock(256, conv=conv,norm=norm, strides=(1,1), act=nn.relu)(x)
        x = ResNetBlock(256, conv=conv,norm=norm, strides=(1,1), act=nn.relu)(x)
        # -- Stage 4 weights (512, 256, 1, 1)--
        x = ResNetBlock(512, conv=conv,norm=norm, strides=(2,2), act=nn.relu)(x)
        x = ResNetBlock(512, conv=conv,norm=norm, strides=(1,1), act=nn.relu)(x)
        x = ResNetBlock(512, conv=conv,norm=norm, strides=(1,1), act=nn.relu)(x)
        x = ResNetBlock(512, conv=conv,norm=norm, strides=(1,1), act=nn.relu)(x)
        
        x = jnp.mean(x, axis=(1,2))
        x = nn.Dense(512, dtype=self.dtype)(x)
        x = jnp.asarray(x, self.dtype)
     
        return x

The function `__call__` was overwritten to be used internally by the Module.

### Evaluation Function
Next we define a function to evaluate the entire model and summarize the metrics.

In [6]:
def eval_model(params, test_ds):
    metrics = eval_step(params, test_ds)
    metrics = jax.device_get(metrics)
    summary = jax.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']

In [9]:
def cross_entropy_loss(*, y_pred, y_label):
    """Compute cross-entropy loss given the activations and ground-truth labels."""
    # Labels are numeric, convert them to one hot
    y_label_one_hot = jax.nn.one_hot(y_label, num_classes=10)
    # Return cross-entropy loss
    return -jnp.mean(jnp.sum(y_label_one_hot * y_pred, axis=-1))

def compute_metrics(*, y_pred, y_label):
    """Compute metrics for a subset of data."""
    return {
        'loss': cross_entropy_loss(y_pred=y_pred, y_label=y_label),
        'accuracy': jnp.mean(jnp.argmax(y_pred, -1) == y_label),
    }

In [8]:
# Create random number generator
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

# Create train state
learning_rate = 0.1
momentum = 0.9
state = create_train_state(init_rng, learning_rate, momentum)

print(jax.tree_map(lambda x: x.shape, state.params))

2022-07-14 19:25:54.422226: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error


FrozenDict({
    Dense_0: {
        bias: (512,),
        kernel: (512, 512),
    },
    ResNetBlock_0: {
        BatchNorm_0: {
            bias: (64,),
            scale: (64,),
        },
        BatchNorm_1: {
            bias: (64,),
            scale: (64,),
        },
        BatchNorm_2: {
            bias: (64,),
            scale: (64,),
        },
        Conv_0: {
            kernel: (3, 3, 1, 64),
        },
        Conv_1: {
            kernel: (3, 3, 64, 64),
        },
        Conv_2: {
            kernel: (1, 1, 1, 64),
        },
    },
    ResNetBlock_1: {
        BatchNorm_0: {
            bias: (64,),
            scale: (64,),
        },
        BatchNorm_1: {
            bias: (64,),
            scale: (64,),
        },
        Conv_0: {
            kernel: (3, 3, 64, 64),
        },
        Conv_1: {
            kernel: (3, 3, 64, 64),
        },
    },
    ResNetBlock_10: {
        BatchNorm_0: {
            bias: (256,),
            scale: (256,),
        },
  

### Load MNIST Dataset
We will use Tensorflow to load the MNIST dataset.

In [10]:
import tensorflow_datasets as tfds

def load_mnist():
    """Load MNIST dataset using Tensorflow."""
    
    # Download data
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    
    # Create datasets
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    
    # Normalize images between 0.0 and 1.0
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    
    return train_ds, test_ds

In [11]:
# Load datasets
train_ds, test_ds = load_mnist()

# Get overview of both datasets
print("Train dataset:", jax.tree_map(lambda x: (x.shape, x.dtype), train_ds))
print(" Test dataset:", jax.tree_map(lambda x: (x.shape, x.dtype), test_ds))

2022-07-14 19:27:28.889113: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error


Train dataset: {'image': ((60000, 28, 28, 1), dtype('float32')), 'label': ((60000,), dtype('int64'))}
 Test dataset: {'image': ((10000, 28, 28, 1), dtype('float32')), 'label': ((10000,), dtype('int64'))}


## Training Step

Now it is time to define a training step on a single batch of images. This function will be executed thousands of times during the training loop, therefore it is interesting to JIT compile it. Do remember that by JIT compiling this function we are also compling everything that is called by `train_step`, in particular, `CNN.apply` and `loss_fn` will also be compiled.


In [7]:
def create_train_state(rng, learning_rate, momentum):
    # Instantiate model
    cnn = ResNet18()
    # Initialize parameters
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    # Instantiate optimizer
    tx = optax.sgd(learning_rate, momentum)
    # Instantiate train state
    return TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)

In [12]:

@jax.jit
def train_step(state, batch):
    """Perform a training step on a batch of images and update state."""
    
    def loss_fn(params):
        """Feed forward current batch, returns loss and
        activations of the last layer."""
        
        # Apply (i.e. predict) the model on a batch of images
        y_pred = ResNet18().apply({'params': params}, batch['image'])
        # Compute loss from the predictions
        loss = cross_entropy_loss(y_pred=y_pred, y_label=batch['label'])
        
        return loss, y_pred
    
    # Compute gradient (and value) of loss function
    (_, y_pred), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    # Update model parameters using gradients (!)
    state = state.apply_gradients(grads=grads)
    # Compute metrics for this batch
    metrics = compute_metrics(y_pred=y_pred, y_label=batch['label'])
    
    return state, metrics

In [13]:
@jax.jit
def eval_step(params, batch):
    """Evaluate model on a testing batch."""
    y_pred = ResNet18().apply({'params': params}, batch['image'])
    return compute_metrics(y_pred=y_pred, y_label=batch['label'])

In [14]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
    """Train the model for an entire epoch."""
    
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size
    
    # Shuffle the dataset and divide it into mini-batches
    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))
    
    # Perform a train step in each batch and save metrics
    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)
    
    # Average metrics for each batch
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }
    
    # Retrieve loss and accuracy
    loss = epoch_metrics_np['loss']
    acc = epoch_metrics_np['accuracy'] * 100
    
    return state, loss, acc

In [15]:
def eval_model(params, test_ds):
    metrics = eval_step(params, test_ds)
    metrics = jax.device_get(metrics)
    summary = jax.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']

In [17]:
# Create random number generator
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

# Create train state
learning_rate = 0.1
momentum = 0.9
state = create_train_state(init_rng, learning_rate, momentum)

print(jax.tree_map(lambda x: x.shape, state.params))

FrozenDict({
    Dense_0: {
        bias: (512,),
        kernel: (512, 512),
    },
    ResNetBlock_0: {
        BatchNorm_0: {
            bias: (64,),
            scale: (64,),
        },
        BatchNorm_1: {
            bias: (64,),
            scale: (64,),
        },
        BatchNorm_2: {
            bias: (64,),
            scale: (64,),
        },
        Conv_0: {
            kernel: (3, 3, 1, 64),
        },
        Conv_1: {
            kernel: (3, 3, 64, 64),
        },
        Conv_2: {
            kernel: (1, 1, 1, 64),
        },
    },
    ResNetBlock_1: {
        BatchNorm_0: {
            bias: (64,),
            scale: (64,),
        },
        BatchNorm_1: {
            bias: (64,),
            scale: (64,),
        },
        Conv_0: {
            kernel: (3, 3, 64, 64),
        },
        Conv_1: {
            kernel: (3, 3, 64, 64),
        },
    },
    ResNetBlock_10: {
        BatchNorm_0: {
            bias: (256,),
            scale: (256,),
        },
  

In [18]:
num_epochs = 10
batch_size = 32

for epoch in range(1, num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, loss, acc = train_epoch(state, train_ds, batch_size, epoch, input_rng)
    test_loss, test_acc = eval_model(state.params, test_ds)
    
    print(f"EPOCH {epoch:02} -----------------------------")
    print(f">> Train:  loss={loss:.8f}  acc={acc:.2f}%")
    print(f">>  Test:  loss={test_loss:.8f}  acc={test_acc*100:.2f}%")

ScopeCollectionNotFound: Tried to access "mean" from collection "batch_stats" in "/batchnorm" but the collection is empty. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ScopeCollectionNotFound)