In [None]:
# del session.logger
# del session

import multiprocessing
from train.training_session_v1 import TrainingSessionV1
from config import TrainingConfigV1
from config import SimpleConvConfig

# Per-study hold-out specification: for each study name, the subject and task
# ids listed under "testing_subjects" / "testing_tasks".
# Studies that are commented out are not part of this run.
data_partition = {
    "gwilliams2023": {"testing_subjects": [19, 20, 21], "testing_tasks": [0]},
    # "armeini2022": {"testing_subjects": [], "testing_tasks": [8, 9]},
}

# Brain-encoder configuration. It is passed to TrainingConfigV1 further down
# as `brain_encoder_config`.
model_config = SimpleConvConfig(
    mel_normalization=False,
    # Maps each condition name (str) to its list of possible values
    conditions={
        "study": [],
        "subject": [],
    },
    # Channels
    in_channels=208,
    out_channels=80,
    hidden_dim=384,
    dropout=0.2,
    initial_batch_norm=True,
    # Sensor layout settings
    layout_dim=2,
    layout_proj=True,
    layout_scaling="minmax",
    # Merger with spatial attention -- switched off for this run (merger=False).
    # NOTE(review): the remaining merger_* values look like placeholders that
    # are presumably ignored while the merger is off -- confirm in SimpleConvConfig.
    merger=False,
    merger_emb_type=None,
    merger_emb_dim=0,
    merger_channels=0,
    merger_dropout=0.0,  # Float
    merger_conditional=None,
    # Initial layers
    initial_linear=384,
    initial_depth=1,
    # Conditional layers
    conditional_layers=False,
    conditional_layers_dim=None,  # input or hidden_dim
    # Conv layer overall structure
    depth=6,
    kernel_size=3,
    growth=1.0,
    dilation_growth=2,
    dilation_period=5,
    glu=1,
    conv_dropout=0.2,
    dropout_input=0.2,
    batch_norm=True,
    half=True,
    cnn_pos_encoding=False,
    # Quantizer -- switched off for this run (quantizer=False).
    # NOTE(review): the zeroed quantizer_* / codebook values are presumably
    # unused while the quantizer is off -- confirm.
    quantizer=False,
    num_codebooks=0,
    codebook_size=0,
    quantizer_commitment=0,
    quantizer_temp_init=0,
    quantizer_temp_min=0,
    quantizer_temp_decay=0,
    # Transformer encoders (zero layers / heads requested here)
    transformer_input=None,
    transformer_encoder_emb=None,
    transformer_encoder_layers=0,
    transformer_encoder_heads=0,
    # Conformer encoder variant
    # NOTE(review): presumably only relevant when encoder layers > 0 -- confirm.
    rnn_type="conformer",
    depthwise_conv_kernel_size=31,
    use_group_norm=True,
    convolution_first=False,
    # Transformer decoders (zero layers / heads requested here)
    transformer_decoder_emb=None,
    transformer_decoder_layers=0,
    transformer_decoder_heads=0,
    transformer_decoder_dim=0,
)

# Training configuration: bundles the brain-encoder config, the data
# partition, the AdaLoRA schedule, pre-processing settings and the
# optimisation hyperparameters / loss weights.
config = TrainingConfigV1(
    brain_encoder_config=model_config,
    data_partition=data_partition,
    # AdaLoRA settings
    # Around 100k total batches an epoch for gwilliams
    # NOTE(review): adalora_total_step (32000) is smaller than the ~100k
    # batches per epoch quoted above -- confirm which unit "step" refers to.
    adalora_init_r=12,
    adalora_target_r=4,
    adalora_tinit=1600,  # 5% total steps
    adalora_tfinal=25600,  # 50-80% total steps
    adalora_deltaT=1600,  # 1-5% total steps
    adalora_lora_alpha=32,
    adalora_lora_dropout=0.1,
    adalora_total_step=32000,
    # Pre-processing parameters
    # Brain
    new_freq=200,
    frequency_bands={"all": (0.5, 80)},
    max_random_shift=1.0,
    window_size=4,
    window_stride=1,
    brain_clipping=None,
    baseline_window=0.5,
    notch_filter=True,
    scaling="both",
    delay=0.15,
    # Audio
    audio_model="openai/whisper-tiny.en",
    # Hyperparameters
    # NOTE: learning_rate is overwritten with 5e-4 right after construction.
    learning_rate=3e-4,
    weight_decay=1e-4,
    epochs=40,
    batch_size=128,
    random_test_size=10,
    seed=42,
    # Per-objective weights for the mel-level alignment loss (sum to 1.0)
    mel_alignment_objectives={
        "clip_loss": 0.6,
        "mse_loss": 0.4,
        "commitment_loss": 0.0,
    },
    # Per-objective weights for the latent-level alignment loss (sum to 1.0)
    latent_alignment_objectives={
        "cosine_similarity": 0.4,
        "mse_loss": 0.4,
        "clip_loss": 0.2,
    },
    # Layer indices used for latent alignment; [-1] selects a single layer
    # (presumably the last one -- confirm against the session code).
    latent_alignment_layers=[-1],
)

