<a href="https://colab.research.google.com/github/sayakpaul/MLP-Mixer-CIFAR10/blob/main/MLP_Mixer_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Note**: Some portions of the code in this notebook are adapted from the Keras tutorial [Image Classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).

## Setup

In [1]:
# !pip install -q tf-models-official tensorflow-probability

In [2]:
from official.vision.image_classification.augment import RandAugment
from tensorflow.keras import layers
from tensorflow import keras
import tensorflow_probability as tfp
import tensorflow as tf
import numpy as np

from tensorflow.keras import mixed_precision
# Set the global Keras dtype policy to mixed_float16. TensorFlow warns (see the
# output of this cell) that this policy only runs quickly on Nvidia GPUs with
# compute capability of at least 7.0.
mixed_precision.set_global_policy("mixed_float16")

The dtype policy mixed_float16 may run slowly because this machine does not have a GPU. Only Nvidia GPUs with compute capability of at least 7.0 run quickly with mixed_float16.


In [3]:
# Prefer a TPU when one can be set up; any ValueError raised while resolving or
# initialising it makes the notebook fall back to MirroredStrategy instead.
tpu = None
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    strategy = tf.distribute.MirroredStrategy()

# The replica count is reused later to scale the global batch size.
print("Number of accelerators: ", strategy.num_replicas_in_sync)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Number of accelerators:  1


## Hyperparameters

In [4]:
# Image geometry: training images are resized to RESIZE_TO x RESIZE_TO and then
# randomly cropped to CROP_TO x CROP_TO, which is also the model's input size.
RESIZE_TO = 96
CROP_TO = 72
# Side length of each square patch; 72 / 9 gives an 8 x 8 grid (64 patches).
PATCH_SIZE = 9

# Mixer architecture.
NUM_MIXER_LAYERS = 6  # number of mixer blocks stacked in the model
HIDDEN_SIZE = 128  # channels produced by the patch-embedding convolution
MLP_SEQ_DIM = 64  # hidden width of the token-mixing MLP
MLP_CHANNEL_DIM = 128  # hidden width of the channel-mixing MLP

# Training schedule; the batch size scales with the number of replicas.
EPOCHS = 100
BATCH_SIZE = 256 * strategy.num_replicas_in_sync

## Dataset Utilities

In [5]:
# Load CIFAR-10 as (train, test) pairs of image and label arrays.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# NOTE: despite its name, `val_samples` is the index at which the training
# arrays are split. Everything before it is kept for training and the
# remainder becomes the validation set.
val_samples = 49500
new_train_x, val_x = x_train[:val_samples], x_train[val_samples:]
new_y_train, val_y = y_train[:val_samples], y_train[val_samples:]

In [6]:
# Shared RandAugment policy applied to every training image.
augmenter = RandAugment(num_layers=2, magnitude=9)


def preprocess_train(image, label):
    """Cast and augment a single training example.

    The image is randomly flipped left/right, resized to
    ``RESIZE_TO x RESIZE_TO``, randomly cropped to ``CROP_TO x CROP_TO x 3``
    and finally passed through the module-level ``augmenter``.

    Args:
        image: A single image tensor.
        label: The label associated with ``image``.

    Returns:
        An ``(image, label)`` tuple; the label is cast to float32.
    """
    image = tf.cast(image, tf.float32)
    label = tf.cast(label, tf.float32)

    flipped = tf.image.random_flip_left_right(image)
    enlarged = tf.image.resize(flipped, [RESIZE_TO, RESIZE_TO])
    cropped = tf.image.random_crop(enlarged, [CROP_TO, CROP_TO, 3])
    return augmenter.distort(cropped), label

