In [None]:
# Keep this as the FIRST cell: it seeds every RNG before any dataset, model, or DataLoader is built.
import random
import numpy as np
import torch

SEED = 1

# Seed every RNG this notebook touches: Python's `random`, NumPy's legacy
# global generator, and Torch (CPU and all CUDA devices).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Seeding alone does not make GPU runs repeatable: cuDNN may autotune to a
# different convolution algorithm on each run, and some of its algorithms are
# non-deterministic. Pin it to deterministic kernels and disable autotuning.
# Both flags are plain settings and are harmless on CPU-only machines.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Helpers for DataLoader reproducibility
def seed_worker(worker_id):
    """Re-seed NumPy and `random` from the calling process's Torch seed.

    Intended as a DataLoader ``worker_init_fn``. ``worker_id`` is accepted
    only to satisfy that callback signature; it is not used.
    """
    # Fold Torch's 64-bit seed into the 32-bit range NumPy's seeding accepts.
    derived_seed = torch.initial_seed() % (1 << 32)
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(derived_seed)

# Dedicated, explicitly seeded generator that is handed to the DataLoaders
# (manual_seed returns the generator itself, so the calls can be chained).
g = torch.Generator().manual_seed(SEED)

from pathlib import Path
from loguru import logger
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Subset
from src.datasets.seeg_dataset import SEEGDataset
from src.models.model import SEEGFusionModel, BaselineModel
from src.training.train import train_model
from src.training.evaluate import evaluate_model

# Use the GPU when Torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

In [None]:
def compute_class_weights(train_ds, n_classes=None):
    """Return inverse-frequency class weights as a float32 tensor.

    The dataset is iterated once and element ``[1]`` of every item is taken as
    its label. Each class gets the weight ``total_samples / samples_in_class``.

    Args:
        train_ds: Iterable of items whose second element is the class label.
        n_classes: Optional number of classes. When given, labels are treated
            as integer class indices and the result always has one entry per
            index ``0..n_classes-1``, so it lines up with a loss that indexes
            weights by class id even if a class is missing from this fold.
            Classes with no samples get weight 0. When omitted, there is one
            weight per label value actually observed, in sorted label order.

    Returns:
        1-D ``torch.float32`` tensor of class weights.

    Raises:
        ValueError: If ``n_classes`` is given and a label is negative or
            not smaller than ``n_classes``.
    """
    labels = np.array([sample[1] for sample in train_ds])

    if n_classes is None:
        # One count per distinct observed label, in sorted order.
        _, counts = np.unique(labels, return_counts=True)
    else:
        # bincount itself rejects negative labels with a ValueError.
        counts = np.bincount(labels.astype(np.int64).ravel(), minlength=n_classes)
        if counts.size > n_classes:
            raise ValueError(
                f"label {counts.size - 1} is out of range for n_classes={n_classes}")

    # Absent classes have a zero count; silence the division warning and give
    # them weight 0 instead of inf/nan (they cannot contribute to the loss).
    with np.errstate(divide='ignore', invalid='ignore'):
        weight = counts.sum() / counts
    weight[~np.isfinite(weight)] = 0.0
    return torch.from_numpy(weight).float()

# Helper to get indices for specific subjects
def get_subject_indices(dataset, subj_list):
    """Return the positions in ``dataset.data`` that belong to the given subjects.

    Args:
        dataset: Object whose ``data`` attribute is a sequence of mappings,
            each with a ``'subject'`` key.
        subj_list: Iterable of subject ids, or a single subject id string.

    Returns:
        List of integer indices, in dataset order.
    """
    # A bare string would make ``in`` a substring test ('Epat3' in 'Epat31'),
    # so wrap it to force exact matching on the whole subject id.
    if isinstance(subj_list, str):
        subj_list = [subj_list]
    wanted = set(subj_list)
    return [i for i, s in enumerate(dataset.data) if s['subject'] in wanted]

# Build the dataset a single time for all subjects; the splits used later are
# index-based Subset views into this one object.
subjects = [
    'Epat31', 'Epat35', 'Epat37', 'Epat38',
    'Spat31', 'Spat37',
]
full_dataset = SEEGDataset(subjects=subjects)

