# ResNet Implementation in mlax with Optax optimizers.
This notebook uses the [Optax](https://optax.readthedocs.io/en/latest/optax-101.html) JAX optimization library.

You can view a mixed-precision implementation in `resnet_mixed_precision.ipynb`.

You can view the PyTorch reference implementation in `resnet_reference.ipynb`.

In [1]:
import jax
from jax import (
    numpy as jnp,
    nn,
    random,
    lax
)
import numpy as np
import optax
import torchvision
from torch.utils.data import DataLoader

In [2]:
from mlax import Module
from mlax.nn import (
    Conv, Scaler, ZNorm, Linear, Bias, F, Series, Parallel
)

### Load and batch the CIFAR-10 datasets.
Following the JAX example
[Training a Simple Neural Network, with PyTorch Data Loading](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html), we use PyTorch dataloaders to load and batch the data.

In [3]:
class ToNumpy:
    """Transform that converts an image to a NumPy array.

    Used as a torchvision transform, so the input is whatever the previous
    transform yields (presumably a PIL image). `np.array` always copies, so
    the result is a fresh, writable array.
    """

    def __call__(self, pic):
        array = np.array(pic)
        return array

# Training split: augment each image with AutoAugment, then convert it to NumPy.
# NOTE(review): torchvision's AutoAugment() uses the ImageNet policy by
# default; AutoAugmentPolicy.CIFAR10 may be a better fit for this dataset.
cifar_train = torchvision.datasets.CIFAR10(
    root="../data",
    train=True,
    download=True,
    transform=torchvision.transforms.Compose(
        [torchvision.transforms.AutoAugment(), ToNumpy()]
    ),
)
# Test split: no augmentation, only the conversion to NumPy.
cifar_test = torchvision.datasets.CIFAR10(
    root="../data",
    train=False,
    download=True,
    transform=ToNumpy(),
)
# Shapes of the stored image arrays for each split.
print(cifar_train.data.shape)
print(cifar_test.data.shape)

Files already downloaded and verified
Files already downloaded and verified
(50000, 32, 32, 3)
(10000, 32, 32, 3)


In [4]:
def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays (a DataLoader `collate_fn`).

    - ndarray samples are stacked along a new leading axis.
    - tuple/list samples (e.g. `(image, label)`) are split field by field,
      and each field is collated recursively; a list with one entry per field
      is returned.
    - anything else (e.g. integer labels) is wrapped with `np.array`.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if not isinstance(first, (tuple, list)):
        return np.array(batch)
    return [numpy_collate(field) for field in zip(*batch)]

batch_size = 128
# Only the training loader shuffles; both collate into NumPy arrays for JAX.
train_dataloader = DataLoader(
    cifar_train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=numpy_collate,
    num_workers=6,
)
test_dataloader = DataLoader(
    cifar_test,
    batch_size=batch_size,
    collate_fn=numpy_collate,
    num_workers=6,
)
# Number of batches in each loader.
print(len(train_dataloader), len(test_dataloader))

391 79


### Initialize ResNet model parameters.

In [5]:
# Conv block used throughout the model: 3x3 conv (padding 1, stride `strides`)
# -> ZNorm("channel_last") -> Scaler -> Bias -> ReLU, for channels-last inputs.
# NOTE(review): ZNorm is presumably a batch-norm-style normalization; confirm
# against the mlax docs.
def conv_layers(rng, out_channels, strides):
    """Return the list of layers for one conv block (splice it into a `Series`).

    Each parameterized layer gets its own key, `random.fold_in(rng, i)` for
    i = 0..3 in layer order.
    """
    conv_key, norm_key, scale_key, bias_key = (
        random.fold_in(rng, i) for i in range(4)
    )
    return [
        Conv(conv_key, out_channels, 3, strides, padding=1),
        ZNorm(norm_key, "channel_last"),
        Scaler(scale_key, (0, 0, -1)),
        Bias(bias_key, (0, 0, -1)),
        F(nn.relu),
    ]

# Residual block without downsampling: (H, W, C) -> (H, W, C).
# The identity shortcut is added elementwise, so out_channels must equal C.
class ResBlock1(Module):
    """Two stride-1 conv blocks plus an identity shortcut."""

    def __init__(self, rng, out_channels):
        super().__init__()
        first_key = random.fold_in(rng, 0)
        second_key = random.fold_in(rng, 1)
        self.block = Series([
            *conv_layers(first_key, out_channels, strides=1),
            *conv_layers(second_key, out_channels, strides=1),
        ])

    def set_up(self, x):
        # This block has no lazily-initialized state of its own.
        pass

    def forward(self, x, rng=None, inference_mode=False, batch_axis_name=()):
        # Calling the block returns (activations, updated_block); store the
        # updated block so any state it changed (presumably normalization
        # statistics) is kept.
        out, updated = self.block(x, None, inference_mode, batch_axis_name)
        self.block = updated
        return lax.add(out, x)

# Residual block with downsampling: (H, W, C) -> (H // 2, W // 2, out_channels)
# for even H and W. In this model it doubles the channel count (16 -> 32 -> 64).
class ResBlock2(Module):
    """Downsampling residual block.

    Main path: a stride-2 conv block followed by a stride-1 conv block.
    Shortcut: a single stride-2 conv block (a strided projection rather than
    an identity), so both paths produce the same shape and can be added.
    """

    def __init__(self, rng, out_channels):
        super().__init__()
        self.block = Parallel([
            # Main path.
            Series([
                *conv_layers(random.fold_in(rng, 0), out_channels, strides=2),
                *conv_layers(random.fold_in(rng, 1), out_channels, strides=1)
            ]),
            # Projection shortcut.
            Series(conv_layers(random.fold_in(rng, 2), out_channels, strides=2))
        ])

    def set_up(self, x):
        # This block has no lazily-initialized state of its own.
        pass

    def forward(self, x, rng=None, inference_mode=False, batch_axis_name=()):
        # Parallel presumably maps one input to each branch; both branches
        # receive the same x. The updated block is stored back on self.
        acts, self.block = self.block(
            [x, x], None, inference_mode, batch_axis_name
        )
        return lax.add(acts[0], acts[1])

# One PRNG key per key-consuming component below (6 in total).
keys_iter = iter([random.fold_in(random.PRNGKey(0), i) for i in range(6)])
model = Series([
    F(lambda x: x.astype(jnp.float32) / 255.0),  # cast to float32, scale by 1/255
    *conv_layers(next(keys_iter), 16, strides=1),
    ResBlock1(next(keys_iter), 16),
    ResBlock2(next(keys_iter), 32),
    ResBlock2(next(keys_iter), 64),
    # Global average pool over the two spatial axes, then flatten.
    F(lambda x: jnp.reshape(x.mean((0, 1)), (-1,))),
    Linear(next(keys_iter), 10),
    Bias(next(keys_iter), 10)
])

# Parameters are initialized lazily on the first call. The model works on
# single (unbatched) examples, so feed it the first example of one batch.
X, _ = next(iter(train_dataloader))
activations, _ = model(X[0], None, inference_mode=True)
print(activations.shape)
print(activations.dtype)

(10,)
float32


### Define loss function.

In [6]:
def loss_fn(batched_preds, batched_targets):
    """Mean softmax cross-entropy between logits and integer class labels.

    Args:
        batched_preds: logits, one row per example.
        batched_targets: integer class labels, one per example.

    Returns:
        The scalar mean loss over the batch.
    """
    # Optax returns one loss per example; average them to get the batch loss.
    per_example_losses = optax.softmax_cross_entropy_with_integer_labels(
        batched_preds, batched_targets
    )
    return per_example_losses.mean()

### Define optimizer using Optax.

In [14]:
# Adam with a constant learning rate of 5e-3.
optimizer = optax.adam(learning_rate=5e-3)
# The optimizer state tracks only the parameters returned by model.filter()
# (presumably the trainables, matching model.partition() in train_step).
optim_state = optimizer.init(model.filter())

### Define training and testing steps.

In [15]:
@jax.jit
def train_step(X, y, model, optim_state):
    """Run one optimization step on a single batch.

    Args:
        X: batch of inputs; the per-example model is vmapped over axis 0.
        y: integer labels for the batch.
        model: mlax model (a pytree holding trainable and non-trainable state).
        optim_state: Optax optimizer state for the model's trainables.

    Returns:
        (loss, updated_model, updated_optim_state). `loss` is the mean batch
        loss computed with the parameters from before the update.
    """
    def _model_loss(X, y, trainables, non_trainables):
        model = trainables.combine(non_trainables)
        # Map the per-example model over the batch. axis_name "N" is also
        # passed as batch_axis_name so layers can reduce across the batch.
        preds, model = jax.vmap(
            model.__call__,
            in_axes=(0, None, None, None),
            out_axes=(0, None),
            axis_name="N"
        )(X, None, False, "N")
        return loss_fn(preds, y), model

    # Batch loss and gradients with respect to the trainables (argument 2 of
    # _model_loss). The model returned by the forward pass (with any updated
    # non-trainable state) is auxiliary output; the loss is the true output.
    trainables, non_trainables = model.partition()
    (loss, model), gradients = jax.value_and_grad(
        _model_loss,
        argnums=2,
        has_aux=True
    )(X, y, trainables, non_trainables)

    # Partition the model returned by the forward pass so its updated
    # non-trainable state is kept.
    trainables, non_trainables = model.partition()

    # Turn gradients into parameter updates. Passing the current params lets
    # optimizers that need them (e.g. decoupled weight decay) work; Adam
    # ignores them.
    updates, optim_state = optimizer.update(gradients, optim_state, trainables)

    # optax.apply_updates takes (params, updates) in that order and keeps each
    # param's dtype.
    trainables = optax.apply_updates(trainables, updates)
    return loss, trainables.combine(non_trainables), optim_state

In [16]:
@jax.jit
def test_step(X, y, model):
    """Evaluate the model on one batch in inference mode.

    Returns:
        (mean batch loss, number of correctly classified examples in the batch).
    """
    batched_forward = jax.vmap(
        model.__call__,
        in_axes=(0, None, None, None),
        out_axes=(0, None),
        axis_name="N",
    )
    # The returned model is discarded: evaluation must not change model state.
    preds, _ = batched_forward(X, None, True, "N")
    predicted_labels = jnp.argmax(preds, axis=1)
    num_correct = jnp.sum(predicted_labels == y)
    return loss_fn(preds, y), num_correct

### Define training and testing loops.

In [17]:
def train_epoch(dataloader, model, optim_state):
    """Run `train_step` over every batch in `dataloader` and print the mean loss.

    Args:
        dataloader: iterable of (X, y) batches; `len(y)` gives the batch size.
        model: current mlax model.
        optim_state: current Optax optimizer state.

    Returns:
        (model, optim_state) after the last update.
    """
    total_loss = 0.0
    num_examples = 0
    for X, y in dataloader:
        loss, model, optim_state = train_step(X, y, model, optim_state)
        # loss_fn returns the batch *mean*. Weight it by batch size so a
        # smaller final batch does not skew the per-example epoch average.
        n = len(y)
        total_loss += loss * n
        num_examples += n

    # Guard against an empty loader instead of dividing by zero.
    mean_loss = total_loss / num_examples if num_examples else float("nan")
    print(f"Train loss: {mean_loss}")
    return model, optim_state

In [18]:
def test(dataloader, model):
    test_loss, accurate = 0.0, 0
    for X, y in dataloader:
        loss, acc = test_step(X, y, model)
        test_loss += loss
        accurate += acc
    
    print(f"Test loss: {test_loss / len(dataloader)}, accuracy: {accurate / len(dataloader.dataset)}")

In [19]:
def train_loop(
    train_dataloader,
    test_dataloader,
    model,
    optim_state,
    epochs,
    test_every
):
    """Train for `epochs` epochs, evaluating on the test set every `test_every` epochs.

    Returns:
        (model, optim_state) after the final epoch.
    """
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}\n----------------")
        model, optim_state = train_epoch(train_dataloader, model, optim_state)
        should_evaluate = epoch % test_every == 0
        if should_evaluate:
            test(test_dataloader, model)
        print("----------------")
    return model, optim_state

