**Note**: Parts of the code in this notebook are adapted from the following Keras tutorials:

* [Image Classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/)
* [Consistency Training with Supervision](https://keras.io/examples/vision/consistency_training/)

## Setup

In [1]:
from tensorflow.keras import layers
from tensorflow import keras
import tensorflow as tf
import numpy as np

# Mixed precision only pays off on GPUs with tensor cores (compute capability
# >= 7.0). On a CPU-only host float16 compute is slow, so keep the default
# float32 policy when no GPU is visible.
# NOTE(review): on TPUs the appropriate policy would be "mixed_bfloat16".
from tensorflow.keras import mixed_precision
if tf.config.list_physical_devices("GPU"):
    mixed_precision.set_global_policy("mixed_float16")

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPUs will likely run quickly with dtype policy mixed_float16 as they all have compute capability of at least 7.0


In [2]:
# Prefer a TPU when one can be resolved. A ValueError raised by any of the TPU
# setup calls (e.g. no TPU found) triggers the fallback to MirroredStrategy
# over the local devices; `tpu` stays None on that path.
tpu = None
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    strategy = tf.distribute.MirroredStrategy()

print("Number of accelerators: ", strategy.num_replicas_in_sync)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
Number of accelerators:  4


## Hyperparameters

In [3]:
RESIZE_TO = 72   # images are resized to RESIZE_TO x RESIZE_TO inside the model
PATCH_SIZE = 9   # kernel size and stride of the patch-embedding Conv2D
# The patch-embedding Conv2D uses "valid" padding with stride PATCH_SIZE, so a
# size that is not a multiple would silently drop the border pixels.
assert RESIZE_TO % PATCH_SIZE == 0, "RESIZE_TO must be a multiple of PATCH_SIZE"

NUM_MIXER_LAYERS = 2
HIDDEN_SIZE = 128      # number of channels of each patch embedding
MLP_SEQ_DIM = 64       # hidden width of the token-mixing MLP
MLP_CHANNEL_DIM = 128  # hidden width of the channel-mixing MLP

EPOCHS = 100
# Global batch size: 512 examples per replica.
BATCH_SIZE = 512 * strategy.num_replicas_in_sync

## Dataset Utilities

In [4]:
# Load CIFAR-10 as in-memory arrays. x_train is also used to adapt the
# Normalization layer; the test split is only used for the final evaluation
# in run_experiment.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

In [5]:
def get_augmentation_layers(adapt_data=None, resize_to=None):
    """Build the preprocessing + augmentation pipeline that runs inside the model.

    Args:
        adapt_data: images used to compute the mean and variance of the
            Normalization layer. Defaults to the global ``x_train``.
            NOTE(review): ``run_experiment`` later holds out 10% of ``x_train``
            via ``validation_split``, so the default statistics include the
            validation images; pass a training-only subset to avoid that.
        resize_to: target height and width. Defaults to the global ``RESIZE_TO``.

    Returns:
        A ``keras.Sequential`` named ``"data_augmentation"`` containing
        Normalization, Resizing, RandomFlip, RandomRotation and RandomZoom,
        in that order.
    """
    # Resolve defaults at call time so later edits to the globals are honoured.
    if adapt_data is None:
        adapt_data = x_train
    if resize_to is None:
        resize_to = RESIZE_TO

    normalization = layers.experimental.preprocessing.Normalization()
    data_augmentation = keras.Sequential(
        [
            normalization,
            layers.experimental.preprocessing.Resizing(resize_to, resize_to),
            layers.experimental.preprocessing.RandomFlip("horizontal"),
            layers.experimental.preprocessing.RandomRotation(factor=0.02),
            layers.experimental.preprocessing.RandomZoom(
                height_factor=0.2, width_factor=0.2
            ),
        ],
        name="data_augmentation",
    )
    # Compute the mean and variance used for normalization.
    normalization.adapt(adapt_data)

    return data_augmentation

## MLP-Mixer Utilities

This implementation follows the reference code in **Appendix E** of the [original paper](https://arxiv.org/pdf/2105.01601.pdf).

In [6]:
def mlp_block(x, mlp_dim):
    """Two-layer MLP with a GELU in between, used by both Mixer sub-blocks.

    Projects the last axis of ``x`` up to ``mlp_dim`` and back down to the
    width of the *input*, so the result can be added residually to ``x``
    whatever ``mlp_dim`` is.

    Args:
        x: symbolic Keras tensor; the Dense layers act on its last axis.
        mlp_dim: hidden width of the MLP.

    Returns:
        Tensor with the same shape as ``x``.
    """
    hidden = layers.Dense(mlp_dim)(x)
    hidden = tf.nn.gelu(hidden)
    # Project back to the input width. Reading the width after overwriting `x`
    # would yield mlp_dim instead, which breaks the residual Add in
    # mixer_block whenever mlp_dim differs from the input width.
    return layers.Dense(x.shape[-1])(hidden)

def mixer_block(x, tokens_mlp_dim, channels_mlp_dim):
    """One Mixer layer: a token-mixing MLP followed by a channel-mixing MLP.

    Each sub-block is pre-normalised with LayerNormalization and wrapped in a
    residual Add with its input.

    Args:
        x: symbolic tensor, presumably (batch, num_patches, hidden_dim) — the
            Permute((2, 1)) calls swap its last two axes.
        tokens_mlp_dim: hidden width of the token-mixing MLP.
        channels_mlp_dim: hidden width of the channel-mixing MLP.

    Returns:
        The output of the second residual Add.
    """
    # Token mixing: transpose so the Dense layers inside mlp_block act across
    # the token axis, then transpose back before the residual add.
    normed = layers.LayerNormalization()(x)
    transposed = layers.Permute((2, 1))(normed)
    token_mlp_out = mlp_block(transposed, tokens_mlp_dim)
    token_mlp_out = layers.Permute((2, 1))(token_mlp_out)
    residual = layers.Add()([x, token_mlp_out])

    # Channel mixing: the Dense layers act directly on the last axis.
    normed = layers.LayerNormalization()(residual)
    channel_mlp_out = mlp_block(normed, channels_mlp_dim)
    return layers.Add()([residual, channel_mlp_out])

def mlp_mixer(x, num_blocks, patch_size, hidden_dim,
              tokens_mlp_dim, channels_mlp_dim,
              num_classes=10):
    """Apply the MLP-Mixer trunk and classification head to an image tensor.

    Args:
        x: symbolic image tensor with rank 4 (the Conv2D output is unpacked
            into four dimensions below).
        num_blocks: number of stacked mixer blocks.
        patch_size: kernel size and stride of the patch-embedding Conv2D.
        hidden_dim: number of channels per patch embedding.
        tokens_mlp_dim: hidden width of each token-mixing MLP.
        channels_mlp_dim: hidden width of each channel-mixing MLP.
        num_classes: number of output classes.

    Returns:
        float32 softmax probabilities of size ``num_classes``.
    """
    # Patch embedding: a Conv2D whose stride equals its kernel size projects
    # each non-overlapping patch independently.
    patches = layers.Conv2D(
        hidden_dim,
        kernel_size=patch_size,
        strides=patch_size,
        padding="valid",
    )(x)
    _, grid_h, grid_w, channels = patches.shape
    tokens = layers.Reshape((grid_h * grid_w, channels))(patches)

    for _block_index in range(num_blocks):
        tokens = mixer_block(tokens, tokens_mlp_dim, channels_mlp_dim)

    tokens = layers.LayerNormalization()(tokens)
    tokens = layers.Dropout(0.25)(tokens)
    pooled = layers.GlobalAveragePooling1D()(tokens)
    # Keep the classifier in float32 so the softmax is not computed in the
    # global mixed-precision dtype.
    return layers.Dense(num_classes, activation="softmax", dtype="float32")(pooled)

In [7]:
def create_mlp_mixer():
    """Assemble the end-to-end MLP-Mixer classifier for 32x32x3 images.

    The normalization/augmentation pipeline is part of the model, so raw
    images are fed in directly. Architecture sizes come from the
    hyperparameter constants.
    """
    augment = get_augmentation_layers()

    image_input = layers.Input(shape=(32, 32, 3))
    class_probs = mlp_mixer(
        augment(image_input),
        NUM_MIXER_LAYERS,
        PATCH_SIZE,
        HIDDEN_SIZE,
        MLP_SEQ_DIM,
        MLP_CHANNEL_DIM,
    )
    return tf.keras.Model(inputs=image_input, outputs=class_probs, name="mlp_mixer")

In [8]:
def run_experiment(model, checkpoint_filepath="/tmp/checkpoint"):
    """Compile and train ``model``, restore its best weights, and evaluate on the test set.

    Args:
        model: Keras model to compile and train. It is expected to output
            class probabilities, because the loss is used without
            ``from_logits``.
        checkpoint_filepath: where ModelCheckpoint stores the weights of the
            epoch with the best ``val_accuracy``.

    Returns:
        ``(history, model)``: the History returned by ``fit`` and the same
        model object with the best checkpoint loaded.
    """
    # The learning rate is scaled by the replica count, like BATCH_SIZE.
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=0.001 * strategy.num_replicas_in_sync
    )

    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    # Restore the epoch with the best validation accuracy before testing.
    model.load_weights(checkpoint_filepath)
    # Evaluate with the same global batch size as training; the Keras default
    # of 32 would be split across every replica.
    _, top_1_accuracy = model.evaluate(x_test, y_test, batch_size=BATCH_SIZE)
    print(f"Test accuracy: {round(top_1_accuracy * 100, 2)}%")

    return history, model

## Model Training and Evaluation

In [9]:
# Build, compile and train under the distribution strategy's scope so the
# model and optimizer variables are created by that strategy.
with strategy.scope():
    mlp_mixer_classifier = create_mlp_mixer()
    # run_experiment returns the same model object, with the best checkpoint loaded.
    history, model = run_experiment(mlp_mixer_classifier)
    # The save path encodes the number of Mixer layers.
    model.save(f"mlp_mixer_{NUM_MIXER_LAYERS}")

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1/100
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:batch_all_reduce: 30 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/devi