In [None]:
import torch
import gc
import logging

# del session.logger
# del session.epoch_logger
# del session
# gc.collect()
# torch.cuda.empty_cache()

import multiprocessing
from train.training_session_v2 import TrainingSessionV2, load_training_session
from config import TrainingConfigV2
from config import SimpleConvConfig

# Per-study hold-out specification. It is handed to TrainingConfigV2 as
# `data_partition`, and its keys also decide which studies the session loads.
data_partition = dict(
    gwilliams2023=dict(
        testing_subjects=[19, 20, 21],
        testing_tasks=[0],
    ),
    # Second study, currently disabled:
    # armeini2022=dict(
    #     testing_subjects=[],
    #     testing_tasks=[8, 9],
    # ),
)

def _build_brain_encoder_config():
    """Assemble the SimpleConvConfig used for this run.

    The keyword arguments are grouped into small per-section dicts purely for
    readability; they are all forwarded unchanged to ``SimpleConvConfig``.
    """
    general = dict(
        mel_normalization=False,
        # Maps each condition name (str) to its list of possible values.
        conditions={
            "study": [],
            "subject": [],
        },
    )
    channels = dict(
        in_channels=208,
        out_channels=80,
        hidden_dim=256,
        dropout=0.2,
        initial_batch_norm=True,
    )
    sensor_layout = dict(
        layout_dim=2,
        layout_proj=True,
        layout_scaling="minmax",
    )
    # Merger with spatial attention (switched off for this run).
    merger_opts = dict(
        merger=False,
        merger_emb_type=None,
        merger_emb_dim=0,
        merger_channels=0,
        merger_dropout=0.0,  # float
        merger_conditional=None,
    )
    # Initial layers.
    initial_opts = dict(
        initial_linear=256,
        initial_depth=1,
    )
    conditional_opts = dict(
        conditional_layers=False,
        conditional_layers_dim=None,  # "input" or "hidden_dim"
    )
    # Overall structure of the convolutional stack.
    conv_opts = dict(
        depth=4,
        kernel_size=3,
        growth=1.0,
        dilation_growth=2,
        dilation_period=5,
        glu=1,
        conv_dropout=0.2,
        dropout_input=0.1,
        batch_norm=True,
        half=True,
        cnn_pos_encoding=False,
    )
    # Quantizer (switched off for this run).
    quantizer_opts = dict(
        quantizer=False,
        num_codebooks=0,
        codebook_size=0,
        quantizer_commitment=0,
        quantizer_temp_init=0,
        quantizer_temp_min=0,
        quantizer_temp_decay=0,
    )
    # Transformer encoder.
    encoder_opts = dict(
        transformer_input="continuous",
        transformer_encoder_emb="sinusoidal",
        transformer_encoder_layers=4,
        transformer_encoder_heads=4,
    )
    # Conformer encoder variant.
    conformer_opts = dict(
        rnn_type="conformer",
        depthwise_conv_kernel_size=15,
        use_group_norm=False,
        convolution_first=False,
    )
    # Transformer decoder (no layers configured).
    decoder_opts = dict(
        transformer_decoder_emb=None,
        transformer_decoder_layers=0,
        transformer_decoder_heads=0,
        transformer_decoder_dim=0,
    )
    return SimpleConvConfig(
        **general,
        **channels,
        **sensor_layout,
        **merger_opts,
        **initial_opts,
        **conditional_opts,
        **conv_opts,
        **quantizer_opts,
        **encoder_opts,
        **conformer_opts,
        **decoder_opts,
    )


model_config = _build_brain_encoder_config()

# Single source of truth for the run length. The AdaLoRA schedule below is
# expressed in optimiser steps, so it is derived from these values instead of
# repeating the literals 450 and 50.
STEPS_PER_EPOCH = 450
EPOCHS = 50
TOTAL_STEPS = STEPS_PER_EPOCH * EPOCHS  # 22,500

