In [1]:
import torch
import gc
import logging

# del session.logger
# del session.epoch_logger
# del session
# gc.collect()
# torch.cuda.empty_cache()

import multiprocessing
from train.training_session_v2 import TrainingSessionV2, load_training_session
from config import TrainingConfigV2
from config import SimpleConvConfig

# Test-split specification keyed by study name. Each study lists the subject
# and task indices under "testing_subjects" / "testing_tasks"; the mapping is
# passed to TrainingConfigV2 and its keys select the studies to load.
# NOTE(review): presumably these indices are held out from training - confirm
# against TrainingSessionV2.
data_partition = dict(
    gwilliams2023=dict(testing_subjects=[19, 20, 21], testing_tasks=[0]),
    armeini2022=dict(testing_subjects=[], testing_tasks=[0, 1]),
)

# Keyword arguments for the brain encoder, collected in one mapping and
# expanded into SimpleConvConfig below. Names and values are exactly the ones
# the constructor receives.
_encoder_kwargs = {
    "mel_normalization": False,
    # Str to list of possible conditions
    "conditions": {"study": [], "subject": []},
    # Channels
    "in_channels": 208,  # reassigned to 269 on the config after construction
    "out_channels": 80,
    "hidden_dim": 256,
    "dropout": 0.2,
    "initial_batch_norm": True,
    # Sensor layout and merger (spatial attention) settings are not passed
    # here; this cell assigns them on config.brain_encoder_config after the
    # training config is built. The commented-out constructor form was:
    #   layout_dim=2, layout_proj=True, layout_scaling="minmax",
    #   merger=False, merger_emb_type=None, merger_emb_dim=0,
    #   merger_channels=0, merger_dropout=0.0 (float), merger_conditional=None
    # Initial
    "initial_linear": 256,
    "initial_depth": 1,
    # Conditional layers
    "conditional_layers": False,
    "conditional_layers_dim": None,  # input or hidden_dim
    # Conv layer overall structure
    "depth": 4,
    "kernel_size": 3,
    "growth": 1.0,
    "dilation_growth": 2,
    "dilation_period": 5,
    "glu": 1,
    "conv_dropout": 0.2,
    "dropout_input": 0.1,
    "batch_norm": True,
    "half": True,
    "cnn_pos_encoding": False,
    # Quantizer
    "quantizer": False,
    "num_codebooks": 0,
    "codebook_size": 0,
    "quantizer_commitment": 0,
    "quantizer_temp_init": 0,
    "quantizer_temp_min": 0,
    "quantizer_temp_decay": 0,
    # Transformer encoders
    "transformer_input": "continuous",
    "transformer_encoder_emb": "sinusoidal",
    "transformer_encoder_layers": 4,
    "transformer_encoder_heads": 4,
    # Conformer encoder variant
    "rnn_type": "conformer",
    "depthwise_conv_kernel_size": 15,
    "use_group_norm": False,
    "convolution_first": False,
    # Transformer decoders
    "transformer_decoder_emb": None,
    "transformer_decoder_layers": 0,
    "transformer_decoder_heads": 0,
    "transformer_decoder_dim": 0,
}

model_config = SimpleConvConfig(**_encoder_kwargs)

# Units for the AdaLoRA schedule. 654 equals the steps_per_epoch value this
# cell assigns to config after construction, and 50 equals `epochs`, so the
# schedule lengths below are expressed in whole epochs of the actual run.
ADALORA_STEPS_PER_EPOCH = 654
ADALORA_EPOCHS = 50

# Training configuration: AdaLoRA schedule, brain/audio pre-processing,
# optimiser hyperparameters and loss weights. Several fields are reassigned
# on the resulting object further down in this cell; those are marked inline.
config = TrainingConfigV2(
    brain_encoder_config=model_config,
    data_partition=data_partition,
    # AdaLoRA settings
    # Around 100k total batches an epoch for gwilliams
    use_adalora=True,
    adalora_init_r=12,
    adalora_target_r=4,
    # 3 epochs = 6% of adalora_total_step (original note: 5% total steps)
    adalora_tinit=3 * ADALORA_STEPS_PER_EPOCH,
    # 8 epochs = 16% of adalora_total_step (original note: 50-80% total steps)
    adalora_tfinal=8 * ADALORA_STEPS_PER_EPOCH,
    # 1 epoch = 2% of adalora_total_step (original note: 1-5% total steps)
    adalora_deltaT=1 * ADALORA_STEPS_PER_EPOCH,
    adalora_lora_alpha=32,
    adalora_lora_dropout=0.1,
    adalora_total_step=ADALORA_EPOCHS * ADALORA_STEPS_PER_EPOCH,
    # Pre-processing parameters
    # Brain
    new_freq=200,
    frequency_bands={"all": (0.5, 80)},
    max_random_shift=1.0,
    window_size=4,
    window_stride=1,
    brain_clipping=None,
    baseline_window=0.5,
    notch_filter=True,
    scaling="both",
    delay=0.15,
    # Audio
    audio_model="openai/whisper-tiny.en",
    # Hyperparameters
    learning_rate=1e-4,  # reassigned to 3e-4 after construction
    weight_decay=1e-4,
    epochs=50,
    steps_per_epoch=450,  # reassigned to 654 after construction
    batch_size=128,  # reassigned to 256 after construction
    random_test_size=10,
    seed=42,
    # Loss-name -> weight for each alignment objective
    mel_alignment_objectives={
        "clip_loss": 0.6,
        "mse_loss": 0.4,
    },
    latent_alignment_objectives={
        "cosine_similarity": 0.4,
        "mse_loss": 0.4,
        "clip_loss": 0.6,
        "mmd_loss": 0.0,
    },
    decode_timestamps=True,
)

