In [1]:
# One-time setup (currently disabled): decompress each study's downloaded
# archives from downloaded_data/<study> into data/<study>, verifying against
# checksums.txt and deleting the compressed files afterwards.
# NOTE(review): this loop reads `data_partition`, which is only defined in the
# training cell further down. Uncommenting this cell and running it on a fresh
# kernel raises NameError; define `data_partition` in an earlier config cell first.
# NOTE(review): `batch_type` is a misleading name. It receives the per-study
# partition dict ({"testing_subjects": ..., "testing_tasks": ...}), not a batch
# type, and it is unused in the loop body.
# from utils.compression import compress_directories, decompress_directories

# for base_path, batch_type in data_partition.items():
#     decompress_directories(
#         source_path=f'downloaded_data/{base_path}',
#         destination_path=f'data/{base_path}',
#         checksum_file_name="checksums.txt",
#         delete_compressed_files=True,
#         num_workers=26
#     )

In [3]:
# del session.logger
# del session

from pickle import NONE
from train.training_session_v0 import TrainingSessionV0
from config import TrainingConfigV0
import multiprocessing
from config import SimpleConvConfig
from models.simpleconv import SimpleConv
import torch

# Studies to train on. For each study, the listed subjects and tasks are held
# out of training. They form the unseen-subject, unseen-task and unseen-both
# test splits that are reported after every epoch.
data_partition = dict(
    gwilliams2023=dict(
        testing_subjects=[19, 20, 21],
        testing_tasks=[0],
    ),
    # armeini2022=dict(
    #     testing_subjects=[],
    #     testing_tasks=[8, 9],
    # ),
)

# SimpleConv brain-encoder configuration. The options are listed in one mapping,
# grouped by sub-module, and unpacked into the constructor as keyword arguments.
# The merger and quantizer are switched off (merger=False, quantizer=False). Their
# settings, like the transformer encoder/decoder settings, are zero or None.
model_config = SimpleConvConfig(
    **{
        # Mel normalization disabled
        "mel_normalization": False,
        # Condition name -> list of possible values (left empty here)
        "conditions": {
            "study": [],
            "subject": [],
        },
        # Channel widths and regularisation
        "in_channels": 208,
        "out_channels": 128,
        "hidden_dim": 384,
        "dropout": 0.2,
        "initial_batch_norm": True,
        # Sensor layout handling
        "layout_dim": 2,
        "layout_proj": True,
        "layout_scaling": "minmax",
        # Spatial-attention merger (off)
        "merger": False,
        "merger_emb_type": None,
        "merger_emb_dim": 0,
        "merger_channels": 0,
        "merger_dropout": False,
        "merger_conditional": None,
        # Initial projection
        "initial_linear": 384,
        "initial_depth": 1,
        # Conditional layers (off); dim would be "input" or "hidden_dim"
        "conditional_layers": False,
        "conditional_layers_dim": None,
        # Convolutional stack
        "depth": 6,
        "kernel_size": 3,
        "growth": 1.0,
        "dilation_growth": 2,
        "dilation_period": 5,
        "glu": 1,
        "conv_dropout": 0.2,
        "dropout_input": 0.2,
        "batch_norm": True,
        "half": False,
        # Quantizer (off)
        "quantizer": False,
        "num_codebooks": 0,
        "codebook_size": 0,
        "quantizer_commitment": 0,
        "quantizer_temp_init": 0,
        "quantizer_temp_min": 0,
        "quantizer_temp_decay": 0,
        # Transformer encoder (zero layers)
        "transformer_input": None,
        "transformer_encoder_emb": None,
        "transformer_encoder_layers": 0,
        "transformer_encoder_heads": 0,
        # Transformer decoder (zero layers)
        "transformer_decoder_emb": None,
        "transformer_decoder_layers": 0,
        "transformer_decoder_heads": 0,
        "transformer_decoder_dim": 0,
    }
)

