# ResNet Implementation in mlax with Optax optimizers.
This notebook uses the [Optax](https://optax.readthedocs.io/en/latest/optax-101.html) JAX optimization library.

You can view a full-precision implementation in `resnet_mixed_precision.ipynb`.

You can view the full-precision PyTorch reference implementation in
`resnet_reference.ipynb`.

In [1]:
import jax
import jax.numpy as jnp
from jax import (
    nn,
    random,
    tree_util
)
import numpy as np
import optax
import torchvision
from torch.utils.data import DataLoader

In [2]:
from mlax.nn import Conv, Scaler, BatchNorm, Linear, Bias, F
from mlax.block import Series, Parallel
from mlax.functional import identity

### Load in and batch the CIFAR-10 datasets.
We follow the example
[Training a Simple Neural Network, with PyTorch Data Loading](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html) and use PyTorch data loaders.

In [3]:
class ToNumpy:
    """Callable transform that converts its input into a NumPy array."""

    def __call__(self, pic):
        """Return ``pic`` converted with ``np.array``."""
        as_array = np.array(pic)
        return as_array

# Training split: each image goes through AutoAugment and is then converted
# to a NumPy array by ToNumpy.
# NOTE(review): AutoAugment() is called without a policy, so torchvision's
# default policy is used — confirm whether a CIFAR-10 specific policy was
# intended.
cifar_train = torchvision.datasets.CIFAR10(
    root="../data",
    train=True,
    download=True,
    transform = torchvision.transforms.Compose([
    torchvision.transforms.AutoAugment(),
    ToNumpy()
])
)
# Test split: no augmentation, only the conversion to a NumPy array.
cifar_test = torchvision.datasets.CIFAR10(
    root="../data",
    train=False,
    download=True,
    transform=ToNumpy()
)
# Show the shapes of the underlying image arrays of both splits.
print(cifar_train.data.shape)
print(cifar_test.data.shape)

Files already downloaded and verified
Files already downloaded and verified
(50000, 32, 32, 3)
(10000, 32, 32, 3)


In [4]:
def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays.

    The type of the first sample decides how the batch is combined:

    * ``np.ndarray`` samples are stacked along a new leading axis;
    * tuple/list samples are transposed field-by-field and each field is
      collated recursively, giving a list with one entry per field;
    * anything else is handed to ``np.array`` as-is.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        # zip(*batch) groups the i-th field of every sample together.
        return [numpy_collate(field) for field in zip(*batch)]
    return np.array(batch)

# Number of samples per batch, shared by the training and test loaders.
batch_size = 128

# Both loaders use numpy_collate so that batches are NumPy arrays; the batch
# size comes from the single `batch_size` setting above instead of being
# repeated as a literal.
train_dataloader = DataLoader(
    cifar_train, batch_size=batch_size, shuffle=True, collate_fn=numpy_collate, num_workers=8
)
test_dataloader = DataLoader(
    cifar_test, batch_size=batch_size, shuffle=True, collate_fn=numpy_collate, num_workers=8
)
# Number of batches per epoch for each loader.
print(len(train_dataloader), len(test_dataloader))

391 79


### Initialize ResNet model parameters.

In [5]:
def toHalfPrecision(x):
    """Cast ``x`` to ``jnp.float16`` via its ``astype`` method."""
    half = x.astype(jnp.float16)
    return half

def toFullPrecision(x):
    """Cast ``x`` to ``jnp.float32`` via its ``astype`` method."""
    full = x.astype(jnp.float32)
    return full

def conv2d_block_init(key, in_channels, out_channels, strides=1):
    """Initialize a 2D conv block: 3x3 conv, batch norm, scale, bias, ReLU.

    The batch norm layer is wrapped in casts so that it receives float32
    input, and its output is cast back to float16 for the remaining layers.

    Args:
        key: PRNG key; split into four sub-keys, one each for the Conv,
            BatchNorm, Scaler and Bias initializers.
        in_channels: number of input channels passed to the convolution.
        out_channels: number of output channels of the convolution; also
            used to size the BatchNorm, Scaler and Bias layers.
        strides: stride forwarded to the convolution (default 1).

    Returns:
        Whatever ``Series.init`` returns for the layer sequence.
    """
    conv_key, norm_key, scale_key, bias_key = random.split(key, 4)

    convolution = Conv.init(
        conv_key, 2,
        in_channels, out_channels, 3,
        strides=strides,
        padding=1,
        channel_last=True
    )
    per_channel_shape = (None, None, out_channels)
    layers = [
        convolution,
        # Batch norm runs on float32 input ...
        F.init(toFullPrecision),
        BatchNorm.init(norm_key, out_channels, channel_last=True),
        # ... and its output is cast back to float16.
        F.init(toHalfPrecision),
        Scaler.init(scale_key, per_channel_shape),
        Bias.init(bias_key, per_channel_shape),
        F.init(nn.relu),
    ]
    return Series.init(*layers)

def fan_out(x):
    """Duplicate ``x`` into a 2-tuple so two branches receive the same input."""
    return (x, x)

def fan_in_add(pair):
    """Merge a pair of branch outputs by adding them with ``jax.lax.add``."""
    first_branch, second_branch = pair
    merged = jax.lax.add(first_branch, second_branch)
    return merged

def res_block1_init(key, in_channels):
    """Initialize a residual block that preserves image size and channels.

    The input is duplicated; one copy passes through unchanged (identity)
    while the other goes through two conv2d blocks with ``in_channels``
    input and output channels. The two results are then added.

    Args:
        key: PRNG key; split into one sub-key per conv2d block.
        in_channels: channel count used for both conv2d blocks.

    Returns:
        Whatever ``Series.init`` returns for the assembled block.
    """
    first_key, second_key = random.split(key)

    shortcut = F.init(identity)
    residual = Series.init(
        conv2d_block_init(first_key, in_channels, in_channels),
        conv2d_block_init(second_key, in_channels, in_channels)
    )
    return Series.init(
        F.init(fan_out),
        Parallel.init(shortcut, residual),
        F.init(fan_in_add)
    )

def res_block2_init(key, in_channels):
    """Initialize a downsampling residual block.

    Intended to halve the image width and height and double the number of
    channels. Both parallel branches start with a stride-2 conv2d block that
    maps ``in_channels`` to ``2 * in_channels``; the second branch adds one
    more conv2d block at the doubled channel count. The branch outputs are
    then added.

    Args:
        key: PRNG key; split into three sub-keys, one per conv2d block.
        in_channels: number of input channels.

    Returns:
        Whatever ``Series.init`` returns for the assembled block.
    """
    shortcut_key, down_key, refine_key = random.split(key, 3)
    out_channels = in_channels * 2

    shortcut = conv2d_block_init(
        shortcut_key, in_channels, out_channels, strides=2
    )
    residual = Series.init(
        conv2d_block_init(down_key, in_channels, out_channels, strides=2),
        conv2d_block_init(refine_key, out_channels, out_channels)
    )
    return Series.init(
        F.init(fan_out),
        Parallel.init(shortcut, residual),
        F.init(fan_in_add)
    )

def model_init(key):
    """Initialize the full ResNet as a ``Series`` of layers.

    The model casts its input to float16 and scales it by 1/255, runs it
    through a stem conv2d block and three residual blocks, averages over
    axes 1 and 2, applies a linear layer with bias producing 10 outputs, and
    finally casts the result to float32.

    Args:
        key: PRNG key; split into six sub-keys, consumed in layer order.

    Returns:
        Whatever ``Series.init`` returns for the assembled model.
    """
    (
        stem_key,
        block1_key,
        block2_key,
        block3_key,
        linear_key,
        bias_key,
    ) = random.split(key, 6)
    return Series.init(
        # Convert the input (presumably uint8 NumPy image batches) to a
        # float16 JAX array and divide by 255
        F.init(
            lambda x: jnp.asarray(x, jnp.float16) / 255.0,
        ),
        # (N, 32, 32, 3)
        conv2d_block_init(stem_key, 3, 16),
        # (N, 32, 32, 16)
        res_block1_init(block1_key, 16),
        # (N, 32, 32, 16)
        res_block2_init(block2_key, 16),
        # (N, 16, 16, 32)
        res_block2_init(block3_key, 32),
        # (N, 8, 8, 64)
        F.init(lambda x: jnp.reshape(x.mean((1, 2)), (-1, 64))),
        # (N, 64)
        Linear.init(linear_key, 64, 10),
        # (N, 10)
        Bias.init(bias_key, (10,)),
        # Return the outputs in float32
        F.init(toFullPrecision)
    )

In [6]:
# Build the model from a fixed PRNG seed (0); the result of model_init is
# unpacked into trainable weights, non-trainable state and hyperparameters.
trainables, non_trainables, hyperparams = model_init(random.PRNGKey(0))

### Define ResNet dataflow.

In [7]:
# The model is a Series, so its forward pass is Series.fwd itself.
model_fwd = Series.fwd

### Define loss function.

In [8]:
def loss_fn(
    batched_preds: jnp.array,
    batched_targets: np.array
):
    """Return the mean softmax cross-entropy over a batch.

    Args:
        batched_preds: per-example predictions (logits) for the batch.
        batched_targets: integer class labels for the batch.

    Returns:
        The mean of the per-example losses computed by Optax.
    """
    # Optax returns one loss value per example; reduce them to a batch mean.
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
        batched_preds,
        batched_targets
    )
    return per_example_loss.mean()

