In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import utils
import matplotlib.pyplot as plt

In [None]:
# Hyperparameters, taken from the Keras ViT example:
# https://github.com/keras-team/keras-io/blob/master/examples/vision/image_classification_with_vision_transformer.py#L54

# Optimisation settings
learning_rate = 1e-3
weight_decay = 1e-4
batch_size = 256
num_epochs = 100

# Input geometry
image_size = 72  # Input images are resized to image_size x image_size
patch_size = 6  # Size of the patches to be extracted from the input images

# Model settings (all of these are forwarded to the ViT constructor)
projection_dim = 64
num_heads = 4
transformer_layers = 8
transformer_units = [projection_dim * 2, projection_dim]
mlp_head_units = [2048, 1024]

In [None]:
# Load the CIFAR-100 train and test splits through utils.unpickle.
train_dict, test_dict = (
    utils.unpickle("./cifar-100-python/train"),
    utils.unpickle("./cifar-100-python/test"),
)
# The loaded dicts use byte-string keys, so plain str lookups are not possible
# until the keys have been decoded.
def decode_keys(data):
    """Return a copy of ``data`` whose byte-string keys are decoded to ``str``.

    Args:
        data: Mapping whose keys are ``bytes`` (UTF-8 encoded) and/or ``str``.

    Returns:
        A new ``dict`` holding the same values in the same order. ``bytes`` keys
        are decoded as UTF-8; keys of any other type are kept unchanged, so an
        already decoded dict can be passed again without raising.
    """
    return {
        (key.decode("utf-8") if isinstance(key, bytes) else key): value
        for key, value in data.items()
    }
# Decode the keys of both splits, then show which keys the training split offers.
train_dict_decoded, test_dict_decoded = map(decode_keys, (train_dict, test_dict))
print(train_dict_decoded.keys())

In [None]:
# Explore the data format: shape of the first data entry and its fine label
print(
    train_dict_decoded["data"][0].shape,
    train_dict_decoded["fine_labels"][0],
    sep="\n",
)

In [None]:
# Rebuild each flat data row as a height x width x channel image
def prepare_dict(data):
    """Build a training-ready copy of a decoded CIFAR-100 split.

    Every entry of ``data["data"]`` is reshaped to ``(3, 32, 32)`` and transposed
    to ``(32, 32, 3)``. The fine labels are exposed under ``"label"``, and the
    ``"filenames"``, ``"batch_label"``, ``"coarse_labels"`` and ``"fine_labels"``
    entries are removed from the result. The input dict itself is not modified.

    Args:
        data: Dict with a ``"data"`` entry (array or sequence of rows holding
            3 * 32 * 32 values each) and a ``"fine_labels"`` entry. The other
            metadata entries are optional.

    Returns:
        A new dict in which ``"data"`` is a list of ``(32, 32, 3)`` arrays and
        ``"label"`` holds the fine labels.

    Raises:
        KeyError: If ``"data"`` or ``"fine_labels"`` is missing.
        ValueError: If the rows cannot be reshaped to ``(3, 32, 32)``.
    """
    prepared = dict(data)  # shallow copy: the caller's dict keeps all its keys

    # One vectorised reshape/transpose for the whole split instead of a Python
    # loop over rows; list() keeps the "list of per-image arrays" return shape.
    rows = np.asarray(data["data"])
    images = rows.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    prepared["data"] = list(images)

    prepared["label"] = prepared["fine_labels"]
    # pop(..., None) tolerates splits that lack some of the metadata entries.
    for key in ("filenames", "batch_label", "coarse_labels", "fine_labels"):
        prepared.pop(key, None)
    return prepared

# Prepare both splits, then print the first training sample: shape, label, pixels.
train_data, test_data = (
    prepare_dict(split) for split in (train_dict_decoded, test_dict_decoded)
)
print(
    train_data["data"][0].shape,
    train_data["label"][0],
    train_data["data"][0],
    sep="\n",
)

In [None]:
# Create the train and test arrays: images are stacked into a single array per
# split, labels are taken over unchanged.
X_train = np.array(train_data["data"])
y_train = train_data["label"]
X_test = np.array(test_data["data"])
y_test = test_data["label"]