# Training configuration: the encoder config and data split, then the
# brain-signal preprocessing options, then the optimisation hyperparameters.
# The options are listed in one mapping and unpacked as keyword arguments.
config = TrainingConfigV0(
    **{
        "brain_encoder_config": model_config,
        "data_partition": data_partition,
        # Brain-signal preprocessing.
        # NOTE(review): the units of new_freq/frequency_bands (presumably Hz) and of
        # window_size/window_stride/max_random_shift/delay (presumably seconds)
        # are not visible here; confirm them in TrainingConfigV0.
        "new_freq": 100,
        "frequency_bands": {"all": (0.5, 40)},
        "max_random_shift": 1.0,
        "window_size": 4,
        "window_stride": 1,
        "brain_clipping": None,
        "baseline_window": 0.0,
        "notch_filter": True,
        "scaling": "standard",
        "delay": 0.15,
        # Optimisation. `epochs` is an upper bound, because early stopping can end
        # the run sooner. With both losses on, the logged total loss equals
        # alpha * CLIP + (1 - alpha) * MSE.
        "learning_rate": 3e-4,
        "weight_decay": 1e-4,
        "epochs": 50,
        "batch_size": 256,
        "use_clip_loss": True,
        "use_mse_loss": True,
        "alpha": 0.6,
        "random_test_size": 10,
        "seed": 42,
    }
)

# Build the training session. Every study in the partition is loaded with the
# "audio" batch type.
# NOTE(review): clear_cache=False presumably reuses whatever is already stored
# under cache_name. Confirm this in TrainingSessionV0 before changing any
# preprocessing setting.
session = TrainingSessionV0(
    config=config,
    studies=dict.fromkeys(data_partition, "audio"),
    data_path="data",
    save_path="saves/phase1/objectives/CLIP_MSE",
    clear_cache=False,
    cache_name="cache/1",
)