def preprocess_test(image, label):
    """Prepare one evaluation example without any random augmentation.

    The image is cast to float32 and resized to ``CROP_TO x CROP_TO``, the
    spatial size declared by the model's input layer. No cropping is applied
    after the resize: a further ``central_crop(image, 0.5)`` would halve each
    side and produce images that no longer match the model input.

    Args:
        image: A single image tensor.
        label: The label associated with ``image``.

    Returns:
        An ``(image, label)`` tuple, both cast to float32.
    """
    image = tf.cast(image, tf.float32)
    label = tf.cast(label, tf.float32)

    image = tf.image.resize(image, [CROP_TO, CROP_TO])
    return image, label

In [7]:
AUTO = tf.data.AUTOTUNE

# Training pipeline: shuffle with a buffer of ten batches, augment each
# example, then batch and prefetch.
train_ds = (
    tf.data.Dataset.from_tensor_slices((new_train_x, new_y_train))
    .shuffle(BATCH_SIZE * 10, seed=42)
    .map(preprocess_train, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# Validation and test pipelines share the deterministic preprocessing and are
# never shuffled.
validation_ds = (
    tf.data.Dataset.from_tensor_slices((val_x, val_y))
    .map(preprocess_test, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(preprocess_test, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


## MLP-Mixer Utilities

The implementation below is adapted from **Appendix E** of the [original paper](https://arxiv.org/pdf/2105.01601.pdf).

In [8]:
def mlp_block(x, mlp_dim):
    """Two-layer MLP that expands to ``mlp_dim`` and projects back.

    The output keeps the same last dimension as the input ``x`` so that callers
    can add the result to a residual branch. The input is kept under its own
    name on purpose: reading ``x.shape[-1]`` after overwriting ``x`` with the
    hidden activation would size the output projection to ``mlp_dim`` instead.

    Args:
        x: Input tensor; the MLP is applied along its last axis.
        mlp_dim: Width of the hidden dense layer.

    Returns:
        A tensor whose last dimension equals ``x.shape[-1]``.
    """
    hidden = layers.Dense(mlp_dim)(x)
    hidden = tf.nn.gelu(hidden)
    # Project back to the width of the original input, not of `hidden`.
    return layers.Dense(x.shape[-1])(hidden)

def mixer_block(x, tokens_mlp_dim, channels_mlp_dim):
    """Apply one mixer block: token mixing followed by channel mixing.

    Each half normalises its input, runs it through ``mlp_block`` and adds the
    result back to the unnormalised input (a residual connection). For token
    mixing the two non-batch axes are swapped before and after the MLP so the
    MLP acts along the other axis.

    Args:
        x: Input tensor with two non-batch axes, as produced by ``mlp_mixer``.
        tokens_mlp_dim: Hidden width of the token-mixing MLP.
        channels_mlp_dim: Hidden width of the channel-mixing MLP.

    Returns:
        The tensor produced after both residual additions.
    """
    # Token mixing: normalise, transpose, run the MLP, transpose back.
    normed_tokens = layers.LayerNormalization()(x)
    transposed = layers.Permute((2, 1))(normed_tokens)
    mixed_tokens = mlp_block(transposed, tokens_mlp_dim)
    mixed_tokens = layers.Permute((2, 1))(mixed_tokens)
    after_tokens = layers.Add()([x, mixed_tokens])

    # Channel mixing: the MLP runs directly on the last axis.
    normed_channels = layers.LayerNormalization()(after_tokens)
    mixed_channels = mlp_block(normed_channels, channels_mlp_dim)
    return layers.Add()([after_tokens, mixed_channels])

def mlp_mixer(x, num_blocks, patch_size, hidden_dim,
              tokens_mlp_dim, channels_mlp_dim,
              num_classes=10):
    """Build the MLP-Mixer graph on top of an image tensor.

    A strided convolution embeds non-overlapping patches. Its two spatial axes
    are flattened into one patch axis, ``num_blocks`` mixer blocks are applied,
    and a normalised, dropout-regularised, average-pooled representation is fed
    to a softmax classifier.

    Args:
        x: Image tensor to classify.
        num_blocks: Number of mixer blocks to stack.
        patch_size: Kernel size and stride of the patch-embedding convolution.
        hidden_dim: Number of filters of the patch-embedding convolution.
        tokens_mlp_dim: Hidden width of each token-mixing MLP.
        channels_mlp_dim: Hidden width of each channel-mixing MLP.
        num_classes: Number of output classes.

    Returns:
        The float32 softmax output tensor with ``num_classes`` units.
    """
    # Patch embedding: kernel size == stride, so patches do not overlap.
    patch_embedding = layers.Conv2D(
        hidden_dim,
        kernel_size=patch_size,
        strides=patch_size,
        padding="valid",
    )
    feature_map = patch_embedding(x)
    num_patches = feature_map.shape[1] * feature_map.shape[2]
    tokens = layers.Reshape((num_patches, feature_map.shape[3]))(feature_map)

    for _ in range(num_blocks):
        tokens = mixer_block(tokens, tokens_mlp_dim, channels_mlp_dim)

    # Classification head; the final layer is created with dtype float32.
    tokens = layers.LayerNormalization()(tokens)
    tokens = layers.Dropout(0.25)(tokens)
    pooled = layers.GlobalAveragePooling1D()(tokens)
    classifier = layers.Dense(num_classes, activation="softmax", dtype="float32")
    return classifier(pooled)

In [9]:
def create_mlp_mixer():
    """Assemble the Keras model from the notebook's hyperparameters.

    Inputs of shape ``(CROP_TO, CROP_TO, 3)`` are rescaled by ``1/255`` and
    passed to ``mlp_mixer``.

    Returns:
        A ``tf.keras.Model`` named ``"mlp_mixer"``.
    """
    image_input = layers.Input(shape=(CROP_TO, CROP_TO, 3))
    rescale = layers.experimental.preprocessing.Rescaling(scale=1./255)
    predictions = mlp_mixer(
        rescale(image_input),
        num_blocks=NUM_MIXER_LAYERS,
        patch_size=PATCH_SIZE,
        hidden_dim=HIDDEN_SIZE,
        tokens_mlp_dim=MLP_SEQ_DIM,
        channels_mlp_dim=MLP_CHANNEL_DIM,
    )
    return tf.keras.Model(image_input, predictions, name="mlp_mixer")

In [12]:
def run_experiment(model):
    """Compile, train and evaluate ``model`` on the notebook's datasets.

    The model is trained on ``train_ds`` for ``EPOCHS`` epochs and validated on
    ``validation_ds``. Weights are checkpointed whenever validation accuracy
    improves, and those weights are reloaded before evaluating on ``test_ds``.

    Args:
        model: An uncompiled Keras model.

    Returns:
        A ``(history, model)`` tuple: the object returned by ``model.fit`` and
        the model with the checkpointed weights loaded.
    """
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    # Save weights only, and only when validation accuracy improves.
    checkpoint_filepath = "/tmp/checkpoint"
    best_weights_saver = keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        train_ds,
        epochs=EPOCHS,
        validation_data=validation_ds,
        callbacks=[best_weights_saver],
    )

    # Evaluate with the checkpointed weights rather than the last-epoch ones.
    model.load_weights(checkpoint_filepath)
    test_loss, top_1_accuracy = model.evaluate(test_ds)
    print(f"Test accuracy: {round(top_1_accuracy * 100, 2)}%")

    return history, model

In [None]:
# Build, train and export the model inside the distribution strategy scope.
with strategy.scope():
    mlp_mixer_classifier = create_mlp_mixer()
    # run_experiment returns a (history, model) tuple; unpack both so that
    # `model` is defined for the save call below.
    history, model = run_experiment(mlp_mixer_classifier)
    model.save(f"mlp_mixer_{NUM_MIXER_LAYERS}")

Epoch 1/100
  1/194 [..............................] - ETA: 7:27:19 - loss: 2.5295 - accuracy: 0.0859