In [None]:
import wandb
from wandb.keras import WandbCallback

import numpy as np
from typing import Union
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, Input, Sequential
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras import (
    layers,
    losses,
    metrics,
    datasets,
    mixed_precision,
    optimizers,
)

In [None]:
# Central hyperparameter table; also passed verbatim to `wandb.init(config=...)`.
CONFIGS = {
    # --- data ---------------------------------------------------------------
    "dataset_name": "CIFAR-10",
    # Side length of the square images given to the model's Input layer.
    "image_size": 32,
    # Side length produced by the Resizing step of the preprocessing layer.
    "target_size": 72,
    # --- architecture -------------------------------------------------------
    # Kernel size and stride of the patch-embedding convolution.
    "patch_size": 9,
    "num_mixer_layers": 4,
    "embedding_dim": 128,
    "channels_mlp_dim": 128,
    "num_classes": 10,
    "dropout": 0.25,
    # --- optimisation -------------------------------------------------------
    "batch_size": 512,
    "learning_rate": 0.001,
    "epochs": 50,
    "label_smoothing": 0.0,
    # When True, the global Keras dtype policy is switched to "mixed_float16".
    "mixed_precision": False,
    # --- reporting ----------------------------------------------------------
    # Human-readable labels handed to the W&B callback.
    "class_names": [
        "airplane",
        "automobile",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    ],
}

In [None]:
# Set the global dtype policy explicitly in BOTH cases. Only setting it when
# the flag is True made this cell non-idempotent: after one run with
# mixed precision enabled, flipping CONFIGS["mixed_precision"] back to False
# and re-running left the kernel stuck on "mixed_float16".
# NOTE(review): "float32" is assumed to be the desired full-precision policy
# (it is the Keras default unless floatx was changed elsewhere).
mixed_precision.set_global_policy(
    "mixed_float16" if CONFIGS["mixed_precision"] else "float32"
)

In [None]:
def get_cifar10(num_classes: int):
    """Load CIFAR-10 and one-hot encode the labels of both splits.

    Args:
        num_classes: Width of the one-hot label vectors.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` where the ``x_*`` arrays are
        returned exactly as ``datasets.cifar10.load_data()`` provides them and
        the ``y_*`` arrays are the categorical encodings of the labels.
    """
    train_split, test_split = datasets.cifar10.load_data()

    def _with_one_hot_labels(split):
        # Keep the images untouched; only the labels are re-encoded.
        images, labels = split
        return images, keras.utils.to_categorical(labels, num_classes)

    return _with_one_hot_labels(train_split), _with_one_hot_labels(test_split)


# Take the class count from CONFIGS rather than repeating the literal 10, so
# the one-hot label width always matches the model's output layer, which is
# also built from CONFIGS["num_classes"].
(x_train, y_train), (x_test, y_test) = get_cifar10(
    num_classes=CONFIGS["num_classes"]
)
print("x_train.shape:", x_train.shape)
print("y_train.shape:", y_train.shape)
print("x_test.shape:", x_test.shape)
print("y_test.shape:", y_test.shape)

In [None]:
def get_preprocessing_layer(
    data_batch: Union[np.ndarray, tf.Tensor], target_size: int
) -> Sequential:
    """Build the fixed input pipeline: normalization followed by resizing.

    Args:
        data_batch: Data passed to ``Normalization.adapt`` so the layer can
            derive its statistics before it is used in the model.
        target_size: Height and width given to the ``Resizing`` layer.

    Returns:
        A ``Sequential`` named ``"preprocessing"`` containing the adapted
        normalization layer and the resizing layer, in that order.
    """
    normalizer = preprocessing.Normalization()
    normalizer.adapt(data_batch)
    stages = [normalizer, preprocessing.Resizing(target_size, target_size)]
    return Sequential(stages, name="preprocessing")


def get_augmentation_layer() -> Sequential:
    """Build the random image-augmentation stack.

    Returns:
        A ``Sequential`` named ``"augmentation"`` that applies, in order, a
        horizontal random flip, a random rotation (factor 0.02) and a random
        zoom (height and width factor 0.2).
    """
    random_flip = preprocessing.RandomFlip("horizontal")
    random_rotation = preprocessing.RandomRotation(factor=0.02)
    random_zoom = preprocessing.RandomZoom(height_factor=0.2, width_factor=0.2)
    return keras.Sequential(
        [random_flip, random_rotation, random_zoom], name="augmentation"
    )