In [None]:
# Example: LOPO outer loop
# Leave-one-patient-out: each subject is held out as the test set in turn and
# a model is trained on an inner train/val split of the remaining subjects.
model_type = 'Fusion'
metric_dict = {}
for test_subj in subjects:
    logger.info(f"\n=== Test subject: {test_subj} ===")
    remaining_subjs = [s for s in subjects if s != test_subj]
    if not remaining_subjs:
        raise ValueError("LOPO needs at least two subjects")

    # Outer split: test vs remaining
    test_idx = get_subject_indices(full_dataset, [test_subj])
    test_ds = Subset(full_dataset, test_idx)

    # Inner split subjects (for hyperparam tuning)
    # Shuffle remaining subjects so different folds vary
    random.shuffle(remaining_subjs)

    # One inner split per remaining subject: that subject validates and the
    # others train. Derived from the list itself instead of a hard-coded count
    # so the loop keeps working when the number of subjects changes.
    inner_splits = []
    for val_subj in remaining_subjs:
        # Exact comparison; `s not in val_subj` on a string would be a
        # substring test and could drop subjects whose id is a prefix.
        train_subjs = [s for s in remaining_subjs if s != val_subj]
        inner_splits.append((train_subjs, val_subj))

    # Run inner CV for this test subject (only the first split is trained)
    for k, (train_subjs, val_subj) in enumerate([inner_splits[0]]):
        logger.info(f"\nInner split {k+1}: train={train_subjs}, val={val_subj}")

        train_idx = get_subject_indices(full_dataset, train_subjs)
        val_idx = get_subject_indices(full_dataset, [val_subj])

        train_ds = Subset(full_dataset, train_idx)
        val_ds = Subset(full_dataset, val_idx)

        # NOTE(review): the 'test' loader is also handed to train_model;
        # confirm it is not used for early stopping or model selection.
        dataloaders = {
            'train': DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=0, worker_init_fn=seed_worker, generator=g),
            'val': DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=0, worker_init_fn=seed_worker, generator=g),
            'test': DataLoader(test_ds, batch_size=4, shuffle=False, num_workers=0, worker_init_fn=seed_worker, generator=g)
        }

        # Inverse-frequency weights from the training fold only.
        weights = compute_class_weights(train_ds)

        if model_type == 'Fusion':
            model = SEEGFusionModel(embed_dim=128, n_classes=2, device=device)
        elif model_type == 'Baseline':
            model = BaselineModel(embed_dim=128, n_classes=2, device=device, stim_model='convergent', n_elecs=25, generator=g)
        else:
            # Fail loudly instead of silently reusing a model left over from
            # a previous iteration (or hitting a NameError on the first one).
            raise ValueError(f"Unknown model_type: {model_type!r}")
        model.to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-6)
        # Cyclic LR between 1e-6 and 1e-4; the step sizes count calls to
        # scheduler.step(). NOTE(review): whether train_model steps per batch
        # or per epoch is not visible here.
        scheduler = optim.lr_scheduler.CyclicLR(
            optimizer,
            base_lr=1e-6,
            max_lr=1e-4,
            step_size_up=50,
            step_size_down=50,
            cycle_momentum=False
        )

        criterion = nn.CrossEntropyLoss(weight=weights.to(device))

        model, history, best_epoch = train_model(
            model=model,
            dataloaders=dataloaders,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            save_prefix=f'{test_subj}_model_{model_type}_split_{k}',
            n_epochs=10,
            patience=2,
        )

    # Score the model returned by the (single) inner split on the held-out subject.
    metrics = evaluate_model(model, dataloaders['test'], device)
    metric_dict[test_subj] = metrics

In [None]:
def convert_to_final_dict(metric_dict):
    """Collect per-subject metrics into per-metric lists and print mean +/- std.

    Args:
        metric_dict: Mapping ``subject -> {metric_name: value}``.

    Returns:
        Dict ``metric_name -> [value, ...]`` with one entry per subject that
        reported the metric. Values are ordered by sorted subject key, not by
        insertion order, so two dicts built for the same subjects yield lists
        that line up element by element (as a paired test requires).
    """
    final_dict = {}

    # Sorted keys make the value order independent of how metric_dict was
    # filled (e.g. training-loop order vs. filesystem glob order).
    for subj in sorted(metric_dict, key=str):
        for k, v in metric_dict[subj].items():
            final_dict.setdefault(k, []).append(v)

    for k, collected in final_dict.items():
        vals = np.array(collected)
        print(f'{k}:           {np.mean(vals):0.3f} +/- {np.std(vals):0.3f}')

    return final_dict

In [None]:
# Aggregate the per-subject metrics gathered in `metric_dict` by the LOPO loop
# and print their mean +/- std. NOTE(review): the name assumes that loop was
# run with model_type = 'Fusion' — verify before comparing against the baseline.
fusion_final_dict = convert_to_final_dict(metric_dict)

In [None]:
# Re-evaluate saved Baseline checkpoints on their held-out subject. The
# subject id is the part of the checkpoint filename before the first '_'.
baseline_metric_dict = {}
experiments_dir = Path('../experiments')
# sorted(): Path.glob yields files in arbitrary filesystem order; sorting makes
# the run repeatable (and fixes which file wins if a subject has several).
for model_path in sorted(experiments_dir.glob('*model_Baseline_split_0_best_*.pt')):
    model = BaselineModel(embed_dim=128, n_classes=2, device=device, stim_model='convergent', n_elecs=25, generator=g)
    # map_location lets checkpoints saved from GPU tensors load on a
    # CPU-only machine instead of failing while deserialising.
    model.load_state_dict(torch.load(model_path, weights_only=True, map_location=device))
    model.to(device)
    test_subj = model_path.name.split('_')[0]

    test_idx = get_subject_indices(full_dataset, [test_subj])
    test_ds = Subset(full_dataset, test_idx)

    dataloader = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=0, worker_init_fn=seed_worker, generator=g)
    metrics = evaluate_model(model, dataloader, device)
    baseline_metric_dict[test_subj] = metrics

# Put results in `subjects` order (the order the LOPO loop fills its dict) so
# the per-metric lists pair up subject by subject in the later paired test.
# update() keeps those positions and appends any subject not in `subjects`.
ordered_baseline = {s: baseline_metric_dict[s] for s in subjects if s in baseline_metric_dict}
ordered_baseline.update(baseline_metric_dict)
baseline_metric_dict = ordered_baseline

baseline_final_dict = convert_to_final_dict(baseline_metric_dict)

In [None]:
from scipy.stats import wilcoxon

# Wilcoxon signed-rank test of baseline vs. fusion scores, one metric at a
# time. NOTE(review): assumes both lists hold the same subjects in the same
# order — confirm against how the two dicts were built.
for metric_name in ('auroc', 'f1', 'youden_index'):
    baseline_scores = baseline_final_dict[metric_name]
    fusion_scores = fusion_final_dict[metric_name]
    result = wilcoxon(baseline_scores, fusion_scores)
    print(f"{metric_name}, p-value={result.pvalue}")