config = TrainingConfigV2(
    brain_encoder_config=model_config,
    data_partition=data_partition,
    # AdaLoRA settings.
    # Original author's note: around 100k total batches an epoch for gwilliams.
    adalora_init_r=12,
    adalora_target_r=4,
    # Schedule boundaries, in steps. The percentages are relative to
    # TOTAL_STEPS and computed from the values actually used here.
    adalora_tinit=STEPS_PER_EPOCH * 3,  # 1,350 steps = 6% of total
    # NOTE(review): the earlier comment on this line recommended 50-80% of
    # total steps, but the value in use is 16% - confirm which is intended.
    adalora_tfinal=STEPS_PER_EPOCH * 8,  # 3,600 steps = 16% of total
    adalora_deltaT=STEPS_PER_EPOCH * 1,  # 450 steps = 2% of total
    adalora_lora_alpha=32,
    adalora_lora_dropout=0.1,
    adalora_total_step=TOTAL_STEPS,
    # Pre-processing parameters
    # Brain
    new_freq=200,
    frequency_bands={"all": (0.5, 80)},
    max_random_shift=1.0,
    window_size=4,
    window_stride=1,
    brain_clipping=None,
    baseline_window=0.5,
    notch_filter=True,
    scaling="both",
    delay=0.15,
    # Audio
    audio_model="openai/whisper-tiny.en",
    # Hyperparameters
    # NOTE: learning_rate and batch_size are reassigned on the config object
    # right after construction, so the values here are not the ones trained with.
    learning_rate=1e-4,
    weight_decay=1e-4,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    batch_size=128,
    random_test_size=10,
    seed=42,
    # Weights of the mel-level objectives.
    mel_alignment_objectives={
        "clip_loss": 0.6,
        "mse_loss": 0.4,
    },
    # Latent-level objectives are all weighted 0.0 in this run.
    latent_alignment_objectives={
        "cosine_similarity": 0.0,
        "mse_loss": 0.0,
        "clip_loss": 0.0,
        "mmd_loss": 0.0,
    },
    decode_timestamps=True,
)

# Per-run overrides applied after construction, in their original order.
# Only learning_rate and batch_size differ from the constructor arguments;
# the remaining assignments restate values that were already set.
config.brain_encoder_config.mel_normalization = False
for _attr, _value in (
    ("learning_rate", 3e-4),
    ("batch_size", 256),
    ("steps_per_epoch", 450),
    ("decode_timestamps", True),
):
    setattr(config, _attr, _value)

# config.brain_encoder_config.hidden_dim = 1024
# config.brain_encoder_config.initial_linear = 1024

# Start a fresh training session. Every study named in `data_partition` is
# requested with the "audiotext" batch type.
session = TrainingSessionV2(
    config=config,
    studies=dict.fromkeys(data_partition, "audiotext"),
    data_path="data",
    save_path="saves/phase3/objectives/baseline_gwilliams_latent_loss_no_latent_alignment",
    clear_cache=False,
    cache_name="/home/ubuntu/cache",
    download_studies=True,
)


# session = load_training_session(
#     save_path="saves/phase3/objectives/baseline_gwilliams_latent_loss_no_latent_alignment/epoch_39",
#     studies={"gwilliams2023": "audiotext"},
#     data_path="data",
#     cache_name="/home/ubuntu/cache",
# )

# Use all but two CPU cores for workers, but never ask for fewer than one:
# on a host with two or fewer cores the plain subtraction would yield 0 or a
# negative worker count.
worker_count = max(1, multiprocessing.cpu_count() - 2)

try:
    session.train(
        device="cuda",
        buffer_size=30,
        num_workers=worker_count,
        max_cache_size=800,
        current_epoch=0,
    )
except KeyboardInterrupt:
    # Interrupting the cell is the expected way to stop a long run, so report
    # it with a short message instead of a traceback.
    print("Exited")

# try:
#     session.pre_process_all_recordings(
#         buffer_size=30, num_workers=multiprocessing.cpu_count() - 4, max_cache_size=800
#     )
# except KeyboardInterrupt as e:
#     print("Exited")

Loading Gwilliams2023 with batch type audiotext
Data partitioned on studies ['gwilliams2023'].
Train: 135, Unseen Task: 45, Unseen Subject: 12, Unseen Both: 4.

RNNEncoder initialized as conformer with 4 layers, 256 d_model, 4 nhead
	Embedding: sinusoidal, params: 6075392
