In [None]:
pip install tensorflow-addons

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

In [None]:
pip install vit-keras

In [None]:
from vit_keras import vit

In [None]:
# Training configuration: every tunable hyperparameter lives here.
IMAGE_SIZE = 224           # side length images are resized to before entering the ViT
NUM_CLASSES = 100          # number of CIFAR-100 classes (size of the output layer)
NUM_CALSSES = NUM_CLASSES  # backward-compatible alias for the earlier misspelled name
LR = 0.001                 # AdamW learning rate
WEIGHT_DECAY = 0.0001      # AdamW weight decay
BATCH_SIZE = 16
NUM_EPOCHS = 10

In [None]:
# Load CIFAR-100 as NumPy arrays and report the shape of each split.
train_split, test_split = keras.datasets.cifar100.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

for split_name, images, labels in (('train', x_train, y_train), ('test', x_test, y_test)):
    print(f'x_{split_name}_shape: {images.shape} - y_{split_name}_shape: {labels.shape}')

In [None]:
# Preprocessing + augmentation stack placed at the front of the model:
# standardize pixel values, resize to the ViT input size, then apply random
# horizontal flips, small rotations and zooms.
normalization_layer = layers.Normalization()

data_augmentation = keras.Sequential(
    [
        normalization_layer,
        layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
        layers.RandomFlip('horizontal'),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name='data_augmentation',
)

# Compute the mean and variance of the training data and store them as the
# normalization layer's weights (same layer object as the pipeline's first stage).
normalization_layer.adapt(x_train)

In [None]:
# Pretrained ViT-B/32 backbone used as a feature extractor.
# include_top=False drops ViT's own classification head, so the head-only
# arguments (`classes`, `activation`) would have no effect and are omitted.
# The classification head for this task is added in the fine-tuning model.
vit_model = vit.vit_b32(
    image_size=IMAGE_SIZE,
    pretrained=True,
    include_top=False,
    pretrained_top=False,
)

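## Fine-tuning the model

The pretrained ViT backbone is stacked after the augmentation pipeline and topped with a small batch-normalized dense head that outputs a softmax over the CIFAR-100 classes.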

In [None]:
# Classification head on top of the ViT features: flatten, batch-norm,
# a 256-unit GELU layer, batch-norm again, then a softmax over all classes.
# Note: the final layer emits probabilities (softmax), not logits.
head_layers = [
    tf.keras.layers.Flatten(),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(256, activation=tfa.activations.gelu),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(NUM_CALSSES, activation='softmax'),
]

# End-to-end network: preprocessing/augmentation -> ViT backbone -> head.
model = tf.keras.Sequential(
    [data_augmentation, vit_model, *head_layers],
    name='ft_vit',
)

In [None]:
# Print the layer-by-layer architecture and parameter counts of the fine-tuning model.
model.summary()

In [None]:
# AdamW (Adam with decoupled weight decay) from TensorFlow Addons,
# configured from the hyperparameters in the config cell.
optimizer = tfa.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=LR)

In [None]:
# The model's final Dense layer already applies softmax, so the loss must treat
# its outputs as probabilities (from_logits=False). Passing from_logits=True
# here would apply softmax a second time and distort the loss and gradients.
# The Sparse* loss/metrics expect integer class labels.
model.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[
        # Logged under the key 'Accuracy' (and 'val_Accuracy' on validation data).
        keras.metrics.SparseCategoricalAccuracy(name='Accuracy'),
        keras.metrics.SparseTopKCategoricalAccuracy(5, name='top-5-accuracy'),
    ],
)

In [None]:
# Keep only the best weights seen so far, judged by validation accuracy.
# The accuracy metric is compiled with name='Accuracy', so its validation log
# key is 'val_Accuracy' (keys are case-sensitive; 'val_accuracy' never appears
# and the callback would silently skip every save). mode='max' is explicit
# rather than relying on 'auto' to infer the direction from the metric name.
checkpoint_filepath = 'checkpoint'
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    checkpoint_filepath,
    monitor='val_Accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
)

In [None]:
# Fine-tune on the training set, holding out 10% of it for validation;
# the checkpoint callback saves the best weights as training progresses.
fit_config = dict(
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    validation_split=0.1,
    callbacks=[checkpoint_callback],
)
history = model.fit(x_train, y_train, **fit_config)