In [1]:
# Per-study held-out split. Each key is a study name; its value lists the
# subject ids and task ids placed in the test partitions (the session reports
# them as "Unseen Subject" / "Unseen Task" / "Unseen Both").
data_partition = {}
data_partition["gwilliams2023"] = {
    "testing_subjects": [19, 20, 21],
    "testing_tasks": [0],
}
# Schoffelen2022 split, currently disabled:
# data_partition["schoffelen2022"] = {
#     "testing_subjects": [],
#     "testing_tasks": [8, 9],
# }

In [2]:
# Setup step: decompress each study's archives from downloaded_data/<study>
# into data/<study>. Kept commented out because delete_compressed_files=True
# removes the archives, so this can only succeed once per machine.
# Uncomment on a fresh machine before running the cells below.
# NOTE(review): `batch_type` below actually receives the per-study partition
# dict (the values of data_partition) and is unused; only the key is needed.
# from utils.compression import compress_directories, decompress_directories

# for base_path, batch_type in data_partition.items():
#     decompress_directories(
#         source_path=f'downloaded_data/{base_path}',
#         destination_path=f'data/{base_path}',
#         checksum_file_name="checksums.txt",
#         delete_compressed_files=True,
#         num_workers=26
#     )

In [3]:
from config import SimpleConvConfig
from models.simpleconv import SimpleConv
import torch

# SimpleConv brain-encoder configuration, assembled from one option group per
# component. Components switched off for this run (merger, conditional layers,
# quantizer, transformer encoder/decoder) carry zero/None placeholder values.

# Core I/O widths, regularisation and input normalisation.
_core_opts = {
    "mel_normalization": False,
    # condition name (str) -> list of possible conditions; empty lists here
    "conditions": {
        "study": [],
        "subject": [],
    },
    "in_channels": 208,
    "out_channels": 128,
    "hidden_dim": 384,
    "dropout": 0.2,
    "initial_batch_norm": True,
}
# Sensor layout handling.
_layout_opts = {
    "layout_dim": 2,
    "layout_proj": True,
    "layout_scaling": "minmax",
}
# Channel merger with spatial attention (disabled).
_merger_opts = {
    "merger": False,
    "merger_emb_type": None,
    "merger_emb_dim": 0,
    "merger_channels": 0,
    "merger_dropout": False,
    "merger_conditional": None,
}
# Initial block ahead of the conv stack.
_initial_opts = {
    "initial_linear": 384,
    "initial_depth": 1,
}
# Conditional layers (disabled); dim would be "input" or hidden_dim.
_conditional_opts = {
    "conditional_layers": False,
    "conditional_layers_dim": None,
}
# Convolutional stack structure.
_conv_opts = {
    "depth": 6,
    "kernel_size": 3,
    "growth": 1.0,
    "dilation_growth": 2,
    "dilation_period": 5,
    "glu": 1,
    "conv_dropout": 0.2,
    "dropout_input": 0.2,
    "batch_norm": True,
}
# Quantizer (disabled).
_quantizer_opts = {
    "quantizer": False,
    "num_codebooks": 0,
    "codebook_size": 0,
    "quantizer_commitment": 0,
    "quantizer_temp_init": 0,
    "quantizer_temp_min": 0,
    "quantizer_temp_decay": 0,
}
# Transformer encoder (disabled).
_encoder_opts = {
    "transformer_input": None,
    "transformer_encoder_emb": None,
    "transformer_encoder_layers": 0,
    "transformer_encoder_heads": 0,
}
# Transformer decoder (disabled).
_decoder_opts = {
    "transformer_decoder_emb": None,
    "transformer_decoder_layers": 0,
    "transformer_decoder_heads": 0,
    "transformer_decoder_dim": 0,
}

# Multiple ** unpackings raise TypeError on any repeated key, so groups cannot
# silently override each other.
model_config = SimpleConvConfig(
    **_core_opts,
    **_layout_opts,
    **_merger_opts,
    **_initial_opts,
    **_conditional_opts,
    **_conv_opts,
    **_quantizer_opts,
    **_encoder_opts,
    **_decoder_opts,
)

# Keep the notebook namespace limited to model_config.
del (
    _core_opts, _layout_opts, _merger_opts, _initial_opts, _conditional_opts,
    _conv_opts, _quantizer_opts, _encoder_opts, _decoder_opts,
)

