In [1]:
# Place this as the FIRST cell, before importing torch.
import os, random
import numpy as np

SEED = 42

# For Python determinism
os.environ["PYTHONHASHSEED"] = str(SEED)
# Deterministic cuBLAS (required for some CUDA matmul ops)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch

# Seed Python, NumPy, Torch (CPU and CUDA)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Turn on deterministic behavior
torch.use_deterministic_algorithms(True)  # may raise on nondeterministic ops
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Keep math consistent
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

torch.autograd.set_detect_anomaly(True)

# Optional: remove threading non-determinism
# torch.set_num_threads(1)

# Helpers for DataLoader reproducibility
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(SEED)

  self.setter(val)


<torch._C.Generator at 0x7dfd500837b0>

In [None]:
from loguru import logger
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Subset
from src.datasets.seeg_dataset import SEEGDataset
from src.models.model import SEEGFusionModel
from src.training.train import train_model
from src.training.evaluate import evaluate_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

Using device: cpu


In [None]:
def compute_class_weights(train_ds):
    labels = np.array([v[1] for v in train_ds])
    class_sample_count = np.array(
        [len(np.where(labels == t)[0]) for t in np.unique(labels)])
    weight = class_sample_count.sum() / class_sample_count
    return torch.from_numpy(weight).float()

subjects = ['Epat26', 'Epat30', 'Epat31', 'Epat34', 'Epat35', 'Epat37']

# Helper to get indices for specific subjects
def get_subject_indices(dataset, subj_list):
    return [i for i, s in enumerate(dataset.data) if s['subject'] in subj_list]

# Example: LOPO outer loop
for test_subj in subjects:
    logger.info(f"\n=== Test subject: {test_subj} ===")
    remaining_subjs = [s for s in subjects if s != test_subj]

    # Create dataset once (loads all subjects)
    full_dataset = SEEGDataset(subjects=subjects)

    # Outer split: test vs remaining
    test_idx = get_subject_indices(full_dataset, [test_subj])
    test_ds = Subset(full_dataset, test_idx)

    # Inner split subjects (for hyperparam tuning)
    # Shuffle remaining subjects so different folds vary
    random.shuffle(remaining_subjs)

    # Do 5 different inner splits (8 train / 2 val)
    inner_splits = []
    for i in range(5):
        # rotate subjects for different validation sets
        val_subjs = remaining_subjs[i*2:(i+1)*2] if i*2 < len(remaining_subjs) else remaining_subjs[-2:]
        train_subjs = [s for s in remaining_subjs if s not in val_subjs]
        inner_splits.append((train_subjs, val_subjs))

    # Run inner CV for this test subject
    for k, (train_subjs, val_subjs) in enumerate(inner_splits):
        logger.info(f"  Inner split {k+1}: train={train_subjs}, val={val_subjs}")

        train_idx = get_subject_indices(full_dataset, train_subjs)
        val_idx = get_subject_indices(full_dataset, val_subjs)

        train_ds = Subset(full_dataset, train_idx)
        val_ds = Subset(full_dataset, val_idx)

        dataloaders = {
            'train': DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=0, worker_init_fn=seed_worker, generator=g),
            'val': DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=0, worker_init_fn=seed_worker, generator=g),
            'test': DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=0, worker_init_fn=seed_worker, generator=g)
        }

        weights = compute_class_weights(train_ds)

        print(next(iter(dataloaders['train']))[0]['convergent'].shape)
        print(next(iter(dataloaders['train']))[0]['divergent'].shape)

        model = SEEGFusionModel(embed_dim=128, n_classes=2, device=device)
        model.to(device)
        optimizer = optim.AdamW(model.parameters(), lr=1e-6)
        criterion = nn.CrossEntropyLoss(weight=weights.to(device))

        model, history, best_epoch = train_model(
            model=model,
            dataloaders=dataloaders,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            save_prefix=f'{test_subj}_split_{k}',
            n_epochs=15,
            patience=2,
        )


=== Test subject: Epat26 ===


[32m2025-11-06 20:59:09.003[0m | [32m[1mSUCCESS [0m | [36msrc.datasets.seeg_dataset[0m:[36m__init__[0m:[36m128[0m - [32m[1m✅ Loaded 279 total samples from 6 subjects.[0m


  Inner split 1: train=['Epat34', 'Epat37', 'Epat30'], val=['Epat35', 'Epat31']
torch.Size([16, 70, 50, 487])
torch.Size([16, 124, 50, 487])


[32m2025-11-06 20:59:09.466[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m41[0m - [1m
Starting training for 15 epochs on device: cpu[0m
Epoch 1 [train]:   0%|          | 0/11 [00:00<?, ?it/s]

resnet_conv_output: mean -2.253e-02, std 4.492e-01, min -1.292e+01, max 1.035e+01


Epoch 1 [train]:   9%|▉         | 1/11 [01:17<12:52, 77.29s/it]

resnet_conv_output: mean -2.199e-02, std 4.463e-01, min -9.420e+00, max 7.459e+00


Epoch 1 [train]:  18%|█▊        | 2/11 [02:37<11:50, 78.97s/it]

resnet_conv_output: mean -1.040e-02, std 3.972e-01, min -7.021e+01, max 5.588e+01


Epoch 1 [train]:  27%|██▋       | 3/11 [03:55<10:29, 78.71s/it]

resnet_conv_output: mean -1.650e-02, std 3.813e-01, min -7.656e+01, max 7.907e+01
