In [1]:
import tensorflow as tf
from utils import ResultsWriter, init_results_file

2024-04-19 15:08:37.307573: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-19 15:08:37.331197: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
from dataset import get_datasets
from stormer import Stormer

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
#### hyper parameters that defines the structure of the model
num_classes = 31 # ds.get_labels()

learning_rate = 0.01
weight_decay = 0.005
batch_size = 64
num_epochs = 10000  # For real training, use num_epochs=100. 10 is a test value
# patch_size = 6  # Size of the patches to be extract from the input images
# num_patches = (image_size // patch_size) ** 2

num_heads = 4
num_repeats = 2
num_state_cells = 10
input_seq_size = 31
projection_dim = 32
inner_ff_dim = 2 * projection_dim
dropout = 0.1
probability_of_noise = 0.8


In [7]:
train, valid, test = get_datasets(
    batch_size=batch_size,
    type='mel',
    probability_of_noise=probability_of_noise,
)













In [8]:
stormer = Stormer(
    num_classes=num_classes,
    num_repeats=num_repeats,
    num_heads=num_heads,
    num_state_cells=num_state_cells,
    input_seq_size=input_seq_size,
    projection_dim=projection_dim,
    inner_ff_dim=inner_ff_dim,
    dropout=dropout,
    kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
)

In [9]:
# load the model weights
model_path =f"./models/stormer_r{num_repeats}_h{num_heads}_dm{projection_dim}/stormer_r{num_repeats}_h{num_heads}.ckpt"
load_weights = False
if load_weights:
    stormer.load_weights(model_path)

In [10]:
results_filename = f'results_r{num_repeats}_h{num_heads}.csv' 
init_results_file(
    filename=results_filename,
    repeats_to_examine=[num_repeats],
    state_cells_to_examine=[num_state_cells],
    epochs=num_epochs
)

stormer.compile(
    optimizer=tf.keras.optimizers.AdamW(learning_rate),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=model_path,
    save_weights_only=True,
    save_freq="epoch",
    verbose=0,
)

state_transformer_history = stormer.fit(
    train,
    validation_data=valid,
    epochs=num_epochs,
    callbacks=[
        model_checkpoint_callback,
        ResultsWriter(results_filename, num_repeats, num_state_cells)
    ],
)

Epoch 1/10000


2024-04-19 15:11:20.273211: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:606] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2024-04-19 15:11:20.334332: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f4fd80070f0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-04-19 15:11:20.334346: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA RTX A5000, Compute Capability 8.6
2024-04-19 15:11:20.372617: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:255] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-04-19 15:11:20.553268: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:432] Loaded cuDNN version 8600
2024-04-19 15:11:20.714488: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.



KeyboardInterrupt: 

In [27]:
stormer.summary()

Model: "stormer_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_103 (Dense)           multiple                  1280      
                                                                 
 rotary_positional_encoding  multiple                  0         
 _2 (RotaryPositionalEncodi                                      
 ng)                                                             
                                                                 
 stormer_ru_3 (StormerRU)    multiple                  69632     
                                                                 
 stormer_ru_4 (StormerRU)    multiple                  69632     
                                                                 
 dense_170 (Dense)           multiple                  1023      
                                                                 
Total params: 141567 (553.00 KB)
Trainable params: 140927