In [1]:
# Per-study hold-out specification: for each study name, the subject ids and
# task ids listed here are kept out of training and used for evaluation.
data_partition = dict(
    gwilliams2023=dict(
        testing_subjects=[19, 20, 21],
        testing_tasks=[0],
    ),
    # Currently disabled second study; to enable it, add:
    #     schoffelen2022=dict(testing_subjects=[], testing_tasks=[8, 9]),
)

In [2]:
# Disabled setup step, kept for reference. When uncommented it calls
# decompress_directories once per study key in data_partition, reading from
# downloaded_data/<study> and writing to data/<study>, and deletes the
# compressed files afterwards (delete_compressed_files=True).
# NOTE(review): the loop variable `batch_type` is bound to the per-study
# partition dict (the values of data_partition) and is never used in the loop.
# NOTE(review): presumably only needed once after downloading the archives — confirm.
# from utils.compression import compress_directories, decompress_directories

# for base_path, batch_type in data_partition.items():
#     decompress_directories(
#         source_path=f'downloaded_data/{base_path}',
#         destination_path=f'data/{base_path}',
#         checksum_file_name="checksums.txt",
#         delete_compressed_files=True,
#         num_workers=26
#     )

In [3]:
from config import SimpleConvConfig
from models.simpleconv import SimpleConv
import torch

# Brain-encoder configuration. The settings are collected in one mapping and
# handed to SimpleConvConfig as keyword arguments, grouped by the part of the
# model they concern.
model_config = SimpleConvConfig(
    **{
        "mel_normalization": False,
        # Conditioning: condition name -> list of possible values
        # (both lists are left empty here).
        "conditions": {
            "study": [],
            "subject": [],
        },
        # Channel widths and general regularisation
        "in_channels": 208,
        "out_channels": 128,
        "hidden_dim": 384,
        "dropout": 0.2,
        "initial_batch_norm": True,
        # Sensor layout
        "layout_dim": 2,
        "layout_proj": True,
        "layout_scaling": "minmax",
        # Merger with spatial attention (switched off: merger=False)
        "merger": False,
        "merger_emb_type": None,
        "merger_emb_dim": 0,
        "merger_channels": 0,
        "merger_dropout": False,
        "merger_conditional": None,
        # Initial layers
        "initial_linear": 384,
        "initial_depth": 1,
        # Conditional layers (switched off); the dim option takes
        # "input" or "hidden_dim" when used
        "conditional_layers": False,
        "conditional_layers_dim": None,
        # Overall structure of the convolutional stack
        "depth": 6,
        "kernel_size": 3,
        "growth": 1.0,
        "dilation_growth": 2,
        "dilation_period": 5,
        "glu": 1,
        "conv_dropout": 0.2,
        "dropout_input": 0.2,
        "batch_norm": True,
        # Quantizer (switched off: quantizer=False, all related values zero)
        "quantizer": False,
        "num_codebooks": 0,
        "codebook_size": 0,
        "quantizer_commitment": 0,
        "quantizer_temp_init": 0,
        "quantizer_temp_min": 0,
        "quantizer_temp_decay": 0,
        # Transformer encoder (zero layers / heads)
        "transformer_input": None,
        "transformer_encoder_emb": None,
        "transformer_encoder_layers": 0,
        "transformer_encoder_heads": 0,
        # Transformer decoder (zero layers / heads)
        "transformer_decoder_emb": None,
        "transformer_decoder_layers": 0,
        "transformer_decoder_heads": 0,
        "transformer_decoder_dim": 0,
    }
)

In [None]:
from train.training_session_v0 import TrainingSessionV0
from config import TrainingConfigV0
import multiprocessing

# Training configuration for this run: the encoder config and data partition
# defined in earlier cells, followed by pre-processing settings and optimiser
# hyperparameters, all passed to TrainingConfigV0 as keyword arguments.
config = TrainingConfigV0(
    **{
        "brain_encoder_config": model_config,
        "data_partition": data_partition,
        # Pre-processing of the brain recordings
        "new_freq": 100,
        "frequency_bands": {"all": (0.5, 80)},
        "max_random_shift": 1.0,
        "window_size": 4,
        "window_stride": 1,
        "brain_clipping": 20,
        "baseline_window": 0.5,
        "notch_filter": True,
        "scaling": "standard",
        # The value under ablation in this notebook (see the session save path).
        "delay": 0.15,
        # Optimisation hyperparameters
        "learning_rate": 3e-4,
        "weight_decay": 1e-4,
        "epochs": 50,
        "batch_size": 256,
        # Both loss terms are enabled.
        # NOTE(review): alpha presumably weights one loss against the other — confirm.
        "use_clip_loss": True,
        "use_mse_loss": True,
        "alpha": 0.6,
        "random_test_size": 10,
        "seed": 42,
    }
)

