In [1]:
# del session.logger
# del session

import multiprocessing


from train.training_session_v0 import TrainingSessionV0
from config import TrainingConfigV0
from config import SimpleConvConfig

# Evaluation split, keyed by study name. Each study lists the subject and
# task indices that go under its "testing_*" keys; the mapping is handed to
# TrainingConfigV0 as `data_partition` and its keys select the studies loaded
# by the training session.
data_partition = dict(
    gwilliams2023=dict(
        testing_subjects=[19, 20, 21],
        testing_tasks=[0],
    ),
    # Second study, currently disabled:
    # armeini2022=dict(
    #     testing_subjects=[],
    #     testing_tasks=[8, 9],
    # ),
)

# model_config = SimpleConvConfig(
#     # Str to list of possible conditions
#     mel_normalization=False,
#     conditions={
#         "study": [],
#         "subject": [],
#     },
#     # Channels
#     in_channels=208,
#     out_channels=128,
#     hidden_dim=384,
#     dropout=0.2,
#     initial_batch_norm=True,
#     # Sensor layout settings
#     layout_dim=2,
#     layout_proj=True,
#     layout_scaling="minmax",
#     # Merger with spatial attn
#     merger=False,
#     merger_emb_type=None,
#     merger_emb_dim=0,
#     merger_channels=0,
#     merger_dropout=False,
#     merger_conditional=None,
#     # Initial
#     initial_linear=384,
#     initial_depth=1,
#     # Conditional layers
#     conditional_layers=False,
#     conditional_layers_dim=None,  # input or hidden_dim
#     # Conv layer overall structure
#     depth=6,
#     kernel_size=3,
#     growth=1.0,
#     dilation_growth=2,
#     dilation_period=5,
#     glu=1,
#     conv_dropout=0.2,
#     dropout_input=0.2,
#     batch_norm=True,
#     half=True,
#     cnn_pos_encoding=False,
#     # Quantizer
#     quantizer=False,
#     num_codebooks=0,
#     codebook_size=0,
#     quantizer_commitment=0,
#     quantizer_temp_init=0,
#     quantizer_temp_min=0,
#     quantizer_temp_decay=0,
#     # Transformers Encoders
#     transformer_input=None,
#     transformer_encoder_emb=None,
#     transformer_encoder_layers=0,
#     transformer_encoder_heads=0,
#     # Transformer Decoders
#     transformer_decoder_emb=None,
#     transformer_decoder_layers=0,
#     transformer_decoder_heads=0,
#     transformer_decoder_dim=0,
# )

# Brain-encoder configuration used for this run. Compared with the
# commented-out baseline kept earlier in this cell, this variant turns on the
# merger, the conditional layers, the CNN positional encoding and a 4-layer
# transformer encoder; the quantizer and the transformer decoder stay off.
model_config = SimpleConvConfig(
    # Str to list of possible conditions
    mel_normalization=False,
    conditions={
        "study": [],
        "subject": [],
    },
    # Channels
    in_channels=208,
    out_channels=128,
    hidden_dim=384,
    dropout=0.2,
    initial_batch_norm=True,
    # Sensor layout settings
    layout_dim=2,
    layout_proj=True,
    layout_scaling="minmax",
    # Merger with spatial attn
    merger=True,
    merger_emb_type="linear",
    merger_emb_dim=2048,
    merger_channels=256,
    merger_dropout=0.2,  # Float
    merger_conditional=None,
    # Initial
    initial_linear=384,
    initial_depth=1,
    # Conditional layers
    conditional_layers=True,
    conditional_layers_dim="input",  # input or hidden_dim
    # Conv layer overall structure
    depth=6,
    kernel_size=3,
    growth=1.0,
    dilation_growth=2,
    dilation_period=5,
    glu=1,
    conv_dropout=0.2,
    dropout_input=0.2,
    batch_norm=True,
    half=True,
    cnn_pos_encoding=True,
    # Quantizer (disabled, so the remaining quantizer fields are zeroed)
    quantizer=False,
    num_codebooks=0,
    codebook_size=0,
    quantizer_commitment=0,
    quantizer_temp_init=0,
    quantizer_temp_min=0,
    quantizer_temp_decay=0,
    # Transformers Encoders
    transformer_input="continuous",  # concat or quantized or continuous
    transformer_encoder_emb="groupconv",
    transformer_encoder_layers=4,
    transformer_encoder_heads=8,
    # Conformer encoder variant
    rnn_type="transformer",
    depthwise_conv_kernel_size=31,
    use_group_norm=True,
    convolution_first=False,
    # Transformer Decoders (disabled: zero layers)
    transformer_decoder_emb=None,
    transformer_decoder_layers=0,
    transformer_decoder_heads=0,
    transformer_decoder_dim=0,
)

