# ResNet Implementation in mlax with Optax optimizers.
This notebook uses the [Optax](https://optax.readthedocs.io/en/latest/optax-101.html) JAX optimization library.

You can view the Pytorch reference implementation in `resnet_reference.ipynb`.

In [1]:
import jax
import jax.numpy as jnp
from jax import (
    nn,
    random,
    tree_util
)
import numpy as np
import optax
import torchvision
from torch.utils.data import DataLoader

In [2]:
from mlax.nn import Conv, Scaler, BatchNorm, Linear, Bias, F
from mlax.block import Series, Parallel

### Load in and batch the CIFAR-10 datasets.
We follow this example
[Training a Simple Neural Network, with PyTorch Data Loading](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html) in using Pytorch dataloaders.

In [3]:
class ToNumpy(object):
  def __call__(self, pic):
    return np.array(pic)

cifar_train = torchvision.datasets.CIFAR10(
    root="../data",
    train=True,
    download=True,
    transform = torchvision.transforms.Compose([
    torchvision.transforms.AutoAugment(),
    ToNumpy()
])
)
cifar_test = torchvision.datasets.CIFAR10(
    root="../data",
    train=False,
    download=True,
    transform=ToNumpy()
)
print(cifar_train.data.shape)
print(cifar_test.data.shape)

Files already downloaded and verified
Files already downloaded and verified
(50000, 32, 32, 3)
(10000, 32, 32, 3)


In [4]:
def numpy_collate(batch):
  if isinstance(batch[0], np.ndarray):
    return np.stack(batch)
  elif isinstance(batch[0], (tuple,list)):
    transposed = zip(*batch)
    return [numpy_collate(samples) for samples in transposed]
  else:
    return np.array(batch)

batch_size=128
train_dataloader = DataLoader(
    cifar_train, batch_size=128, shuffle=True, collate_fn=numpy_collate, num_workers=8
)
test_dataloader = DataLoader(
    cifar_test, batch_size=128, shuffle=True, collate_fn=numpy_collate, num_workers=8
)
print(len(train_dataloader), len(test_dataloader))

391 79


### Initialize ResNet model parameters.

In [5]:
def bypass(x):
    return x

def split(x):
    return x, x

def add(pair):
    x, y = pair
    return jax.lax.add(x, y)

def res_block1_init(key, size):
    keys_iter = iter(random.split(key, 8))
    return Series.init(
        F.init(split),
        Parallel.init(
            F.init(bypass),
            Series.init(
                # Conv with batchnorm
                Conv.init(next(keys_iter), 2, size, size, 3, padding=1, channel_last=True),
                BatchNorm.init(next(keys_iter), size, channel_axis=-1),
                Scaler.init(next(keys_iter), (None, None, size)),
                Bias.init(next(keys_iter), (None, None, size)),
                F.init(nn.relu),
                # Conv with batchnorm
                Conv.init(next(keys_iter), 2, size, size, 3, padding=1, channel_last=True),
                BatchNorm.init(next(keys_iter), size, channel_axis=-1),
                Scaler.init(next(keys_iter), (None, None, size)),
                Bias.init(next(keys_iter), (None, None, size)),
                F.init(nn.relu),
            )
        ),
        F.init(add)
    )

def res_block2_init(key, size):
    keys_iter = iter(random.split(key, 12))
    size2 = size * 2
    return Series.init(
        F.init(split),
        Parallel.init(
            # Downsampling conv with batchnorm
            Series.init(
                Conv.init(
                    next(keys_iter), 2, size, size2, 3, strides=2, padding=1, channel_last=True
                ),
                BatchNorm.init(next(keys_iter), size2, channel_axis=-1),
                Scaler.init(next(keys_iter), (None, None, size2)),
                Bias.init(next(keys_iter), (None, None, size2)),
                F.init(nn.relu),
            ),
            Series.init(
                # Downsampling conv with batchnorm
                Conv.init(
                    next(keys_iter), 2, size, size2, 3, strides=2, padding=1, channel_last=True
                ),
                BatchNorm.init(next(keys_iter), size2, channel_axis=-1),
                Scaler.init(next(keys_iter), (None, None, size2)),
                Bias.init(next(keys_iter), (None, None, size2)),
                F.init(nn.relu),
                # Conv with batchnorm
                Conv.init(next(keys_iter), 2, size2, size2, 3, padding=1, channel_last=True),
                BatchNorm.init(next(keys_iter), size2, channel_axis=-1),
                Scaler.init(next(keys_iter), (None, None, size2)),
                Bias.init(next(keys_iter), (None, None, size2)),
                F.init(nn.relu),
            )
        ),
        F.init(add)
    )

def model_init(key):
    keys_iter = iter(random.split(key, 9))
    return Series.init(
        # Convert int8 numpy inputs to float32 JAX arrays
        F.init(
            lambda x: jnp.asarray(x, jnp.float32) / 256.0,    
        ),
        # (N, 32, 32, 3)
        Conv.init(next(keys_iter), 2, 3, 16, 3, padding=1, channel_last=True),
        BatchNorm.init(next(keys_iter), 16, channel_axis=-1),
        Scaler.init(next(keys_iter), (None, None, 16)),
        Bias.init(next(keys_iter), (None, None, 16)),
        F.init(nn.relu),
        # (N, 32, 32, 16)
        res_block1_init(next(keys_iter), 16),
        # (N, 32, 32, 16)
        res_block2_init(next(keys_iter), 16),
        # (N, 16, 16, 32)
        res_block2_init(next(keys_iter), 32),
        # (N, 8, 8, 64)
        F.init(lambda x: jnp.reshape(x.mean((1, 2)), (-1, 64))),
        # (N, 64)
        Linear.init(next(keys_iter), 64, 10),
        # (N, 10)
        Bias.init(next(keys_iter), (10,))
    )

In [6]:
trainables, non_trainables, hyperparams = model_init(random.PRNGKey(0))

### Define ResNet dataflow.

In [7]:
model_fwd = jax.jit(
    Series.fwd,
    static_argnames=["hyperparams", "inference_mode"]
)

### Define loss function.

In [8]:
@jax.jit
def loss_fn(
    batched_preds: jnp.array,
    batched_targets: np.array
):
    return optax.softmax_cross_entropy_with_integer_labels(
        batched_preds,
        batched_targets
    ).mean() # Optax returns per-example loss, this returns the mean batch loss

We define a convenience function that get model predictions on batched inputs,
and calculates the loss against batched targets.

In [9]:
def model_loss(
    x_batch: np.array,
    y_batch: np.array,
    trainables,
    non_trainables,
    hyperparams,
):
    preds, non_trainables = model_fwd(
        x_batch, trainables, non_trainables, hyperparams
    )
    return loss_fn(preds, y_batch), non_trainables

### Define optimizer using Optax.

In [10]:
optimizer = optax.adam(1e-2)
optim_state = optimizer.init(trainables)
optim_fn = tree_util.Partial(optimizer.update)

### Define training step.

In [11]:
@tree_util.Partial(jax.jit, static_argnames="hyperparams")
def train_step(
    x_batch: np.array, 
    y_batch: np.array,
    trainables, non_trainables, hyperparams,
    optim_fn, optim_state
):
    # Find batch loss and gradients
    (loss, non_trainables), gradients = jax.value_and_grad(
        model_loss,
        argnums=2, # gradients wrt trainables (argument 2 of model_loss)
        has_aux=True # non_trainables is auxiliary data, loss is the true ouput
    )(x_batch, y_batch, trainables, non_trainables, hyperparams)

    # Get new gradients and optimizer state
    gradients, optim_state = optim_fn(gradients, optim_state)

    # Update model_weights with new gradients
    trainables = optax.apply_updates(gradients, trainables)
    return loss, trainables, non_trainables, optim_state

### Define functions for training and testing loops.

In [12]:
def train_epoch(
    dataloader,
    trainables, non_trainables, hyperparams,
    optim_fn, optim_state
):
    train_loss = 0.0
    for X, y in dataloader:
        loss, trainables, non_trainables, optim_state = train_step(
            X, y,
            trainables, non_trainables, hyperparams,
            optim_fn, optim_state
        )
        train_loss += loss

    print(f"Train loss: {train_loss / len(dataloader)}") 
    return trainables, non_trainables, optim_state

In [13]:
def test(
    dataloader,
    trainables, non_trainables, hyperparams
):
    test_loss, accuracy = 0, 0.0
    for X, y in dataloader:
        preds, _ = model_fwd(
            X, trainables, non_trainables, hyperparams
        )
        loss = loss_fn(preds, y)
        test_loss += loss
        accuracy += (jnp.argmax(preds, axis=1) == y).sum()
    
    print(f"Test loss: {test_loss / len(dataloader)}, accuracy: {accuracy / len(dataloader.dataset)}")

In [14]:
def train_loop(
    train_dataloader, test_dataloader, 
    trainables, non_trainables, hyperparams,
    optim_fn, optim_state,
    epochs, test_every
):
    for i in range(epochs):
        epoch = i + 1
        print(f"Epoch {epoch}\n----------------")
        trainables, non_trainables, optim_state = train_epoch(
            train_dataloader,
            trainables, non_trainables, hyperparams,
            optim_fn, optim_state
        )
        if (epoch % test_every == 0):
            test(test_dataloader, trainables, non_trainables, hyperparams)
        print(f"----------------")
    
    return trainables, non_trainables, optim_state

### Train ResNet on the CIFAR-10 dataset.

In [15]:
new_trainables, new_non_trainables, new_optim_state = train_loop(
    train_dataloader, test_dataloader,
    trainables, non_trainables, hyperparams,
    optim_fn, optim_state,
    40, 5
)

Epoch 1
----------------
Train loss: 1.821688175201416
----------------
Epoch 2
----------------
Train loss: 1.3466664552688599
----------------
Epoch 3
----------------
Train loss: 1.1618753671646118
----------------
Epoch 4
----------------
Train loss: 1.042677879333496
----------------
Epoch 5
----------------
Train loss: 0.9711943864822388
Test loss: 0.7877271175384521, accuracy: 0.7255000472068787
----------------
Epoch 6
----------------
Train loss: 0.9213968515396118
----------------
Epoch 7
----------------
Train loss: 0.8760479688644409
----------------
Epoch 8
----------------
Train loss: 0.8446820378303528
----------------
Epoch 9
----------------
Train loss: 0.8178773522377014
----------------
Epoch 10
----------------
Train loss: 0.7975940108299255
Test loss: 0.6389548778533936, accuracy: 0.7777000665664673
----------------
Epoch 11
----------------
Train loss: 0.7704460024833679
----------------
Epoch 12
----------------
Train loss: 0.7479048371315002
----------------
Epo