# Create the training session. Every study named in data_partition is mapped
# to the "audio" batch type; data is read from "data" and results are written
# under the delay-ablation save directory.
session = TrainingSessionV0(
    config=config,
    # One entry per study key, all sharing the same batch type.
    studies=dict.fromkeys(data_partition, "audio"),
    data_path="data",
    save_path="saves/phase1/ablation/delay/0_15",
    clear_cache=False,
)

Loading Gwilliams2023 with batch type audio
Data partitioned on studies ['gwilliams2023'].
Train: 135, Unseen Task: 12, Unseen Subject: 45, Unseen Both: 4.

GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower than expected.
SimpleConv initialized with 8448160 parameters, cond: ['study', 'subject']
Merger False, merger channels 0
ConvBlocks: 6, hidden_dim: 384, params 7973376


In [5]:
# Run training on the GPU, starting at current_epoch=0. A kernel interrupt
# (KeyboardInterrupt) is treated as a deliberate stop: it is caught and a
# short message is printed instead of a traceback.
try:
    session.train(
        device="cuda",
        buffer_size=30,
        # Keep two cores unused, but never request fewer than one worker:
        # cpu_count() - 2 is zero or negative on one- and two-core machines.
        num_workers=max(1, multiprocessing.cpu_count() - 2),
        max_cache_size=400,
        current_epoch=0,
    )
except KeyboardInterrupt:
    print("Exited")

2024-12-20 13:38:24,453	INFO worker.py:1821 -- Started a local Ray instance.
Training Epoch 1: 100%|██████████| 135/135 [04:15<00:00,  1.90s/it]


Testing at epoch 1
Test unseen_subject completed. Accuracy: 0.0225, Top 1: 0.0418, Top 5: 0.1934, Top 10: 0.3351, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.0365, Top 1: 0.0556, Top 5: 0.2296, Top 10: 0.3951, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.0323, Top 1: 0.0519, Top 5: 0.2140, Top 10: 0.3663, Perplexity: 0.0000
Epoch 1 completed in 4.65m. 0.03m per recording.


Training Epoch 2: 100%|██████████| 135/135 [04:15<00:00,  1.90s/it]


Testing at epoch 2
Test unseen_subject completed. Accuracy: 0.1242, Top 1: 0.1857, Top 5: 0.4266, Top 10: 0.5523, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1153, Top 1: 0.1614, Top 5: 0.4273, Top 10: 0.5840, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1103, Top 1: 0.1471, Top 5: 0.3879, Top 10: 0.5346, Perplexity: 0.0000
Epoch 2 completed in 4.61m. 0.03m per recording.


Training Epoch 3: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 3
Test unseen_subject completed. Accuracy: 0.0728, Top 1: 0.1086, Top 5: 0.3465, Top 10: 0.5042, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1078, Top 1: 0.1522, Top 5: 0.4084, Top 10: 0.5593, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.0695, Top 1: 0.1039, Top 5: 0.3166, Top 10: 0.4574, Perplexity: 0.0000
Epoch 3 completed in 4.64m. 0.03m per recording.


Training Epoch 4: 100%|██████████| 135/135 [04:16<00:00,  1.90s/it]


Testing at epoch 4
Test unseen_subject completed. Accuracy: 0.1313, Top 1: 0.1795, Top 5: 0.4036, Top 10: 0.5277, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1613, Top 1: 0.2221, Top 5: 0.5090, Top 10: 0.6431, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1237, Top 1: 0.1775, Top 5: 0.3952, Top 10: 0.5253, Perplexity: 0.0000
Epoch 4 completed in 4.62m. 0.03m per recording.


Training Epoch 5: 100%|██████████| 135/135 [04:16<00:00,  1.90s/it]


Testing at epoch 5
Test unseen_subject completed. Accuracy: 0.1152, Top 1: 0.1693, Top 5: 0.4493, Top 10: 0.5856, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1774, Top 1: 0.2439, Top 5: 0.5392, Top 10: 0.6821, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1321, Top 1: 0.1811, Top 5: 0.4494, Top 10: 0.5828, Perplexity: 0.0000
Epoch 5 completed in 4.63m. 0.03m per recording.


Training Epoch 6: 100%|██████████| 135/135 [04:18<00:00,  1.91s/it]


Testing at epoch 6
Test unseen_subject completed. Accuracy: 0.1386, Top 1: 0.1882, Top 5: 0.4466, Top 10: 0.5704, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2340, Top 1: 0.3044, Top 5: 0.5958, Top 10: 0.7212, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1649, Top 1: 0.2324, Top 5: 0.4498, Top 10: 0.5662, Perplexity: 0.0000
Epoch 6 completed in 4.66m. 0.03m per recording.