# Keyword arguments are grouped by purpose and unpacked in the same order
# they were originally written.

# Pre-processing parameters (brain signal).
preprocessing_kwargs = {
    "new_freq": 200,
    "frequency_bands": {"all": (0.5, 80)},
    "max_random_shift": 1.0,
    "window_size": 4,
    "window_stride": 1,
    "brain_clipping": None,
    "baseline_window": 0.5,
    "notch_filter": True,
    "scaling": "both",
    "delay": 0.15,
}

# Optimisation hyperparameters; both the CLIP and the MSE loss are enabled.
hyperparameter_kwargs = {
    "learning_rate": 5e-4,
    "weight_decay": 1e-4,
    "epochs": 50,
    "batch_size": 256,
    "use_clip_loss": True,
    "use_mse_loss": True,
    "alpha": 0.6,
    "random_test_size": 10,
    "seed": 42,
}

config = TrainingConfigV0(
    brain_encoder_config=model_config,
    data_partition=data_partition,
    **preprocessing_kwargs,
    **hyperparameter_kwargs,
)

# Post-construction overrides: the learning rate assigned here replaces the
# 5e-4 given to the constructor, and the batch size is re-assigned to the
# same value it already had.
config.learning_rate = 3e-4
config.batch_size = 256

# Build the training session. Every study named in `data_partition` is mapped
# to the "audio" entry of `studies`; data is read from "data", outputs go
# under the CLIP+MSE save directory, and the existing cache is kept.
session = TrainingSessionV0(
    config=config,
    studies=dict.fromkeys(data_partition, "audio"),
    data_path="data",
    save_path="saves/phase1/objectives/CLIP_MSE_full_model",
    clear_cache=False,
    cache_name="cache/1",
)

# Run training on the GPU, starting from epoch 0. A KeyboardInterrupt (e.g.
# interrupting the notebook kernel) ends the cell with a short message
# instead of a traceback.
try:
    session.train(
        device="cuda",
        buffer_size=30,
        # Use all but two CPU cores, but never request fewer than one worker:
        # cpu_count() - 2 is zero or negative on machines with 1-2 cores.
        num_workers=max(1, multiprocessing.cpu_count() - 2),
        max_cache_size=400,
        current_epoch=0,
    )
except KeyboardInterrupt:
    print("Exited")

Loading Gwilliams2023 with batch type audio
Data partitioned on studies ['gwilliams2023'].
Train: 135, Unseen Task: 12, Unseen Subject: 45, Unseen Both: 4.

Conditional layer study initialized with 2 conditions
Conditional layer subject initialized with 28 conditions
RNNEncoder initialized as transformer with 4 layers, 384 d_model, 8 nhead
	Embedding: groupconv, params: 7099776
SimpleConv initialized with 16986880 parameters, cond: ['study', 'subject']
Merger True, merger channels 256
ConvBlocks: 6, hidden_dim: 384, params 8858496