# Post-construction overrides. learning_rate replaces the 3e-4 passed to the
# constructor; the other two re-assign values that already match the
# constructor arguments (mel_normalization=False, batch_size=128).
config.brain_encoder_config.mel_normalization = False
config.learning_rate = 5e-4
config.batch_size = 128


# Create the training session. Every study named in `data_partition` is
# registered with the "audio" batch type; the on-disk cache is reused
# (clear_cache=False) and no study download is requested.
session = TrainingSessionV1(
    config=config,
    studies=dict.fromkeys(data_partition, "audio"),
    data_path="/home/ubuntu/storage/data",
    save_path="saves/phase2/objectives/CLIP_MSE_TEST",
    clear_cache=False,
    cache_name="cache",
    download_studies=False,
)


# Start training from epoch 0 on the GPU.
# The worker count is "all CPU cores but two", clamped to at least one:
# cpu_count() - 2 alone is zero or negative on hosts with one or two cores.
try:
    session.train(
        device="cuda",
        buffer_size=30,
        num_workers=max(1, multiprocessing.cpu_count() - 2),
        max_cache_size=400,
        current_epoch=0,
    )
except KeyboardInterrupt:
    # A manual interrupt (e.g. stopping the notebook cell) ends the run
    # quietly instead of dumping a traceback.
    print("Exited")

# try:
#     session.pre_process_all_recordings(
#         buffer_size=30, num_workers=multiprocessing.cpu_count() - 2, max_cache_size=400
#     )
# except KeyboardInterrupt as e:
#     print("Exited")

Loading Gwilliams2023 with batch type audio
Data partitioned on studies ['gwilliams2023'].
Train: 135, Unseen Task: 12, Unseen Subject: 45, Unseen Both: 4.

SimpleConv initialized with 9295984 parameters, cond: ['study', 'subject']
Merger False, merger channels 0
ConvBlocks: 6, hidden_dim: 384, params 8858112
Found 24 target modules for AdaLora: ['k_proj', 'q_proj', 'v_proj', 'out_proj', 'fc1', 'fc2']
openai/whisper-tiny.en loaded with 8540472 frozen params (4 layers and 384) dim.
AdaLora has 332064 trainable params.


2025-02-04 16:37:17,751	INFO worker.py:1841 -- Started a local Ray instance.
Training Epoch 1: 100%|██████████| 135/135 [13:31<00:00,  6.01s/it]


Epoch 1, Loss: 4.7548, Mel Loss: 3.3012
Clip Loss: 4.9186, MSE Loss: 0.8752, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0164, Top 5 Accuracy: 0.0708, Top 10 Accuracy: 0.1295
Final Layer Clip Loss: 5.2921, Final Layer MSE Loss: 0.7067, Final Layer Cosine Similarity Loss: 0.2812, Final Layer Total Loss: 1.4535
Test unseen_subject completed., Loss: 4.1261, Mel Loss: 2.8718
Clip Loss: 4.4794, MSE Loss: 0.4605, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0575, Top 5 Accuracy: 0.1721, Top 10 Accuracy: 0.2678
Final Layer Clip Loss: 4.6031, Final Layer MSE Loss: 0.5928, Final Layer Cosine Similarity Loss: 0.2414, Final Layer Total Loss: 1.2543
Test unseen_task completed., Loss: 4.1174, Mel Loss: 2.8460
Clip Loss: 4.4299, MSE Loss: 0.4702, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0442, Top 5 Accuracy: 0.1623, Top 10 Accuracy: 0.2638
Final Layer Clip Loss: 4.6688, Final Layer MSE Loss: 0.6007, Final Layer Cosine Similarity Loss: 0.2433, Final Layer Total L

