In [None]:
# del session.logger
# del session

import multiprocessing


from train.training_session_v0 import TrainingSessionV0
from config import TrainingConfigV0
from config import SimpleConvConfig

# Per-study test split. The subjects/tasks listed here are presumably held out
# from training (the session log reports "Unseen Subject", "Unseen Task" and
# "Unseen Both" partitions) — confirm against TrainingSessionV0.
# Only gwilliams2023 is active; the armeini2022 entry is kept disabled.
data_partition = dict(
    gwilliams2023=dict(
        testing_subjects=[19, 20, 21],
        testing_tasks=[0],
    ),
    # armeini2022=dict(
    #     testing_subjects=[],
    #     testing_tasks=[8, 9],
    # ),
)

# SimpleConv brain-encoder configuration. The optional merger, quantizer and
# transformer stages are switched off here (merger=False, quantizer=False,
# zero transformer layers), leaving the initial projection plus the conv stack.
model_config = SimpleConvConfig(
    mel_normalization=False,
    # Conditioning: condition name -> list of possible values (left empty here)
    conditions={
        "study": [],
        "subject": [],
    },
    # Channels
    in_channels=208,  # presumably the MEG sensor count for gwilliams2023 — TODO confirm
    out_channels=128,
    hidden_dim=384,
    dropout=0.2,
    initial_batch_norm=True,
    # Sensor layout settings
    layout_dim=2,
    layout_proj=True,
    layout_scaling="minmax",
    # Merger with spatial attention (disabled)
    merger=False,
    merger_emb_type=None,
    merger_emb_dim=0,
    merger_channels=0,
    merger_dropout=False,
    merger_conditional=None,
    # Initial projection
    initial_linear=384,
    initial_depth=1,
    # Conditional layers
    conditional_layers=False,
    conditional_layers_dim=None,  # input or hidden_dim
    # Conv layer overall structure
    depth=6,
    kernel_size=3,
    growth=1.0,
    dilation_growth=2,
    dilation_period=5,
    glu=1,
    conv_dropout=0.2,
    dropout_input=0.2,
    batch_norm=True,
    half=True,
    cnn_pos_encoding=False,
    # Quantizer (disabled)
    quantizer=False,
    num_codebooks=0,
    codebook_size=0,
    quantizer_commitment=0,
    quantizer_temp_init=0,
    quantizer_temp_min=0,
    quantizer_temp_decay=0,
    # Transformer encoders (zero layers — presumably disabled)
    transformer_input=None,
    transformer_encoder_emb=None,
    transformer_encoder_layers=0,
    transformer_encoder_heads=0,
    # Transformer decoders (zero layers — presumably disabled)
    transformer_decoder_emb=None,
    transformer_decoder_layers=0,
    transformer_decoder_heads=0,
    transformer_decoder_dim=0,
)

# Training configuration: combines the encoder config and the data split with
# preprocessing and optimisation settings.
config = TrainingConfigV0(
    brain_encoder_config=model_config,
    data_partition=data_partition,
    # Pre-processing parameters
    # Brain
    new_freq=200,  # presumably the resampling rate in Hz — confirm
    frequency_bands={"all": (0.5, 80)},
    max_random_shift=1.0,
    window_size=4,
    window_stride=1,
    brain_clipping=None,
    baseline_window=0.5,
    notch_filter=True,
    scaling="both",
    delay=0.15,
    # Hyperparameters
    learning_rate=1e-4,
    weight_decay=1e-4,
    epochs=50,
    batch_size=64,
    # NOTE(review): in the recorded run of this cell the MSE term spiked to
    # ~2365 in epoch 2 (total loss ~952), and training accuracy did not return
    # to its epoch-1 level through epoch 5. Consider gradient clipping or a
    # lower learning rate — TODO confirm the cause.
    use_clip_loss=True,
    use_mse_loss=True,
    # The logged totals match loss = alpha * clip_loss + (1 - alpha) * mse_loss.
    alpha=0.6,
    random_test_size=10,
    seed=42,
)

# Build the training session. Every study named in the partition is loaded
# with the "audio" batch type; the cache is cleared before loading.
session = TrainingSessionV0(
    config=config,
    studies=dict.fromkeys(data_partition, "audio"),
    data_path="data",
    save_path="saves/phase1/objectives/CLIP_MSE",
    clear_cache=True,
    cache_name="cache/1",
)