# Run training. Two CPU cores are reserved and half of the remainder is passed
# as `num_workers`. The result is clamped to at least 1, because the raw
# expression is 0 on a 2-core machine and -1 on a 1-core machine.
try:
    session.train(
        device="cuda",
        buffer_size=30,
        num_workers=max(1, (multiprocessing.cpu_count() - 2) // 2),
        max_cache_size=400,
        current_epoch=0,
    )
except KeyboardInterrupt:
    # Interrupting the kernel stops the run cleanly instead of showing a traceback.
    print("Exited")

Loading Gwilliams2023 with batch type audio
Data partitioned on studies ['gwilliams2023'].
Train: 135, Unseen Task: 12, Unseen Subject: 45, Unseen Both: 4.

GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower than expected.
SimpleConv initialized with 8448160 parameters, cond: ['study', 'subject']
Merger False, merger channels 0
ConvBlocks: 6, hidden_dim: 384, params 7973376


Training Epoch 1: 100%|██████████| 135/135 [05:06<00:00,  2.27s/it]


Epoch 1 completed. Loss: 3.3592, Clip Loss: 5.0375, MSE Loss: 0.8417
Accuracy: 0.0499, Top 5: 0.1583, Top 10: 0.2434
Test unseen_subject completed. Accuracy: 0.0705, Top 5: 0.2010, Top 10: 0.3080
Test unseen_task completed. Accuracy: 0.0576, Top 5: 0.1725, Top 10: 0.2664
Test unseen_both completed. Accuracy: 0.0712, Top 5: 0.1737, Top 10: 0.2587
Testing completed in 0.46m.
Epoch 1 completed in 5.56m. 0.04m per recording.


Training Epoch 2: 100%|██████████| 135/135 [05:15<00:00,  2.34s/it]


Epoch 2 completed. Loss: 2.7826, Clip Loss: 4.3002, MSE Loss: 0.5063
Accuracy: 0.0900, Top 5: 0.2550, Top 10: 0.3697
Test unseen_subject completed. Accuracy: 0.1105, Top 5: 0.2735, Top 10: 0.3825
Test unseen_task completed. Accuracy: 0.1018, Top 5: 0.2851, Top 10: 0.4047
Test unseen_both completed. Accuracy: 0.0963, Top 5: 0.2675, Top 10: 0.4113
Testing completed in 0.46m.
Epoch 2 completed in 5.72m. 0.04m per recording.


Training Epoch 3: 100%|██████████| 135/135 [05:04<00:00,  2.26s/it]


Epoch 3 completed. Loss: 2.6340, Clip Loss: 4.0594, MSE Loss: 0.4959
Accuracy: 0.1246, Top 5: 0.3179, Top 10: 0.4400
Test unseen_subject completed. Accuracy: 0.0690, Top 5: 0.2445, Top 10: 0.3415
Test unseen_task completed. Accuracy: 0.1009, Top 5: 0.2669, Top 10: 0.3892
Test unseen_both completed. Accuracy: 0.0675, Top 5: 0.2175, Top 10: 0.3312
Testing completed in 0.41m.
Epoch 3 completed in 5.49m. 0.04m per recording.


Training Epoch 4: 100%|██████████| 135/135 [04:32<00:00,  2.02s/it]


Epoch 4 completed. Loss: 2.5610, Clip Loss: 3.9382, MSE Loss: 0.4951
Accuracy: 0.1466, Top 5: 0.3553, Top 10: 0.4781
Test unseen_subject completed. Accuracy: 0.1240, Top 5: 0.3255, Top 10: 0.4415
Test unseen_task completed. Accuracy: 0.1398, Top 5: 0.3666, Top 10: 0.4925
Test unseen_both completed. Accuracy: 0.1225, Top 5: 0.3275, Top 10: 0.4375
Testing completed in 0.36m.
Epoch 4 completed in 4.90m. 0.04m per recording.


Training Epoch 5: 100%|██████████| 135/135 [04:45<00:00,  2.12s/it]


Epoch 5 completed. Loss: 2.4652, Clip Loss: 3.7828, MSE Loss: 0.4889
Accuracy: 0.1690, Top 5: 0.3922, Top 10: 0.5185
Test unseen_subject completed. Accuracy: 0.0775, Top 5: 0.2660, Top 10: 0.3930
Test unseen_task completed. Accuracy: 0.1302, Top 5: 0.3192, Top 10: 0.4338
Test unseen_both completed. Accuracy: 0.0988, Top 5: 0.2800, Top 10: 0.3650
Testing completed in 0.37m.
Epoch 5 completed in 5.13m. 0.04m per recording.


Training Epoch 6: 100%|██████████| 135/135 [05:54<00:00,  2.62s/it]


Epoch 6 completed. Loss: 2.3910, Clip Loss: 3.6605, MSE Loss: 0.4869
Accuracy: 0.1868, Top 5: 0.4172, Top 10: 0.5423
Test unseen_subject completed. Accuracy: 0.1500, Top 5: 0.3610, Top 10: 0.4705
Test unseen_task completed. Accuracy: 0.1876, Top 5: 0.4233, Top 10: 0.5483
Test unseen_both completed. Accuracy: 0.1713, Top 5: 0.3738, Top 10: 0.4850
Testing completed in 0.41m.
Epoch 6 completed in 6.31m. 0.05m per recording.


Training Epoch 7: 100%|██████████| 135/135 [05:57<00:00,  2.65s/it]


Epoch 7 completed. Loss: 2.3592, Clip Loss: 3.6086, MSE Loss: 0.4850
Accuracy: 0.1963, Top 5: 0.4327, Top 10: 0.5571
Test unseen_subject completed. Accuracy: 0.0955, Top 5: 0.2865, Top 10: 0.3900
Test unseen_task completed. Accuracy: 0.1700, Top 5: 0.3870, Top 10: 0.5252
Test unseen_both completed. Accuracy: 0.1125, Top 5: 0.3100, Top 10: 0.4375
Testing completed in 0.44m.
Epoch 7 completed in 6.41m. 0.05m per recording.


Training Epoch 8: 100%|██████████| 135/135 [05:51<00:00,  2.60s/it]


Epoch 8 completed. Loss: 2.2803, Clip Loss: 3.4792, MSE Loss: 0.4820
Accuracy: 0.2140, Top 5: 0.4579, Top 10: 0.5814
Test unseen_subject completed. Accuracy: 0.1650, Top 5: 0.3700, Top 10: 0.4910
Test unseen_task completed. Accuracy: 0.1183, Top 5: 0.3140, Top 10: 0.4326
Test unseen_both completed. Accuracy: 0.1475, Top 5: 0.3300, Top 10: 0.4462
Testing completed in 0.40m.
Epoch 8 completed in 6.25m. 0.05m per recording.


Training Epoch 9: 100%|██████████| 135/135 [05:50<00:00,  2.59s/it]


Epoch 9 completed. Loss: 2.2086, Clip Loss: 3.3615, MSE Loss: 0.4793
Accuracy: 0.2318, Top 5: 0.4846, Top 10: 0.6085
Test unseen_subject completed. Accuracy: 0.1510, Top 5: 0.3620, Top 10: 0.4865
Test unseen_task completed. Accuracy: 0.1986, Top 5: 0.4337, Top 10: 0.5603
Test unseen_both completed. Accuracy: 0.1713, Top 5: 0.3488, Top 10: 0.4537
Testing completed in 0.42m.
Epoch 9 completed in 6.25m. 0.05m per recording.


Training Epoch 10: 100%|██████████| 135/135 [05:49<00:00,  2.59s/it]


Epoch 10 completed. Loss: 2.1941, Clip Loss: 3.3384, MSE Loss: 0.4777
Accuracy: 0.2366, Top 5: 0.4908, Top 10: 0.6129
Test unseen_subject completed. Accuracy: 0.1920, Top 5: 0.4095, Top 10: 0.5140
Test unseen_task completed. Accuracy: 0.2142, Top 5: 0.4682, Top 10: 0.5968
Test unseen_both completed. Accuracy: 0.1688, Top 5: 0.3688, Top 10: 0.4863
Testing completed in 0.47m.
Epoch 10 completed in 6.30m. 0.05m per recording.


Training Epoch 11: 100%|██████████| 135/135 [05:50<00:00,  2.60s/it]


Epoch 11 completed. Loss: 2.0978, Clip Loss: 3.1792, MSE Loss: 0.4758
Accuracy: 0.2640, Top 5: 0.5214, Top 10: 0.6443
Test unseen_subject completed. Accuracy: 0.1660, Top 5: 0.3870, Top 10: 0.4920
Test unseen_task completed. Accuracy: 0.1782, Top 5: 0.4058, Top 10: 0.5359
Test unseen_both completed. Accuracy: 0.1375, Top 5: 0.3563, Top 10: 0.4750
Testing completed in 0.42m.
Epoch 11 completed in 6.27m. 0.05m per recording.


Training Epoch 12: 100%|██████████| 135/135 [05:51<00:00,  2.61s/it]


Epoch 12 completed. Loss: 2.0549, Clip Loss: 3.1069, MSE Loss: 0.4768
Accuracy: 0.2771, Top 5: 0.5379, Top 10: 0.6567
Test unseen_subject completed. Accuracy: 0.1785, Top 5: 0.3835, Top 10: 0.4815
Test unseen_task completed. Accuracy: 0.2208, Top 5: 0.4545, Top 10: 0.5772
Test unseen_both completed. Accuracy: 0.1713, Top 5: 0.3575, Top 10: 0.4750
Testing completed in 0.41m.
Epoch 12 completed in 6.27m. 0.05m per recording.


Training Epoch 13: 100%|██████████| 135/135 [05:57<00:00,  2.64s/it]


Epoch 13 completed. Loss: 1.9688, Clip Loss: 2.9644, MSE Loss: 0.4754
Accuracy: 0.2968, Top 5: 0.5684, Top 10: 0.6876
Test unseen_subject completed. Accuracy: 0.1740, Top 5: 0.3880, Top 10: 0.4970
Test unseen_task completed. Accuracy: 0.1888, Top 5: 0.4054, Top 10: 0.5236
Test unseen_both completed. Accuracy: 0.1888, Top 5: 0.3900, Top 10: 0.4875
Testing completed in 0.44m.
Epoch 13 completed in 6.39m. 0.05m per recording.


Training Epoch 14: 100%|██████████| 135/135 [05:54<00:00,  2.63s/it]


Epoch 14 completed. Loss: 1.8926, Clip Loss: 2.8381, MSE Loss: 0.4742
Accuracy: 0.3195, Top 5: 0.5933, Top 10: 0.7083
Test unseen_subject completed. Accuracy: 0.2040, Top 5: 0.4170, Top 10: 0.5220
Test unseen_task completed. Accuracy: 0.1936, Top 5: 0.4292, Top 10: 0.5550
Test unseen_both completed. Accuracy: 0.1713, Top 5: 0.3713, Top 10: 0.4800
Testing completed in 0.41m.
Epoch 14 completed in 6.32m. 0.05m per recording.


Training Epoch 15: 100%|██████████| 135/135 [05:52<00:00,  2.61s/it]


Epoch 15 completed. Loss: 1.7685, Clip Loss: 2.6322, MSE Loss: 0.4731
Accuracy: 0.3575, Top 5: 0.6378, Top 10: 0.7482
Test unseen_subject completed. Accuracy: 0.2025, Top 5: 0.4135, Top 10: 0.5205
Test unseen_task completed. Accuracy: 0.1904, Top 5: 0.4103, Top 10: 0.5274
Test unseen_both completed. Accuracy: 0.1900, Top 5: 0.3563, Top 10: 0.4650
Testing completed in 0.42m.
Epoch 15 completed in 6.30m. 0.05m per recording.


Training Epoch 16: 100%|██████████| 135/135 [05:51<00:00,  2.60s/it]


Epoch 16 completed. Loss: 1.6533, Clip Loss: 2.4398, MSE Loss: 0.4735
Accuracy: 0.3911, Top 5: 0.6730, Top 10: 0.7767
Test unseen_subject completed. Accuracy: 0.2050, Top 5: 0.4280, Top 10: 0.5315
Test unseen_task completed. Accuracy: 0.2087, Top 5: 0.4420, Top 10: 0.5586
Test unseen_both completed. Accuracy: 0.1888, Top 5: 0.3950, Top 10: 0.4875
Testing completed in 0.41m.
Epoch 16 completed in 6.26m. 0.05m per recording.


Training Epoch 17: 100%|██████████| 135/135 [05:55<00:00,  2.63s/it]


Epoch 17 completed. Loss: 1.5393, Clip Loss: 2.2500, MSE Loss: 0.4732
Accuracy: 0.4310, Top 5: 0.7113, Top 10: 0.8081
Test unseen_subject completed. Accuracy: 0.1080, Top 5: 0.2835, Top 10: 0.4120
Test unseen_task completed. Accuracy: 0.1623, Top 5: 0.3707, Top 10: 0.4961
Test unseen_both completed. Accuracy: 0.1212, Top 5: 0.3162, Top 10: 0.4213
Testing completed in 0.43m.
Epoch 17 completed in 6.36m. 0.05m per recording.


Training Epoch 18: 100%|██████████| 135/135 [05:57<00:00,  2.65s/it]


Epoch 18 completed. Loss: 1.4341, Clip Loss: 2.0749, MSE Loss: 0.4730
Accuracy: 0.4704, Top 5: 0.7446, Top 10: 0.8338
Test unseen_subject completed. Accuracy: 0.1795, Top 5: 0.4040, Top 10: 0.5065
Test unseen_task completed. Accuracy: 0.1930, Top 5: 0.4269, Top 10: 0.5469
Test unseen_both completed. Accuracy: 0.1600, Top 5: 0.3900, Top 10: 0.4975
Testing completed in 0.43m.
Epoch 18 completed in 6.40m. 0.05m per recording.


Training Epoch 19: 100%|██████████| 135/135 [06:00<00:00,  2.67s/it]


Epoch 19 completed. Loss: 1.2877, Clip Loss: 1.8306, MSE Loss: 0.4733
Accuracy: 0.5199, Top 5: 0.7863, Top 10: 0.8680
Test unseen_subject completed. Accuracy: 0.1625, Top 5: 0.3655, Top 10: 0.4745
Test unseen_task completed. Accuracy: 0.2157, Top 5: 0.4440, Top 10: 0.5638
Test unseen_both completed. Accuracy: 0.1775, Top 5: 0.3725, Top 10: 0.4662
Testing completed in 0.48m.
Epoch 19 completed in 6.49m. 0.05m per recording.


Training Epoch 20: 100%|██████████| 135/135 [05:55<00:00,  2.63s/it]


Epoch 20 completed. Loss: 1.1552, Clip Loss: 1.6096, MSE Loss: 0.4737
Accuracy: 0.5694, Top 5: 0.8250, Top 10: 0.8945
Test unseen_subject completed. Accuracy: 0.1670, Top 5: 0.3760, Top 10: 0.4725
Test unseen_task completed. Accuracy: 0.1936, Top 5: 0.4316, Top 10: 0.5489
Test unseen_both completed. Accuracy: 0.1412, Top 5: 0.3325, Top 10: 0.4525
Testing completed in 0.48m.
Epoch 20 completed in 6.41m. 0.05m per recording.


Training Epoch 21: 100%|██████████| 135/135 [05:49<00:00,  2.59s/it]


Epoch 21 completed. Loss: 1.0314, Clip Loss: 1.4038, MSE Loss: 0.4727
Accuracy: 0.6127, Top 5: 0.8558, Top 10: 0.9167
Test unseen_subject completed. Accuracy: 0.1500, Top 5: 0.3570, Top 10: 0.4605
Test unseen_task completed. Accuracy: 0.1491, Top 5: 0.3682, Top 10: 0.4897
Test unseen_both completed. Accuracy: 0.1200, Top 5: 0.2938, Top 10: 0.4175
Testing completed in 0.42m.
Epoch 21 completed in 6.26m. 0.05m per recording.


Training Epoch 22: 100%|██████████| 135/135 [05:55<00:00,  2.63s/it]


Epoch 22 completed. Loss: 0.9088, Clip Loss: 1.1987, MSE Loss: 0.4738
Accuracy: 0.6602, Top 5: 0.8854, Top 10: 0.9363
Test unseen_subject completed. Accuracy: 0.1200, Top 5: 0.3020, Top 10: 0.4090
Test unseen_task completed. Accuracy: 0.1554, Top 5: 0.3753, Top 10: 0.4959
Test unseen_both completed. Accuracy: 0.1225, Top 5: 0.3050, Top 10: 0.4125
Testing completed in 0.42m.
Epoch 22 completed in 6.35m. 0.05m per recording.


Training Epoch 23: 100%|██████████| 135/135 [05:51<00:00,  2.60s/it]


Epoch 23 completed. Loss: 0.7489, Clip Loss: 0.9340, MSE Loss: 0.4712
Accuracy: 0.7304, Top 5: 0.9250, Top 10: 0.9612
Test unseen_subject completed. Accuracy: 0.1380, Top 5: 0.3300, Top 10: 0.4425
Test unseen_task completed. Accuracy: 0.1583, Top 5: 0.3801, Top 10: 0.4928
Test unseen_both completed. Accuracy: 0.1300, Top 5: 0.2863, Top 10: 0.4150
Testing completed in 0.45m.
Epoch 23 completed in 6.30m. 0.05m per recording.


Training Epoch 24: 100%|██████████| 135/135 [05:52<00:00,  2.61s/it]


Epoch 24 completed. Loss: 0.7188, Clip Loss: 0.8832, MSE Loss: 0.4722
Accuracy: 0.7418, Top 5: 0.9282, Top 10: 0.9628
Test unseen_subject completed. Accuracy: 0.1550, Top 5: 0.3370, Top 10: 0.4520
Test unseen_task completed. Accuracy: 0.1484, Top 5: 0.3525, Top 10: 0.4695
Test unseen_both completed. Accuracy: 0.1425, Top 5: 0.3137, Top 10: 0.4150
Testing completed in 0.43m.
Epoch 24 completed in 6.31m. 0.05m per recording.


Training Epoch 25: 100%|██████████| 135/135 [05:59<00:00,  2.66s/it]


Epoch 25 completed. Loss: 0.6780, Clip Loss: 0.8160, MSE Loss: 0.4710
Accuracy: 0.7604, Top 5: 0.9351, Top 10: 0.9683
Test unseen_subject completed. Accuracy: 0.1575, Top 5: 0.3615, Top 10: 0.4505
Test unseen_task completed. Accuracy: 0.1567, Top 5: 0.3714, Top 10: 0.4953
Test unseen_both completed. Accuracy: 0.1263, Top 5: 0.2750, Top 10: 0.3925
Testing completed in 0.42m.
Epoch 25 completed in 6.40m. 0.05m per recording.


Training Epoch 26: 100%|██████████| 135/135 [05:55<00:00,  2.64s/it]


Epoch 26 completed. Loss: 1.7646, Clip Loss: 2.5176, MSE Loss: 0.6351
Accuracy: 0.4729, Top 5: 0.6632, Top 10: 0.7338
Test unseen_subject completed. Accuracy: 0.0325, Top 5: 0.1435, Top 10: 0.2435
Test unseen_task completed. Accuracy: 0.0369, Top 5: 0.1478, Top 10: 0.2407
Test unseen_both completed. Accuracy: 0.0275, Top 5: 0.1400, Top 10: 0.2150
Testing completed in 0.43m.
Epoch 26 completed in 6.37m. 0.05m per recording.


Training Epoch 27: 100%|██████████| 135/135 [04:19<00:00,  1.93s/it]


Epoch 27 completed. Loss: 2.9453, Clip Loss: 4.4905, MSE Loss: 0.6276
Accuracy: 0.1074, Top 5: 0.2814, Top 10: 0.3909
Test unseen_subject completed. Accuracy: 0.0510, Top 5: 0.1700, Top 10: 0.2600
Test unseen_task completed. Accuracy: 0.0560, Top 5: 0.1739, Top 10: 0.2630
Test unseen_both completed. Accuracy: 0.0488, Top 5: 0.1475, Top 10: 0.2275
Testing completed in 0.36m.
Epoch 27 completed in 4.69m. 0.03m per recording.
Early stopping at epoch 27. Highest top 10 accuracy at epoch 16.
Training completed.
unseen_subject: Acc: 0.2050, Top 5: 0.4280, Top 10: 0.5315
unseen_task: Acc: 0.2087, Top 5: 0.4420, Top 10: 0.5586
unseen_both: Acc: 0.1888, Top 5: 0.3950, Top 10: 0.4875


## CLIP + MSE objective

In [4]:
# # del session.logger
# # del session

# from pickle import NONE
# from train.training_session_v0 import TrainingSessionV0
# from config import TrainingConfigV0
# import multiprocessing
# from config import SimpleConvConfig
# from models.simpleconv import SimpleConv
# import torch

# data_partition = {
#     "gwilliams2023": {
#         "testing_subjects": [19, 20, 21],
#         "testing_tasks": [0],
#     },
#     # "armeini2022": {
#     #     "testing_subjects": [],
#     #     "testing_tasks": [8, 9],
#     # },
# }

# model_config = SimpleConvConfig(
#     # Str to list of possible conditions
#     mel_normalization=False,
#     conditions={
#         "study": [],
#         "subject": [],
#     },
#     # Channels
#     in_channels=208,
#     out_channels=128,
#     hidden_dim=384,
#     dropout=0.2,
#     initial_batch_norm=True,
#     # Sensor layout settings
#     layout_dim=2,
#     layout_proj=True,
#     layout_scaling="minmax",
#     # Merger with spatial attn
#     merger=False,
#     merger_emb_type=None,
#     merger_emb_dim=0,
#     merger_channels=0,
#     merger_dropout=False,
#     merger_conditional=None,
#     # Inital
#     initial_linear=384,
#     initial_depth=1,
#     # Conditional layers
#     conditional_layers=False,
#     conditional_layers_dim=None,  # input or hidden_dim
#     # Conv layer overall structure
#     depth=6,
#     kernel_size=3,
#     growth=1.0,
#     dilation_growth=2,
#     dilation_period=5,
#     glu=1,
#     conv_dropout=0.2,
#     dropout_input=0.2,
#     batch_norm=True,
#     half=False,
#     # Quantizer
#     quantizer=False,
#     num_codebooks=0,
#     codebook_size=0,
#     quantizer_commitment=0,
#     quantizer_temp_init=0,
#     quantizer_temp_min=0,
#     quantizer_temp_decay=0,
#     # Transformers Encoders
#     transformer_input=None,
#     transformer_encoder_emb=None,
#     transformer_encoder_layers=0,
#     transformer_encoder_heads=0,
#     # Transformer Decoders
#     transformer_decoder_emb=None,
#     transformer_decoder_layers=0,
#     transformer_decoder_heads=0,
#     transformer_decoder_dim=0,
# )

# config = TrainingConfigV0(
#     brain_encoder_config=model_config,
#     data_partition=data_partition,
#     # Pre-processing parameters
#     # Brain
#     new_freq=100,
#     frequency_bands={"all": (0.5, 40)},
#     max_random_shift=1.0,
#     window_size=4,
#     window_stride=1,
#     brain_clipping=None,
#     baseline_window=0.0,
#     notch_filter=True,
#     scaling="standard",
#     delay=0.15,
#     # Hyperparameters
#     learning_rate=3e-4,
#     weight_decay=1e-4,
#     epochs=50,
#     batch_size=256,
#     use_clip_loss=True,
#     use_mse_loss=True,
#     alpha=0.6,
#     random_test_size=10,
#     seed=42,
# )

# config.notch_filter = False

# session = TrainingSessionV0(
#     config=config,
#     studies={study: "audio" for study in data_partition.keys()},
#     data_path="data",
#     save_path="saves/phase1/ablation/notch_filter/False",
#     clear_cache=True,
#     cache_name="cache/1",
# )

# try:
#     session.train(
#         device="cuda",
#         buffer_size=30,
#         num_workers=(multiprocessing.cpu_count() - 2) // 2,
#         max_cache_size=400,
#         current_epoch=0,
#     )
# except KeyboardInterrupt as e:
#     print("Exited")

## Ablation: no notch filter