In [1]:
# Per-study held-out split: the subject and task indices listed here are the
# ones the session keeps out of training (it reports unseen-subject,
# unseen-task and unseen-both test sets).
# Another study can be enabled by adding an entry with the same two keys, e.g.
#     schoffelen2022=dict(testing_subjects=[], testing_tasks=[8, 9]),
data_partition = dict(
    gwilliams2023=dict(
        testing_subjects=[19, 20, 21],
        testing_tasks=[0],
    ),
)

In [2]:
# One-time setup: decompress each study in `data_partition` from
# downloaded_data/<study> into data/<study>. Disabled by default so a fresh
# "Restart & Run All" does not redo it; set to True on a new machine.
# NOTE(review): delete_compressed_files=True presumably removes the archives
# after extraction, so this is not re-runnable once done — confirm.
DECOMPRESS_DATA = False

if DECOMPRESS_DATA:
    from utils.compression import decompress_directories

    for base_path in data_partition:
        decompress_directories(
            source_path=f'downloaded_data/{base_path}',
            destination_path=f'data/{base_path}',
            checksum_file_name="checksums.txt",
            delete_compressed_files=True,
            num_workers=26
        )

In [3]:
from config import SimpleConvConfig
from models.simpleconv import SimpleConv
import torch

# Brain-encoder (SimpleConv) hyperparameters. In this run the merger,
# conditional layers, quantizer and transformer blocks are all switched off
# (False / 0 / None).
model_config = SimpleConvConfig(
    mel_normalization=False,
    # Condition name -> list of possible values. The lists are left empty here;
    # NOTE(review): confirm how SimpleConv populates them (the model log still
    # reports cond: ['study', 'subject']).
    conditions={
        "study": [],
        "subject": [],
    },
    # Channels
    # NOTE(review): 208 presumably matches the sensor count of gwilliams2023 — confirm.
    in_channels=208,
    out_channels=128,
    hidden_dim=384,
    dropout=0.2,
    initial_batch_norm=True,
    # Sensor layout settings
    layout_dim=2,
    layout_proj=True,
    layout_scaling="minmax",
    # Merger with spatial attention (disabled)
    merger=False,
    merger_emb_type=None,
    merger_emb_dim=0,
    merger_channels=0,
    merger_dropout=False,
    merger_conditional=None,
    # Initial layers
    initial_linear=384,
    initial_depth=1,
    # Conditional layers (disabled)
    conditional_layers=False,
    conditional_layers_dim=None,  # input or hidden_dim
    # Conv layer overall structure
    depth=6,
    kernel_size=3,
    growth=1.0,
    dilation_growth=2,
    dilation_period=5,
    glu=1,
    conv_dropout=0.2,
    dropout_input=0.2,
    batch_norm=True,
    # Quantizer (disabled)
    quantizer=False,
    num_codebooks=0,
    codebook_size=0,
    quantizer_commitment=0,
    quantizer_temp_init=0,
    quantizer_temp_min=0,
    quantizer_temp_decay=0,
    # Transformer encoders (disabled)
    transformer_input=None,
    transformer_encoder_emb=None,
    transformer_encoder_layers=0,
    transformer_encoder_heads=0,
    # Transformer decoders (disabled)
    transformer_decoder_emb=None,
    transformer_decoder_layers=0,
    transformer_decoder_heads=0,
    transformer_decoder_dim=0,
)

In [4]:
from train.training_session_v0 import TrainingSessionV0
from config import TrainingConfigV0
import multiprocessing

# Value under ablation in this notebook. The save path below is derived from it,
# so running a different delay cannot silently write into (and overwrite) the
# results directory of another delay.
# NOTE(review): unit presumably seconds (like window_size / baseline_window) — confirm.
DELAY = 0.15