In [4]:
from train.training_session_v0 import TrainingSessionV0
from config import TrainingConfigV0
import multiprocessing

# Ablation variable for this notebook: the `delay` passed to the training config.
# Change it here only; the save directory below is derived from it.
DELAY = 0.0
# Encode the delay in the save directory ("0.0" -> "0_0", "0.2" -> "0_2").
# Runs for different delays then land in distinct folders and can never
# overwrite one another.
SAVE_PATH = "saves/phase1/ablation/delay/" + str(DELAY).replace(".", "_")

config = TrainingConfigV0(
    brain_encoder_config=model_config,
    data_partition=data_partition,
    # Pre-processing parameters
    # Brain
    new_freq=100,
    frequency_bands={"all": (0.5, 80)},
    max_random_shift=1.0,
    window_size=4,
    window_stride=1,
    brain_clipping=20,
    baseline_window=0.5,
    notch_filter=True,
    scaling="standard",
    delay=DELAY,
    # Hyperparameters
    learning_rate=3e-4,
    weight_decay=1e-4,
    epochs=50,
    batch_size=256,
    use_clip_loss=True,
    use_mse_loss=True,
    alpha=0.6,
    random_test_size=10,
    seed=42,
)

session = TrainingSessionV0(
    config=config,
    # Every study in the partition is loaded with the "audio" batch type.
    studies={study: "audio" for study in data_partition.keys()},
    data_path="data",
    save_path=SAVE_PATH,
    clear_cache=False,
)

Loading Gwilliams2023 with batch type audio
Data partitioned on studies ['gwilliams2023'].
Train: 135, Unseen Task: 12, Unseen Subject: 45, Unseen Both: 4.

GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower than expected.
SimpleConv initialized with 8448160 parameters, cond: ['study', 'subject']
Merger False, merger channels 0
ConvBlocks: 6, hidden_dim: 384, params 7973376


In [None]:
# Run training. Interrupting the kernel raises KeyboardInterrupt. It is caught
# here, so the `session` object stays usable for inspection in later cells.
try:
    session.train(
        device="cuda",
        buffer_size=30,
        # Leave two cores for the notebook/OS. Never request fewer than one
        # worker: cpu_count() - 2 is 0 or negative on machines with <= 2 CPUs.
        num_workers=max(1, multiprocessing.cpu_count() - 2),
        max_cache_size=400,
        # NOTE(review): presumably the epoch to start/resume from; re-running
        # this cell as-is starts again at 0.
        current_epoch=0,
    )
except KeyboardInterrupt:
    print("Exited")

2024-12-20 17:53:58,294	INFO worker.py:1821 -- Started a local Ray instance.
Training Epoch 1: 100%|██████████| 135/135 [25:21<00:00, 11.27s/it] 


Testing at epoch 1
Test unseen_subject completed. Accuracy: 0.0380, Top 1: 0.0690, Top 5: 0.2625, Top 10: 0.4135, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.0503, Top 1: 0.0761, Top 5: 0.3019, Top 10: 0.4707, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.0488, Top 1: 0.0887, Top 5: 0.2975, Top 10: 0.4337, Perplexity: 0.0000
Epoch 1 completed in 29.41m. 0.22m per recording.


Training Epoch 2: 100%|██████████| 135/135 [04:16<00:00,  1.90s/it]


Testing at epoch 2
Test unseen_subject completed. Accuracy: 0.1025, Top 1: 0.1740, Top 5: 0.3950, Top 10: 0.5230, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.0825, Top 1: 0.1220, Top 5: 0.3541, Top 10: 0.5091, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.0575, Top 1: 0.1013, Top 5: 0.3087, Top 10: 0.4338, Perplexity: 0.0000
Epoch 2 completed in 4.62m. 0.03m per recording.


Training Epoch 3: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 3
Test unseen_subject completed. Accuracy: 0.0890, Top 1: 0.1570, Top 5: 0.3705, Top 10: 0.4830, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1386, Top 1: 0.1904, Top 5: 0.4707, Top 10: 0.6209, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1013, Top 1: 0.1550, Top 5: 0.3875, Top 10: 0.5125, Perplexity: 0.0000
Epoch 3 completed in 4.65m. 0.03m per recording.