Training Epoch 2: 100%|██████████| 135/135 [12:40<00:00,  5.63s/it]


Epoch 2, Loss: 7.0912, Mel Loss: 5.7939
Clip Loss: 5.1847, MSE Loss: 6.7077, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0228, Top 5 Accuracy: 0.0860, Top 10 Accuracy: 0.1483
Final Layer Clip Loss: 4.6552, Final Layer MSE Loss: 0.6542, Final Layer Cosine Similarity Loss: 0.2615, Final Layer Total Loss: 1.2973
Test unseen_subject completed., Loss: 4.2838, Mel Loss: 3.0145
Clip Loss: 4.6221, MSE Loss: 0.6031, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0104, Top 5 Accuracy: 0.0494, Top 10 Accuracy: 0.0998
Final Layer Clip Loss: 4.7081, Final Layer MSE Loss: 0.5830, Final Layer Cosine Similarity Loss: 0.2362, Final Layer Total Loss: 1.2693
Test unseen_task completed., Loss: 4.3939, Mel Loss: 3.0961
Clip Loss: 4.7520, MSE Loss: 0.6121, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0086, Top 5 Accuracy: 0.0441, Top 10 Accuracy: 0.0889
Final Layer Clip Loss: 4.8385, Final Layer MSE Loss: 0.5884, Final Layer Cosine Similarity Loss: 0.2370, Final Layer Total L

Training Epoch 3: 100%|██████████| 135/135 [12:26<00:00,  5.53s/it]


Epoch 3, Loss: 4.3287, Mel Loss: 3.0515
Clip Loss: 4.7923, MSE Loss: 0.4405, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0091, Top 5 Accuracy: 0.0460, Top 10 Accuracy: 0.0925
Final Layer Clip Loss: 4.7480, Final Layer MSE Loss: 0.5835, Final Layer Cosine Similarity Loss: 0.2355, Final Layer Total Loss: 1.2772
Test unseen_subject completed., Loss: 4.4306, Mel Loss: 3.1959
Clip Loss: 5.1120, MSE Loss: 0.3218, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0109, Top 5 Accuracy: 0.0563, Top 10 Accuracy: 0.1047
Final Layer Clip Loss: 4.5629, Final Layer MSE Loss: 0.5727, Final Layer Cosine Similarity Loss: 0.2325, Final Layer Total Loss: 1.2347
Test unseen_task completed., Loss: 4.5808, Mel Loss: 3.3195
Clip Loss: 5.3119, MSE Loss: 0.3309, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0089, Top 5 Accuracy: 0.0464, Top 10 Accuracy: 0.0906
Final Layer Clip Loss: 4.6826, Final Layer MSE Loss: 0.5785, Final Layer Cosine Similarity Loss: 0.2335, Final Layer Total L

Training Epoch 4: 100%|██████████| 135/135 [12:34<00:00,  5.59s/it]


Epoch 4, Loss: 4.2636, Mel Loss: 3.0052
Clip Loss: 4.7829, MSE Loss: 0.3387, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0108, Top 5 Accuracy: 0.0516, Top 10 Accuracy: 0.0990
Final Layer Clip Loss: 4.6652, Final Layer MSE Loss: 0.5796, Final Layer Cosine Similarity Loss: 0.2337, Final Layer Total Loss: 1.2584
Test unseen_subject completed., Loss: 4.1298, Mel Loss: 2.8818
Clip Loss: 4.6257, MSE Loss: 0.2660, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0138, Top 5 Accuracy: 0.0568, Top 10 Accuracy: 0.1022
Final Layer Clip Loss: 4.6317, Final Layer MSE Loss: 0.5718, Final Layer Cosine Similarity Loss: 0.2324, Final Layer Total Loss: 1.2480
Test unseen_task completed., Loss: 4.2451, Mel Loss: 2.9679
Clip Loss: 4.7620, MSE Loss: 0.2768, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0090, Top 5 Accuracy: 0.0466, Top 10 Accuracy: 0.0896
Final Layer Clip Loss: 4.7627, Final Layer MSE Loss: 0.5781, Final Layer Cosine Similarity Loss: 0.2336, Final Layer Total L

