In [1]:
# Keep this as the FIRST cell: it seeds every RNG before any later cell (data loading, model init, shuffling) draws random numbers.
import random
import numpy as np
import torch

SEED = 1

# Seed every RNG the notebook touches: Python, NumPy, Torch (CPU and all CUDA devices)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Seeding alone does not make GPU runs repeatable: cuDNN may choose non-deterministic
# kernels, and benchmark mode may choose different kernels from run to run.
# This trades some speed for reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Helpers for DataLoader reproducibility
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: reseed NumPy and Python ``random`` from torch's seed.

    ``worker_id`` is unused; the per-worker variation comes from ``torch.initial_seed()``.
    """
    # NumPy's legacy seeder only accepts values in [0, 2**32), so fold the torch seed into that range.
    derived_seed = torch.initial_seed() % (2 ** 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

# Seeded generator passed to the DataLoaders so shuffling order is reproducible.
g = torch.Generator().manual_seed(SEED)

In [2]:
from loguru import logger
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Subset
from src.datasets.seeg_dataset import SEEGDataset
from src.models.model import SEEGFusionModel, BaselineModel
from src.training.train import train_model
from src.training.evaluate import evaluate_model

has_gpu = torch.cuda.is_available()
device = 'cuda' if has_gpu else 'cpu'
print(f'Using device: {device}')

Using device: cuda


In [3]:
def compute_class_weights(train_ds, n_classes=None):
    """Inverse-frequency class weights for ``nn.CrossEntropyLoss(weight=...)``.

    Class ``c`` gets ``total_samples / count_c``, so rarer classes weigh more.

    Args:
        train_ds: Iterable of samples whose second element (``sample[1]``) is an
            integer class label, assumed to lie in ``0..n_classes-1``.
        n_classes: Length of the returned vector. Defaults to ``max(label) + 1``.
            Pass the model's class count so the vector still lines up with class
            ids when a class is missing from this split.

    Returns:
        float32 tensor of shape ``(n_classes,)`` where entry ``i`` is the weight
        of label ``i``. Classes with no samples get weight 0.0 (instead of inf).

    Raises:
        ValueError: if a label is ``>= n_classes`` or negative.
    """
    # ravel() tolerates labels stored as 1-element arrays/tensors.
    labels = np.asarray([sample[1] for sample in train_ds]).astype(np.int64).ravel()
    # bincount indexes counts by label value, so weight[i] always belongs to class i
    # (np.unique would silently drop/shift absent classes).
    counts = np.bincount(labels, minlength=n_classes or 0)
    if n_classes is not None and counts.size > n_classes:
        raise ValueError(
            f"Found label {counts.size - 1} but n_classes={n_classes}")
    total = counts.sum()
    weights = np.zeros(counts.shape, dtype=np.float64)
    # Only divide where a class is present, avoiding inf/nan weights.
    np.divide(total, counts, out=weights, where=counts > 0)
    return torch.from_numpy(weights).float()

# Helper to get indices for specific subjects
def get_subject_indices(dataset, subj_list):
    """Return positions in ``dataset.data`` whose ``'subject'`` is in ``subj_list``.

    Args:
        dataset: Object exposing a ``data`` sequence of dict-like records that
            carry a ``'subject'`` key.
        subj_list: Iterable of subject IDs, or a single subject ID string.

    Returns:
        List of integer indices, suitable for ``torch.utils.data.Subset``.
    """
    # A bare string would be treated as a container of characters, turning the
    # membership test into substring matching ('Epat3' in 'Epat31' is True).
    if isinstance(subj_list, str):
        subj_list = [subj_list]
    wanted = set(subj_list)
    return [i for i, record in enumerate(dataset.data) if record['subject'] in wanted]

# Create dataset once (loads all subjects)
subjects = [
    'Epat31', 'Epat35', 'Epat37', 'Epat38',
    'Spat31', 'Spat37',
]
# One dataset for every subject; splits below are index Subsets of it.
full_dataset = SEEGDataset(subjects=subjects)

[32m2025-11-12 08:37:01.644[0m | [32m[1mSUCCESS [0m | [36msrc.datasets.seeg_dataset[0m:[36m__init__[0m:[36m128[0m - [32m[1m✅ Loaded 281 total samples from 6 subjects.[0m


In [4]:
# Leave-one-patient-out (LOPO): each subject is held out once as the test set.
model_type = 'Fusion'  # 'Fusion' or 'Baseline'
# How many inner (train/val) splits to train per test subject. The test set is
# scored with the model from the LAST inner split that ran.
n_inner_splits_to_run = 1

if model_type not in ('Fusion', 'Baseline'):
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'Fusion' or 'Baseline'")
if n_inner_splits_to_run < 1:
    raise ValueError("n_inner_splits_to_run must be >= 1 (a model is needed for the test set)")

# Store per-subject test metrics under a name that matches the model actually
# trained, so the Baseline-vs-Fusion comparison cells can't pick up the wrong
# model's results (previously Fusion runs were written to `baseline_metric_dict`).
lopo_metric_dict = {}
if model_type == 'Baseline':
    baseline_metric_dict = lopo_metric_dict
else:
    fusion_lopo_metric_dict = lopo_metric_dict

for test_subj in subjects:
    logger.info(f"\n=== Test subject: {test_subj} ===")
    remaining_subjs = [s for s in subjects if s != test_subj]

    # Outer split: test vs remaining
    test_idx = get_subject_indices(full_dataset, [test_subj])
    test_ds = Subset(full_dataset, test_idx)

    # Shuffle remaining subjects so the validation subject differs between test folds
    random.shuffle(remaining_subjs)

    # One inner split per remaining subject: it validates, the others train.
    # IDs are compared with != (not `in` on a string, which does substring matching).
    inner_splits = [
        ([s for s in remaining_subjs if s != val_subj], val_subj)
        for val_subj in remaining_subjs
    ]

    for k, (train_subjs, val_subj) in enumerate(inner_splits[:n_inner_splits_to_run]):
        logger.info(f"\nInner split {k+1}: train={train_subjs}, val={val_subj}")

        train_idx = get_subject_indices(full_dataset, train_subjs)
        # Pass a list so membership is by exact subject ID
        val_idx = get_subject_indices(full_dataset, [val_subj])

        train_ds = Subset(full_dataset, train_idx)
        val_ds = Subset(full_dataset, val_idx)

        dataloaders = {
            'train': DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=0, worker_init_fn=seed_worker, generator=g),
            'val': DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=0, worker_init_fn=seed_worker, generator=g),
            'test': DataLoader(test_ds, batch_size=4, shuffle=False, num_workers=0, worker_init_fn=seed_worker, generator=g)
        }

        weights = compute_class_weights(train_ds)

        if model_type == 'Fusion':
            model = SEEGFusionModel(embed_dim=128, n_classes=2, device=device)
        else:
            model = BaselineModel(embed_dim=128, n_classes=2, device=device, stim_model='convergent', n_elecs=25, generator=g)
        model.to(device)
        optimizer = optim.AdamW(model.parameters(), lr=1e-4)
        criterion = nn.CrossEntropyLoss(weight=weights.to(device))

        model, history, best_epoch = train_model(
            model=model,
            dataloaders=dataloaders,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            save_prefix=f'{test_subj}_model_{model_type}_split_{k}',
            n_epochs=20,
            patience=3,
        )

    metrics = evaluate_model(model, dataloaders['test'], device)
    lopo_metric_dict[test_subj] = metrics

[32m2025-11-12 08:37:03.048[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1m
=== Test subject: Epat31 ===[0m
[32m2025-11-12 08:37:03.049[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1m
Inner split 1: train=['Spat31', 'Spat37', 'Epat35', 'Epat37'], val=Epat38[0m
[32m2025-11-12 08:37:03.561[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m48[0m - [1m
Starting training for 20 epochs on device: cuda[0m




[32m2025-11-12 08:37:22.360[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 1/20 Summary:[0m
[32m2025-11-12 08:37:22.360[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.6167 | Train Acc: 0.6995[0m
[32m2025-11-12 08:37:22.360[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.5349 | Val Acc:   0.9130[0m
[32m2025-11-12 08:37:22.360[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 18.80s[0m




[32m2025-11-12 08:37:40.561[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 2/20 Summary:[0m
[32m2025-11-12 08:37:40.561[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.7109 | Train Acc: 0.8907[0m
[32m2025-11-12 08:37:40.561[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.6391 | Val Acc:   0.9130[0m
[32m2025-11-12 08:37:40.561[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 18.16s[0m




[32m2025-11-12 08:37:58.797[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 3/20 Summary:[0m
[32m2025-11-12 08:37:58.798[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.8113 | Train Acc: 0.8907[0m
[32m2025-11-12 08:37:58.798[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.7583 | Val Acc:   0.9130[0m
[32m2025-11-12 08:37:58.798[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 18.21s[0m




[32m2025-11-12 08:38:17.052[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 4/20 Summary:[0m
[32m2025-11-12 08:38:17.052[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.9029 | Train Acc: 0.8907[0m
[32m2025-11-12 08:38:17.053[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.7463 | Val Acc:   0.9130[0m
[32m2025-11-12 08:38:17.053[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 18.23s[0m
[32m2025-11-12 08:38:17.076[0m | [32m[1mSUCCESS [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m138[0m - [32m[1m⏹ Early stopping at epoch 4 (no val loss improvement for 3 epochs)[0m




[32m2025-11-12 08:38:17.750[0m | [1mINFO    [0m | [36msrc.training.evaluate[0m:[36mevaluate_model[0m:[36m69[0m - [1m{'accuracy': np.float64(0.9310344827586207), 'auroc': 0.9545454545454545, 'f1': 0.875, 'sensitivity': np.float64(1.0), 'specificity': np.float64(0.9090909090909091), 'youden_index': np.float64(0.9090909090909091), 'optimal_threshold': np.float32(0.039523985)}[0m
[32m2025-11-12 08:38:17.751[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1m
=== Test subject: Epat35 ===[0m
[32m2025-11-12 08:38:17.751[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1m
Inner split 1: train=['Epat38', 'Epat37', 'Spat37', 'Spat31'], val=Epat31[0m
[32m2025-11-12 08:38:17.779[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m48[0m - [1m
Starting training for 20 epochs on device: cuda[0m




[32m2025-11-12 08:38:38.479[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 1/20 Summary:[0m
[32m2025-11-12 08:38:38.479[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.6819 | Train Acc: 0.8899[0m
[32m2025-11-12 08:38:38.479[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   1.5728 | Val Acc:   0.7586[0m
[32m2025-11-12 08:38:38.480[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 20.70s[0m




[32m2025-11-12 08:38:59.264[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 2/20 Summary:[0m
[32m2025-11-12 08:38:59.265[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.9275 | Train Acc: 0.8899[0m
[32m2025-11-12 08:38:59.265[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   1.8551 | Val Acc:   0.7586[0m
[32m2025-11-12 08:38:59.265[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 20.70s[0m




[32m2025-11-12 08:39:19.967[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 3/20 Summary:[0m
[32m2025-11-12 08:39:19.967[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.9793 | Train Acc: 0.8899[0m
[32m2025-11-12 08:39:19.967[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   1.6883 | Val Acc:   0.7586[0m
[32m2025-11-12 08:39:19.968[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 20.68s[0m




[32m2025-11-12 08:39:40.652[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 4/20 Summary:[0m
[32m2025-11-12 08:39:40.652[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.9354 | Train Acc: 0.8899[0m
[32m2025-11-12 08:39:40.652[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   1.6918 | Val Acc:   0.7586[0m
[32m2025-11-12 08:39:40.652[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 20.66s[0m
[32m2025-11-12 08:39:40.676[0m | [32m[1mSUCCESS [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m138[0m - [32m[1m⏹ Early stopping at epoch 4 (no val loss improvement for 3 epochs)[0m




[32m2025-11-12 08:39:41.574[0m | [1mINFO    [0m | [36msrc.training.evaluate[0m:[36mevaluate_model[0m:[36m69[0m - [1m{'accuracy': np.float64(0.7058823529411765), 'auroc': 0.8125, 'f1': 0.2857142857142857, 'sensitivity': np.float64(1.0), 'specificity': np.float64(0.6875), 'youden_index': np.float64(0.6875), 'optimal_threshold': np.float32(0.043753345)}[0m
[32m2025-11-12 08:39:41.575[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1m
=== Test subject: Epat37 ===[0m
[32m2025-11-12 08:39:41.575[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1m
Inner split 1: train=['Spat31', 'Spat37', 'Epat31', 'Epat35'], val=Epat38[0m
[32m2025-11-12 08:39:41.603[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m48[0m - [1m
Starting training for 20 epochs on device: cuda[0m




[32m2025-11-12 08:39:54.867[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 1/20 Summary:[0m
[32m2025-11-12 08:39:54.867[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.6765 | Train Acc: 0.8543[0m
[32m2025-11-12 08:39:54.868[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.5142 | Val Acc:   0.9130[0m
[32m2025-11-12 08:39:54.868[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 13.26s[0m




[32m2025-11-12 08:40:08.220[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 2/20 Summary:[0m
[32m2025-11-12 08:40:08.221[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.6878 | Train Acc: 0.8543[0m
[32m2025-11-12 08:40:08.221[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.5629 | Val Acc:   0.9130[0m
[32m2025-11-12 08:40:08.221[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 13.21s[0m




[32m2025-11-12 08:40:21.501[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 3/20 Summary:[0m
[32m2025-11-12 08:40:21.501[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.9630 | Train Acc: 0.8543[0m
[32m2025-11-12 08:40:21.502[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.5796 | Val Acc:   0.9130[0m
[32m2025-11-12 08:40:21.502[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 13.26s[0m




[32m2025-11-12 08:40:34.759[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 4/20 Summary:[0m
[32m2025-11-12 08:40:34.760[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.8804 | Train Acc: 0.8543[0m
[32m2025-11-12 08:40:34.760[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.6559 | Val Acc:   0.9130[0m
[32m2025-11-12 08:40:34.760[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 13.23s[0m
[32m2025-11-12 08:40:34.783[0m | [32m[1mSUCCESS [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m138[0m - [32m[1m⏹ Early stopping at epoch 4 (no val loss improvement for 3 epochs)[0m




[32m2025-11-12 08:40:37.580[0m | [1mINFO    [0m | [36msrc.training.evaluate[0m:[36mevaluate_model[0m:[36m69[0m - [1m{'accuracy': np.float64(0.45901639344262296), 'auroc': 0.6428571428571428, 'f1': 0.23255813953488372, 'sensitivity': np.float64(1.0), 'specificity': np.float64(0.4107142857142857), 'youden_index': np.float64(0.4107142857142857), 'optimal_threshold': np.float32(0.047815494)}[0m
[32m2025-11-12 08:40:37.580[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1m
=== Test subject: Epat38 ===[0m
[32m2025-11-12 08:40:37.580[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1m
Inner split 1: train=['Epat31', 'Epat37', 'Spat37', 'Spat31'], val=Epat35[0m
[32m2025-11-12 08:40:37.609[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m48[0m - [1m
Starting training for 20 epochs on device: cuda[0m




[32m2025-11-12 08:40:53.433[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 1/20 Summary:[0m
[32m2025-11-12 08:40:53.434[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.7511 | Train Acc: 0.7697[0m
[32m2025-11-12 08:40:53.434[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.4394 | Val Acc:   0.9412[0m
[32m2025-11-12 08:40:53.434[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 15.82s[0m




[32m2025-11-12 08:41:09.400[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 2/20 Summary:[0m
[32m2025-11-12 08:41:09.400[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.7056 | Train Acc: 0.8596[0m
[32m2025-11-12 08:41:09.401[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.4364 | Val Acc:   0.9412[0m
[32m2025-11-12 08:41:09.401[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 15.90s[0m




[32m2025-11-12 08:41:25.324[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 3/20 Summary:[0m
[32m2025-11-12 08:41:25.325[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.7277 | Train Acc: 0.8596[0m
[32m2025-11-12 08:41:25.325[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.4479 | Val Acc:   0.9412[0m
[32m2025-11-12 08:41:25.325[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 15.86s[0m




[32m2025-11-12 08:41:41.210[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 4/20 Summary:[0m
[32m2025-11-12 08:41:41.210[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.7866 | Train Acc: 0.8596[0m
[32m2025-11-12 08:41:41.210[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.5160 | Val Acc:   0.9412[0m
[32m2025-11-12 08:41:41.210[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 15.86s[0m




[32m2025-11-12 08:41:57.116[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 5/20 Summary:[0m
[32m2025-11-12 08:41:57.117[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.8325 | Train Acc: 0.8596[0m
[32m2025-11-12 08:41:57.117[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.5360 | Val Acc:   0.9412[0m
[32m2025-11-12 08:41:57.118[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 15.88s[0m
[32m2025-11-12 08:41:57.141[0m | [32m[1mSUCCESS [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m138[0m - [32m[1m⏹ Early stopping at epoch 5 (no val loss improvement for 3 epochs)[0m




[32m2025-11-12 08:41:59.834[0m | [1mINFO    [0m | [36msrc.training.evaluate[0m:[36mevaluate_model[0m:[36m69[0m - [1m{'accuracy': np.float64(0.6231884057971014), 'auroc': 0.7089947089947091, 'f1': 0.3157894736842105, 'sensitivity': np.float64(1.0), 'specificity': np.float64(0.5873015873015873), 'youden_index': np.float64(0.5873015873015873), 'optimal_threshold': np.float32(0.036762375)}[0m
[32m2025-11-12 08:41:59.835[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1m
=== Test subject: Spat31 ===[0m
[32m2025-11-12 08:41:59.835[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1m
Inner split 1: train=['Epat31', 'Spat37', 'Epat37', 'Epat38'], val=Epat35[0m
[32m2025-11-12 08:41:59.867[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m48[0m - [1m
Starting training for 20 epochs on device: cuda[0m




[32m2025-11-12 08:42:20.036[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 1/20 Summary:[0m
[32m2025-11-12 08:42:20.036[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.5411 | Train Acc: 0.8373[0m
[32m2025-11-12 08:42:20.037[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.4977 | Val Acc:   0.9412[0m
[32m2025-11-12 08:42:20.037[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 20.17s[0m




[32m2025-11-12 08:42:40.227[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 2/20 Summary:[0m
[32m2025-11-12 08:42:40.227[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.7668 | Train Acc: 0.9043[0m
[32m2025-11-12 08:42:40.227[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.6616 | Val Acc:   0.9412[0m
[32m2025-11-12 08:42:40.228[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 20.13s[0m




[32m2025-11-12 08:43:00.370[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 3/20 Summary:[0m
[32m2025-11-12 08:43:00.370[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.9052 | Train Acc: 0.9043[0m
[32m2025-11-12 08:43:00.371[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.6486 | Val Acc:   0.9412[0m
[32m2025-11-12 08:43:00.371[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 20.12s[0m




[32m2025-11-12 08:43:20.491[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 4/20 Summary:[0m
[32m2025-11-12 08:43:20.491[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.9174 | Train Acc: 0.9043[0m
[32m2025-11-12 08:43:20.492[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.6323 | Val Acc:   0.9412[0m
[32m2025-11-12 08:43:20.492[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 20.10s[0m
[32m2025-11-12 08:43:20.516[0m | [32m[1mSUCCESS [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m138[0m - [32m[1m⏹ Early stopping at epoch 4 (no val loss improvement for 3 epochs)[0m




[32m2025-11-12 08:43:21.521[0m | [1mINFO    [0m | [36msrc.training.evaluate[0m:[36mevaluate_model[0m:[36m69[0m - [1m{'accuracy': np.float64(0.7105263157894737), 'auroc': 0.7946127946127945, 'f1': 0.6451612903225806, 'sensitivity': np.float64(0.9090909090909091), 'specificity': np.float64(0.6296296296296297), 'youden_index': np.float64(0.5387205387205387), 'optimal_threshold': np.float32(0.030621873)}[0m
[32m2025-11-12 08:43:21.522[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1m
=== Test subject: Spat37 ===[0m
[32m2025-11-12 08:43:21.522[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1m
Inner split 1: train=['Epat38', 'Epat35', 'Epat31', 'Spat31'], val=Epat37[0m
[32m2025-11-12 08:43:21.551[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m48[0m - [1m
Starting training for 20 epochs on device: cuda[0m




[32m2025-11-12 08:43:36.828[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 1/20 Summary:[0m
[32m2025-11-12 08:43:36.828[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.6789 | Train Acc: 0.8176[0m
[32m2025-11-12 08:43:36.829[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.5163 | Val Acc:   0.9180[0m
[32m2025-11-12 08:43:36.829[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 15.28s[0m




[32m2025-11-12 08:43:52.149[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 2/20 Summary:[0m
[32m2025-11-12 08:43:52.150[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.7093 | Train Acc: 0.8471[0m
[32m2025-11-12 08:43:52.150[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.5477 | Val Acc:   0.9180[0m
[32m2025-11-12 08:43:52.150[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 15.25s[0m




[32m2025-11-12 08:44:07.451[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 3/20 Summary:[0m
[32m2025-11-12 08:44:07.452[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.8110 | Train Acc: 0.8471[0m
[32m2025-11-12 08:44:07.452[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.5362 | Val Acc:   0.9180[0m
[32m2025-11-12 08:44:07.452[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 15.28s[0m




[32m2025-11-12 08:44:22.764[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m113[0m - [1m
Epoch 4/20 Summary:[0m
[32m2025-11-12 08:44:22.765[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m114[0m - [1m  Train Loss: 0.7498 | Train Acc: 0.8471[0m
[32m2025-11-12 08:44:22.765[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m115[0m - [1m  Val Loss:   0.5390 | Val Acc:   0.9180[0m
[32m2025-11-12 08:44:22.765[0m | [1mINFO    [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m116[0m - [1m  Time: 15.25s[0m
[32m2025-11-12 08:44:22.789[0m | [32m[1mSUCCESS [0m | [36msrc.training.train[0m:[36mtrain_model[0m:[36m138[0m - [32m[1m⏹ Early stopping at epoch 4 (no val loss improvement for 3 epochs)[0m




[32m2025-11-12 08:44:24.667[0m | [1mINFO    [0m | [36msrc.training.evaluate[0m:[36mevaluate_model[0m:[36m69[0m - [1m{'accuracy': np.float64(0.78), 'auroc': 0.7708333333333334, 'f1': 0.26666666666666666, 'sensitivity': np.float64(1.0), 'specificity': np.float64(0.7708333333333334), 'youden_index': np.float64(0.7708333333333334), 'optimal_threshold': np.float32(0.106995106)}[0m


In [None]:
# Disabled smoke test: one training epoch on the first 16 samples, to check the pipeline runs end to end.
# NOTE(review): this passes a DataLoader to compute_class_weights, which iterates its argument and
# reads item[1] -- so it sees batched labels rather than per-sample labels. Passing the Subset is
# likely what was intended. Also note n_elecs=30 here vs 25 in the LOPO loop -- confirm which is right.
# subset_loader = torch.utils.data.DataLoader(
#     torch.utils.data.Subset(full_dataset, range(16)), 
#     batch_size=16, shuffle=True
# )
# weights = compute_class_weights(subset_loader)
# # model = SEEGFusionModel(embed_dim=128, n_classes=2, device=device)
# model = BaselineModel(embed_dim=128, n_classes=2, device=device, stim_model='convergent', n_elecs=30)
# optimizer = optim.AdamW(model.parameters(), lr=1e-4)
# criterion = nn.CrossEntropyLoss(weight=weights.to(device))

# model, history, best_epoch = train_model(model, {'train':subset_loader, 'val': subset_loader}, criterion, optimizer, device, save_prefix='testing', n_epochs=1)

In [None]:
# Regroup per-subject metric dicts into one list per metric (in test-subject order),
# then report mean +/- std (np.std default: population std, ddof=0).
baseline_final_dict = {}
# baseline_final_dict['val_loss'] = [
#     0.6856,
#     0.5122,
#     0.5740,
#     0.4049,
#     0.5611,
#     0.6310
# ]

for subject_metrics in baseline_metric_dict.values():
    for metric_name, value in subject_metrics.items():
        baseline_final_dict.setdefault(metric_name, []).append(value)

for metric_name, values in baseline_final_dict.items():
    vals = np.array(values)
    print(f'{metric_name}:           {np.mean(vals):0.3f} +/- {np.std(vals):0.3f}')

In [6]:
# Per-subject Fusion-model results, hard-coded from an earlier run.
# NOTE(review): list order presumably follows `subjects` (the paired Wilcoxon test
# below relies on matching order) -- confirm against the run that produced them.
fusion_final_dict = {
    'auroc': [0.8961, 0.8438, 0.5786, 0.6243, 0.7811, 0.6875],
    'f1': [0.7, 0.4444444, 0.2609, 0.2703, 0.6452, 0.2],
    'youden_index': [0.727272, 0.84375, 0.3321, 0.4206, 0.5387, 0.666666],
    'val_loss': [0.6968, 0.5951, 0.4687, 0.5319, 0.4780, 0.6317],
}

for metric_name, values in fusion_final_dict.items():
    vals = np.array(values)
    print(f'{metric_name}:           {np.mean(vals):0.3f} +/- {np.std(vals):0.3f}')

auroc:           0.735 +/- 0.115
f1:           0.420 +/- 0.194
youden_index:           0.588 +/- 0.177
val_loss:           0.567 +/- 0.082


In [None]:
from scipy.stats import wilcoxon

In [None]:
# Paired (per-subject) signed-rank test of Baseline vs Fusion for each metric.
for metric_name in ('auroc', 'f1', 'youden_index'):
    result = wilcoxon(baseline_final_dict[metric_name], fusion_final_dict[metric_name])
    print(f"{metric_name}, p-value={result.pvalue}")

In [5]:
from pathlib import Path

# Re-evaluate saved Fusion checkpoints (one per held-out subject) on their test subject.
fusion_metric_dict = {}
experiments_dir = Path('../experiments')
# Path.glob yields files in filesystem order; sort so the run is deterministic.
checkpoint_paths = sorted(experiments_dir.glob('*model_Fusion_split_0_best_*.pt'))
if not checkpoint_paths:
    logger.warning(f'No Fusion checkpoints found in {experiments_dir.resolve()}')

for model_path in checkpoint_paths:
    # File names start with the held-out subject: '<subject>_model_Fusion_split_0_best_...'
    test_subj = model_path.name.split('_')[0]
    if test_subj not in subjects:
        # Such a subject has no samples in full_dataset -> empty test loader.
        logger.warning(f'Skipping {model_path.name}: {test_subj!r} is not in `subjects`')
        continue
    if test_subj in fusion_metric_dict:
        logger.warning(f'Multiple checkpoints for {test_subj}; {model_path.name} replaces the earlier one')

    model = SEEGFusionModel(embed_dim=128, n_classes=2, device=device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.to(device)

    test_idx = get_subject_indices(full_dataset, [test_subj])
    test_ds = Subset(full_dataset, test_idx)

    dataloader = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=0, worker_init_fn=seed_worker, generator=g)
    fusion_metric_dict[test_subj] = evaluate_model(model, dataloader, device)

missing_subjs = [s for s in subjects if s not in fusion_metric_dict]
if missing_subjs:
    logger.warning(f'No Fusion checkpoint evaluated for: {missing_subjs}')

# Build per-metric lists in `subjects` order so they line up position-by-position
# with other per-subject result lists (e.g. for a paired test).
final_fusion_dict = {}
for subj in subjects:
    if subj not in fusion_metric_dict:
        continue
    for k, v in fusion_metric_dict[subj].items():
        final_fusion_dict.setdefault(k, []).append(v)

for k, values in final_fusion_dict.items():
    vals = np.array(values)
    print(f'{k}:           {np.mean(vals):0.3f} +/- {np.std(vals):0.3f}')

[32m2025-11-12 08:45:24.176[0m | [1mINFO    [0m | [36msrc.training.evaluate[0m:[36mevaluate_model[0m:[36m69[0m - [1m{'accuracy': np.float64(0.66), 'auroc': 0.6666666666666666, 'f1': 0.19047619047619047, 'sensitivity': np.float64(1.0), 'specificity': np.float64(0.6458333333333333), 'youden_index': np.float64(0.6458333333333333), 'optimal_threshold': np.float32(0.45557716)}[0m
[32m2025-11-12 08:45:26.896[0m | [1mINFO    [0m | [36msrc.training.evaluate[0m:[36mevaluate_model[0m:[36m69[0m - [1m{'accuracy': np.float64(0.8115942028985508), 'auroc': 0.8756613756613757, 'f1': 0.48, 'sensitivity': np.float64(1.0), 'specificity': np.float64(0.7936507936507937), 'youden_index': np.float64(0.7936507936507937), 'optimal_threshold': np.float32(0.46679524)}[0m
[32m2025-11-12 08:45:27.836[0m | [1mINFO    [0m | [36msrc.training.evaluate[0m:[36mevaluate_model[0m:[36m69[0m - [1m{'accuracy': np.float64(0.38235294117647056), 'auroc': 0.515625, 'f1': 0.16, 'sensitivity': np

accuracy:           0.707 +/- 0.158
auroc:           0.751 +/- 0.146
f1:           0.429 +/- 0.216
sensitivity:           0.906 +/- 0.141
specificity:           0.694 +/- 0.166
youden_index:           0.600 +/- 0.177
optimal_threshold:           0.445 +/- 0.060