Training Epoch 4: 100%|██████████| 135/135 [04:16<00:00,  1.90s/it]


Testing at epoch 4
Test unseen_subject completed. Accuracy: 0.0550, Top 1: 0.0960, Top 5: 0.3045, Top 10: 0.4680, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.0940, Top 1: 0.1336, Top 5: 0.4007, Top 10: 0.5668, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.0762, Top 1: 0.1200, Top 5: 0.3275, Top 10: 0.4850, Perplexity: 0.0000
Epoch 4 completed in 4.64m. 0.03m per recording.


Training Epoch 5: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 5
Test unseen_subject completed. Accuracy: 0.0830, Top 1: 0.1385, Top 5: 0.3735, Top 10: 0.5220, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1424, Top 1: 0.1946, Top 5: 0.4802, Top 10: 0.6348, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1175, Top 1: 0.1850, Top 5: 0.4088, Top 10: 0.5525, Perplexity: 0.0000
Epoch 5 completed in 4.64m. 0.03m per recording.


Training Epoch 6: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 6
Test unseen_subject completed. Accuracy: 0.0830, Top 1: 0.1255, Top 5: 0.3135, Top 10: 0.4420, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1335, Top 1: 0.1886, Top 5: 0.4413, Top 10: 0.5794, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1000, Top 1: 0.1362, Top 5: 0.3175, Top 10: 0.4275, Perplexity: 0.0000
Epoch 6 completed in 4.65m. 0.03m per recording.


Training Epoch 7: 100%|██████████| 135/135 [04:17<00:00,  1.90s/it]


Testing at epoch 7
Test unseen_subject completed. Accuracy: 0.1335, Top 1: 0.2050, Top 5: 0.4485, Top 10: 0.5625, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1857, Top 1: 0.2498, Top 5: 0.5450, Top 10: 0.7010, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1400, Top 1: 0.2400, Top 5: 0.4663, Top 10: 0.5925, Perplexity: 0.0000
Epoch 7 completed in 4.63m. 0.03m per recording.


Training Epoch 8: 100%|██████████| 135/135 [04:15<00:00,  1.89s/it]


Testing at epoch 8
Test unseen_subject completed. Accuracy: 0.1175, Top 1: 0.1915, Top 5: 0.4385, Top 10: 0.5715, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1576, Top 1: 0.2122, Top 5: 0.5008, Top 10: 0.6491, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1075, Top 1: 0.1787, Top 5: 0.3987, Top 10: 0.5325, Perplexity: 0.0000
Epoch 8 completed in 4.61m. 0.03m per recording.


Training Epoch 9: 100%|██████████| 135/135 [04:16<00:00,  1.90s/it]


Testing at epoch 9
Test unseen_subject completed. Accuracy: 0.0805, Top 1: 0.1345, Top 5: 0.3620, Top 10: 0.5105, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1317, Top 1: 0.1839, Top 5: 0.4587, Top 10: 0.6101, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.0975, Top 1: 0.1588, Top 5: 0.4112, Top 10: 0.5387, Perplexity: 0.0000
Epoch 9 completed in 4.63m. 0.03m per recording.


Training Epoch 10: 100%|██████████| 135/135 [04:17<00:00,  1.90s/it]


Testing at epoch 10
Test unseen_subject completed. Accuracy: 0.1425, Top 1: 0.2190, Top 5: 0.4805, Top 10: 0.6105, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1958, Top 1: 0.2706, Top 5: 0.5816, Top 10: 0.7199, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1512, Top 1: 0.2200, Top 5: 0.4675, Top 10: 0.5725, Perplexity: 0.0000
Epoch 10 completed in 4.63m. 0.03m per recording.


Training Epoch 11: 100%|██████████| 135/135 [04:17<00:00,  1.90s/it]


Testing at epoch 11
Test unseen_subject completed. Accuracy: 0.1685, Top 1: 0.2490, Top 5: 0.4895, Top 10: 0.6070, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1806, Top 1: 0.2469, Top 5: 0.5382, Top 10: 0.6784, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1413, Top 1: 0.2225, Top 5: 0.4475, Top 10: 0.5425, Perplexity: 0.0000
Epoch 11 completed in 4.63m. 0.03m per recording.