We define two convenience functions that, from batched inputs and targets,
compute respectively the training loss, and the inference predictions and loss.

In [9]:
# Loss-scaling factor: model_train_loss multiplies the loss by this value and
# train_step divides the loss and gradients by it again.
scaling_factor = 2 ** 12

def model_train_loss(
    x_batch: np.array,
    y_batch: np.array,
    trainables,
    non_trainables,
    hyperparams
):
    """Run the model in training mode and return the scaled batch loss.

    Args:
        x_batch: batched model inputs.
        y_batch: batched integer targets.
        trainables: trainable model weights.
        non_trainables: non-trainable model state.
        hyperparams: model hyperparameters.

    Returns:
        A pair ``(scaled_loss, non_trainables)`` where ``scaled_loss`` is the
        batch loss multiplied by ``scaling_factor`` and ``non_trainables`` is
        the state returned by the forward pass.
    """
    logits, updated_non_trainables = model_fwd(
        x_batch, trainables, non_trainables, hyperparams, inference_mode=False
    )
    scaled_loss = loss_fn(logits, y_batch) * scaling_factor
    return scaled_loss, updated_non_trainables

@tree_util.Partial(jax.jit, static_argnames="hyperparams")
def model_inference_preds_loss(
    x_batch: np.array,
    y_batch: np.array,
    trainables,
    non_trainables,
    hyperparams
):
    """Run the model in inference mode and return predictions and loss.

    Jit-compiled with ``hyperparams`` treated as a static argument.

    Args:
        x_batch: batched model inputs.
        y_batch: batched integer targets.
        trainables: trainable model weights.
        non_trainables: non-trainable model state.
        hyperparams: model hyperparameters.

    Returns:
        A pair ``(preds, loss)``: the model outputs and the unscaled batch
        loss. The state returned by the forward pass is discarded.
    """
    forward_result = model_fwd(
        x_batch, trainables, non_trainables, hyperparams, inference_mode=True
    )
    preds, _ = forward_result
    batch_loss = loss_fn(preds, y_batch)
    return preds, batch_loss

