In [None]:
# Uncomment before re-running this cell to release the previous run's session:
# del session.logger
# del session

import multiprocessing


from train.training_session_v0 import TrainingSessionV0
from config import TrainingConfigV0
from config import SimpleConvConfig

# Per-study held-out split. Recordings from the subjects / tasks listed here
# are kept out of the training set and make up the unseen_subject,
# unseen_task and unseen_both evaluation splits reported after each epoch.
data_partition = dict(
    gwilliams2023=dict(
        testing_subjects=[19, 20, 21],
        testing_tasks=[0],
    ),
    # Second study, currently disabled:
    # armeini2022=dict(
    #     testing_subjects=[],
    #     testing_tasks=[8, 9],
    # ),
)

# Brain-encoder (SimpleConv) hyperparameters, grouped by section. Optional
# components (merger, quantizer, transformer encoder/decoder) are switched
# off here via False / 0 / None values.
model_config = SimpleConvConfig(
    **{
        # Mel normalization (off)
        "mel_normalization": False,
        # Conditioning: maps each condition name to its list of possible values.
        # NOTE(review): the lists are empty here; presumably they are filled in
        # from the loaded data — confirm in SimpleConvConfig / the session.
        "conditions": {
            "study": [],
            "subject": [],
        },
        # Channel sizes and dropout
        "in_channels": 208,
        "out_channels": 128,
        "hidden_dim": 384,
        "dropout": 0.2,
        "initial_batch_norm": True,
        # Sensor-layout settings
        "layout_dim": 2,
        "layout_proj": True,
        "layout_scaling": "minmax",
        # Spatial-attention merger (off)
        "merger": False,
        "merger_emb_type": None,
        "merger_emb_dim": 0,
        "merger_channels": 0,
        "merger_dropout": False,
        "merger_conditional": None,
        # Initial projection
        "initial_linear": 384,
        "initial_depth": 1,
        # Conditional layers (off)
        "conditional_layers": False,
        "conditional_layers_dim": None,  # input or hidden_dim
        # Overall structure of the convolutional stack
        "depth": 6,
        "kernel_size": 3,
        "growth": 1.0,
        "dilation_growth": 2,
        "dilation_period": 5,
        "glu": 1,
        "conv_dropout": 0.2,
        "dropout_input": 0.2,
        "batch_norm": True,
        "half": True,
        "cnn_pos_encoding": False,
        # Quantizer (off)
        "quantizer": False,
        "num_codebooks": 0,
        "codebook_size": 0,
        "quantizer_commitment": 0,
        "quantizer_temp_init": 0,
        "quantizer_temp_min": 0,
        "quantizer_temp_decay": 0,
        # Transformer encoders (off)
        "transformer_input": None,
        "transformer_encoder_emb": None,
        "transformer_encoder_layers": 0,
        "transformer_encoder_heads": 0,
        # Transformer decoders (off)
        "transformer_decoder_emb": None,
        "transformer_decoder_layers": 0,
        "transformer_decoder_heads": 0,
        "transformer_decoder_dim": 0,
    }
)

# Training configuration for this run.
# Every hyperparameter is passed to the constructor with its final value,
# rather than being patched on the object afterwards. This keeps the call the
# single source of truth for the run, and it means anything TrainingConfigV0
# may derive in __init__ sees the values actually used.
config = TrainingConfigV0(
    brain_encoder_config=model_config,
    data_partition=data_partition,
    # Pre-processing parameters
    # Brain
    new_freq=200,
    frequency_bands={"all": (0.5, 80)},
    max_random_shift=1.0,
    window_size=4,
    window_stride=1,
    brain_clipping=None,
    baseline_window=0.5,
    notch_filter=True,
    scaling="both",
    delay=0.15,
    # Hyperparameters
    learning_rate=5e-4,
    weight_decay=1e-4,
    epochs=50,
    batch_size=64,
    # Objective: CLIP loss only. MSE is disabled and alpha is 0.0
    # (presumably alpha weights MSE against CLIP — confirm in TrainingConfigV0).
    # NOTE(review): the training log still reports an "MSE Loss" that grows
    # without bound while "Loss" equals "Clip Loss". The MSE term looks like it
    # is computed for reporting only; confirm it is not back-propagated and
    # consider skipping it when use_mse_loss is False.
    use_clip_loss=True,
    use_mse_loss=False,
    alpha=0.0,
    random_test_size=10,
    seed=42,
)

# Build the training session: every study in the partition is loaded with the
# "audio" batch type. Checkpoints go to save_path, and the existing cache
# under cache_name is reused (clear_cache=False).
session = TrainingSessionV0(
    config=config,
    studies={study: "audio" for study in data_partition},
    data_path="data",
    save_path="saves/phase1/objectives/CLIP_BT",
    clear_cache=False,
    cache_name="cache/1",
)

try:
    session.train(
        device="cuda",
        buffer_size=30,
        # Keep two cores free, but never ask for fewer than one worker:
        # cpu_count() - 2 is 0 or negative on machines with <= 2 CPUs.
        num_workers=max(1, multiprocessing.cpu_count() - 2),
        max_cache_size=400,
        current_epoch=0,
    )
except KeyboardInterrupt:
    # Let a long notebook run be stopped (Ctrl-C / "Interrupt kernel") without
    # a traceback. `session` stays bound for inspection afterwards.
    print("Exited")

Loading Gwilliams2023 with batch type audio
Data partitioned on studies ['gwilliams2023'].
Train: 135, Unseen Task: 12, Unseen Subject: 45, Unseen Both: 4.