Training Epoch 12: 100%|██████████| 135/135 [04:18<00:00,  1.91s/it]


Testing at epoch 12
Test unseen_subject completed. Accuracy: 0.1575, Top 1: 0.2305, Top 5: 0.4630, Top 10: 0.5825, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2182, Top 1: 0.2907, Top 5: 0.5840, Top 10: 0.7140, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1725, Top 1: 0.2350, Top 5: 0.4662, Top 10: 0.5863, Perplexity: 0.0000
Epoch 12 completed in 4.66m. 0.03m per recording.


Training Epoch 13: 100%|██████████| 135/135 [04:17<00:00,  1.90s/it]


Testing at epoch 13
Test unseen_subject completed. Accuracy: 0.1750, Top 1: 0.2610, Top 5: 0.5015, Top 10: 0.6115, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1981, Top 1: 0.2675, Top 5: 0.5534, Top 10: 0.6953, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1562, Top 1: 0.2200, Top 5: 0.4750, Top 10: 0.5763, Perplexity: 0.0000
Epoch 13 completed in 4.64m. 0.03m per recording.


Training Epoch 14: 100%|██████████| 135/135 [04:17<00:00,  1.90s/it]


Testing at epoch 14
Test unseen_subject completed. Accuracy: 0.0930, Top 1: 0.1630, Top 5: 0.4100, Top 10: 0.5540, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1508, Top 1: 0.2094, Top 5: 0.4998, Top 10: 0.6547, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1162, Top 1: 0.2075, Top 5: 0.4375, Top 10: 0.5500, Perplexity: 0.0000
Epoch 14 completed in 4.63m. 0.03m per recording.


Training Epoch 15: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 15
Test unseen_subject completed. Accuracy: 0.1665, Top 1: 0.2575, Top 5: 0.4990, Top 10: 0.5985, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2174, Top 1: 0.2897, Top 5: 0.5842, Top 10: 0.7159, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1588, Top 1: 0.2462, Top 5: 0.4700, Top 10: 0.5763, Perplexity: 0.0000
Epoch 15 completed in 4.63m. 0.03m per recording.


Training Epoch 16: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 16
Test unseen_subject completed. Accuracy: 0.1630, Top 1: 0.2410, Top 5: 0.4710, Top 10: 0.5890, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.2160, Top 1: 0.2842, Top 5: 0.5756, Top 10: 0.7167, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1450, Top 1: 0.2275, Top 5: 0.4600, Top 10: 0.5763, Perplexity: 0.0000
Epoch 16 completed in 4.64m. 0.03m per recording.


Training Epoch 17: 100%|██████████| 135/135 [04:17<00:00,  1.91s/it]


Testing at epoch 17
Test unseen_subject completed. Accuracy: 0.1790, Top 1: 0.2545, Top 5: 0.5025, Top 10: 0.6100, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1861, Top 1: 0.2516, Top 5: 0.5596, Top 10: 0.6959, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1450, Top 1: 0.2150, Top 5: 0.4488, Top 10: 0.5600, Perplexity: 0.0000
Epoch 17 completed in 4.64m. 0.03m per recording.


Training Epoch 18: 100%|██████████| 135/135 [04:18<00:00,  1.91s/it]


Testing at epoch 18
Test unseen_subject completed. Accuracy: 0.1755, Top 1: 0.2530, Top 5: 0.4820, Top 10: 0.5805, Perplexity: 0.0000
Test unseen_task completed. Accuracy: 0.1685, Top 1: 0.2303, Top 5: 0.5227, Top 10: 0.6659, Perplexity: 0.0000
Test unseen_both completed. Accuracy: 0.1162, Top 1: 0.1825, Top 5: 0.4262, Top 10: 0.5375, Perplexity: 0.0000
Epoch 18 completed in 4.66m. 0.03m per recording.


Training Epoch 19:  43%|████▎     | 58/135 [02:13<02:33,  2.00s/it]

In [6]:
# Display the session's `highest_epoch` once training has stopped.
# NOTE(review): what this attribute tracks is not visible here. The log above
# numbers epochs from 1, and logged epoch 14 is not among the best. If this is
# the best-scoring epoch, it is likely 0-indexed (logged epoch 15 has the
# highest summed Top-1). Confirm before using it to load a checkpoint.
session.highest_epoch

14