# Per-run overrides applied after construction. Each row is
# (object to modify, attribute name, new value); rows are applied top to
# bottom, in the same order as the original one-per-line assignments.
_encoder_cfg = config.brain_encoder_config
_override_table = (
    (_encoder_cfg, "mel_normalization", False),
    (config, "learning_rate", 3e-4),
    (config, "batch_size", 256),
    (config, "steps_per_epoch", 654),
    (config, "epochs", 50),
    (_encoder_cfg, "in_channels", 269),
    (config, "decode_timestamps", True),
    # Sensor layout settings
    (_encoder_cfg, "layout_dim", 2),
    (_encoder_cfg, "layout_proj", False),
    (_encoder_cfg, "layout_scaling", "midpoint"),
    # Merger with spatial attn
    (_encoder_cfg, "merger", True),
    (_encoder_cfg, "merger_emb_type", "mlp"),
    (_encoder_cfg, "merger_emb_dim", 256),
    (_encoder_cfg, "merger_channels", 269),
    (_encoder_cfg, "merger_dropout", 0.1),  # Float
    (_encoder_cfg, "merger_conditional", None),
)
for _target, _field, _value in _override_table:
    setattr(_target, _field, _value)

# config.brain_clipping = 20

# config.brain_encoder_config.hidden_dim = 1024
# config.brain_encoder_config.initial_linear = 1024

# Create the training session. Every study named in data_partition is mapped
# to the "audiotext" batch type; results are written under save_path.
session = TrainingSessionV2(
    config=config,
    studies=dict.fromkeys(data_partition, "audiotext"),
    data_path="data",
    save_path="saves/phase3/combining/channel_merger_269_mlp_256_midpoint_2d",
    clear_cache=False,
    cache_name="/home/ubuntu/cache",
    download_studies=True,
)


# session = load_training_session(
#     save_path="saves/phase3/objectives/baseline_gwilliams_latent_loss_no_latent_alignment/epoch_39",
#     studies={"gwilliams2023": "audiotext"},
#     data_path="data",
#     cache_name="/home/ubuntu/cache",
# )

# Use all but two CPU cores for workers. The subtraction alone yields 0 or a
# negative count on machines with two or fewer cores, so clamp to at least 1.
_num_workers = max(1, multiprocessing.cpu_count() - 2)

# Run training from the first epoch. A manual interrupt (Ctrl-C / kernel
# interrupt) is treated as a normal way to stop the run rather than an error.
try:
    session.train(
        device="cuda",
        buffer_size=30,
        num_workers=_num_workers,
        max_cache_size=800,
        current_epoch=0,
    )
except KeyboardInterrupt:
    print("Exited")

# try:
#     session.pre_process_all_recordings(
#         buffer_size=30, num_workers=multiprocessing.cpu_count() - 10, max_cache_size=800
#     )
# except KeyboardInterrupt as e:
#     print("Exited")

Loading Gwilliams2023 with batch type audiotext
Loading Armeini2022 with batch type audiotext
Data partitioned on studies ['gwilliams2023', 'armeini2022'].
Train: 159, Unseen Task: 51, Unseen Subject: 12, Unseen Both: 4.

RNNEncoder initialized as conformer with 4 layers, 256 d_model, 4 nhead
	Embedding: sinusoidal, params: 6075392