### Define optimizer using Optax.

In [10]:
# Adam optimizer with learning rate 1e-2.
optimizer = optax.adam(1e-2)
# Optimizer state is created from the trainable weights only.
optim_state = optimizer.init(trainables)
# Wrap the update function in tree_util.Partial; it is later passed to the
# jitted train_step as a regular (non-static) argument.
optim_fn = tree_util.Partial(optimizer.update)

### Define training step.

In [11]:
@tree_util.Partial(jax.jit, static_argnames="hyperparams")
def train_step(
    x_batch: np.array,
    y_batch: np.array,
    trainables, non_trainables, hyperparams,
    optim_fn, optim_state
):
    """Perform one optimization step on a single batch.

    Jit-compiled with ``hyperparams`` treated as a static argument.

    Args:
        x_batch: batched model inputs.
        y_batch: batched integer targets.
        trainables: trainable model weights (differentiated against).
        non_trainables: non-trainable model state.
        hyperparams: model hyperparameters.
        optim_fn: optimizer update function, called as
            ``optim_fn(gradients, optim_state)``.
        optim_state: current optimizer state.

    Returns:
        ``(loss, trainables, non_trainables, optim_state)``: the unscaled
        batch loss, the updated weights, the state returned by the forward
        pass, and the new optimizer state.
    """
    # Find the (scaled) batch loss and gradients
    (loss, non_trainables), gradients = jax.value_and_grad(
        model_train_loss,
        argnums=2, # gradients wrt trainables (argument 2 of model_train_loss)
        has_aux=True # non_trainables is auxiliary data, loss is the true output
    )(x_batch, y_batch, trainables, non_trainables, hyperparams)

    # Undo the loss scaling applied in model_train_loss, on both the loss
    # and every gradient leaf
    loss = loss / scaling_factor
    gradients = tree_util.tree_map(
        lambda grad: grad / scaling_factor, gradients
    )

    # Turn the gradients into parameter updates and get the new optimizer state
    updates, optim_state = optim_fn(gradients, optim_state)

    # optax.apply_updates takes (params, updates), in that order; the result
    # follows the parameters' structure and dtype
    trainables = optax.apply_updates(trainables, updates)
    return loss, trainables, non_trainables, optim_state

### Define functions for training and testing loops.