config = TrainingConfigV0(
    brain_encoder_config=model_config,
    data_partition=data_partition,
    # Pre-processing parameters
    # Brain
    new_freq=100,
    frequency_bands={"all": (0.5, 80)},
    max_random_shift=1.0,
    window_size=4,
    window_stride=1,
    brain_clipping=20,
    baseline_window=0.5,
    notch_filter=True,
    scaling="standard",
    delay=DELAY,
    # Hyperparameters
    learning_rate=3e-4,
    weight_decay=1e-4,
    epochs=50,
    batch_size=256,
    use_clip_loss=True,
    use_mse_loss=True,
    alpha=0.6,
    random_test_size=10,
    seed=42,
)

session = TrainingSessionV0(
    config=config,
    # Every study in the partition is loaded with the "audio" batch type.
    studies=dict.fromkeys(data_partition, "audio"),
    data_path="data",
    # e.g. DELAY = 0.15 -> "saves/phase1/ablation/delay/0_15"
    save_path=f"saves/phase1/ablation/delay/{str(DELAY).replace('.', '_')}",
    clear_cache=False,
)

Loading Gwilliams2023 with batch type audio
Data partitioned on studies ['gwilliams2023'].
Train: 135, Unseen Task: 12, Unseen Subject: 45, Unseen Both: 4.

GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower than expected.
SimpleConv initialized with 8448160 parameters, cond: ['study', 'subject']
Merger False, merger channels 0
ConvBlocks: 6, hidden_dim: 384, params 7973376


In [5]:
# Train from epoch 0. Interrupting the kernel (Ctrl-C) stops training without a
# traceback, leaving `session` in the namespace for the cells below.
try:
    session.train(
        device="cuda",
        buffer_size=30,
        # Leave two cores free, but never request fewer than one worker:
        # cpu_count() - 2 is 0 or negative on machines with <= 2 cores.
        num_workers=max(1, multiprocessing.cpu_count() - 2),
        max_cache_size=400,
        current_epoch=0,
    )
except KeyboardInterrupt:
    print("Exited")

2024-12-20 20:02:39,957	INFO worker.py:1821 -- Started a local Ray instance.
Training Epoch 1: 100%|██████████| 135/135 [05:10<00:00,  2.30s/it]


Testing at epoch 1
Test unseen_subject completed. Accuracy: 0.0850, Top 1: 0.1380, Top 5: 0.3555, Top 10: 0.4875, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.0760, Top 1: 0.1123, Top 5: 0.3554, Top 10: 0.5241, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.0663, Top 1: 0.1213, Top 5: 0.3387, Top 10: 0.4888, Perplexity: 0.0000
Epoch 1 completed in 5.56m. 0.04m per recording.


Training Epoch 2: 100%|██████████| 135/135 [04:19<00:00,  1.92s/it]


Testing at epoch 2
Test unseen_subject completed. Accuracy: 0.1150, Top 1: 0.1835, Top 5: 0.4300, Top 10: 0.5530, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1001, Top 1: 0.1472, Top 5: 0.4131, Top 10: 0.5891, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.0900, Top 1: 0.1450, Top 5: 0.3862, Top 10: 0.5388, Perplexity: 0.0000
Epoch 2 completed in 4.67m. 0.03m per recording.


Training Epoch 3: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 3
Test unseen_subject completed. Accuracy: 0.1030, Top 1: 0.1765, Top 5: 0.3955, Top 10: 0.5270, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1099, Top 1: 0.1497, Top 5: 0.3793, Top 10: 0.5180, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.0938, Top 1: 0.1425, Top 5: 0.3538, Top 10: 0.4688, Perplexity: 0.0000
Epoch 3 completed in 4.64m. 0.03m per recording.


Training Epoch 4: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 4
Test unseen_subject completed. Accuracy: 0.1350, Top 1: 0.2075, Top 5: 0.4275, Top 10: 0.5530, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1569, Top 1: 0.2127, Top 5: 0.5138, Top 10: 0.6638, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1288, Top 1: 0.2050, Top 5: 0.4362, Top 10: 0.5662, Perplexity: 0.0000
Epoch 4 completed in 4.65m. 0.03m per recording.


Training Epoch 5: 100%|██████████| 135/135 [04:18<00:00,  1.91s/it]