SimpleConv initialized with 9079658 parameters, cond: ['study', 'subject']
Merger True, merger channels 269
ConvBlocks: 4, hidden_dim: 256, params 2626048
Using torch.bfloat16
Found 64 target modules for AdaLora: ['model.encoder.layers.0.self_attn.k_proj', 'model.encoder.layers.0.self_attn.v_proj', 'model.encoder.layers.0.self_attn.q_proj', 'model.encoder.layers.0.self_attn.out_proj', 'model.encoder.layers.0.fc1', 'model.encoder.layers.0.fc2', 'model.encoder.layers.1.self_attn.k_proj', 'model.encoder.layers.1.self_attn.v_proj', 'model.encoder.layers.1.self_attn.q_proj', 'model.encoder.layers.1.self_attn.out_proj', 'model.encoder.layers.1.fc1', 'model.encoder

2025-03-21 15:56:10,297	INFO worker.py:1841 -- Started a local Ray instance.
Training Epoch 1: 100%|██████████| 159/159 [18:37<00:00,  7.03s/it]


Testing done in 4.54m.
Epoch 1 done in 23.21m. 0.15m/recording.


New best epoch 1 with CER 0.9581 and BLEU 0.0047.
Mel Loss: 6.2077, Clip Loss: 9.5250, MSE: 1.2316
Mel accuracy: 0.0042, Top 5: 0.0226, Top 10: 0.0452


[36m(raylet)[0m Spilled 26412 MiB, 27 objects, write throughput 780 MiB/s. Set RAY_verbose_spill_logs=0 to disable this message.
Training Epoch 2: 100%|██████████| 159/159 [18:29<00:00,  6.98s/it]


Testing done in 3.82m.
Epoch 2 done in 22.32m. 0.14m/recording.


New best epoch 2 with CER 0.9536 and BLEU 0.0015.
Mel Loss: 5.8794, Clip Loss: 9.5309, MSE: 0.4023
Mel accuracy: 0.0046, Top 5: 0.0230, Top 10: 0.0458


[36m(raylet)[0m Spilled 53842 MiB, 54 objects, write throughput 781 MiB/s.
Training Epoch 3: 100%|██████████| 159/159 [18:31<00:00,  6.99s/it]


Testing done in 3.32m.
Epoch 3 done in 21.85m. 0.14m/recording.


New best epoch 3 with CER 0.9456 and BLEU 0.0016.
Mel Loss: 5.8035, Clip Loss: 9.5177, MSE: 0.2323
Mel accuracy: 0.0046, Top 5: 0.0226, Top 10: 0.0450


Training Epoch 4:   0%|          | 0/159 [00:00<?, ?it/s]

Starting rank reallocation at recording 1962.


[36m(raylet)[0m Spilled 55867 MiB, 55 objects, write throughput 783 MiB/s.
[36m(raylet)[0m Spilled 80308 MiB, 82 objects, write throughput 850 MiB/s.
Training Epoch 4: 100%|██████████| 159/159 [18:37<00:00,  7.03s/it]


Testing done in 1.17m.
Epoch 4 done in 19.79m. 0.12m/recording.


New best epoch 4 with CER 0.0000 and BLEU 0.0000.
Mel Loss: 5.7935, Clip Loss: 9.5140, MSE: 0.2127
Mel accuracy: 0.0044, Top 5: 0.0228, Top 10: 0.0454


Training Epoch 5: 100%|██████████| 159/159 [18:39<00:00,  7.04s/it]


Testing done in 2.04m.
Epoch 5 done in 20.71m. 0.13m/recording.


New best epoch 5 with CER 0.8177 and BLEU 0.0083.
Mel Loss: 5.8179, Clip Loss: 9.5538, MSE: 0.2142
Mel accuracy: 0.0044, Top 5: 0.0223, Top 10: 0.0448


Training Epoch 6: 100%|██████████| 159/159 [18:40<00:00,  7.05s/it]


Testing done in 3.66m.
Epoch 6 done in 22.34m. 0.14m/recording.


[36m(raylet)[0m Spilled 102205 MiB, 109 objects, write throughput 850 MiB/s.
Training Epoch 7: 100%|██████████| 159/159 [18:42<00:00,  7.06s/it]


Testing done in 4.09m.
Epoch 7 done in 22.80m. 0.14m/recording.


[36m(raylet)[0m Spilled 126956 MiB, 136 objects, write throughput 844 MiB/s.
Training Epoch 8: 100%|██████████| 159/159 [18:39<00:00,  7.04s/it]


Testing done in 3.74m.
Epoch 8 done in 22.40m. 0.14m/recording.


Training Epoch 9: 100%|██████████| 159/159 [18:32<00:00,  7.00s/it]


Testing done in 4.25m.
Epoch 9 done in 22.80m. 0.14m/recording.


Training Epoch 10: 100%|██████████| 159/159 [18:44<00:00,  7.07s/it]


Testing done in 3.08m.
Epoch 10 done in 21.82m. 0.14m/recording.


Training Epoch 11: 100%|██████████| 159/159 [18:43<00:00,  7.07s/it]


Testing done in 3.87m.
Epoch 11 done in 22.60m. 0.14m/recording.


Training Epoch 12: 100%|██████████| 159/159 [18:45<00:00,  7.08s/it]


Testing done in 3.77m.
Epoch 12 done in 22.54m. 0.14m/recording.


Training Epoch 13: 100%|██████████| 159/159 [18:42<00:00,  7.06s/it]


Testing done in 4.99m.
Epoch 13 done in 23.71m. 0.15m/recording.


Training Epoch 14: 100%|██████████| 159/159 [18:36<00:00,  7.02s/it]


Testing done in 5.20m.
Epoch 14 done in 23.81m. 0.15m/recording.


Training Epoch 15: 100%|██████████| 159/159 [18:44<00:00,  7.07s/it]


Testing done in 4.87m.
Epoch 15 done in 23.62m. 0.15m/recording.


Training Epoch 16: 100%|██████████| 159/159 [18:37<00:00,  7.03s/it]


Testing done in 3.98m.
Epoch 16 done in 22.60m. 0.14m/recording.


Training Epoch 17: 100%|██████████| 159/159 [18:49<00:00,  7.11s/it]


Testing done in 4.05m.
Epoch 17 done in 22.89m. 0.14m/recording.


[36m(raylet)[0m Spilled 151327 MiB, 163 objects, write throughput 841 MiB/s.
Training Epoch 18: 100%|██████████| 159/159 [18:48<00:00,  7.10s/it]


Testing done in 4.98m.
Epoch 18 done in 23.80m. 0.15m/recording.


Training Epoch 19: 100%|██████████| 159/159 [18:46<00:00,  7.09s/it]


Testing done in 5.05m.
Epoch 19 done in 23.84m. 0.15m/recording.


Training Epoch 20: 100%|██████████| 159/159 [18:51<00:00,  7.12s/it]


Testing done in 4.91m.
Epoch 20 done in 23.78m. 0.15m/recording.


Training Epoch 21: 100%|██████████| 159/159 [18:52<00:00,  7.12s/it]


Testing done in 5.26m.
Epoch 21 done in 24.14m. 0.15m/recording.


Training Epoch 22: 100%|██████████| 159/159 [18:52<00:00,  7.12s/it]


Testing done in 4.32m.
Epoch 22 done in 23.19m. 0.15m/recording.


Training Epoch 23: 100%|██████████| 159/159 [19:39<00:00,  7.42s/it]


Testing done in 4.51m.
Epoch 23 done in 24.18m. 0.15m/recording.


Training Epoch 24: 100%|██████████| 159/159 [18:39<00:00,  7.04s/it]


Testing done in 4.80m.
Epoch 24 done in 23.45m. 0.15m/recording.


Training Epoch 25: 100%|██████████| 159/159 [18:34<00:00,  7.01s/it]


Testing done in 4.78m.
Epoch 25 done in 23.35m. 0.15m/recording.


Training Epoch 26: 100%|██████████| 159/159 [18:31<00:00,  6.99s/it]


Testing done in 4.49m.
Epoch 26 done in 23.02m. 0.14m/recording.


Training Epoch 27: 100%|██████████| 159/159 [18:33<00:00,  7.00s/it]


Testing done in 4.62m.
Epoch 27 done in 23.17m. 0.15m/recording.


[36m(raylet)[0m Spilled 275456 MiB, 294 objects, write throughput 846 MiB/s.
Training Epoch 28: 100%|██████████| 159/159 [18:54<00:00,  7.14s/it]


Testing done in 4.70m.
Epoch 28 done in 23.61m. 0.15m/recording.


Training Epoch 29: 100%|██████████| 159/159 [18:48<00:00,  7.10s/it]


Testing done in 5.17m.
Epoch 29 done in 23.98m. 0.15m/recording.


Training Epoch 30: 100%|██████████| 159/159 [18:44<00:00,  7.07s/it]


Testing done in 4.95m.
Epoch 30 done in 23.69m. 0.15m/recording.


Training Epoch 31: 100%|██████████| 159/159 [18:49<00:00,  7.10s/it]


Testing done in 5.29m.
Epoch 31 done in 24.11m. 0.15m/recording.


Training Epoch 32: 100%|██████████| 159/159 [18:41<00:00,  7.05s/it]


Testing done in 4.63m.
Epoch 32 done in 23.32m. 0.15m/recording.


Training Epoch 33: 100%|██████████| 159/159 [18:54<00:00,  7.14s/it]


Testing done in 4.59m.
Epoch 33 done in 23.51m. 0.15m/recording.


Training Epoch 34: 100%|██████████| 159/159 [18:47<00:00,  7.09s/it]


Testing done in 4.60m.
Epoch 34 done in 23.39m. 0.15m/recording.


Training Epoch 35: 100%|██████████| 159/159 [18:49<00:00,  7.10s/it]


Testing done in 4.53m.
Epoch 35 done in 23.35m. 0.15m/recording.


Training Epoch 36: 100%|██████████| 159/159 [18:51<00:00,  7.12s/it]


Testing done in 4.44m.
Epoch 36 done in 23.30m. 0.15m/recording.


Training Epoch 37: 100%|██████████| 159/159 [18:44<00:00,  7.07s/it]


Testing done in 4.41m.
Epoch 37 done in 23.16m. 0.15m/recording.


Training Epoch 38: 100%|██████████| 159/159 [18:44<00:00,  7.07s/it]


Testing done in 4.25m.
Epoch 38 done in 23.00m. 0.14m/recording.


Training Epoch 39: 100%|██████████| 159/159 [18:43<00:00,  7.06s/it]


Testing done in 4.57m.
Epoch 39 done in 23.30m. 0.15m/recording.


Training Epoch 40: 100%|██████████| 159/159 [18:50<00:00,  7.11s/it]


Testing done in 4.68m.
Epoch 40 done in 23.54m. 0.15m/recording.


Training Epoch 41: 100%|██████████| 159/159 [18:47<00:00,  7.09s/it]


Testing done in 4.66m.
Epoch 41 done in 23.46m. 0.15m/recording.


Training Epoch 42: 100%|██████████| 159/159 [18:42<00:00,  7.06s/it]


Testing done in 4.86m.
Epoch 42 done in 23.56m. 0.15m/recording.


Training Epoch 43: 100%|██████████| 159/159 [18:42<00:00,  7.06s/it]


Testing done in 4.80m.
Epoch 43 done in 23.51m. 0.15m/recording.


Training Epoch 44: 100%|██████████| 159/159 [18:43<00:00,  7.07s/it]


Testing done in 4.80m.
Epoch 44 done in 23.53m. 0.15m/recording.


Training Epoch 45: 100%|██████████| 159/159 [18:42<00:00,  7.06s/it]


Testing done in 4.87m.
Epoch 45 done in 23.58m. 0.15m/recording.


Training Epoch 46: 100%|██████████| 159/159 [18:43<00:00,  7.07s/it]


Testing done in 4.78m.
Epoch 46 done in 23.51m. 0.15m/recording.


Training Epoch 47: 100%|██████████| 159/159 [18:46<00:00,  7.08s/it]


Testing done in 4.50m.
Epoch 47 done in 23.28m. 0.15m/recording.


[36m(raylet)[0m Spilled 539367 MiB, 577 objects, write throughput 851 MiB/s.
Training Epoch 48: 100%|██████████| 159/159 [18:46<00:00,  7.08s/it]


Testing done in 4.56m.
Epoch 48 done in 23.34m. 0.15m/recording.


Training Epoch 49: 100%|██████████| 159/159 [18:42<00:00,  7.06s/it]


Testing done in 4.56m.
Epoch 49 done in 23.27m. 0.15m/recording.


Training Epoch 50: 100%|██████████| 159/159 [18:39<00:00,  7.04s/it]


Testing done in 4.69m.
Epoch 50 done in 23.35m. 0.15m/recording.


Training completed. Highest epoch at 5.


Test unseen_subject at epoch 5. Mel Loss: 5.7840, Clip Loss: 9.4633, MSE: 0.2650
Mel accuracy: 0.0040, Top 5: 0.0230, Top 10: 0.0458
BLEU: 0.0065, ROUGE-1: 0.0670, BERT: 0.3491, CER: 0.8161, SELF-BLEU: 0.4986


Test unseen_task at epoch 5. Mel Loss: 5.7719, Clip Loss: 9.4537, MSE: 0.2492
Mel accuracy: 0.0053, Top 5: 0.0249, Top 10: 0.0494
BLEU: 0.0060, ROUGE-1: 0.0558, BERT: 0.3662, CER: 0.8124, SELF-BLEU: 0.4274


Test unseen_both at epoch 5. Mel Loss: 5.7670, Clip Loss: 9.4447, MSE: 0.2505
Mel accuracy: 0.0037, Top 5: 0.0209, Top 10: 0.0468
BLEU: 0.0122, ROUGE-1: 0.1123, BERT: 0.3879, CER: 0.8247, SELF-BLEU: 0.5623


In [2]:
import torch
import gc
import logging

# del session.logger
# del session.epoch_logger
# del session
# gc.collect()
# torch.cuda.empty_cache()

import multiprocessing
from train.training_session_v2 import TrainingSessionV2, load_training_session
from config import TrainingConfigV2
from config import SimpleConvConfig

# Test-split specification keyed by study name. Each study lists the subject
# and task indices under "testing_subjects" / "testing_tasks"; the mapping is
# passed to TrainingConfigV2 and its keys select the studies to load.
# NOTE(review): presumably these indices are held out from training - confirm
# against TrainingSessionV2.
data_partition = dict(
    gwilliams2023=dict(testing_subjects=[19, 20, 21], testing_tasks=[0]),
    armeini2022=dict(testing_subjects=[], testing_tasks=[0, 1]),
)

# Keyword arguments for the brain encoder, collected in one mapping and
# expanded into SimpleConvConfig below. Names and values are exactly the ones
# the constructor receives.
_encoder_kwargs = {
    "mel_normalization": False,
    # Str to list of possible conditions
    "conditions": {"study": [], "subject": []},
    # Channels
    "in_channels": 208,  # reassigned to 269 on the config after construction
    "out_channels": 80,
    "hidden_dim": 256,
    "dropout": 0.2,
    "initial_batch_norm": True,
    # Sensor layout and merger (spatial attention) settings are not passed
    # here; this cell assigns them on config.brain_encoder_config after the
    # training config is built. The commented-out constructor form was:
    #   layout_dim=2, layout_proj=True, layout_scaling="minmax",
    #   merger=False, merger_emb_type=None, merger_emb_dim=0,
    #   merger_channels=0, merger_dropout=0.0 (float), merger_conditional=None
    # Initial
    "initial_linear": 256,
    "initial_depth": 1,
    # Conditional layers
    "conditional_layers": False,
    "conditional_layers_dim": None,  # input or hidden_dim
    # Conv layer overall structure
    "depth": 4,
    "kernel_size": 3,
    "growth": 1.0,
    "dilation_growth": 2,
    "dilation_period": 5,
    "glu": 1,
    "conv_dropout": 0.2,
    "dropout_input": 0.1,
    "batch_norm": True,
    "half": True,
    "cnn_pos_encoding": False,
    # Quantizer
    "quantizer": False,
    "num_codebooks": 0,
    "codebook_size": 0,
    "quantizer_commitment": 0,
    "quantizer_temp_init": 0,
    "quantizer_temp_min": 0,
    "quantizer_temp_decay": 0,
    # Transformer encoders
    "transformer_input": "continuous",
    "transformer_encoder_emb": "sinusoidal",
    "transformer_encoder_layers": 4,
    "transformer_encoder_heads": 4,
    # Conformer encoder variant
    "rnn_type": "conformer",
    "depthwise_conv_kernel_size": 15,
    "use_group_norm": False,
    "convolution_first": False,
    # Transformer decoders
    "transformer_decoder_emb": None,
    "transformer_decoder_layers": 0,
    "transformer_decoder_heads": 0,
    "transformer_decoder_dim": 0,
}

model_config = SimpleConvConfig(**_encoder_kwargs)

# Units for the AdaLoRA schedule. 654 equals the steps_per_epoch value this
# cell assigns to config after construction, and 50 equals `epochs`, so the
# schedule lengths below are expressed in whole epochs of the actual run.
ADALORA_STEPS_PER_EPOCH = 654
ADALORA_EPOCHS = 50

# Training configuration: AdaLoRA schedule, brain/audio pre-processing,
# optimiser hyperparameters and loss weights. Several fields are reassigned
# on the resulting object further down in this cell; those are marked inline.
config = TrainingConfigV2(
    brain_encoder_config=model_config,
    data_partition=data_partition,
    # AdaLoRA settings
    # Around 100k total batches an epoch for gwilliams
    use_adalora=True,
    adalora_init_r=12,
    adalora_target_r=4,
    # 3 epochs = 6% of adalora_total_step (original note: 5% total steps)
    adalora_tinit=3 * ADALORA_STEPS_PER_EPOCH,
    # 8 epochs = 16% of adalora_total_step (original note: 50-80% total steps)
    adalora_tfinal=8 * ADALORA_STEPS_PER_EPOCH,
    # 1 epoch = 2% of adalora_total_step (original note: 1-5% total steps)
    adalora_deltaT=1 * ADALORA_STEPS_PER_EPOCH,
    adalora_lora_alpha=32,
    adalora_lora_dropout=0.1,
    adalora_total_step=ADALORA_EPOCHS * ADALORA_STEPS_PER_EPOCH,
    # Pre-processing parameters
    # Brain
    new_freq=200,
    frequency_bands={"all": (0.5, 80)},
    max_random_shift=1.0,
    window_size=4,
    window_stride=1,
    brain_clipping=None,
    baseline_window=0.5,
    notch_filter=True,
    scaling="both",
    delay=0.15,
    # Audio
    audio_model="openai/whisper-tiny.en",
    # Hyperparameters
    learning_rate=1e-4,  # reassigned to 3e-4 after construction
    weight_decay=1e-4,
    epochs=50,
    steps_per_epoch=450,  # reassigned to 654 after construction
    batch_size=128,  # reassigned to 256 after construction
    random_test_size=10,
    seed=42,
    # Loss-name -> weight for each alignment objective
    mel_alignment_objectives={
        "clip_loss": 0.6,
        "mse_loss": 0.4,
    },
    latent_alignment_objectives={
        "cosine_similarity": 0.4,
        "mse_loss": 0.4,
        "clip_loss": 0.6,
        "mmd_loss": 0.0,
    },
    decode_timestamps=True,
)

# Per-run overrides applied after construction. Each row is
# (object to modify, attribute name, new value); rows are applied top to
# bottom, in the same order as the original one-per-line assignments.
_encoder_cfg = config.brain_encoder_config
_override_table = (
    (_encoder_cfg, "mel_normalization", False),
    (config, "learning_rate", 3e-4),
    (config, "batch_size", 256),
    (config, "steps_per_epoch", 654),
    (config, "epochs", 50),
    (_encoder_cfg, "in_channels", 269),
    (config, "decode_timestamps", True),
    # Sensor layout settings
    (_encoder_cfg, "layout_dim", 3),
    (_encoder_cfg, "layout_proj", False),
    (_encoder_cfg, "layout_scaling", "midpoint"),
    # Merger with spatial attn
    (_encoder_cfg, "merger", True),
    (_encoder_cfg, "merger_emb_type", "mlp"),
    (_encoder_cfg, "merger_emb_dim", 256),
    (_encoder_cfg, "merger_channels", 269),
    (_encoder_cfg, "merger_dropout", 0.1),  # Float
    (_encoder_cfg, "merger_conditional", None),
)
for _target, _field, _value in _override_table:
    setattr(_target, _field, _value)

# config.brain_clipping = 20

# config.brain_encoder_config.hidden_dim = 1024
# config.brain_encoder_config.initial_linear = 1024

# Create the training session. Every study named in data_partition is mapped
# to the "audiotext" batch type; results are written under save_path.
session = TrainingSessionV2(
    config=config,
    studies=dict.fromkeys(data_partition, "audiotext"),
    data_path="data",
    save_path="saves/phase3/combining/channel_merger_269_mlp_256_midpoint_3d",
    clear_cache=False,
    cache_name="/home/ubuntu/cache",
    download_studies=True,
)


# session = load_training_session(
#     save_path="saves/phase3/objectives/baseline_gwilliams_latent_loss_no_latent_alignment/epoch_39",
#     studies={"gwilliams2023": "audiotext"},
#     data_path="data",
#     cache_name="/home/ubuntu/cache",
# )

# Use all but two CPU cores for workers. The subtraction alone yields 0 or a
# negative count on machines with two or fewer cores, so clamp to at least 1.
_num_workers = max(1, multiprocessing.cpu_count() - 2)

# Run training from the first epoch. A manual interrupt (Ctrl-C / kernel
# interrupt) is treated as a normal way to stop the run rather than an error.
try:
    session.train(
        device="cuda",
        buffer_size=30,
        num_workers=_num_workers,
        max_cache_size=800,
        current_epoch=0,
    )
except KeyboardInterrupt:
    print("Exited")

# try:
#     session.pre_process_all_recordings(
#         buffer_size=30, num_workers=multiprocessing.cpu_count() - 10, max_cache_size=800
#     )
# except KeyboardInterrupt as e:
#     print("Exited")

Loading Gwilliams2023 with batch type audiotext
Loading Armeini2022 with batch type audiotext
Data partitioned on studies ['gwilliams2023', 'armeini2022'].
Train: 159, Unseen Task: 51, Unseen Subject: 12, Unseen Both: 4.

RNNEncoder initialized as conformer with 4 layers, 256 d_model, 4 nhead
	Embedding: sinusoidal, params: 6075392
SimpleConv initialized with 9079914 parameters, cond: ['study', 'subject']
Merger True, merger channels 269
ConvBlocks: 4, hidden_dim: 256, params 2626048
Using torch.bfloat16
Found 64 target modules for AdaLora: ['model.encoder.layers.0.self_attn.k_proj', 'model.encoder.layers.0.self_attn.v_proj', 'model.encoder.layers.0.self_attn.q_proj', 'model.encoder.layers.0.self_attn.out_proj', 'model.encoder.layers.0.fc1', 'model.encoder.layers.0.fc2', 'model.encoder.layers.1.self_attn.k_proj', 'model.encoder.layers.1.self_attn.v_proj', 'model.encoder.layers.1.self_attn.q_proj', 'model.encoder.layers.1.self_attn.out_proj', 'model.encoder.layers.1.fc1', 'model.encoder

Training Epoch 1: 100%|██████████| 159/159 [20:03<00:00,  7.57s/it]


Testing done in 8.21m.
Epoch 1 done in 28.28m. 0.18m/recording.


New best epoch 1 with CER 4.0195 and BLEU 0.0037.
Mel Loss: 6.2299, Clip Loss: 9.5205, MSE: 1.2939
Mel accuracy: 0.0043, Top 5: 0.0226, Top 10: 0.0448


Training Epoch 2: 100%|██████████| 159/159 [19:14<00:00,  7.26s/it]


Testing done in 1.30m.
Epoch 2 done in 20.54m. 0.13m/recording.


New best epoch 2 with CER 0.0000 and BLEU 0.0000.
Mel Loss: 5.8909, Clip Loss: 9.5206, MSE: 0.4463
Mel accuracy: 0.0046, Top 5: 0.0226, Top 10: 0.0448


Training Epoch 3: 100%|██████████| 159/159 [19:00<00:00,  7.18s/it]


Testing done in 3.50m.
Epoch 3 done in 22.52m. 0.14m/recording.


New best epoch 3 with CER 0.9459 and BLEU 0.0013.
Mel Loss: 5.8144, Clip Loss: 9.5157, MSE: 0.2624
Mel accuracy: 0.0045, Top 5: 0.0222, Top 10: 0.0447


Training Epoch 4:   0%|          | 0/159 [00:00<?, ?it/s]

Starting rank reallocation at recording 1962.


Training Epoch 4: 100%|██████████| 159/159 [19:17<00:00,  7.28s/it]


Testing done in 3.39m.
Epoch 4 done in 22.69m. 0.14m/recording.


New best epoch 4 with CER 0.9455 and BLEU 0.0016.
Mel Loss: 5.8038, Clip Loss: 9.5304, MSE: 0.2138
Mel accuracy: 0.0047, Top 5: 0.0224, Top 10: 0.0447


Training Epoch 5: 100%|██████████| 159/159 [19:17<00:00,  7.28s/it]


Testing done in 4.01m.
Epoch 5 done in 23.30m. 0.15m/recording.


Training Epoch 6: 100%|██████████| 159/159 [19:18<00:00,  7.29s/it]


Testing done in 4.51m.
Epoch 6 done in 23.82m. 0.15m/recording.


New best epoch 6 with CER 0.9421 and BLEU 0.0043.
Mel Loss: 5.8590, Clip Loss: 9.5693, MSE: 0.2937
Mel accuracy: 0.0110, Top 5: 0.0454, Top 10: 0.0833


Training Epoch 7: 100%|██████████| 159/159 [19:21<00:00,  7.30s/it]


Testing done in 3.63m.
Epoch 7 done in 22.99m. 0.14m/recording.


Training Epoch 8: 100%|██████████| 159/159 [19:06<00:00,  7.21s/it]


Testing done in 3.83m.
Epoch 8 done in 22.94m. 0.14m/recording.


Training Epoch 9: 100%|██████████| 159/159 [19:16<00:00,  7.27s/it]


Testing done in 3.93m.
Epoch 9 done in 23.21m. 0.15m/recording.


Training Epoch 10: 100%|██████████| 159/159 [19:33<00:00,  7.38s/it]


Testing done in 3.88m.
Epoch 10 done in 23.44m. 0.15m/recording.


Training Epoch 11: 100%|██████████| 159/159 [19:33<00:00,  7.38s/it]


Testing done in 3.44m.
Epoch 11 done in 23.00m. 0.14m/recording.


Training Epoch 12: 100%|██████████| 159/159 [19:34<00:00,  7.39s/it]


Testing done in 4.44m.
Epoch 12 done in 24.01m. 0.15m/recording.


New best epoch 12 with CER 0.9345 and BLEU 0.0045.
Mel Loss: 5.9132, Clip Loss: 9.5278, MSE: 0.4913
Mel accuracy: 0.0460, Top 5: 0.1475, Top 10: 0.2278


Training Epoch 13: 100%|██████████| 159/159 [19:36<00:00,  7.40s/it]


Testing done in 4.50m.
Epoch 13 done in 24.11m. 0.15m/recording.


Training Epoch 14: 100%|██████████| 159/159 [19:36<00:00,  7.40s/it]


Testing done in 3.90m.
Epoch 14 done in 23.51m. 0.15m/recording.


Training Epoch 15: 100%|██████████| 159/159 [19:34<00:00,  7.38s/it]


Testing done in 4.08m.
Epoch 15 done in 23.65m. 0.15m/recording.


Training Epoch 16: 100%|██████████| 159/159 [19:36<00:00,  7.40s/it]


Testing done in 3.81m.
Epoch 16 done in 23.41m. 0.15m/recording.


Training Epoch 17: 100%|██████████| 159/159 [19:35<00:00,  7.39s/it]


Testing done in 4.10m.
Epoch 17 done in 23.69m. 0.15m/recording.


Training Epoch 18: 100%|██████████| 159/159 [19:37<00:00,  7.41s/it]


Testing done in 4.23m.
Epoch 18 done in 23.86m. 0.15m/recording.


Training Epoch 19: 100%|██████████| 159/159 [19:33<00:00,  7.38s/it]


Testing done in 3.90m.
Epoch 19 done in 23.47m. 0.15m/recording.


Training Epoch 20: 100%|██████████| 159/159 [19:40<00:00,  7.43s/it]


Testing done in 4.02m.
Epoch 20 done in 23.70m. 0.15m/recording.


Training Epoch 21: 100%|██████████| 159/159 [19:36<00:00,  7.40s/it]


Testing done in 4.10m.
Epoch 21 done in 23.71m. 0.15m/recording.


Training Epoch 22: 100%|██████████| 159/159 [19:36<00:00,  7.40s/it]


Testing done in 4.21m.
Epoch 22 done in 23.82m. 0.15m/recording.


Training Epoch 23: 100%|██████████| 159/159 [19:38<00:00,  7.41s/it]


Testing done in 4.25m.
Epoch 23 done in 23.89m. 0.15m/recording.
Early stopping at epoch 23. Highest metrics at epoch 12.


Training completed. Highest epoch at 12.


Test unseen_subject at epoch 12. Mel Loss: 5.8548, Clip Loss: 9.4214, MSE: 0.5047
Mel accuracy: 0.0463, Top 5: 0.1432, Top 10: 0.2430
BLEU: 0.0069, ROUGE-1: 0.0478, BERT: 0.3626, CER: 0.9433, SELF-BLEU: 0.5101


Test unseen_task at epoch 12. Mel Loss: 5.8786, Clip Loss: 9.4680, MSE: 0.4946
Mel accuracy: 0.0431, Top 5: 0.1208, Top 10: 0.1947
BLEU: 0.0034, ROUGE-1: 0.0260, BERT: 0.3591, CER: 0.9345, SELF-BLEU: 0.4907


Test unseen_both at epoch 12. Mel Loss: 5.8590, Clip Loss: 9.4356, MSE: 0.4940
Mel accuracy: 0.0505, Top 5: 0.1256, Top 10: 0.2488
BLEU: 0.0033, ROUGE-1: 0.0220, BERT: 0.3571, CER: 0.9256, SELF-BLEU: 0.5224
