# JAX & FLAX

> **Author**: Gustavo L. F. Walbon / **Date**: July 2022.

In this work we will have the inference of resnet18 with Jax engines.

## Objective
_In the fourth project you will have to implement the ResNet-18 model using JAX and Flax libraries for inference only. The FLAX library, discussed in the second lecture, provides primitives to stack multiple kinds of layers in order to form a neural network architecture._

_After you have implemented the model and loaded the weights, it is time to test your code on a few test images. You may use the images provided in [5](https://github.com/MO436-MC934/work). Finally, you job is to obtain the maximum amount of performance improvement from your network using JAX transformations: jit, vmap, pmap, etc. You may use your CPU, GPU or TPU (in Google Collab). In the end, you should write a report (PDF) describing how you implemented your model, which transformations you applied and why. If you could not apply some transformations, discuss the problems you found while trying to use it. Finally, do a performance analysis showing how fast your model has become compared to the non-transformed model (tables and graphs are welcome)._

_Reference: https://github.com/MO436-MC934/notebooks/wiki/5.JAX-Library_

### Loading Resnet18
Pre-trained network is used to see the weights of kernel and bias of Resnet.

In [1]:
import onnx
from onnx import numpy_helper

model = onnx.load("resnet18.onnx")

for initializer in model.graph.initializer:
    array = numpy_helper.to_array(initializer)
    print(f"- Tensor: {initializer.name!r:45} shape={array.shape}")

- Tensor: 'resnetv15_conv0_weight'                      shape=(64, 3, 7, 7)
- Tensor: 'resnetv15_batchnorm0_gamma'                  shape=(64,)
- Tensor: 'resnetv15_batchnorm0_beta'                   shape=(64,)
- Tensor: 'resnetv15_batchnorm0_running_mean'           shape=(64,)
- Tensor: 'resnetv15_batchnorm0_running_var'            shape=(64,)
- Tensor: 'resnetv15_stage1_conv0_weight'               shape=(64, 64, 3, 3)
- Tensor: 'resnetv15_stage1_batchnorm0_gamma'           shape=(64,)
- Tensor: 'resnetv15_stage1_batchnorm0_beta'            shape=(64,)
- Tensor: 'resnetv15_stage1_batchnorm0_running_mean'    shape=(64,)
- Tensor: 'resnetv15_stage1_batchnorm0_running_var'     shape=(64,)
- Tensor: 'resnetv15_stage1_conv1_weight'               shape=(64, 64, 3, 3)
- Tensor: 'resnetv15_stage1_batchnorm1_gamma'           shape=(64,)
- Tensor: 'resnetv15_stage1_batchnorm1_beta'            shape=(64,)
- Tensor: 'resnetv15_stage1_batchnorm1_running_mean'    shape=(64,)
- Tensor: 'resnetv15_s

With that the next step is to create the dicionary with the same shape to be used in the interference.

### Creating de CNN

Resnet models has the format NCHW(Batch size, Channels, Height, Width), so that means the eg. (64, 3, 7, 7) refers to N=64, Channels=3, Height=7 and Width=7.

In [17]:
import flax.linen as nn
import optax

# We need this to hold the training state
from flax.training.train_state import TrainState


class ResNet18(nn.Module):
    
    @nn.compact
    def __call__(self, x):
        conv = partial(self.conv, use_bias=False, dtype=self.dtype)
        norm = partial(nn.BatchNorm,
                       use_running_average=not train, 
                       momentum=0.9,
                       epsilon=1e-5,
                       dtype=self.dtype)
        x = conv(self.num_filters, (7,7), (2,2),
                 padding=[(3,3), (3,3)],
                 name='conv_init')(x)
        x = norm(name='bn_init')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3,3), strides=(2,2), padding='SAME')
        for i, block_size in enumerate(self.stage_size):
            for j in range(block_size):
                strides = (2,2) if i > 0 and j == 0 else (1,1)
                x = self.block_cls(self.num_filters * 2**i,
                                   strides = strides,
                                   conv = conv,
                                   norm = norm,
                                   act  = self.act)(x)
        x = jnp.mean(x, axis=(1,2))
        x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
        x = jnp.asarray(x, self.dtype)
        return x

The function `__call__` was overwritten to be used internally by the Module.

### Evaluation Function
Next we define a function to evaluate the entire model and summarize the metrics.

In [18]:
def eval_model(params, test_ds):
    metrics = eval_step(params, test_ds)
    metrics = jax.device_get(metrics)
    summary = jax.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']
def eval_model(params, test_ds):
    metrics = eval_step(params, test_ds)
    metrics = jax.device_get(metrics)
    summary = jax.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']

In [19]:
import jax.numpy as jnp
import jax.scipy as jsp
import jax
import jax.numpy as jnp
from functools import partial

In [20]:
def create_train_state(rng, learning_rate, momentum):
    # Instantiate model
    cnn = ResNet18()
    # Initialize parameters
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    # Instantiate optimizer
    tx = optax.sgd(learning_rate, momentum)
    # Instantiate train state
    return TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)

In [21]:
# Create random number generator
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

# Create train state
learning_rate = 0.1
momentum = 0.9
state = create_train_state(init_rng, learning_rate, momentum)

print(jax.tree_map(lambda x: x.shape, state.params))

AttributeError: "ResNet18" object has no attribute "conv"