Testing at epoch 5
Test unseen_subject completed. Accuracy: 0.1535, Top 1: 0.2290, Top 5: 0.4460, Top 10: 0.5530, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1595, Top 1: 0.2211, Top 5: 0.5160, Top 10: 0.6526, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1300, Top 1: 0.1938, Top 5: 0.4125, Top 10: 0.5287, Perplexity: 0.0000
Epoch 5 completed in 4.65m. 0.03m per recording.


Training Epoch 6: 100%|██████████| 135/135 [04:18<00:00,  1.92s/it]


Testing at epoch 6
Test unseen_subject completed. Accuracy: 0.1540, Top 1: 0.2410, Top 5: 0.4865, Top 10: 0.6105, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1953, Top 1: 0.2627, Top 5: 0.5571, Top 10: 0.6940, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1900, Top 1: 0.2775, Top 5: 0.4950, Top 10: 0.5862, Perplexity: 0.0000
Epoch 6 completed in 4.67m. 0.03m per recording.


Training Epoch 7: 100%|██████████| 135/135 [04:18<00:00,  1.91s/it]


Testing at epoch 7
Test unseen_subject completed. Accuracy: 0.1405, Top 1: 0.2180, Top 5: 0.4770, Top 10: 0.5945, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1773, Top 1: 0.2469, Top 5: 0.5509, Top 10: 0.6923, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1438, Top 1: 0.2263, Top 5: 0.4637, Top 10: 0.5863, Perplexity: 0.0000
Epoch 7 completed in 4.65m. 0.03m per recording.


Training Epoch 8: 100%|██████████| 135/135 [04:19<00:00,  1.92s/it]


Testing at epoch 8
Test unseen_subject completed. Accuracy: 0.1915, Top 1: 0.2810, Top 5: 0.5175, Top 10: 0.6225, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1932, Top 1: 0.2653, Top 5: 0.5688, Top 10: 0.7167, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1713, Top 1: 0.2538, Top 5: 0.5125, Top 10: 0.6350, Perplexity: 0.0000
Epoch 8 completed in 4.68m. 0.03m per recording.


Training Epoch 9: 100%|██████████| 135/135 [04:16<00:00,  1.90s/it]


Testing at epoch 9
Test unseen_subject completed. Accuracy: 0.1535, Top 1: 0.2335, Top 5: 0.4645, Top 10: 0.5790, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1997, Top 1: 0.2716, Top 5: 0.5710, Top 10: 0.6907, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1575, Top 1: 0.2337, Top 5: 0.4612, Top 10: 0.5713, Perplexity: 0.0000
Epoch 9 completed in 4.63m. 0.03m per recording.


Training Epoch 10: 100%|██████████| 135/135 [04:19<00:00,  1.92s/it]


Testing at epoch 10
Test unseen_subject completed. Accuracy: 0.1355, Top 1: 0.2120, Top 5: 0.4655, Top 10: 0.5970, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2073, Top 1: 0.2725, Top 5: 0.5561, Top 10: 0.6839, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1650, Top 1: 0.2363, Top 5: 0.4500, Top 10: 0.5750, Perplexity: 0.0000
Epoch 10 completed in 4.67m. 0.03m per recording.


Training Epoch 11: 100%|██████████| 135/135 [04:19<00:00,  1.92s/it]


Testing at epoch 11
Test unseen_subject completed. Accuracy: 0.1845, Top 1: 0.2710, Top 5: 0.5010, Top 10: 0.6125, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2166, Top 1: 0.2913, Top 5: 0.6034, Top 10: 0.7424, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1763, Top 1: 0.2500, Top 5: 0.4762, Top 10: 0.5850, Perplexity: 0.0000
Epoch 11 completed in 4.67m. 0.03m per recording.


Training Epoch 12: 100%|██████████| 135/135 [04:18<00:00,  1.91s/it]