Training Epoch 5: 100%|██████████| 135/135 [12:38<00:00,  5.62s/it]


Epoch 5, Loss: 4.1957, Mel Loss: 2.9654
Clip Loss: 4.7246, MSE Loss: 0.3265, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0175, Top 5 Accuracy: 0.0752, Top 10 Accuracy: 0.1354
Final Layer Clip Loss: 4.5172, Final Layer MSE Loss: 0.5823, Final Layer Cosine Similarity Loss: 0.2348, Final Layer Total Loss: 1.2303
Test unseen_subject completed., Loss: 4.0409, Mel Loss: 2.8349
Clip Loss: 4.5270, MSE Loss: 0.2969, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0297, Top 5 Accuracy: 0.1137, Top 10 Accuracy: 0.1936
Final Layer Clip Loss: 4.3921, Final Layer MSE Loss: 0.5824, Final Layer Cosine Similarity Loss: 0.2364, Final Layer Total Loss: 1.2059
Test unseen_task completed., Loss: 4.0902, Mel Loss: 2.8915
Clip Loss: 4.6148, MSE Loss: 0.3065, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0259, Top 5 Accuracy: 0.1010, Top 10 Accuracy: 0.1766
Final Layer Clip Loss: 4.3422, Final Layer MSE Loss: 0.5884, Final Layer Cosine Similarity Loss: 0.2374, Final Layer Total L

Training Epoch 6: 100%|██████████| 135/135 [12:34<00:00,  5.59s/it]


Epoch 6, Loss: 15.8948, Mel Loss: 14.6301
Clip Loss: 5.4496, MSE Loss: 28.4010, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0248, Top 5 Accuracy: 0.0903, Top 10 Accuracy: 0.1559
Final Layer Clip Loss: 4.5624, Final Layer MSE Loss: 0.6297, Final Layer Cosine Similarity Loss: 0.2508, Final Layer Total Loss: 1.2647
Test unseen_subject completed., Loss: 4.2552, Mel Loss: 3.0145
Clip Loss: 4.6442, MSE Loss: 0.5699, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0084, Top 5 Accuracy: 0.0474, Top 10 Accuracy: 0.0988
Final Layer Clip Loss: 4.5791, Final Layer MSE Loss: 0.5778, Final Layer Cosine Similarity Loss: 0.2344, Final Layer Total Loss: 1.2407
Test unseen_task completed., Loss: 4.3565, Mel Loss: 3.0988
Clip Loss: 4.7856, MSE Loss: 0.5687, Commitment Loss: 0.0000
Perplexity: 0.0000, Accuracy: 0.0102, Top 5 Accuracy: 0.0452, Top 10 Accuracy: 0.0901
Final Layer Clip Loss: 4.6491, Final Layer MSE Loss: 0.5841, Final Layer Cosine Similarity Loss: 0.2356, Final Layer Tota

Training Epoch 7:   0%|          | 0/135 [00:00<?, ?it/s]

In [None]:
# import torch

# torch.cuda.get_device_capability()

In [None]:
# for name, param in session.model.encoder.named_parameters():
#     if "lora_A" in name or "lora_B" in name or "lora_E" in name:
#         print(name)

In [None]:
# import time

# dataloader = session.get_dataloader(buffer_size=30, num_workers=24, max_cache_size=400)
# dataloader.start_fetching(session.dataset["train"], cache=True)

# # Process batches as they become available
# try:
#     batches, recs, start_time = 0, 0, time.time()
#     print(f"Total recordings: {len(session.dataset['train'])}")

#     while True:
#         batch = dataloader.get_recording()

#         if batch is None:
#             break

#         brain = batch.brain_segments["all"]
#         batches += brain.shape[0]
#         recs += 1

#         if recs % 10 == 0:
#             print(f'Processed {recs} recordings of {len(session.dataset["train"])}')

#     print(
#         f"Batch {batches} ({recs} recordings) processed in {time.time() - start_time:.2f}s"
#     )
#     print(
#         f"Average processing time per recording: {(time.time() - start_time) / recs:.2f}s"
#     )
#     print(
#         f"Average processing time per batch: {(time.time() - start_time) / batches:.2f}s"
#     )

# except KeyboardInterrupt:
#     print("Interrupted")
#     dataloader.stop()
# except Exception as e:
#     print("Error", e)
#     dataloader.stop()