Training Epoch 7: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 7
Test unseen_subject completed. Accuracy: 0.1497, Top 1: 0.2101, Top 5: 0.4712, Top 10: 0.6008, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1976, Top 1: 0.2618, Top 5: 0.5651, Top 10: 0.7048, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1518, Top 1: 0.2069, Top 5: 0.4542, Top 10: 0.5852, Perplexity: 0.0000
Epoch 7 completed in 4.64m. 0.03m per recording.


Training Epoch 8: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 8
Test unseen_subject completed. Accuracy: 0.1215, Top 1: 0.1691, Top 5: 0.4294, Top 10: 0.5845, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1109, Top 1: 0.1556, Top 5: 0.4428, Top 10: 0.6042, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1054, Top 1: 0.1531, Top 5: 0.4169, Top 10: 0.5674, Perplexity: 0.0000
Epoch 8 completed in 4.64m. 0.03m per recording.


Training Epoch 9: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 9
Test unseen_subject completed. Accuracy: 0.1760, Top 1: 0.2440, Top 5: 0.5163, Top 10: 0.6342, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2061, Top 1: 0.2740, Top 5: 0.5655, Top 10: 0.7042, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1763, Top 1: 0.2438, Top 5: 0.4788, Top 10: 0.5863, Perplexity: 0.0000
Epoch 9 completed in 4.65m. 0.03m per recording.


Training Epoch 10: 100%|██████████| 135/135 [04:16<00:00,  1.90s/it]


Testing at epoch 10
Test unseen_subject completed. Accuracy: 0.1776, Top 1: 0.2495, Top 5: 0.5040, Top 10: 0.6089, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1702, Top 1: 0.2359, Top 5: 0.5184, Top 10: 0.6534, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1459, Top 1: 0.2171, Top 5: 0.4469, Top 10: 0.5715, Perplexity: 0.0000
Epoch 10 completed in 4.62m. 0.03m per recording.


Training Epoch 11: 100%|██████████| 135/135 [04:16<00:00,  1.90s/it]


Testing at epoch 11
Test unseen_subject completed. Accuracy: 0.1734, Top 1: 0.2309, Top 5: 0.5052, Top 10: 0.6219, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1868, Top 1: 0.2506, Top 5: 0.5505, Top 10: 0.6812, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1682, Top 1: 0.2173, Top 5: 0.4806, Top 10: 0.6017, Perplexity: 0.0000
Epoch 11 completed in 4.63m. 0.03m per recording.


Training Epoch 12: 100%|██████████| 135/135 [04:17<00:00,  1.90s/it]


Testing at epoch 12
Test unseen_subject completed. Accuracy: 0.2031, Top 1: 0.2730, Top 5: 0.5271, Top 10: 0.6289, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2160, Top 1: 0.2858, Top 5: 0.5932, Top 10: 0.7203, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1849, Top 1: 0.2424, Top 5: 0.4896, Top 10: 0.6024, Perplexity: 0.0000
Epoch 12 completed in 4.63m. 0.03m per recording.


Training Epoch 13: 100%|██████████| 135/135 [04:16<00:00,  1.90s/it]


Testing at epoch 13
Test unseen_subject completed. Accuracy: 0.1821, Top 1: 0.2420, Top 5: 0.5041, Top 10: 0.6230, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1969, Top 1: 0.2629, Top 5: 0.5531, Top 10: 0.6806, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1141, Top 1: 0.1752, Top 5: 0.3986, Top 10: 0.5340, Perplexity: 0.0000
Epoch 13 completed in 4.63m. 0.03m per recording.


Training Epoch 14: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 14
Test unseen_subject completed. Accuracy: 0.1369, Top 1: 0.2029, Top 5: 0.4989, Top 10: 0.6203, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1939, Top 1: 0.2652, Top 5: 0.5878, Top 10: 0.7283, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1766, Top 1: 0.2484, Top 5: 0.4960, Top 10: 0.6087, Perplexity: 0.0000
Epoch 14 completed in 4.64m. 0.03m per recording.


Training Epoch 15: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 15
Test unseen_subject completed. Accuracy: 0.1840, Top 1: 0.2465, Top 5: 0.4981, Top 10: 0.6183, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2285, Top 1: 0.3059, Top 5: 0.6017, Top 10: 0.7252, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1702, Top 1: 0.2166, Top 5: 0.4463, Top 10: 0.5825, Perplexity: 0.0000
Epoch 15 completed in 4.64m. 0.03m per recording.