Testing at epoch 12
Test unseen_subject completed. Accuracy: 0.1945, Top 1: 0.2800, Top 5: 0.5035, Top 10: 0.6130, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2327, Top 1: 0.3127, Top 5: 0.6183, Top 10: 0.7477, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1875, Top 1: 0.2687, Top 5: 0.5150, Top 10: 0.6288, Perplexity: 0.0000
Epoch 12 completed in 4.66m. 0.03m per recording.


Training Epoch 13: 100%|██████████| 135/135 [04:18<00:00,  1.91s/it]


Testing at epoch 13
Test unseen_subject completed. Accuracy: 0.1960, Top 1: 0.2880, Top 5: 0.5230, Top 10: 0.6445, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2216, Top 1: 0.2962, Top 5: 0.5924, Top 10: 0.7258, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1725, Top 1: 0.2537, Top 5: 0.4863, Top 10: 0.6100, Perplexity: 0.0000
Epoch 13 completed in 4.66m. 0.03m per recording.


Training Epoch 14: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 14
Test unseen_subject completed. Accuracy: 0.1950, Top 1: 0.2855, Top 5: 0.5310, Top 10: 0.6285, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2107, Top 1: 0.2831, Top 5: 0.5888, Top 10: 0.7296, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1800, Top 1: 0.2538, Top 5: 0.4750, Top 10: 0.6000, Perplexity: 0.0000
Epoch 14 completed in 4.64m. 0.03m per recording.


Training Epoch 15: 100%|██████████| 135/135 [04:19<00:00,  1.92s/it]


Testing at epoch 15
Test unseen_subject completed. Accuracy: 0.1685, Top 1: 0.2565, Top 5: 0.4860, Top 10: 0.6115, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2421, Top 1: 0.3220, Top 5: 0.6258, Top 10: 0.7513, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1963, Top 1: 0.2675, Top 5: 0.5200, Top 10: 0.6350, Perplexity: 0.0000
Epoch 15 completed in 4.68m. 0.03m per recording.


Training Epoch 16: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 16
Test unseen_subject completed. Accuracy: 0.1970, Top 1: 0.2900, Top 5: 0.5270, Top 10: 0.6265, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2265, Top 1: 0.3013, Top 5: 0.6168, Top 10: 0.7496, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.2075, Top 1: 0.2750, Top 5: 0.4925, Top 10: 0.6038, Perplexity: 0.0000
Epoch 16 completed in 4.65m. 0.03m per recording.


Training Epoch 17: 100%|██████████| 135/135 [04:18<00:00,  1.91s/it]


Testing at epoch 17
Test unseen_subject completed. Accuracy: 0.1200, Top 1: 0.1955, Top 5: 0.4450, Top 10: 0.5725, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1579, Top 1: 0.2133, Top 5: 0.5053, Top 10: 0.6501, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1187, Top 1: 0.1925, Top 5: 0.4213, Top 10: 0.5600, Perplexity: 0.0000
Epoch 17 completed in 4.66m. 0.03m per recording.


Training Epoch 18: 100%|██████████| 135/135 [04:19<00:00,  1.92s/it]


Testing at epoch 18
Test unseen_subject completed. Accuracy: 0.1650, Top 1: 0.2435, Top 5: 0.4770, Top 10: 0.5975, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2105, Top 1: 0.2824, Top 5: 0.5778, Top 10: 0.7046, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1562, Top 1: 0.2238, Top 5: 0.4425, Top 10: 0.5725, Perplexity: 0.0000
Epoch 18 completed in 4.67m. 0.03m per recording.


Training Epoch 19: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 19
Test unseen_subject completed. Accuracy: 0.1690, Top 1: 0.2565, Top 5: 0.4800, Top 10: 0.5865, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2203, Top 1: 0.2869, Top 5: 0.5801, Top 10: 0.7136, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1975, Top 1: 0.2688, Top 5: 0.4750, Top 10: 0.5800, Perplexity: 0.0000
Epoch 19 completed in 4.64m. 0.03m per recording.


Training Epoch 20: 100%|██████████| 135/135 [04:19<00:00,  1.92s/it]


