In [None]:
import sys
# Add the parent directory to the module search path. The `src.*` imports in
# the next cell presumably depend on this -- confirm against the repo layout.
sys.path.append("../")

In [None]:
import os
import logging
from typing import List, Tuple

import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

from src.models.vit import (
    PatcheLayer,
    PatchEncodeLayer,
    FeedForwardLayer,
    EncoderLayer,
    Encoder,
    VisionTransformer,
)
from src.utils.logger import get_logger
from src.utils.session import reset_session
from src.utils.plot import plot_history

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

In [None]:
# Call the project logging helper; its return value is bound to `_` and not
# used by any cell visible here (presumably called for its setup side effect
# -- confirm in src.utils.logger).
_ = get_logger()

## Load Dataset

In [None]:
# Load CIFAR-10 via Keras. The `_raw` arrays are kept untouched so that the
# preprocessing cells further down can always be re-derived from them.
(
    (X_train_raw, y_train_raw),
    (X_test_raw, y_test_raw),
) = tf.keras.datasets.cifar10.load_data()

# Report the array shapes of the training portion ...
print(f"X_train: {X_train_raw.shape}")
print(f"y_train: {y_train_raw.shape}")

# ... and of the test portion
print(f"X_test: {X_test_raw.shape}")
print(f"y_test: {y_test_raw.shape}")

## Example: Patch Sequences

In [None]:
# Demo batch: number of samples (bs) and patch side length (ps)
bs, ps = 16, 4

# First `bs` raw training images together with their labels
inputs, targets = X_train_raw[:bs], y_train_raw[:bs]

# Run the batch through the patch layer; later cells index the result per
# sample and call `.numpy()` on it
layer_p = PatcheLayer(ps)
patches = layer_p(inputs)

In [None]:
def draw_patches(
    raw: np.ndarray,
    patches: np.ndarray,
    patch_size: int = 4,
) -> plt.Figure:
    """Plot an image side by side with the grid of patches cut from it.

    Args:
        raw: Image array whose first two axes are height and width; it is
            passed to ``imshow`` unchanged.
        patches: Iterable of flattened patches. Each one is reshaped to
            ``(patch_size, patch_size, -1)`` before display. Patches are
            placed row by row, so they are assumed to be stored in row-major
            order over the image grid -- TODO confirm against PatcheLayer.
        patch_size: Side length of a (square) patch in pixels.

    Returns:
        The figure containing the "Original" and "Patches" panels.
    """
    # Derive the grid from both image axes so non-square images lay out
    # correctly; for square images this matches the previous single grid size.
    n_rows = raw.shape[0] // patch_size
    n_cols = raw.shape[1] // patch_size

    fig = plt.figure()
    subfigs = fig.subfigures(1, 2, width_ratios=(1, 1))

    # Left panel: the untouched image
    ax = subfigs[0].subplots(1, 1)
    ax.imshow(raw)
    ax.axis("off")

    # Right panel: one small axes per patch. squeeze=False keeps `axs` a 2-D
    # array even when the grid has a single row or column, so the
    # `axs[row, col]` lookup below always works.
    axs = subfigs[1].subplots(n_rows, n_cols, squeeze=False)
    for i, patch in enumerate(patches):
        row, col = divmod(i, n_cols)
        ax = axs[row, col]
        ax.imshow(patch.reshape(patch_size, patch_size, -1))
        ax.axis("off")

    subfigs[0].suptitle("Original", y=0.8)

    subfigs[1].suptitle("Patches", y=0.8)
    # Tighten the patch grid vertically and keep small gaps between patches
    subfigs[1].subplots_adjust(top=0.76, bottom=0.23, wspace=0.1, hspace=0.1)
    return fig

In [None]:
# Visualise sample 3 of the demo batch next to its patches
i = 3
ori, pcs = inputs[i], patches[i].numpy()
_ = draw_patches(raw=ori, patches=pcs, patch_size=ps)

In [None]:
# Visualise sample 7 of the demo batch next to its patches
i = 7
ori, pcs = inputs[i], patches[i].numpy()
_ = draw_patches(raw=ori, patches=pcs, patch_size=ps)

## Preprocess

In [None]:
# Scale pixel values by 1/255. Casting to float32 first keeps the result in
# float32: dividing the raw integer arrays by a Python float directly would
# produce float64 arrays and double the memory used by the image data.
X_train = X_train_raw.astype("float32") / 255.
X_test = X_test_raw.astype("float32") / 255.

# Flatten the label arrays to 1-D (one class id per sample)
y_train = y_train_raw.reshape(-1)
y_test = y_test_raw.reshape(-1)

In [None]:
# Hold out 10% of the training data for validation, with a fixed seed.
# NOTE: the outputs rebind X_train / y_train, so re-running this cell without
# re-running the preprocessing cell first splits the already reduced set again.
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train, y_train, random_state=1234, test_size=0.1
)

# Shapes after the split: train / validation / test
print(f"X_train: {X_train.shape}")
print(f"y_train: {y_train.shape}")

print(f"X_valid: {X_valid.shape}")
print(f"y_valid: {y_valid.shape}")

print(f"X_test: {X_test.shape}")
print(f"y_test: {y_test.shape}")

## Train Model

In [None]:
# ---------------------------------------------------------------------------
# Model configuration (forwarded unchanged to VisionTransformer below)
# ---------------------------------------------------------------------------
input_shape = tuple(X_train.shape[1:])  # per-sample shape, batch axis dropped
num_classes = 10
resize = 32
patch_size = 4
projection_dim = 64
num_heads = 4
num_encoder_blocks = 2
mlp_hidden_units = [512, 128]
dropout_rate = 0.1
learning_rate = 1e-3

# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
batch_size = 256
epochs = 30

# Checkpoints are written below ../model/vit_image_clf/ with the prefix "ckpt"
base_model_dir = "../model/"
os.makedirs(base_model_dir, exist_ok=True)
ckpt_path = os.path.join(base_model_dir, "vit_image_clf", "ckpt")

# Stop once the monitored quantity (Keras default) has not improved for 10 epochs
es_cb = tf.keras.callbacks.EarlyStopping(
    mode="min",
    patience=10,
    verbose=1,
)

# Keep only the best weights seen so far; weights only, not the full model
ckpt_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=ckpt_path,
    save_weights_only=True,
    save_best_only=True,
    mode="min",
    verbose=1,
)

callbacks = [es_cb, ckpt_cb]

In [None]:
# Project helper from src.utils.session; presumably clears leftover
# Keras/TensorFlow state so the model below is built from scratch -- confirm.
reset_session()

In [None]:
# Instantiate the ViT wrapper from the values defined in the config cell
vit = VisionTransformer(
    # data / task
    input_shape=input_shape,
    resize=resize,
    num_classes=num_classes,
    # patching and embedding
    patch_size=patch_size,
    projection_dim=projection_dim,
    # encoder and head
    num_encoder_blocks=num_encoder_blocks,
    num_heads=num_heads,
    mlp_hidden_units=mlp_hidden_units,
    # regularisation / optimisation
    dropout_rate=dropout_rate,
    learning_rate=learning_rate,
)

In [None]:
# Build the wrapper; the next cell reads `vit.model`, which presumably only
# exists after this call -- confirm in src.models.vit.
vit.build()

In [None]:
# Shorthand for the model held by the wrapper; used below for summary, fit,
# load_weights and evaluate.
model = vit.model

In [None]:
# Print an overview of the model before training
model.summary()

In [None]:
# Train on the 90% training split and validate on the held-out 10%. The
# callbacks handle early stopping and best-weights checkpointing.
history = model.fit(
    X_train,
    y_train,
    epochs=epochs,
    batch_size=batch_size,
    validation_data=(X_valid, y_valid),
    callbacks=callbacks,
)

In [None]:
# Plot the recorded training history with the project helper, labelling the
# y-axis "Cross-Entropy"
plot_history(history, ylabel="Cross-Entropy")

In [None]:
# Restore the weights saved by ckpt_cb (save_best_only=True) so that the
# evaluations below use the best checkpoint rather than the last epoch.
model.load_weights(ckpt_path)

In [None]:
# Evaluate on the training split (the part kept after the validation hold-out)
model.evaluate(X_train, y_train, batch_size=batch_size)

In [None]:
# Evaluate on the CIFAR-10 test set, which was not used for training or validation
model.evaluate(X_test, y_test, batch_size=batch_size)

In [None]:
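# Disabled sanity check: run the first `bs` training samples through the model
# and compute the sparse categorical cross-entropy against their labels.
# inps = X_train[:bs]
# tars = y_train[:bs]
# outs = vit.model(inps)
# tf.keras.losses.SparseCategoricalCrossentropy()(tars, outs)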