SimpleConv initialized with 8927984 parameters, cond: ['study', 'subject']
Merger False, merger channels 0
ConvBlocks: 4, hidden_dim: 256, params 2626048
Found 40 target modules for AdaLora: ['model.decoder.layers.0.self_attn.k_proj', 'model.decoder.layers.0.self_attn.v_proj', 'model.decoder.layers.0.self_attn.q_proj', 'model.decoder.layers.0.self_attn.out_proj', 'model.decoder.layers.0.encoder_attn.k_proj', 'model.decoder.layers.0.encoder_attn.v_proj', 'model.decoder.layers.0.encoder_attn.q_proj', 'model.decoder.layers.0.encoder_attn.out_proj', 'model.decoder.layers.0.fc1', 'model.decoder.layers.0.fc2', 'model.decoder.layers.1.self_attn.k_proj', 'model.decoder.layers.1.self_attn.v_proj', 'model.decoder.layers.1.self_

2025-03-04 19:09:58,154	INFO worker.py:1841 -- Started a local Ray instance.
Training Epoch 1: 100%|██████████| 135/135 [10:36<00:00,  4.72s/it]


Testing done in 1.22m.
Epoch 1 done in 11.88m. 0.09m/recording.


New best epoch 1 with CER 0.9505 and BLEU 0.0009.
Mel Loss: 6.2285, Clip Loss: 9.4895, MSE: 1.3369
Mel accuracy: 0.0048, Top 5: 0.0236, Top 10: 0.0472


Training Epoch 2: 100%|██████████| 135/135 [10:30<00:00,  4.67s/it]


Testing done in 1.60m.
Epoch 2 done in 12.12m. 0.09m/recording.


New best epoch 2 with CER 2.0421 and BLEU 0.0035.
Mel Loss: 5.8900, Clip Loss: 9.4911, MSE: 0.4883
Mel accuracy: 0.0051, Top 5: 0.0254, Top 10: 0.0505


Training Epoch 3: 100%|██████████| 135/135 [10:28<00:00,  4.66s/it]


Testing done in 1.82m.
Epoch 3 done in 12.30m. 0.09m/recording.


New best epoch 3 with CER 2.1884 and BLEU 0.0065.
Mel Loss: 5.7935, Clip Loss: 9.4749, MSE: 0.2713
Mel accuracy: 0.0142, Top 5: 0.0555, Top 10: 0.0971


Training Epoch 4:   0%|          | 0/135 [00:00<?, ?it/s]

Starting rank reallocation at recording 1350.


Training Epoch 4: 100%|██████████| 135/135 [10:31<00:00,  4.67s/it]


Testing done in 4.53m.
Epoch 4 done in 15.05m. 0.11m/recording.


Training Epoch 5: 100%|██████████| 135/135 [10:26<00:00,  4.64s/it]


Testing done in 4.91m.
Epoch 5 done in 15.36m. 0.11m/recording.


New best epoch 5 with CER 2.6006 and BLEU 0.0070.
Mel Loss: 5.7460, Clip Loss: 9.4196, MSE: 0.2358
Mel accuracy: 0.0568, Top 5: 0.1697, Top 10: 0.2542


Training Epoch 6: 100%|██████████| 135/135 [10:24<00:00,  4.63s/it]


Testing done in 5.13m.
Epoch 6 done in 15.54m. 0.12m/recording.


New best epoch 6 with CER 2.6887 and BLEU 0.0081.
Mel Loss: 5.7268, Clip Loss: 9.3956, MSE: 0.2235
Mel accuracy: 0.0757, Top 5: 0.2138, Top 10: 0.3103


Training Epoch 7: 100%|██████████| 135/135 [10:30<00:00,  4.67s/it]


Testing done in 5.03m.
Epoch 7 done in 15.54m. 0.12m/recording.


New best epoch 7 with CER 2.3769 and BLEU 0.0088.
Mel Loss: 5.7145, Clip Loss: 9.3787, MSE: 0.2183
Mel accuracy: 0.0900, Top 5: 0.2431, Top 10: 0.3470


Training Epoch 8: 100%|██████████| 135/135 [10:26<00:00,  4.64s/it]


Testing done in 5.07m.
Epoch 8 done in 15.52m. 0.11m/recording.


Training Epoch 9: 100%|██████████| 135/135 [10:24<00:00,  4.63s/it]


Testing done in 4.99m.
Epoch 9 done in 15.41m. 0.11m/recording.


Training Epoch 10:  82%|████████▏ | 111/135 [08:45<01:33,  3.89s/it]