In [None]:
def patch_embedding(
    inputs: tf.Tensor, embedding_dim: int, patch_size: int
) -> tf.Tensor:
    """Embed non-overlapping image patches and flatten the patch grid.

    A convolution whose kernel size and stride both equal ``patch_size``
    produces one ``embedding_dim``-wide vector per patch; the two spatial axes
    of that output are then merged into a single axis.

    Args:
        inputs: Image tensor fed to the convolution.
        embedding_dim: Number of convolution filters (width of each embedding).
        patch_size: Kernel size and stride of the convolution.

    Returns:
        Tensor reshaped to ``(rows * cols, channels)`` per example, where
        ``rows``, ``cols`` and ``channels`` are axes 1, 2 and 3 of the
        convolution output.
    """
    project_patches = layers.Conv2D(
        embedding_dim, kernel_size=patch_size, strides=patch_size
    )
    patches = project_patches(inputs)
    grid_rows = patches.shape[1]
    grid_cols = patches.shape[2]
    channels = patches.shape[3]
    flatten_grid = layers.Reshape((grid_rows * grid_cols, channels))
    return flatten_grid(patches)


def mlp_block(inputs: tf.Tensor, mlp_dim: int) -> tf.Tensor:
    """Two-layer MLP (Dense -> GELU -> Dense) applied along the last axis.

    Args:
        inputs: Tensor whose last axis is mixed by the MLP.
        mlp_dim: Number of units in the hidden ``Dense`` layer.

    Returns:
        Tensor with the same last-axis width as ``inputs``.
    """
    hidden = layers.Dense(mlp_dim)(inputs)
    hidden = layers.Activation("gelu")(hidden)
    # Project back to the width of `inputs`, not of `hidden`. Sizing the output
    # from the hidden tensor made the block emit `mlp_dim` features, which only
    # matches a residual branch when `mlp_dim` happens to equal the input
    # width; any other hidden size would make a following `Add` fail on a
    # shape mismatch.
    return layers.Dense(inputs.shape[-1])(hidden)


def mixer_block(inputs: tf.Tensor, tokens_mlp_dim, channels_mlp_dim) -> tf.Tensor:
    """Apply one Mixer block: token mixing then channel mixing, each residual.

    Args:
        inputs: Tensor with two non-batch axes; they are swapped by
            ``Permute((2, 1))`` for the token-mixing step.
        tokens_mlp_dim: ``mlp_dim`` passed to ``mlp_block`` for token mixing.
        channels_mlp_dim: ``mlp_dim`` passed to ``mlp_block`` for channel mixing.

    Returns:
        The sum of the token-mixed tensor and its channel-mixing branch.
    """
    # Token mixing: normalise, swap the two non-batch axes so the MLP acts on
    # the other axis, then swap back and add the skip connection.
    token_branch = layers.LayerNormalization()(inputs)
    token_branch = layers.Permute((2, 1))(token_branch)
    token_branch = mlp_block(token_branch, tokens_mlp_dim)
    token_branch = layers.Permute((2, 1))(token_branch)
    token_mixed = layers.Add()([inputs, token_branch])

    # Channel mixing: normalise and run the MLP on the last axis directly.
    channel_branch = layers.LayerNormalization()(token_mixed)
    channel_branch = mlp_block(channel_branch, channels_mlp_dim)
    return layers.Add()([token_mixed, channel_branch])


def get_mlp_mixer_model(
    num_mixer_blocks: int,
    patch_size: int,
    embedding_dim: int,
    channels_mlp_dim: int,
    num_classes: int,
    preprocessing_layer: Union[Sequential, None],
    augmentation_layer: Union[Sequential, None],
    image_size: Union[int, None] = None,
    dropout: Union[float, None] = None,
) -> Model:
    """Assemble the MLP-Mixer classifier as a Keras functional model.

    Pipeline: input -> optional preprocessing -> optional augmentation ->
    patch embedding -> ``num_mixer_blocks`` mixer blocks -> layer norm ->
    dropout -> global average pooling -> softmax classifier.

    Args:
        num_mixer_blocks: Number of ``mixer_block`` repetitions.
        patch_size: Patch size forwarded to ``patch_embedding``.
        embedding_dim: Embedding width forwarded to ``patch_embedding``.
        channels_mlp_dim: Hidden width of the channel-mixing MLP.
        num_classes: Number of units in the softmax output layer.
        preprocessing_layer: Applied to the raw inputs first; skipped if None.
        augmentation_layer: Applied after preprocessing; skipped if None.
        image_size: Side length of the square RGB input. Defaults to
            ``CONFIGS["image_size"]`` when None, which is what the function
            always used before this parameter existed.
        dropout: Dropout rate applied before pooling. Defaults to
            ``CONFIGS["dropout"]`` when None.

    Returns:
        A ``Model`` named ``"mlp_mixer"`` mapping images to class probabilities.
    """
    # Fall back to the notebook-wide configuration so existing callers that do
    # not pass these values keep their previous behaviour.
    if image_size is None:
        image_size = CONFIGS["image_size"]
    if dropout is None:
        dropout = CONFIGS["dropout"]

    inputs = Input(shape=(image_size, image_size, 3))
    x = inputs
    if preprocessing_layer is not None:
        x = preprocessing_layer(x)
    if augmentation_layer is not None:
        x = augmentation_layer(x)

    x = patch_embedding(x, embedding_dim, patch_size)
    # The token-mixing MLP uses a hidden width equal to the number of patches.
    tokens_mlp_dim = x.shape[-2]
    for _ in range(num_mixer_blocks):
        x = mixer_block(x, tokens_mlp_dim, channels_mlp_dim)

    x = layers.LayerNormalization()(x)
    x = layers.Dropout(dropout)(x)
    x = layers.GlobalAveragePooling1D()(x)
    # The classifier is pinned to float32 regardless of the global dtype policy.
    outputs = layers.Dense(num_classes, activation="softmax", dtype="float32")(x)
    return Model(inputs, outputs, name="mlp_mixer")