# Launch training. Interrupting the cell (KeyboardInterrupt) is an expected way
# to stop a run, so it is reported with a short message instead of a traceback.
# To resume a run, current_epoch presumably needs to be set accordingly — confirm.
try:
    session.train(
        device="cuda",
        buffer_size=30,
        # Reserve two CPUs (presumably for the main/trainer process), but never
        # request fewer than one worker: cpu_count() - 2 is <= 0 on hosts with
        # two or fewer CPUs.
        num_workers=max(1, multiprocessing.cpu_count() - 2),
        max_cache_size=400,
        current_epoch=0,
    )
except KeyboardInterrupt:
    print("Exited")

Loading Gwilliams2023 with batch type audio
Cleared cache for study gwilliams2023
Data partitioned on studies ['gwilliams2023'].
Train: 135, Unseen Task: 12, Unseen Subject: 45, Unseen Both: 4.

SimpleConv initialized with 9332896 parameters, cond: ['study', 'subject']
Merger False, merger channels 0
ConvBlocks: 6, hidden_dim: 384, params 8858112


2025-01-25 20:51:58,493	INFO worker.py:1841 -- Started a local Ray instance.
Training Epoch 1: 100%|██████████| 135/135 [21:29<00:00,  9.55s/it] 


Epoch 1 completed. Loss: 6.1302, Clip Loss: 10.0194, MSE Loss: 0.2965
Accuracy: 0.0501, Top 5: 0.2340, Top 10: 0.4485
Test unseen_subject completed. Accuracy: 0.1035, Top 5: 0.4630, Top 10: 0.8840
Test unseen_task completed. Accuracy: 0.0654, Top 5: 0.3216, Top 10: 0.6036
Test unseen_both completed. Accuracy: 0.1062, Top 5: 0.4763, Top 10: 0.9213
Testing completed in 3.63m.
Epoch 1 completed in 25.16m. 0.19m per recording.


Training Epoch 2: 100%|██████████| 135/135 [08:21<00:00,  3.71s/it]


Epoch 2 completed. Loss: 951.9758, Clip Loss: 10.0542, MSE Loss: 2364.8580
Accuracy: 0.0427, Top 5: 0.1965, Top 10: 0.3744
Test unseen_subject completed. Accuracy: 0.0710, Top 5: 0.3220, Top 10: 0.5945
Test unseen_task completed. Accuracy: 0.0518, Top 5: 0.2288, Top 10: 0.4216
Test unseen_both completed. Accuracy: 0.0712, Top 5: 0.2963, Top 10: 0.5600
Testing completed in 0.76m.
Epoch 2 completed in 9.12m. 0.07m per recording.


Training Epoch 3: 100%|██████████| 135/135 [08:20<00:00,  3.71s/it]


Epoch 3 completed. Loss: 6.2510, Clip Loss: 10.0625, MSE Loss: 0.5336
Accuracy: 0.0424, Top 5: 0.1977, Top 10: 0.3726
Test unseen_subject completed. Accuracy: 0.0510, Top 5: 0.2510, Top 10: 0.4460
Test unseen_task completed. Accuracy: 0.0357, Top 5: 0.1708, Top 10: 0.3307
Test unseen_both completed. Accuracy: 0.0350, Top 5: 0.2137, Top 10: 0.4300
Testing completed in 0.79m.
Epoch 3 completed in 9.13m. 0.07m per recording.


Training Epoch 4: 100%|██████████| 135/135 [08:21<00:00,  3.71s/it]


Epoch 4 completed. Loss: 6.5258, Clip Loss: 10.0669, MSE Loss: 1.2142
Accuracy: 0.0414, Top 5: 0.1882, Top 10: 0.3542
Test unseen_subject completed. Accuracy: 0.0560, Top 5: 0.2610, Top 10: 0.5230
Test unseen_task completed. Accuracy: 0.0454, Top 5: 0.2046, Top 10: 0.3791
Test unseen_both completed. Accuracy: 0.0587, Top 5: 0.2750, Top 10: 0.5225
Testing completed in 0.78m.
Epoch 4 completed in 9.13m. 0.07m per recording.


Training Epoch 5: 100%|██████████| 135/135 [08:18<00:00,  3.69s/it]


Epoch 5 completed. Loss: 6.2484, Clip Loss: 10.0628, MSE Loss: 0.5268
Accuracy: 0.0425, Top 5: 0.1918, Top 10: 0.3641
Test unseen_subject completed. Accuracy: 0.0665, Top 5: 0.2915, Top 10: 0.5250
Test unseen_task completed. Accuracy: 0.0446, Top 5: 0.2022, Top 10: 0.3825
Test unseen_both completed. Accuracy: 0.0737, Top 5: 0.3150, Top 10: 0.5475
Testing completed in 0.78m.
Epoch 5 completed in 9.09m. 0.07m per recording.


Training Epoch 6:  44%|████▍     | 60/135 [04:25<04:52,  3.89s/it] 