In [1]:
# Place this as the FIRST cell, before importing torch.
import os, random
import numpy as np

SEED = 42

# For Python determinism
os.environ["PYTHONHASHSEED"] = str(SEED)
# Deterministic cuBLAS (required for some CUDA matmul ops)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch

# Seed Python, NumPy, Torch (CPU and CUDA)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Turn on deterministic behavior
torch.use_deterministic_algorithms(True)  # may raise on nondeterministic ops
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Keep math consistent
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

torch.autograd.set_detect_anomaly(True)

# Optional: remove threading non-determinism
# torch.set_num_threads(1)

# Helpers for DataLoader reproducibility
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(SEED)

  self.setter(val)


<torch._C.Generator at 0x7e6fea6b79b0>

In [2]:
from loguru import logger
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Subset
from src.datasets.seeg_dataset import SEEGDataset
from src.models.model import SEEGFusionModel
from src.training.train import train_model
from src.training.evaluate import evaluate_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

Using device: cpu


In [3]:
def compute_class_weights(train_ds):
    labels = np.array([v[1] for v in train_ds])
    class_sample_count = np.array(
        [len(np.where(labels == t)[0]) for t in np.unique(labels)])
    weight = class_sample_count.sum() / class_sample_count
    return torch.from_numpy(weight).float()

subjects = ['Epat26', 'Epat30', 'Epat31', 'Epat34', 'Epat35', 'Epat37']

# Helper to get indices for specific subjects
def get_subject_indices(dataset, subj_list):
    return [i for i, s in enumerate(dataset.data) if s['subject'] in subj_list]

# Create dataset once (loads all subjects)
full_dataset = SEEGDataset(subjects=subjects)

[32m2025-11-07 10:09:51.364[0m | [32m[1mSUCCESS [0m | [36msrc.datasets.seeg_dataset[0m:[36m__init__[0m:[36m128[0m - [32m[1mâœ… Loaded 279 total samples from 6 subjects.[0m


In [None]:
# Example: LOPO outer loop
for test_subj in subjects:
    logger.info(f"\n=== Test subject: {test_subj} ===")
    remaining_subjs = [s for s in subjects if s != test_subj]

    # Outer split: test vs remaining
    test_idx = get_subject_indices(full_dataset, [test_subj])
    test_ds = Subset(full_dataset, test_idx)

    # Inner split subjects (for hyperparam tuning)
    # Shuffle remaining subjects so different folds vary
    random.shuffle(remaining_subjs)

    # Do 5 different inner splits (8 train / 2 val)
    inner_splits = []
    for i in range(5):
        # rotate subjects for different validation sets
        val_subjs = remaining_subjs[i*2:(i+1)*2] if i*2 < len(remaining_subjs) else remaining_subjs[-2:]
        train_subjs = [s for s in remaining_subjs if s not in val_subjs]
        inner_splits.append((train_subjs, val_subjs))

    # Run inner CV for this test subject
    for k, (train_subjs, val_subjs) in enumerate(inner_splits):
        logger.info(f"  Inner split {k+1}: train={train_subjs}, val={val_subjs}")

        train_idx = get_subject_indices(full_dataset, train_subjs)
        val_idx = get_subject_indices(full_dataset, val_subjs)

        train_ds = Subset(full_dataset, train_idx)
        val_ds = Subset(full_dataset, val_idx)

        dataloaders = {
            'train': DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=0, worker_init_fn=seed_worker, generator=g),
            'val': DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=0, worker_init_fn=seed_worker, generator=g),
            'test': DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=0, worker_init_fn=seed_worker, generator=g)
        }

        weights = compute_class_weights(train_ds)

        print(next(iter(dataloaders['train']))[0]['convergent'].shape)
        print(next(iter(dataloaders['train']))[0]['divergent'].shape)

        model = SEEGFusionModel(embed_dim=128, n_classes=2, device=device)
        model.to(device)
        optimizer = optim.AdamW(model.parameters(), lr=1e-5, )
        criterion = nn.CrossEntropyLoss(weight=weights.to(device))

        model, history, best_epoch = train_model(
            model=model,
            dataloaders=dataloaders,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            save_prefix=f'{test_subj}_split_{k}',
            n_epochs=15,
            patience=2,
        )

In [None]:
subset_loader = torch.utils.data.DataLoader(
    torch.utils.data.Subset(full_dataset, range(16)), 
    batch_size=8, shuffle=True
)
weights = compute_class_weights(subset_loader)
model = SEEGFusionModel(embed_dim=128, n_classes=2, device=device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-4,
    steps_per_epoch=len(subset_loader),
    epochs=10,
    pct_start=0.1,
    anneal_strategy='cos',
    div_factor=25.0,
    final_div_factor=1e4,
)
criterion = nn.CrossEntropyLoss(weight=weights.to(device))

train_model(model, {'train':subset_loader, 'val': subset_loader}, criterion, optimizer, scheduler, device, save_prefix='testing', n_epochs=10)

[32m2025-11-07 10:12:57.516[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m41[0m - [1m
Starting training for 10 epochs on device: cpu[0m




[32m2025-11-07 10:14:49.737[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m103[0m - [1m
Epoch 1/10 Summary:[0m
[32m2025-11-07 10:14:49.737[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m104[0m - [1m  Train Loss: 0.6779 | Train Acc: 0.7500[0m
[32m2025-11-07 10:14:49.737[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m105[0m - [1m  Val Loss:   0.6990 | Val Acc:   0.2500[0m
[32m2025-11-07 10:14:49.737[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m106[0m - [1m  Time: 112.22s[0m




[32m2025-11-07 10:16:40.487[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m103[0m - [1m
Epoch 2/10 Summary:[0m
[32m2025-11-07 10:16:40.488[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m104[0m - [1m  Train Loss: 0.6645 | Train Acc: 0.6250[0m
[32m2025-11-07 10:16:40.488[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m105[0m - [1m  Val Loss:   0.6660 | Val Acc:   0.7500[0m
[32m2025-11-07 10:16:40.488[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m106[0m - [1m  Time: 110.67s[0m




[32m2025-11-07 10:18:31.958[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m103[0m - [1m
Epoch 3/10 Summary:[0m
[32m2025-11-07 10:18:31.959[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m104[0m - [1m  Train Loss: 0.7145 | Train Acc: 0.5625[0m
[32m2025-11-07 10:18:31.959[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m105[0m - [1m  Val Loss:   0.6520 | Val Acc:   0.7500[0m
[32m2025-11-07 10:18:31.959[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m106[0m - [1m  Time: 111.40s[0m




[32m2025-11-07 10:20:23.001[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m103[0m - [1m
Epoch 4/10 Summary:[0m
[32m2025-11-07 10:20:23.001[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m104[0m - [1m  Train Loss: 0.7145 | Train Acc: 0.5000[0m
[32m2025-11-07 10:20:23.001[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m105[0m - [1m  Val Loss:   0.6794 | Val Acc:   0.7500[0m
[32m2025-11-07 10:20:23.001[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m106[0m - [1m  Time: 110.98s[0m




                                                      

KeyboardInterrupt: 