2025-01-26 16:05:25,069	INFO worker.py:1841 -- Started a local Ray instance.
Training Epoch 1:   0%|          | 0/135 [00:00<?, ?it/s]

1. scores: torch.Size([256, 256, 208]), max: 4.875, min: -4.21875
2. scores: torch.Size([256, 256, 208]), max: 4.84375, min: -inf
1. scores: torch.Size([170, 256, 208]), max: 4.8125, min: -4.125
2. scores: torch.Size([170, 256, 208]), max: 4.8125, min: -inf


Training Epoch 1:   1%|          | 1/135 [00:43<1:38:15, 43.99s/it]

1. scores: torch.Size([256, 256, 208]), max: 4.875, min: -4.125
2. scores: torch.Size([256, 256, 208]), max: 4.875, min: -inf
1. scores: torch.Size([170, 256, 208]), max: 4.75, min: -4.34375
2. scores: torch.Size([170, 256, 208]), max: 4.75, min: -inf


Training Epoch 1:   1%|▏         | 2/135 [00:46<43:04, 19.43s/it]  

1. scores: torch.Size([256, 256, 208]), max: 4.71875, min: -4.375
2. scores: torch.Size([256, 256, 208]), max: 4.71875, min: -inf
1. scores: torch.Size([170, 256, 208]), max: 4.71875, min: -4.4375
2. scores: torch.Size([170, 256, 208]), max: 4.71875, min: -inf


Training Epoch 1:   2%|▏         | 3/135 [00:48<25:25, 11.55s/it]

1. scores: torch.Size([256, 256, 208]), max: 4.6875, min: -4.625
2. scores: torch.Size([256, 256, 208]), max: 4.6875, min: -inf
1. scores: torch.Size([170, 256, 208]), max: 4.71875, min: -4.75
2. scores: torch.Size([170, 256, 208]), max: 4.71875, min: -inf


Training Epoch 1:   3%|▎         | 4/135 [00:50<16:50,  7.71s/it]

1. scores: torch.Size([256, 256, 208]), max: 4.75, min: -4.90625
2. scores: torch.Size([256, 256, 208]), max: 4.75, min: -inf
1. scores: torch.Size([170, 256, 208]), max: 4.78125, min: -5.0625
2. scores: torch.Size([170, 256, 208]), max: 4.78125, min: -inf


Training Epoch 1:   4%|▎         | 5/135 [00:52<12:46,  5.90s/it]

1. scores: torch.Size([256, 256, 208]), max: 4.78125, min: -5.1875
2. scores: torch.Size([256, 256, 208]), max: 4.78125, min: -inf
1. scores: torch.Size([256, 256, 208]), max: 4.8125, min: -5.28125
2. scores: torch.Size([256, 256, 208]), max: 4.8125, min: -inf
1. scores: torch.Size([226, 256, 208]), max: 4.8125, min: -5.375
2. scores: torch.Size([226, 256, 208]), max: 4.8125, min: -inf


Training Epoch 1:   4%|▍         | 6/135 [00:55<10:26,  4.86s/it]

1. scores: torch.Size([256, 256, 208]), max: 4.84375, min: -5.46875
2. scores: torch.Size([256, 256, 208]), max: 4.84375, min: -inf
1. scores: torch.Size([170, 256, 208]), max: 4.84375, min: -5.5625
2. scores: torch.Size([170, 256, 208]), max: 4.84375, min: -inf


Training Epoch 1:   5%|▌         | 7/135 [00:58<08:41,  4.07s/it]

1. scores: torch.Size([256, 256, 208]), max: 4.875, min: -5.625
2. scores: torch.Size([256, 256, 208]), max: 4.875, min: -inf
1. scores: torch.Size([256, 256, 208]), max: 4.90625, min: -5.6875
2. scores: torch.Size([256, 256, 208]), max: 4.90625, min: -inf


Training Epoch 1:   5%|▌         | 7/135 [01:00<18:28,  8.66s/it]

Exited