Training Epoch 16: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 16
Test unseen_subject completed. Accuracy: 0.1721, Top 1: 0.2395, Top 5: 0.4986, Top 10: 0.6144, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2143, Top 1: 0.2897, Top 5: 0.5855, Top 10: 0.7174, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1452, Top 1: 0.2001, Top 5: 0.4487, Top 10: 0.5689, Perplexity: 0.0000
Epoch 16 completed in 4.64m. 0.03m per recording.


Training Epoch 17: 100%|██████████| 135/135 [04:18<00:00,  1.91s/it]


Testing at epoch 17
Test unseen_subject completed. Accuracy: 0.1403, Top 1: 0.1913, Top 5: 0.4552, Top 10: 0.5460, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1477, Top 1: 0.2054, Top 5: 0.4737, Top 10: 0.6120, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1304, Top 1: 0.1842, Top 5: 0.4094, Top 10: 0.5317, Perplexity: 0.0000
Epoch 17 completed in 4.65m. 0.03m per recording.


Training Epoch 18: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 18
Test unseen_subject completed. Accuracy: 0.1835, Top 1: 0.2465, Top 5: 0.5095, Top 10: 0.6078, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2150, Top 1: 0.2798, Top 5: 0.5572, Top 10: 0.6884, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1993, Top 1: 0.2534, Top 5: 0.4810, Top 10: 0.5901, Perplexity: 0.0000
Epoch 18 completed in 4.64m. 0.03m per recording.


Training Epoch 19: 100%|██████████| 135/135 [04:17<00:00,  1.90s/it]


Testing at epoch 19
Test unseen_subject completed. Accuracy: 0.1236, Top 1: 0.1683, Top 5: 0.4262, Top 10: 0.5541, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1587, Top 1: 0.2172, Top 5: 0.4990, Top 10: 0.6418, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1253, Top 1: 0.1695, Top 5: 0.3994, Top 10: 0.5334, Perplexity: 0.0000
Epoch 19 completed in 4.63m. 0.03m per recording.


Training Epoch 20: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 20
Test unseen_subject completed. Accuracy: 0.1687, Top 1: 0.2272, Top 5: 0.4982, Top 10: 0.6159, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2005, Top 1: 0.2718, Top 5: 0.5610, Top 10: 0.6944, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1788, Top 1: 0.2301, Top 5: 0.4712, Top 10: 0.5652, Perplexity: 0.0000
Epoch 20 completed in 4.64m. 0.03m per recording.


Training Epoch 21: 100%|██████████| 135/135 [04:16<00:00,  1.90s/it]


Testing at epoch 21
Test unseen_subject completed. Accuracy: 0.1048, Top 1: 0.1469, Top 5: 0.4144, Top 10: 0.5442, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1610, Top 1: 0.2174, Top 5: 0.4947, Top 10: 0.6353, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1059, Top 1: 0.1585, Top 5: 0.4032, Top 10: 0.5243, Perplexity: 0.0000
Epoch 21 completed in 4.63m. 0.03m per recording.


Training Epoch 22: 100%|██████████| 135/135 [04:16<00:00,  1.90s/it]


Testing at epoch 22
Test unseen_subject completed. Accuracy: 0.1692, Top 1: 0.2238, Top 5: 0.4758, Top 10: 0.5986, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1443, Top 1: 0.1987, Top 5: 0.4655, Top 10: 0.6091, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1005, Top 1: 0.1556, Top 5: 0.3849, Top 10: 0.5198, Perplexity: 0.0000
Epoch 22 completed in 4.62m. 0.03m per recording.


Training Epoch 23: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 23
Test unseen_subject completed. Accuracy: 0.1302, Top 1: 0.1843, Top 5: 0.4133, Top 10: 0.5444, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1337, Top 1: 0.1872, Top 5: 0.4606, Top 10: 0.6090, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1129, Top 1: 0.1592, Top 5: 0.3871, Top 10: 0.5145, Perplexity: 0.0000
Epoch 23 completed in 4.64m. 0.03m per recording.
Early stopping at epoch 23. Highest top 10 accuracy at epoch 12.
Training completed.
unseen_subject: Acc: 0.2640, Top 1: 0.2640, Top 5: 0.6294, Top 10: 0.7310
unseen_task: Acc: 0.1677, Top 1: 0.2341, Top 5: 0.4900, Top 10: 0.6487
unseen_both: Acc: 0.2115, Top 1: 0.2981, Top 5: 0.5529, Top 10: 0.6635


In [6]:
# Display the epoch the session recorded as best; the training log above
# reports it as the epoch with the highest top-10 accuracy.
session.highest_epoch

12