In [10]:
import configuration
import tensorflow as tf
import utils
import os
import glob

from dataset import get_datasets, get_dataset_shape
from gated_stormer import Stormer

In [11]:
## Example of how you can build a new model and transfer learn from an old model
import copy


def print_model_table(model_list):
    """Print ``model_list`` as an enumerated table via ``utils.print_enumerated_list``."""
    return utils.print_enumerated_list(model_list, "Model")


# Candidate source models: entries under models/ whose name contains "stormer".
# os.path.basename (rather than split("/")) also strips the directory on Windows,
# where glob returns backslash-separated paths.
models_names = sorted(os.path.basename(path) for path in glob.glob("models/*stormer*"))
if not models_names:
    raise FileNotFoundError("No models matching 'models/*stormer*' were found.")
print_model_table(models_names)

# Validate the index explicitly: a negative value would otherwise silently pick a
# model counted from the end of the list, and an out-of-range one gives a bare IndexError.
old_model_index = int(input("Enter the Index of the model you want to transfer learn from: "))
if not 0 <= old_model_index < len(models_names):
    raise IndexError(
        f"Model index must be between 0 and {len(models_names) - 1}, got {old_model_index}."
    )
old_model_name = models_names[old_model_index]

Available Models:
_______________________________________________________
| Index        | Model Name                           |
|______________|______________________________________|
| 0            | gated_stormer_r2_h4_dm32_dataset=mel |
|______________|______________________________________|


In [12]:
# Hyperparameters of the source model, as saved alongside it.
old_hps = utils.load_hps(old_model_name)

# Extra state cells appended to the old architecture; num_repeats grows by one per
# added cell. Deep-copy first so the in-place list extension below cannot leak
# back into old_hps.
new_num_state_cells = [10, 10]
new_hps = copy.deepcopy(old_hps)
new_hps["num_state_cells"] += new_num_state_cells
new_hps["num_repeats"] += len(new_num_state_cells)
new_hps["learning_rate"] = 0.001  # fine-tuning learning rate for the extended model

# Name the new model from its hyperparameters and persist them before training.
new_model_name = "gated_" + utils.get_model_name(**new_hps)
utils.save_hps(new_model_name, new_hps)

In [13]:
# Instantiate the source (old) and target (extended) architectures from their
# hyperparameter dicts; the old one is constructed first.
old_stormer, new_stormer = Stormer(**old_hps), Stormer(**new_hps)

In [14]:
## load the datasets (train / validation / test splits configured by new_hps)
dataset_splits = get_datasets(**new_hps)
train, valid, test = dataset_splits

Loading dataset version 2
Dataset loaded


In [15]:
# Call both models once on a single training batch so their variables get created.
# Presumably this is required before load_weights / utils.transfer in the next cells,
# since Keras subclassed models build lazily on first call.
# next(iter(...)) stops after one element instead of iterating the take(1) dataset
# to its end. An empty dataset raises here instead of silently leaving the models
# unbuilt.
sample_batch = next(iter(train.take(1)), None)
if sample_batch is None:
    raise ValueError("Training dataset is empty; cannot build the models.")
example, _ = sample_batch  # (inputs, labels) pair; labels are not needed to build
old_stormer(example)
new_stormer(example)

2024-05-10 05:45:35.793230: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [16]:
# Restore the source model's saved weights, then hand both models to utils.transfer.
# It returns a (model, transferred layers) pair; presumably it copies matching
# layer weights from old_stormer into new_stormer — TODO confirm in utils.
old_model_path = utils.get_model_path(old_model_name)
old_stormer.load_weights(old_model_path)
stormer, transfered_layers = utils.transfer(old_stormer, new_stormer)

In [17]:
# Freeze the first len(transfered_layers) layers of the transferred model.
# NOTE(review): this assumes the transferred layers are exactly a prefix of
# stormer.layers — depends on utils.transfer; TODO confirm.
frozen_layer_count = len(transfered_layers)
stormer_layers = stormer.layers
for i in range(frozen_layer_count):
    stormer_layers[i].trainable = False

In [18]:
# Train the transfer-learned model, checkpointing weights each epoch and logging
# metrics to a per-model CSV.
results_filename = f'data/results/{new_model_name}.csv'
# Ensure the results directory exists before utils.MetricsLogger (presumably) writes the CSV.
os.makedirs(os.path.dirname(results_filename), exist_ok=True)

metrics = ["accuracy"]

# Compile and fit `stormer`: the model returned by utils.transfer, whose leading
# layers were frozen in the previous cell. Fitting `new_stormer` instead would
# ignore that freezing whenever utils.transfer returns a different model object.
# Freezing happens before compile, as Keras requires for `trainable` to take effect.
stormer.compile(
    optimizer=tf.keras.optimizers.AdamW(learning_rate=new_hps["learning_rate"]),
    loss="categorical_crossentropy",
    metrics=metrics,
)

# Overwrites the same weights file after every epoch (no save_best_only), so it
# always holds the most recently completed epoch.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=utils.get_model_path(new_model_name),
    save_weights_only=True,
    save_freq="epoch",
    verbose=0,
)

# NOTE(review): the logged run shows "Epoch 1/10000" and ends in KeyboardInterrupt.
# With no EarlyStopping callback, training runs until it is interrupted by hand.
# Consider tf.keras.callbacks.EarlyStopping(monitor="val_loss", restore_best_weights=True).
state_transformer_history = stormer.fit(
    train,
    validation_data=valid,
    epochs=new_hps["num_epochs"],
    callbacks=[
        model_checkpoint_callback,
        utils.MetricsLogger(results_filename),
    ],
)

Epoch 1/10000


I0000 00:00:1715320036.425570 4063050 service.cc:145] XLA service 0x7fbe78024680 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1715320036.425690 4063050 service.cc:153]   StreamExecutor device (0): NVIDIA L40S, Compute Capability 8.9
2024-05-10 05:47:24.761056: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-05-10 05:47:43.049913: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8907
I0000 00:00:1715320155.406552 4063050 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m1487/1489[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 36ms/step - accuracy: 0.2645 - loss: 2.9038




[1m1489/1489[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m471s[0m 169ms/step - accuracy: 0.2651 - loss: 2.9015 - val_accuracy: 0.8952 - val_loss: 0.4983
Epoch 2/10000
[1m1489/1489[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 43ms/step - accuracy: 0.9067 - loss: 0.4420 - val_accuracy: 0.9092 - val_loss: 0.4253
Epoch 3/10000
[1m1489/1489[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 43ms/step - accuracy: 0.9189 - loss: 0.3686 - val_accuracy: 0.9134 - val_loss: 0.3819
Epoch 4/10000
[1m1489/1489[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 43ms/step - accuracy: 0.9215 - loss: 0.3410 - val_accuracy: 0.9122 - val_loss: 0.3856
Epoch 5/10000
[1m1489/1489[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 43ms/step - accuracy: 0.9255 - loss: 0.3140 - val_accuracy: 0.9157 - val_loss: 0.3651
Epoch 6/10000
[1m1489/1489[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 43ms/step - accuracy: 0.9260 - loss: 0.3070 - val_accuracy: 0.9139 - val_loss: 0.349

KeyboardInterrupt: 