In [12]:
def train_epoch(
    dataloader,
    trainables, non_trainables, hyperparams,
    optim_fn, optim_state
):
    """Train for one pass over ``dataloader`` and print the mean batch loss.

    Args:
        dataloader: iterable yielding ``(inputs, targets)`` batches; must
            support ``len``.
        trainables: trainable model weights.
        non_trainables: non-trainable model state.
        hyperparams: model hyperparameters.
        optim_fn: optimizer update function forwarded to ``train_step``.
        optim_state: current optimizer state.

    Returns:
        ``(trainables, non_trainables, optim_state)`` after the last batch.
    """
    train_loss = 0.0
    for x_batch, y_batch in dataloader:
        step_result = train_step(
            x_batch, y_batch,
            trainables, non_trainables, hyperparams,
            optim_fn, optim_state
        )
        batch_loss, trainables, non_trainables, optim_state = step_result
        train_loss += batch_loss

    print(f"Train loss: {train_loss / len(dataloader)}")
    return trainables, non_trainables, optim_state

In [13]:
def test(
    dataloader,
    trainables, non_trainables, hyperparams
):
    """Evaluate the model on ``dataloader`` and print mean loss and accuracy.

    The loss is averaged over batches; accuracy is the number of correct
    predictions divided by ``len(dataloader.dataset)``.

    Args:
        dataloader: iterable yielding ``(inputs, targets)`` batches; must
            support ``len`` and expose a ``dataset`` attribute.
        trainables: trainable model weights.
        non_trainables: non-trainable model state.
        hyperparams: model hyperparameters.
    """
    # `accuracy` accumulates the count of correct predictions and is only
    # normalized when printed.
    test_loss, accuracy = 0, 0.0
    for x_batch, y_batch in dataloader:
        preds, batch_loss = model_inference_preds_loss(
            x_batch, y_batch, trainables, non_trainables, hyperparams
        )
        test_loss += batch_loss
        predicted_labels = jnp.argmax(preds, axis=1)
        accuracy += (predicted_labels == y_batch).sum()

    print(f"Test loss: {test_loss / len(dataloader)}, accuracy: {accuracy / len(dataloader.dataset)}")

In [14]:
def train_loop(
    train_dataloader, test_dataloader,
    trainables, non_trainables, hyperparams,
    optim_fn, optim_state,
    epochs, test_every
):
    """Train for ``epochs`` epochs, evaluating every ``test_every`` epochs.

    Args:
        train_dataloader: batches used for training.
        test_dataloader: batches used for evaluation.
        trainables: trainable model weights.
        non_trainables: non-trainable model state.
        hyperparams: model hyperparameters.
        optim_fn: optimizer update function.
        optim_state: current optimizer state.
        epochs: number of training epochs to run.
        test_every: evaluate whenever the 1-based epoch number is a multiple
            of this value.

    Returns:
        ``(trainables, non_trainables, optim_state)`` after the final epoch.
    """
    # Epochs are numbered from 1 for both the printout and the test schedule.
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}\n----------------")
        trainables, non_trainables, optim_state = train_epoch(
            train_dataloader,
            trainables, non_trainables, hyperparams,
            optim_fn, optim_state
        )
        evaluation_due = epoch % test_every == 0
        if evaluation_due:
            test(test_dataloader, trainables, non_trainables, hyperparams)
        print(f"----------------")

    return trainables, non_trainables, optim_state

### Train ResNet on the CIFAR-10 dataset.

In [15]:
# Train for 40 epochs, evaluating on the test set every 5 epochs.
new_trainables, new_non_trainables, new_optim_state = train_loop(
    train_dataloader, test_dataloader,
    trainables, non_trainables, hyperparams,
    optim_fn, optim_state,
    epochs=40, test_every=5
)

Epoch 1
----------------
Train loss: 1.865752100944519
----------------
Epoch 2
----------------
Train loss: 1.4090983867645264
----------------
Epoch 3
----------------
Train loss: 1.216686725616455
----------------
Epoch 4
----------------
Train loss: 1.0826411247253418
----------------
Epoch 5
----------------
Train loss: 0.9965898394584656
Test loss: 0.8694150447845459, accuracy: 0.6937000155448914
----------------
Epoch 6
----------------
Train loss: 0.950756311416626
----------------
Epoch 7
----------------
Train loss: 0.9044058322906494
----------------
Epoch 8
----------------
Train loss: 0.8645985722541809
----------------
Epoch 9
----------------
Train loss: 0.8370248079299927
----------------
Epoch 10
----------------
Train loss: 0.8095261454582214
Test loss: 0.6727048754692078, accuracy: 0.7669000625610352
----------------
Epoch 11
----------------
Train loss: 0.7876483201980591
----------------
Epoch 12
----------------
Train loss: 0.7697210907936096
----------------
Epoc