## Example illustrating how to use the MLP-Mixer API

This example also shows how the MLP-Mixer architecture performs on the CIFAR-10 dataset.

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from mlxtend.plotting import plot_confusion_matrix
from sklearn.metrics import confusion_matrix
from tensorflow.keras.layers import *
from tensorflow.keras.layers.experimental.preprocessing import (
    RandomFlip,
    RandomRotation,
    RandomZoom,
)
from tensorflow.keras.models import Model
from mlp_utils import *  # CreatePatches, PerPatchFullyConnected, MLPBlock, Projection

# Uncomment line below to run on CPU
# tf.config.set_visible_devices([], "GPU")


In [3]:
# Load the dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# normalize pixel values
x_train = x_train / 255
x_test = x_test / 255

# create validation data from test data
val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

# Define the classes
classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

In [3]:
def create_model():
    dims = 512
    k, s = (4, 4)  # (kernel, strides)
    depth = 4

    inputs = Input((32, 32, 3))

    x = tf.keras.Sequential(
        [
            RandomFlip(),
            RandomRotation(factor=0.03),
            RandomZoom(height_factor=0.2, width_factor=0.2),
        ],
        name="data_augmentation",
    )(inputs)

    x = CreatePatches(k, s)(x)
    x = PerPatchFullyConnected(dims)(x)

    for _ in range(depth):
        x = MLPBlock()(x)

    # x = Projection()(x)
    x = GlobalAveragePooling1D()(x)

    # classification layer
    x = GaussianDropout(0.9)(x)
    output = Dense(len(classes), activation="softmax", kernel_regularizer="l2")(x)

    return Model(inputs=inputs, outputs=output)


model = create_model()
model.summary()  # summary() prints itself and returns None

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
data_augmentation (Sequentia (None, 32, 32, 3)         0         
_________________________________________________________________
create_patches (CreatePatche (None, 8, 8, 48)          0         
_________________________________________________________________
per_patch_fully_connected (P (None, 64, 512)           25088     
_________________________________________________________________
mlp_block (MLPBlock)         (None, 64, 512)           535680    
_________________________________________________________________
mlp_block_1 (MLPBlock)       (None, 64, 512)           535680    
_________________________________________________________________
mlp_block_2 (MLPBlock)       (None, 64, 512)           535680

### Explanation of the layers
- The `CreatePatches` layer splits an image into non-overlapping patches using `tf.image.extract_patches`. Setting the kernel size equal to the stride (4 here) keeps the patches small and non-overlapping: each 32 x 32 x 3 image becomes an 8 x 8 grid of patches, each flattened to 4 * 4 * 3 = 48 values.
- The `PerPatchFullyConnected` layer applies the same `Dense` layer to every patch, projecting each one into a `dims`-sized vector. This creates the "Patches x Channels" table described in the MLP-Mixer paper (64 x 512 here).
- The `MLPBlock` takes no arguments because its output has the same shape as its input. It performs the mixing: a token-mixing MLP across patches and a channel-mixing MLP across channels.
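The first two layers can be sketched in plain NumPy to make the shapes concrete. This is a minimal sketch, not the `mlp_utils` implementation: `extract_patches_np` and `per_patch_dense` are hypothetical helpers, the patch flattening order (row, column, channel) is assumed to match `tf.image.extract_patches`, and the random weights stand in for learned ones.

```python
import numpy as np


def extract_patches_np(image, k):
    """Split an (H, W, C) image into non-overlapping k x k patches.

    Returns an (H // k, W // k, k * k * C) array; each patch is flattened in
    (row, column, channel) order. Assumes H and W are divisible by k
    (kernel size == stride, as in create_model()).
    """
    h, w, c = image.shape
    gh, gw = h // k, w // k
    patches = image.reshape(gh, k, gw, k, c)    # split rows and columns into blocks
    patches = patches.transpose(0, 2, 1, 3, 4)  # -> (gh, gw, k, k, c)
    return patches.reshape(gh, gw, k * k * c)   # flatten each patch


def per_patch_dense(patches, weights, bias):
    """Apply the same affine map to every patch: (gh, gw, P) -> (gh * gw, dims)."""
    gh, gw, p = patches.shape
    flat = patches.reshape(gh * gw, p)  # one row per patch
    return flat @ weights + bias


rng = np.random.default_rng(0)
image = rng.random((32, 32, 3))
dims = 512

patches = extract_patches_np(image, k=4)
weights = rng.standard_normal((patches.shape[-1], dims)) * 0.02  # hypothetical init
bias = np.zeros(dims)
table = per_patch_dense(patches, weights, bias)

print(patches.shape)             # (8, 8, 48)
print(table.shape)               # (64, 512)
print(weights.size + bias.size)  # 25088
```

The parameter count, 48 * 512 weights + 512 biases = 25,088, matches the `per_patch_fully_connected` row of the model summary.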

Here we are using some specific hyperparameters:
- the internal projection dimension `dims` is set to 512
- a depth of 4 MLP blocks
- global average pooling, as used in the paper
- Gaussian dropout with a rate of 0.9, because the model tends to overfit on a dataset this small, whose images are also small (32 x 32)
- the Adadelta optimizer with a learning rate of 1.0, since Adadelta adapts the step size per parameter by itself; it also gave better results than Adam here
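Note that `GaussianDropout` does not zero out activations. According to the Keras documentation, at training time it multiplies them by 1-centered Gaussian noise with standard deviation sqrt(rate / (1 - rate)), and it is inactive at inference. A pure-Python sketch of that behaviour (the helper names are hypothetical):

```python
import math
import random


def gaussian_dropout_std(rate):
    """Std of the multiplicative noise used by Gaussian dropout at a given rate."""
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    return math.sqrt(rate / (1.0 - rate))


def gaussian_dropout(values, rate, rng=random):
    """Training-time Gaussian dropout: multiply each value by N(1, std^2) noise."""
    std = gaussian_dropout_std(rate)
    return [v * rng.gauss(1.0, std) for v in values]


for rate in (0.1, 0.5, 0.9):
    print(rate, round(gaussian_dropout_std(rate), 3))
# 0.1 0.333
# 0.5 1.0
# 0.9 3.0
```

At rate 0.9 the noise standard deviation is 3, three times larger than at rate 0.5, which is why this setting acts as a very strong regularizer.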

In [4]:
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adadelta(1),
    metrics=["accuracy"],
)

b_s = 128 # batch_size

val_dataset = val_dataset.batch(b_s)
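The `1/391` in the training progress bar is the number of batches per epoch: Keras runs ceil(examples / batch size) steps, with a smaller final batch. A minimal sketch of that arithmetic for CIFAR-10's 50,000 training and 10,000 test images (`steps_per_epoch` is a hypothetical helper):

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of batches per epoch; the last batch may be partial."""
    return math.ceil(num_examples / batch_size)


print(steps_per_epoch(50_000, 128))  # 391 training steps at b_s = 128
print(steps_per_epoch(10_000, 128))  # 79 validation batches
print(steps_per_epoch(50_000, 256))  # 196 training steps at batch_size=256
```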

In [5]:
history = model.fit(
    x_train,
    y_train,
    validation_data=val_dataset,
    batch_size=b_s,
    epochs=100,
    verbose="auto",
)

Epoch 1/100


  1/391 [..............................] - ETA: 16:18 - loss: 5.7593 - accuracy: 0.0938

Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 7

The MLP-Mixer does not perform very well on a dataset this small: it ends up overfitting even with a Gaussian-dropout rate of 0.9.

The best training/validation accuracy pair occurs at epoch 84: training accuracy 0.9523, validation accuracy 0.6132 (the validation set is the CIFAR-10 test split).
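To find that epoch programmatically rather than by scanning the log, a sketch like the one below can be run on `history.history`, the per-epoch metrics dict held by the `History` object that Keras's `fit()` returns. It assumes "best" means highest validation accuracy; `best_epoch` is a hypothetical helper.

```python
def best_epoch(history_dict, key="val_accuracy"):
    """Return (epoch, train_acc, val_acc) for the epoch with the highest `key`.

    `history_dict` is a Keras History.history-style dict with "accuracy" and
    "val_accuracy" lists (one entry per epoch). Epochs are reported 1-based,
    matching the "Epoch N/100" log lines; ties go to the earliest epoch.
    """
    values = history_dict[key]
    idx = max(range(len(values)), key=values.__getitem__)
    return idx + 1, history_dict["accuracy"][idx], history_dict["val_accuracy"][idx]


# Usage with the model trained above:
# epoch, train_acc, val_acc = best_epoch(history.history)
```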

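Next, let's train a similar model and see whether it does any better in this scenario. In the cell below, global average pooling is kept (the `Projection` layer remains commented out), but a `Permute((2, 1))` layer is added after each `MLPBlock`, so consecutive blocks see the table transposed: 64 x 512, then 512 x 64, and so on.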
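A quick way to see what the extra `Permute((2, 1))` layers do to the tensor shapes is to mimic them with NumPy. Keras's `Permute` takes a 1-indexed permutation that excludes the batch axis, so `Permute((2, 1))` swaps axes 1 and 2, like `np.transpose(x, (0, 2, 1))`. The sketch skips the `MLPBlock`s themselves (they preserve shape) and only tracks shapes:

```python
import numpy as np

# A batch of 2 "Patches x Channels" tables, as produced by PerPatchFullyConnected
x = np.zeros((2, 64, 512))

shapes = []
for _ in range(4):  # depth = 4
    # (an MLPBlock would go here; it keeps the shape unchanged)
    x = np.transpose(x, (0, 2, 1))  # same effect as Keras Permute((2, 1))
    shapes.append(x.shape[1:])

print(shapes)  # [(512, 64), (64, 512), (512, 64), (64, 512)]

# GlobalAveragePooling1D averages over axis 1 (the "steps" axis)
pooled = x.mean(axis=1)
print(pooled.shape)  # (2, 512)
```

Because `depth` is 4 (even), the table ends in its original 64 x 512 orientation, so `GlobalAveragePooling1D` still averages over the 64 patches; with an odd depth it would average over the 512 channels instead.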

In [7]:
history1 = history  # keep the first run's History before training the next model

In [4]:
def create_model():
    dims = 512
    k, s = (4, 4)  # (kernel, strides)
    depth = 4

    inputs = Input((32, 32, 3))

    x = tf.keras.Sequential(
        [
            RandomFlip(),
            RandomRotation(factor=0.03),
            RandomZoom(height_factor=0.2, width_factor=0.2),
        ],
        name="data_augmentation",
    )(inputs)

    x = CreatePatches(k, s)(x)
    x = PerPatchFullyConnected(dims)(x)

    for _ in range(depth):
        x = MLPBlock()(x)
        # swap the patch and channel axes so the next block sees the transposed table
        x = Permute((2, 1))(x)

    # x = Projection()(x)
    x = GlobalAveragePooling1D()(x)

    # classification layer
    x = GaussianDropout(0.9)(x)
    output = Dense(len(classes), activation="softmax", kernel_regularizer="l2")(x)

    return Model(inputs=inputs, outputs=output)


model = create_model()
model.summary()  # summary() prints itself and returns None

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
data_augmentation (Sequentia (None, 32, 32, 3)         0         
_________________________________________________________________
create_patches (CreatePatche (None, 8, 8, 48)          0         
_________________________________________________________________
per_patch_fully_connected (P (None, 64, 512)           25088     
_________________________________________________________________
mlp_block (MLPBlock)         (None, 64, 512)           535680    
_________________________________________________________________
permute (Permute)            (None, 512, 64)           0         
_________________________________________________________________
mlp_block_1 (MLPBlock)       (None, 512, 64)           533888
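The source of `MLPBlock` lives in `mlp_utils` and is not shown here, but the parameter counts in both summaries are reproduced exactly if we assume a block made of a token-mixing MLP (two `Dense` layers as wide as the number of patches), a channel-mixing MLP (two `Dense` layers as wide as the number of channels) and two layer normalizations over the channel axis. This is an inference from the numbers, not the module's actual code, and the helper names are hypothetical:

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a Dense layer."""
    return n_in * n_out + n_out


def layer_norm_params(n):
    """Gamma and beta of a LayerNormalization over the last axis."""
    return 2 * n


def mixer_block_params(tokens, channels):
    """Parameter count of a Mixer block, ASSUMING two Dense(tokens) layers for
    token mixing, two Dense(channels) layers for channel mixing, and two
    LayerNorms over the channel axis (hidden widths equal the input widths)."""
    token_mlp = 2 * dense_params(tokens, tokens)
    channel_mlp = 2 * dense_params(channels, channels)
    norms = 2 * layer_norm_params(channels)
    return token_mlp + channel_mlp + norms


print(mixer_block_params(64, 512))  # 535680, as for mlp_block
print(mixer_block_params(512, 64))  # 533888, as for mlp_block_1 after Permute
```

Under this assumption the token and channel MLPs simply swap roles after `Permute` (8,320 + 525,312 parameters either way), so the 1,792-parameter difference comes entirely from the layer norms, which depend only on the last axis (2,048 for 512 channels vs. 256 for 64).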

In [5]:
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adadelta(1),
    metrics=["accuracy"],
)

# b_s = 256 # batch_size

# val_dataset = val_dataset.batch(b_s)

In [6]:
history = model.fit(
    x_train,
    y_train,
    validation_data=val_dataset,
    batch_size=256,
    epochs=100,
    verbose="auto",
)

Epoch 1/100


Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 7