In [None]:
%load_ext autoreload
%autoreload 2

import tensorflow as tf
import json

# Turn on memory growth for every visible GPU so TensorFlow allocates device
# memory as needed. Per the TensorFlow docs this must happen before the GPUs
# are initialized; if it is too late a RuntimeError is raised, which is only
# printed here so the rest of the notebook can still run.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    # An empty device list simply skips the loop (CPU-only machines).
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
except RuntimeError as err:
    print(err)

from models import build_vit
import plots

In [None]:
def load_config(path):
    """Read a JSON configuration file and return its parsed content.

    Args:
        path: Location of the JSON file (anything accepted by ``open``).

    Returns:
        The decoded JSON document; a ``dict`` when the top level of the file
        is a JSON object.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Pin the encoding to UTF-8 (the standard encoding for JSON) so parsing
    # does not depend on the platform's locale default.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

In [None]:
# Parse config.json (the relative path is resolved against the kernel's
# working directory). Later cells read its "model" and "training" sections.
config = load_config("config.json")

# Read Data

In [None]:
# Load the CIFAR-100 train/test splits through Keras.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()

# Per-sample input shape and number of target classes handed to the model
# builder further down.
image_shape = (32, 32, 3)
classes = 100

# Fail fast if the loaded arrays do not match the constants above, rather than
# surfacing the mismatch as an obscure shape error during training.
assert x_train.shape[1:] == image_shape, x_train.shape
assert x_test.shape[1:] == image_shape, x_test.shape
assert len(x_train) == len(y_train), (len(x_train), len(y_train))
assert len(x_test) == len(y_test), (len(x_test), len(y_test))
assert 0 <= y_train.min() and y_train.max() < classes, "label outside [0, classes)"

# Build Model

The preprocessing and augmentation of the data are done within the model using Keras preprocessing layers.
The following steps are applied:

- Rescaling to [0, 1].
- Horizontal flipping.
- Rotation.
- Zoom.

In [None]:
# Architecture settings come from the "model" section of the config and are
# forwarded to build_vit as keyword arguments; attention scores are not
# requested as a model output.
model_config = config["model"]
vit = build_vit(
    image_shape,
    classes,
    return_attention_score=False,
    **model_config
)

# Train Model

In [None]:
# Resolve the loss and optimizer from the identifiers stored in the
# "training" section of the config, then compile the model with them.
training_config = config["training"]
loss = tf.keras.losses.get(training_config["loss"])
optimizer = tf.keras.optimizers.get(training_config["optimizer"])

vit.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=["sparse_categorical_accuracy"],
)

In [None]:
# Training schedule, read from the "training" section of the config.
epochs = training_config["epochs"]
batch_size = training_config["batch_size"]
validation_split = training_config["validation_split"]

# Keep the object returned by fit so the evaluation section can plot it.
history = vit.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=validation_split,
)

# Evaluation

In [None]:
# Draw the learning curve from the object returned by vit.fit, using the
# project-local plots module.
# NOTE(review): presumably shows loss/metric per epoch — confirm in plots.py.
fig = plots.learning_curve(history)
fig.show()

In [None]:
# Score the trained model on the test split. The two-way unpacking relies on
# evaluate returning the loss followed by the single metric passed to compile.
test_loss, accuracy = vit.evaluate(x_test, y_test)
print(test_loss, accuracy)