Testing at epoch 20
Test unseen_subject completed. Accuracy: 0.1380, Top 1: 0.2230, Top 5: 0.4795, Top 10: 0.5975, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2031, Top 1: 0.2706, Top 5: 0.5873, Top 10: 0.7168, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1675, Top 1: 0.2400, Top 5: 0.4587, Top 10: 0.5775, Perplexity: 0.0000
Epoch 20 completed in 4.68m. 0.03m per recording.


Training Epoch 21: 100%|██████████| 135/135 [04:19<00:00,  1.92s/it]


Testing at epoch 21
Test unseen_subject completed. Accuracy: 0.1680, Top 1: 0.2600, Top 5: 0.4725, Top 10: 0.5760, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1873, Top 1: 0.2539, Top 5: 0.5467, Top 10: 0.6835, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1850, Top 1: 0.2350, Top 5: 0.4388, Top 10: 0.5363, Perplexity: 0.0000
Epoch 21 completed in 4.68m. 0.03m per recording.


Training Epoch 22: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 22
Test unseen_subject completed. Accuracy: 0.1465, Top 1: 0.2275, Top 5: 0.4435, Top 10: 0.5625, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1664, Top 1: 0.2259, Top 5: 0.5154, Top 10: 0.6564, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1562, Top 1: 0.2300, Top 5: 0.4313, Top 10: 0.5388, Perplexity: 0.0000
Epoch 22 completed in 4.64m. 0.03m per recording.


Training Epoch 23: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 23
Test unseen_subject completed. Accuracy: 0.1285, Top 1: 0.2085, Top 5: 0.4540, Top 10: 0.5660, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1506, Top 1: 0.2057, Top 5: 0.4926, Top 10: 0.6411, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1425, Top 1: 0.2000, Top 5: 0.4187, Top 10: 0.5200, Perplexity: 0.0000
Epoch 23 completed in 4.64m. 0.03m per recording.


Training Epoch 24: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 24
Test unseen_subject completed. Accuracy: 0.1125, Top 1: 0.1725, Top 5: 0.3920, Top 10: 0.5155, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1136, Top 1: 0.1626, Top 5: 0.4182, Top 10: 0.5698, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1225, Top 1: 0.1787, Top 5: 0.3625, Top 10: 0.4825, Perplexity: 0.0000
Epoch 24 completed in 4.66m. 0.03m per recording.


Training Epoch 25: 100%|██████████| 135/135 [04:18<00:00,  1.91s/it]


Testing at epoch 25
Test unseen_subject completed. Accuracy: 0.1025, Top 1: 0.1780, Top 5: 0.4190, Top 10: 0.5430, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1257, Top 1: 0.1790, Top 5: 0.4492, Top 10: 0.6127, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1025, Top 1: 0.1562, Top 5: 0.3425, Top 10: 0.4775, Perplexity: 0.0000
Epoch 25 completed in 4.66m. 0.03m per recording.


Training Epoch 26: 100%|██████████| 135/135 [04:18<00:00,  1.92s/it]


Testing at epoch 26
Test unseen_subject completed. Accuracy: 0.1090, Top 1: 0.1860, Top 5: 0.4015, Top 10: 0.5255, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1331, Top 1: 0.1853, Top 5: 0.4535, Top 10: 0.6015, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1275, Top 1: 0.1850, Top 5: 0.3862, Top 10: 0.4938, Perplexity: 0.0000
Epoch 26 completed in 4.66m. 0.03m per recording.
Early stopping at epoch 26. Highest top 10 accuracy at epoch 15.
Training completed.
unseen_subject: Acc: 0.1685, Top 1: 0.2565, Top 5: 0.4860, Top 10: 0.6115
unseen_task: Acc: 0.2421, Top 1: 0.3220, Top 5: 0.6258, Top 10: 0.7513
unseen_both: Acc: 0.1963, Top 1: 0.2675, Top 5: 0.5200, Top 10: 0.6350


In [6]:
# Epoch the session recorded as its best; presumably the same epoch the
# early-stopping message reports as having the highest top-10 accuracy.
session.highest_epoch

15