In [None]:
import configuration
import tensorflow as tf
import utils
import os
import glob

from dataset import get_datasets
from gated_stormer import Stormer

In [None]:
# List the saved Stormer models found under models/ and let the user pick one
# by the index shown in the printed table.


def print_model_table(model_list):
    """Print ``model_list`` as an enumerated table titled "Model".

    Thin wrapper around ``utils.print_enumerated_list``; its return value is
    passed through unchanged.
    """
    return utils.print_enumerated_list(model_list, "Model")


# os.path.basename (rather than splitting on "/") keeps this working on
# platforms where glob returns backslash-separated paths.
models_names = sorted(
    os.path.basename(path) for path in glob.glob(os.path.join("models", "*stormer*"))
)
if not models_names:
    raise FileNotFoundError("No models matching 'models/*stormer*' were found.")

print_model_table(models_names)
model_index = int(input("Choose model: "))
# Reject out-of-range (including negative) indices with a clear message instead
# of an opaque IndexError or a silently wrapped-around selection.
if not 0 <= model_index < len(models_names):
    raise IndexError(
        f"Model index {model_index} out of range; expected 0..{len(models_names) - 1}."
    )
model_name = models_names[model_index]

In [None]:
# Restore the hyperparameters stored for the selected model and construct a
# Stormer configured with all of them as keyword arguments.
hps = utils.load_hps(model_name)
stormer = Stormer(**hps)

In [None]:
# Build the train / validation / test splits, configured by the same
# hyperparameter dict the model was constructed from.
train, valid, test = get_datasets(**hps)

In [None]:
for example,label in train.take(1):
    stormer(example)

In [None]:
# Restore previously saved weights for this model, if a checkpoint exists.
model_path = utils.get_model_path(model_name)
print("Model path:", model_path)
# Checking only the containing directory (which presumably already exists for
# a model picked from models/) does not tell us whether weights were ever
# saved. Look for the checkpoint itself:
# - H5 / ".weights.h5" weights are written as a file at `model_path`;
# - the TensorFlow checkpoint format writes `model_path + ".index"` plus data shards.
load_weights = os.path.exists(model_path) or os.path.exists(model_path + ".index")
if load_weights:
    stormer.load_weights(model_path)
else:
    print("No saved weights found; training from scratch.")

In [None]:
# Compile the model and train it, checkpointing weights every epoch and logging
# metrics to a per-model CSV under data/results/.
results_filename = f'data/results/{model_name}.csv'
# Nothing earlier creates data/results/; make sure it exists before a long run.
# NOTE(review): assumes utils.MetricsLogger does not create parent directories
# itself — confirm in utils.
os.makedirs(os.path.dirname(results_filename), exist_ok=True)

metrics = ["accuracy"]

stormer.compile(
    optimizer=tf.keras.optimizers.AdamW(learning_rate=hps["learning_rate"]),
    loss="categorical_crossentropy",
    metrics=metrics,
)

# Only the model weights are saved (save_weights_only=True). A resumed run
# restores the weights but starts with a fresh optimizer state.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=model_path,
    save_weights_only=True,
    save_freq="epoch",
    verbose=0,
)

state_transformer_history = stormer.fit(
    train,
    validation_data=valid,
    epochs=hps["num_epochs"],
    callbacks=[
        model_checkpoint_callback,
        utils.MetricsLogger(results_filename),
    ],
)