# ResNet Implementation in mlax with Optax optimizers.
This notebook uses the [Optax](https://optax.readthedocs.io/en/latest/optax-101.html) JAX optimization library.

You can view a full-precision implementation in `resnet.ipynb`.

You can view the Pytorch reference implementation in `resnet_reference.ipynb`.

In [14]:
import jax
from jax import (
    numpy as jnp,
    tree_util as jtu,
    nn,
    random,
    lax
)
import numpy as np
from functools import partial
import optax
import torchvision
from torch.utils.data import DataLoader

In [15]:
from mlax import Module, is_trainable
from mlax.nn import (
    Conv, Scaler, BatchNorm, Linear, Bias, F, Series
)

### Load in and batch the CIFAR-10 datasets.
We follow this example
[Training a Simple Neural Network, with PyTorch Data Loading](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html) in using Pytorch dataloaders.

In [16]:
class ToNumpy(object):
  def __call__(self, pic):
    return np.array(pic)

cifar_train = torchvision.datasets.CIFAR10(
    root="../data",
    train=True,
    download=True,
    transform = torchvision.transforms.Compose([
    torchvision.transforms.AutoAugment(),
    ToNumpy()
])
)
cifar_test = torchvision.datasets.CIFAR10(
    root="../data",
    train=False,
    download=True,
    transform=ToNumpy()
)
print(cifar_train.data.shape)
print(cifar_test.data.shape)

Files already downloaded and verified
Files already downloaded and verified
(50000, 32, 32, 3)
(10000, 32, 32, 3)


In [17]:
def numpy_collate(batch):
  if isinstance(batch[0], np.ndarray):
    return np.stack(batch)
  elif isinstance(batch[0], (tuple,list)):
    transposed = zip(*batch)
    return [numpy_collate(samples) for samples in transposed]
  else:
    return np.array(batch)

batch_size=128
train_dataloader = DataLoader(
    cifar_train, batch_size, shuffle=True, collate_fn=numpy_collate, num_workers=6
)
test_dataloader = DataLoader(
    cifar_test, batch_size, shuffle=True, collate_fn=numpy_collate, num_workers=6
)
print(len(train_dataloader), len(test_dataloader))

391 79


### Initialize ResNet model parameters.

In [18]:
# 3x3 channel-last conv block with batchnorm
class ConvBlock(Module):
    def __init__(self, rng, out_channels, strides=1, batch_axis_name="batch"):
        super().__init__()
        rngs_iter = iter(random.split(rng, 4))
        self.conv = Conv(
            next(rngs_iter), 2, out_channels, 3, strides, padding=1, channel_last=True
        )
        self.batchnorm = Series([
            BatchNorm(next(rngs_iter), batch_axis_name, channel_last=True),
            Scaler(next(rngs_iter), (None, None, out_channels)),
            Bias(next(rngs_iter), (None, None, out_channels)),
        ])
    
    def __call__(self, x, rng=None, inference_mode=False):
        x, self.conv = self.conv(x, None, inference_mode)
        x_dtype = x.dtype
        x = x.astype(jnp.float32) # Keep batchnorm in full precision
        x, self.batchnorm = self.batchnorm(x, None, inference_mode)
        x = x.astype(x_dtype)
        return nn.relu(x), self

# Residual block without downsampling (H, W, C) -> (H, W, C)
class ResBlock1(Module):
    def __init__(self, rng, out_channels):
        super().__init__()
        rng1, rng2 = random.split(rng)
        self.block = Series([
            ConvBlock(rng1, out_channels),
            ConvBlock(rng2, out_channels)
        ])
    
    def __call__(self, x, rng=None, inference_mode=False):
        acts, self.block = self.block(x, None, inference_mode)
        return lax.add(acts, x), self

# Residual block with downsampling (H, W, C) -> (H // 2, W // 2 2 * C)
class ResBlock2(Module):
    def __init__(self, rng, out_channels):
        super().__init__()
        rng1, rng2, rng3 = random.split(rng, 3)
        self.block = Series([
            ConvBlock(rng1, out_channels, strides=2),
            ConvBlock(rng2, out_channels)
        ])
        self.downsample = ConvBlock(rng3, out_channels, strides=2)
    
    def __call__(self, x, rng=None, inference_mode=False):
        acts, self.block = self.block(x, None, inference_mode)
        x, self.downsample = self.downsample(x, None, inference_mode)
        return lax.add(acts, x), self

class ResNet(Module):
    def __init__(self, rng):
        super().__init__()
        rngs_iter = iter(random.split(rng, 6))
        self.conv = ConvBlock(next(rngs_iter), 16)
        self.res1 = ResBlock1(next(rngs_iter), 16)
        self.res2 = ResBlock2(next(rngs_iter), 32)
        self.res3 = ResBlock2(next(rngs_iter), 64)
        self.fc = Series([
            Linear(next(rngs_iter), 10),
            Bias(next(rngs_iter), (10,))
        ])

    @partial(
        jax.vmap,
        in_axes = (None, 0, None, None),
        out_axes = (0, None),
        axis_name = "batch"
    ) # Add leading batch dimension
    def __call__(self, x, rng=None, inference_mode=False):
        x = x.astype(jnp.float16) / 255.0
        # (32, 32, 3)
        x, self.conv = self.conv(x, None, inference_mode)
        # (32, 32, 16)
        x, self.res1 = self.res1(x, None, inference_mode)
        # (32, 32, 16)
        x, self.res2 = self.res2(x, None, inference_mode)
        # (16, 16, 32)
        x, self.res3 = self.res3(x, None, inference_mode)
        # (8, 8, 64)
        x = jnp.reshape(x.mean((0, 1)), (-1,))
        # (64,)
        x, self.fc = self.fc(x, None, inference_mode)
        # (10,)
        return x, self

model = ResNet(random.PRNGKey(0))

# Induce lazy weight initialization
for x_batch, y_batch in train_dataloader:
    acts, model = model(x_batch, None, False)
    print(acts.shape)
    break

(128, 10)


### Define loss function.

In [19]:
def loss_fn(
    batched_preds: jnp.array,
    batched_targets: np.array
):
    return optax.softmax_cross_entropy_with_integer_labels(
        batched_preds,
        batched_targets
    ).mean() # Optax returns per-example loss, this returns the mean batch loss

We define two convenience functions.

``model_training_loss`` returns the batch loss and updated `model` from batched
inputs and targets.

``model_inference_preds_loss`` returns the batch loss and predictions from
batched inputs and targets.

In [20]:
def model_training_loss(
    x_batch: np.array,
    y_batch: np.array,
    trainables,
    non_trainables,
    scaling_factor = 1
):
    model = trainables.combine(non_trainables)
    preds, model = model(x_batch, None, False)
    return loss_fn(preds, y_batch) * scaling_factor, model

@jax.jit
def model_inference_loss(
    x_batch: np.array,
    y_batch: np.array,
    model
):
    preds, _ = model(x_batch, None, True)
    return loss_fn(preds, y_batch), preds

### Define optimizer using Optax.

In [21]:
optimizer = optax.adam(1e-2)
optim_state = optimizer.init(model.filter(is_trainable))

### Define training step.

In [22]:
@jax.jit
def train_step(
    x_batch: np.array, 
    y_batch: np.array,
    model,
    optim_state
):
    scaling_factor = 2 ** 12

    # Find batch loss and gradients with resect to trainables
    (loss, model), gradients = jax.value_and_grad(
        model_training_loss,
        argnums=2, # gradients wrt trainables (argument 2 of model_training_loss)
        has_aux=True # model is auxiliary data, loss is the true ouput
    )(x_batch, y_batch, *model.partition(), scaling_factor)
    
    # Loss unscaling
    loss = loss / scaling_factor
    def unscale_gradients(x):
        return x / scaling_factor
    gradients = jtu.tree_map(unscale_gradients, gradients)

    # Get new gradients and optimizer state
    gradients, optim_state = optimizer.update(gradients, optim_state)

    # Update model_weights with new gradients
    trainables, non_trainables = model.partition()
    trainables = optax.apply_updates(gradients, trainables)
    return loss, trainables.combine(non_trainables), optim_state

### Define functions for training and testing loops.

In [23]:
def train_epoch(
    dataloader,
    model,
    optim_state
):
    num_batches = len(dataloader)
    train_loss = 0.0
    for x_batch, y_batch in dataloader:
        loss, model, optim_state = train_step(
            x_batch, y_batch,
            model,
            optim_state
        )
        train_loss += loss

    print(f"Train loss: {train_loss / num_batches}") 
    return model, optim_state

In [24]:
def test(
    dataloader,
    model
):
    num_batches = len(dataloader)
    test_loss, accuracy = 0, 0.0
    for x_batch, y_batch in dataloader:
        loss, preds = model_inference_loss(
            x_batch, y_batch, model
        )
        test_loss += loss
        accuracy += (jnp.argmax(preds, axis=1) == y_batch).sum() / len(x_batch)
    
    print(f"Test loss: {test_loss / num_batches}, accuracy: {accuracy / num_batches}")

In [25]:
def train_loop(
    train_dataloader,
    test_dataloader,
    model,
    optim_state,
    epochs, test_every
):
    for i in range(epochs):
        epoch = i + 1
        print(f"Epoch {epoch}\n----------------")
        model, optim_state = train_epoch(
            train_dataloader,
            model,
            optim_state
        )
        if (epoch % test_every == 0):
            test(test_dataloader, model)
        print(f"----------------")
    
    return model, optim_state

### Train ResNet on the CIFAR-10 dataset.

In [26]:
new_model, new_optim_state = train_loop(
    train_dataloader,
    test_dataloader,
    model,
    optim_state,
    40, 5
)

Epoch 1
----------------
Train loss: 1.8193359375
----------------
Epoch 2
----------------
Train loss: 1.3955078125
----------------
Epoch 3
----------------
Train loss: 1.19140625
----------------
Epoch 4
----------------
Train loss: 1.0732421875
----------------
Epoch 5
----------------
Train loss: 0.99755859375
Test loss: 0.78076171875, accuracy: 0.7212223410606384
----------------
Epoch 6
----------------
Train loss: 0.93603515625
----------------
Epoch 7
----------------
Train loss: 0.89111328125
----------------
Epoch 8
----------------
Train loss: 0.86181640625
----------------
Epoch 9
----------------
Train loss: 0.82177734375
----------------
Epoch 10
----------------
Train loss: 0.79345703125
Test loss: 0.896484375, accuracy: 0.6991693377494812
----------------
Epoch 11
----------------
Train loss: 0.78564453125
----------------
Epoch 12
----------------
Train loss: 0.7607421875
----------------
Epoch 13
----------------
Train loss: 0.7421875
----------------
Epoch 14
------