SimpleConv initialized with 9332896 parameters, cond: ['study', 'subject']
Merger False, merger channels 0
ConvBlocks: 6, hidden_dim: 384, params 8858112


2025-01-26 00:58:30,331	INFO worker.py:1841 -- Started a local Ray instance.
[36m(raylet)[0m Spilled 4046 MiB, 5 objects, write throughput 554 MiB/s. Set RAY_verbose_spill_logs=0 to disable this message.
[36m(raylet)[0m Spilled 4723 MiB, 6 objects, write throughput 559 MiB/s.
[36m(raylet)[0m Spilled 22104 MiB, 32 objects, write throughput 982 MiB/s.
[36m(raylet)[0m Spilled 29995 MiB, 43 objects, write throughput 794 MiB/s.
Training Epoch 1: 100%|██████████| 135/135 [08:55<00:00,  3.97s/it]


Epoch 1 completed. Loss: 9.9831, Clip Loss: 9.9831, MSE Loss: 858426486020.1595
Accuracy: 0.0001, Top 5: 0.0007, Top 10: 0.0014
Test unseen_subject completed. Accuracy: 0.0002, Top 5: 0.0010, Top 10: 0.0020
Test unseen_task completed. Accuracy: 0.0002, Top 5: 0.0008, Top 10: 0.0015
Test unseen_both completed. Accuracy: 0.0002, Top 5: 0.0010, Top 10: 0.0020
Testing completed in 0.79m.
Epoch 1 completed in 9.75m. 0.07m per recording.


[36m(raylet)[0m Spilled 51016 MiB, 75 objects, write throughput 900 MiB/s.
Training Epoch 2: 100%|██████████| 135/135 [08:48<00:00,  3.92s/it]


Epoch 2 completed. Loss: 9.9441, Clip Loss: 9.9441, MSE Loss: 75705303009498.1094
Accuracy: 0.0002, Top 5: 0.0009, Top 10: 0.0018
Test unseen_subject completed. Accuracy: 0.0002, Top 5: 0.0010, Top 10: 0.0021
Test unseen_task completed. Accuracy: 0.0002, Top 5: 0.0009, Top 10: 0.0017
Test unseen_both completed. Accuracy: 0.0002, Top 5: 0.0011, Top 10: 0.0023
Testing completed in 0.77m.
Epoch 2 completed in 9.58m. 0.07m per recording.


[36m(raylet)[0m Spilled 66339 MiB, 98 objects, write throughput 864 MiB/s.
Training Epoch 3: 100%|██████████| 135/135 [08:56<00:00,  3.98s/it]


Epoch 3 completed. Loss: 9.8836, Clip Loss: 9.8836, MSE Loss: 266262851674924.3125
Accuracy: 0.0003, Top 5: 0.0012, Top 10: 0.0024
Test unseen_subject completed. Accuracy: 0.0003, Top 5: 0.0016, Top 10: 0.0032
Test unseen_task completed. Accuracy: 0.0003, Top 5: 0.0016, Top 10: 0.0032
Test unseen_both completed. Accuracy: 0.0004, Top 5: 0.0020, Top 10: 0.0040
Testing completed in 0.79m.
Epoch 3 completed in 9.73m. 0.07m per recording.


Training Epoch 4: 100%|██████████| 135/135 [08:50<00:00,  3.93s/it]


Epoch 4 completed. Loss: 9.8474, Clip Loss: 9.8474, MSE Loss: 1116468011694470.1250
Accuracy: 0.0003, Top 5: 0.0015, Top 10: 0.0029
Test unseen_subject completed. Accuracy: 0.0003, Top 5: 0.0017, Top 10: 0.0035
Test unseen_task completed. Accuracy: 0.0004, Top 5: 0.0021, Top 10: 0.0038
Test unseen_both completed. Accuracy: 0.0006, Top 5: 0.0025, Top 10: 0.0044
Testing completed in 0.76m.
Epoch 4 completed in 9.60m. 0.07m per recording.


[36m(raylet)[0m Spilled 146024 MiB, 217 objects, write throughput 926 MiB/s.
Training Epoch 5: 100%|██████████| 135/135 [08:54<00:00,  3.96s/it]


Epoch 5 completed. Loss: 9.8343, Clip Loss: 9.8343, MSE Loss: 2768345701254063.0000
Accuracy: 0.0003, Top 5: 0.0016, Top 10: 0.0030
Test unseen_subject completed. Accuracy: 0.0004, Top 5: 0.0019, Top 10: 0.0037
Test unseen_task completed. Accuracy: 0.0004, Top 5: 0.0019, Top 10: 0.0037
Test unseen_both completed. Accuracy: 0.0004, Top 5: 0.0022, Top 10: 0.0044
Testing completed in 0.78m.
Epoch 5 completed in 9.68m. 0.07m per recording.


Training Epoch 6: 100%|██████████| 135/135 [08:53<00:00,  3.95s/it]


Epoch 6 completed. Loss: 9.8243, Clip Loss: 9.8243, MSE Loss: 4902646479886874.0000
Accuracy: 0.0003, Top 5: 0.0016, Top 10: 0.0031
Test unseen_subject completed. Accuracy: 0.0004, Top 5: 0.0018, Top 10: 0.0036
Test unseen_task completed. Accuracy: 0.0004, Top 5: 0.0019, Top 10: 0.0037
Test unseen_both completed. Accuracy: 0.0004, Top 5: 0.0020, Top 10: 0.0036
Testing completed in 0.77m.
Epoch 6 completed in 9.66m. 0.07m per recording.


Training Epoch 7:   6%|▌         | 8/135 [01:06<10:07,  4.78s/it]  