### Train ResNet on the CIFAR-10 dataset.

In [20]:
# Train for 40 epochs and evaluate every 5. Force float32 matmuls, because JAX
# may otherwise use reduced-precision matmuls on some accelerators.
with jax.default_matmul_precision("float32"):
    new_model, new_optim_state = train_loop(
        train_dataloader,
        test_dataloader,
        model,
        optim_state,
        epochs=40,
        test_every=5,
    )

Epoch 1
----------------
Train loss: 1.6988275051116943
----------------
Epoch 2
----------------
Train loss: 1.3090256452560425
----------------
Epoch 3
----------------
Train loss: 1.1465835571289062
----------------
Epoch 4
----------------
Train loss: 1.0521374940872192
----------------
Epoch 5
----------------
Train loss: 0.9863284230232239
Test loss: 0.8191496729850769, accuracy: 0.7147000432014465
----------------
Epoch 6
----------------
Train loss: 0.9384129643440247
----------------
Epoch 7
----------------
Train loss: 0.8958556652069092
----------------
Epoch 8
----------------
Train loss: 0.8691748976707458
----------------
Epoch 9
----------------
Train loss: 0.8439715504646301
----------------
Epoch 10
----------------
Train loss: 0.8208428621292114
Test loss: 0.711186408996582, accuracy: 0.7502000331878662
----------------
Epoch 11
----------------
Train loss: 0.7981186509132385
----------------
Epoch 12
----------------
Train loss: 0.7758307456970215
----------------
Ep