In [1]:
import tensorflow as tf
import utils
import os
import glob

from dataset import get_datasets, get_dataset_shape
from stormer import Stormer

2024-04-22 15:44:57.748368: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-22 15:44:57.772815: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## Example of how you can build a new model and transfer learn from an old model
import copy


def print_model_table(model_list):
    """Print `model_list` through utils.print_enumerated_list, using "Model" as the label.

    Returns whatever utils.print_enumerated_list returns, so existing callers
    that used the previous lambda keep working.
    """
    return utils.print_enumerated_list(model_list, "Model")


# os.path.basename strips the directory with the platform's own separator;
# splitting on "/" leaves the "models\" prefix in place on Windows, where
# glob joins results with a backslash.
models_names = sorted(os.path.basename(path) for path in glob.glob("models/stormer*"))
print_model_table(models_names)

# Fail with an explicit message before prompting when there is nothing to choose from.
if not models_names:
    raise FileNotFoundError("No saved models matching 'models/stormer*' were found.")

# The prompt text is kept verbatim; a non-integer answer raises ValueError and an
# out-of-range index raises IndexError, as before.
old_model_name = models_names[int(input("Enter the Index of the model you want to transfer learn from: "))]

Available Models:
-------------------------------------------------
| Index        | Model Name                     |
|--------------|--------------------------------|
| 0            | stormer_r2_h4_dm32_dataset=mel |
|______________|________________________________|


In [3]:
# Hyper-parameters stored for the model selected above.
old_hps = utils.load_hps(old_model_name)

# Entries appended to "num_state_cells"; one extra repeat is added per entry.
new_num_state_cells = [1]
extra_repeats = len(new_num_state_cells)

# Deep copy first, so the in-place updates below leave old_hps unmodified.
new_hps = copy.deepcopy(old_hps)
new_hps["num_state_cells"] += new_num_state_cells
new_hps["num_repeats"] = new_hps["num_repeats"] + extra_repeats

# Derive the new model's name from its hyper-parameters, then hand both to
# utils.save_hps.
new_model_name = utils.get_model_name(**new_hps)
utils.save_hps(new_model_name, new_hps)

In [4]:
# One Stormer per hyper-parameter set; the source model is constructed first.
old_stormer = Stormer(**old_hps)
new_stormer = Stormer(**new_hps)

# Load the source model's weights from the path utils resolves for its name.
old_weights_path = utils.get_model_path(old_model_name)
old_stormer.load_weights(old_weights_path)

# utils.transfer yields the model used for freezing below and the transferred layers.
# NOTE(review): the training cell compiles and fits `new_stormer`, while the
# freezing loop goes through `stormer` - confirm utils.transfer returns the
# same object as new_stormer.
stormer, transfered_layers = utils.transfer(old_stormer, new_stormer)

2024-04-22 15:45:16.284086: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-04-22 15:45:16.300703: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-04-22 15:45:16.300844: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysf

In [6]:
# Mark the first len(transfered_layers) layers of `stormer` as non-trainable.
# NOTE(review): assumes the transferred layers sit at the front of
# stormer.layers - TODO confirm against utils.transfer.
num_transferred = len(transfered_layers)
layer_index = 0
while layer_index < num_transferred:
    stormer.layers[layer_index].trainable = False
    layer_index += 1

In [7]:
## Load the datasets
# get_datasets receives the new model's hyper-parameters as keyword arguments
# and its result is unpacked into the train / validation / test splits.
dataset_splits = get_datasets(**new_hps)
train, valid, test = dataset_splits

Loading dataset version 2
Dataset loaded


2024-04-22 15:45:30.282397: I tensorflow_io/core/kernels/cpu_check.cc:128] Your CPU supports instructions that this TensorFlow IO binary was not compiled to use: AVX AVX2 FMA




In [8]:
# File handed to utils.MetricsLogger; named after the new model.
results_filename = f"data/results/{new_model_name}.csv"
metrics = ["accuracy"]

# NOTE(review): this cell compiles and fits `new_stormer`, whereas layers were
# frozen through `stormer` earlier - confirm both names refer to the same model.
optimizer = tf.keras.optimizers.AdamW(new_hps["learning_rate"])
new_stormer.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=metrics)

# Checkpoint settings: weights only, written once per epoch, silently, to the
# path utils resolves for the new model's name.
checkpoint_options = dict(
    filepath=utils.get_model_path(new_model_name),
    save_weights_only=True,
    save_freq="epoch",
    verbose=0,
)
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(**checkpoint_options)

# The checkpoint callback is listed before the metrics logger.
training_callbacks = [model_checkpoint_callback, utils.MetricsLogger(results_filename)]

# Train for the epoch count stored in the hyper-parameters, validating on `valid`.
state_transformer_history = new_stormer.fit(
    train,
    validation_data=valid,
    epochs=new_hps["num_epochs"],
    callbacks=training_callbacks,
)

Epoch 1/10000