In [None]:
# Load the dataset metadata and decode the fine label names from bytes to str
meta_dict = decode_keys(utils.unpickle("./cifar-100-python/meta"))
class_names = [
    raw_name.decode("utf-8") for raw_name in meta_dict["fine_label_names"]
]
print(class_names[:35])

In [None]:
# Each index of class_names corresponds to the "label" value in the dataset
def visualize_data(train_images, train_labels):
    """Plot up to nine randomly chosen images in a 3x3 grid, titled with their class.

    Args:
        train_images: Indexable collection of images accepted by ``plt.imshow``.
        train_labels: Indexable collection of integer class ids, aligned with
            ``train_images``; each id is used as an index into ``class_names``.

    Notes:
        The selection is drawn with ``random.sample`` and is not seeded here, so
        every call shows different images. When fewer than nine samples are
        available, all of them are shown instead of raising ``ValueError``.
    """
    import random

    num_shown = min(9, len(train_labels))
    random_idx = random.sample(range(len(train_labels)), num_shown)
    plt.figure(figsize=(9, 9))
    for position, idx in enumerate(random_idx, start=1):
        # (rows, cols, position) is the explicit form of the "33<k>" shorthand
        plt.subplot(3, 3, position)
        class_ = train_labels[idx]
        plt.title("Class: " + str(class_) + ", Label:" + class_names[class_])
        plt.axis("off")
        plt.grid(False)
        plt.imshow(train_images[idx])


visualize_data(X_train, y_train)

In [None]:
# Preprocessing layers: resize to image_size x image_size, then normalize.
# NOTE: this cell overwrites X_train / X_test, so run it only once after the
# arrays have been created.
res = tf.keras.layers.Resizing(image_size, image_size)
norm = tf.keras.layers.Normalization()

X_train = res(X_train)
X_test = res(X_test)

# A Normalization layer that is never adapted keeps mean=0 / variance=1 and is
# therefore an identity op. Fit its statistics on the training images only and
# reuse them for the test images.
norm.adapt(X_train)

X_train = norm(X_train)
X_test = norm(X_test)

In [None]:
# Create the input pipelines. The shuffle buffer is sized to the training set
# (instead of an arbitrary large constant), which still shuffles across the
# whole split.
train_ds = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(len(y_train))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
# The evaluation data is batched but deliberately left unshuffled.
test_ds = (
    tf.data.Dataset.from_tensor_slices((X_test, y_test))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
from model import ViT

# Instantiate the ViT from the hyperparameters defined above, with one output
# class per fine label name.
model = ViT(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=len(class_names),
    dim=projection_dim,
    depth=transformer_layers,
    attention_heads=num_heads,
    transformer_units=transformer_units,
    mlp_head_units=mlp_head_units,
)

In [None]:
# function to run an experiment
def run_experiment(model):
    """Compile and train ``model``, then report the metrics of its best checkpoint.

    The function reads the notebook-level ``learning_rate``, ``weight_decay``,
    ``num_epochs``, ``train_ds`` and ``test_ds``. Weights are checkpointed to
    ``/tmp/checkpoint`` whenever ``val_accuracy`` improves; after training the
    best weights are restored, evaluated on ``test_ds`` and the accuracy and
    top-5 accuracy are printed.

    Args:
        model: Keras model to train. The loss is created with
            ``from_logits=True``, i.e. it treats the model outputs as logits.

    Returns:
        The ``History`` object returned by ``model.fit``.

    Notes:
        ``test_ds`` is used both as validation data (which drives checkpoint
        selection) and for the final evaluation, so the printed test metrics
        are not measured on data that was held out from model selection.
    """
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            tf.keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    # Keep only the weights of the epoch with the best validation accuracy.
    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    # train_ds is already batched, so `batch_size` must not be passed to fit():
    # Keras documents that it should not be specified for dataset inputs.
    history = model.fit(
        train_ds,
        epochs=num_epochs,
        validation_data=test_ds,
        callbacks=[checkpoint_callback],
    )

    # Restore the best checkpoint before the final evaluation.
    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(test_ds)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history
# Build for the resolution the pipelines actually feed: the images were resized
# to image_size x image_size, they are no longer 32 x 32.
model.build((None, image_size, image_size, 3))
# summary() prints the table itself and returns None, so it is not wrapped in print().
model.summary()

In [None]:
# Train and evaluate the ViT; keep the returned History so the training curves
# remain available for later inspection.
history = run_experiment(model=model)