In [None]:
# Authenticate with Weights & Biases, then open the run that tracks this
# experiment; the whole CONFIGS dictionary is attached as the run config.
# NOTE(review): relogin=True presumably forces a fresh credential prompt on
# every execution — confirm this is wanted for unattended "Run All".
wandb.login(relogin=True)
run = wandb.init(
    project="mlp-mixer",
    name="cifar-10",
    config=CONFIGS,
)

In [None]:
# Build the two input stages first (normalization statistics are adapted on
# the training images), then pass them to the model factory.
preprocessing_layer = get_preprocessing_layer(
    data_batch=x_train, target_size=CONFIGS["target_size"]
)
augmentation_layer = get_augmentation_layer()

model = get_mlp_mixer_model(
    num_mixer_blocks=CONFIGS["num_mixer_layers"],
    patch_size=CONFIGS["patch_size"],
    embedding_dim=CONFIGS["embedding_dim"],
    channels_mlp_dim=CONFIGS["channels_mlp_dim"],
    num_classes=CONFIGS["num_classes"],
    preprocessing_layer=preprocessing_layer,
    augmentation_layer=augmentation_layer,
)
model.summary()

In [None]:
# Compile with Adam and categorical cross-entropy, tracking top-1/3/5 accuracy.
model.compile(
    optimizer=optimizers.Adam(learning_rate=CONFIGS["learning_rate"]),
    # Wire the configured label smoothing into the loss. It was previously
    # declared in CONFIGS (and logged to W&B) but never applied, so changing
    # it had no effect on training. The configured 0.0 equals the Keras
    # default, so current results are unchanged.
    loss=losses.CategoricalCrossentropy(
        label_smoothing=CONFIGS["label_smoothing"]
    ),
    metrics=[
        metrics.CategoricalAccuracy(name="accuracy"),
        metrics.TopKCategoricalAccuracy(3, name="top-3-accuracy"),
        metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)

In [None]:
# W&B callback: saves the model and receives the test arrays plus class names
# for its image logging. Note that these test arrays are NOT what `fit`
# validates on — see `validation_split` below.
wandb_callback = WandbCallback(
    data_type="image",
    labels=CONFIGS["class_names"],
    validation_data=(x_test, y_test),
    save_model=True,
)

# Train on the training arrays, holding out 10% of them for validation; the
# returned History object is what the plotting helper reads afterwards.
history = model.fit(
    x_train,
    y_train,
    batch_size=CONFIGS["batch_size"],
    epochs=CONFIGS["epochs"],
    validation_split=0.1,
    callbacks=[wandb_callback],
)

In [None]:
def plot_result(item):
    """Plot a training metric and its validation counterpart over epochs.

    Reads the module-level ``history`` object and draws ``history.history[item]``
    together with ``history.history["val_" + item]`` on one figure.

    Args:
        item: Key of the metric in ``history.history`` (for example ``"loss"``).
    """
    # Training curve first, validation curve second; the key doubles as label.
    for series_key in (item, "val_" + item):
        plt.plot(history.history[series_key], label=series_key)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


# One figure per tracked quantity, in the same order as before.
for metric_name in ("loss", "accuracy", "top-3-accuracy", "top-5-accuracy"):
    plot_result(metric_name)