In [None]:
import tensorflow as tf
import utils
import os

from dataset import get_datasets, labels_v1, labels_v2
from stormer import Stormer

In [None]:
# An example of how to initialize the hyperparameters in case you don't already have a model

# Hyperparameters live at module level (no wrapping function): later cells
# read both the individual names (e.g. `learning_rate`, `num_epochs`) and the
# `hps` dict directly, so they must all be plain notebook globals.

# --- dataset / feature settings ---
version = 2
# The class count follows the label list of the selected dataset version.
num_classes = len(labels_v2) if version == 2 else len(labels_v1)
dataset_type = "mel"
frame_length = 256
frame_step = 128
mel_bands = 40
num_coefficients = 13
max_shift_in_ms = 100

# --- optimisation settings ---
learning_rate = 0.01
weight_decay = 0.005
batch_size = 64
# NOTE(review): an earlier comment here suggested 100 epochs for real training
# and 10 for a quick test, yet the value in use is 10000 -- confirm the intent.
num_epochs = 10000

# --- model settings ---
num_heads = 4
num_repeats = 2
num_state_cells = [10, 10]
input_seq_size = 31
projection_dim = 32
inner_ff_dim = 2 * projection_dim  # kept at twice the projection width
dropout = 0.1
probability_of_noise = 0.8

# Single dict with every hyperparameter; it is expanded as keyword arguments
# by the utils helpers below and by later cells.
hps = {
    "version": version,
    "num_classes": num_classes,
    "dataset_type": dataset_type,
    "frame_length": frame_length,
    "frame_step": frame_step,
    "mel_bands": mel_bands,
    "num_coefficients": num_coefficients,
    "max_shift_in_ms": max_shift_in_ms,
    "learning_rate": learning_rate,
    "weight_decay": weight_decay,
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "num_heads": num_heads,
    "num_repeats": num_repeats,
    "num_state_cells": num_state_cells,
    "input_seq_size": input_seq_size,
    "projection_dim": projection_dim,
    "inner_ff_dim": inner_ff_dim,
    "dropout": dropout,
    "probability_of_noise": probability_of_noise,
}

# Derive the model name from the hyperparameters, hand both to utils.save_hps,
# then look up the path that the checkpoint callback will write to.
model_name = utils.get_model_name(**hps)
utils.save_hps(model_name, hps)
model_path = utils.get_model_path(model_name)

In [None]:
# Build the model, forwarding every hyperparameter as a keyword argument.
stormer = Stormer(**hps)

In [None]:
# Load the three dataset splits; get_datasets receives the same
# hyperparameter dict as the model, expanded into keyword arguments.
train, valid, test = get_datasets(**hps)

In [None]:
# Path handed to utils.MetricsLogger; the name encodes the main model settings.
results_filename = f'data/results/results_r{num_repeats}_h{num_heads}_dm{projection_dim}_dataset={dataset_type}.csv'
# Create the target directory up front so a fresh checkout does not fail on a
# missing folder (presumably MetricsLogger writes to this path -- verify).
os.makedirs(os.path.dirname(results_filename), exist_ok=True)

metrics = ["accuracy"]

stormer.compile(
    # Pass the configured weight_decay explicitly; without it AdamW silently
    # falls back to its own default and the value stored in `hps` is ignored.
    # NOTE(review): confirm Stormer does not already apply weight_decay itself.
    optimizer=tf.keras.optimizers.AdamW(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
    ),
    loss="categorical_crossentropy",
    metrics=metrics,
)

# Write the weights only (not the full model) to model_path after every epoch.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=model_path,
    save_weights_only=True,
    save_freq="epoch",
    verbose=0,
)

# Train with validation after each epoch; the returned History object is kept
# for later inspection.
state_transformer_history = stormer.fit(
    train,
    validation_data=valid,
    epochs=num_epochs,
    callbacks=[
        model_checkpoint_callback,
        utils.MetricsLogger(
            results_filename,
        ),
    ],
)