In [9]:
# Ensure repo root (containing package dir src/) is on sys.path for imports
import sys
from pathlib import Path
try:
    REPO_ROOT  # defined in earlier cells
except NameError:
    # Fallback: assume this notebook lives in <repo>/notebooks
    REPO_ROOT = Path.cwd().resolve().parent
# We need the parent directory so `import src...` works
repo_root_str = str(REPO_ROOT)
if repo_root_str not in sys.path:
    sys.path.insert(0, repo_root_str)
print('sys.path[0] =', sys.path[0])
print('Contains src?:', (REPO_ROOT / 'src').is_dir())


sys.path[0] = /workspace/Kaggle-HMS
Contains src?: True


# Diagnose Fast Training & Data Processing
This notebook verifies whether the pipeline is unintentionally using a reduced subset or debug settings causing unrealistically fast epochs.

In [1]:
# 1) Verify Environment and Devices
import os, sys, json, subprocess, platform, time
import torch, psutil
from pathlib import Path
print('Python:', platform.python_version())
print('Torch:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())
print('CUDA device count:', torch.cuda.device_count())
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        p = torch.cuda.get_device_properties(i)
        print(f'GPU {i}: {p.name}, {p.total_memory/1024**3:.1f} GB VRAM')
print('CPU cores:', psutil.cpu_count(logical=True))
print('RAM:', psutil.virtual_memory().total/1024**3, 'GB')

Python: 3.11.14
Torch: 2.5.1+cu121
CUDA available: True
CUDA device count: 1
GPU 0: NVIDIA H200 NVL, 139.8 GB VRAM
CPU cores: 512
RAM: 1511.1955833435059 GB


In [2]:
# 2) Dataset Volume Sanity Checks
import pandas as pd, numpy as np
from pathlib import Path
NOTEBOOK_CWD = Path.cwd()
REPO_ROOT = NOTEBOOK_CWD.parent if NOTEBOOK_CWD.name == 'notebooks' else NOTEBOOK_CWD
raw_root = REPO_ROOT / 'data' / 'raw'
assert raw_root.exists(), f'Missing raw data root {raw_root}'
train_csv = raw_root/'train_unique.csv'
assert train_csv.exists(), 'train_unique.csv missing – check data mount'
df_train = pd.read_csv(train_csv)
n_rows = len(df_train)
print('train_unique rows:', n_rows)
# Count parquet EEG files
train_eegs = list((raw_root/'train_eegs').glob('*.parquet'))
test_eegs  = list((raw_root/'test_eegs').glob('*.parquet'))
print('Train EEG parquet files:', len(train_eegs))
print('Test EEG parquet files:', len(test_eegs))
# Approx expected label windows
label_cols = [c for c in df_train.columns if c.endswith('_vote')]
print('Vote columns:', label_cols)
print('Expert consensus present:', 'expert_consensus' in df_train.columns)
# Basic checksum of first few EEG files (size only for speed)
sizes = [f.stat().st_size for f in train_eegs[:10]]
print('Sample EEG file sizes (bytes):', sizes)
print('Total EEG size (first 10):', sum(sizes))

train_unique rows: 20183
Train EEG parquet files: 17300
Test EEG parquet files: 1
Vote columns: ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
Expert consensus present: True
Sample EEG file sizes (bytes): [980197, 1194256, 417427, 3104172, 3355117, 517347, 1357640, 1125394, 422150, 5846691]
Total EEG size (first 10): 18320391


In [None]:
# 3) Verify Full Dataset Paths and Debug Flags Off
from omegaconf import OmegaConf
train_cfg_path = Path('configs/train_4fold.yaml')
mlp_cfg_path = Path('configs/training_mlp.yaml')
train_cfg = OmegaConf.load(str(train_cfg_path))
mlp_cfg = OmegaConf.load(str(mlp_cfg_path))
debug_indicators = []
for cfg_name, cfg in [('graph', train_cfg), ('mlp', mlp_cfg)]:
    txt = OmegaConf.to_yaml(cfg)
    if any(k in txt for k in ['smoke_test','fast_dev_run','limit_train_batches']):
        debug_indicators.append(cfg_name)
print('Config debug indicators:', debug_indicators)
print('Graph smoke_test:', getattr(train_cfg,'smoke_test', False))
print('MLP trainer fast_dev_run available?', hasattr(mlp_cfg,'trainer') and getattr(mlp_cfg.trainer,'fast_dev_run', False))

In [None]:
# 4) Schema and Sample Validation
required_cols = {'label_id','patient_id','eeg_id','eeg_label_offset_seconds', *label_cols}
missing = required_cols - set(df_train.columns)
print('Missing columns:', missing)
print(df_train.head(3))

In [None]:
# 5) Deterministic Seeding and Data Order Checks
import random, numpy as np, hashlib
SEED=42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
def hash_ids(ids):
    m = hashlib.sha256()
    for s in ids:
        m.update(str(s).encode())
    return m.hexdigest()
first_50 = df_train['label_id'].head(50).tolist()
print('Head label_id hash:', hash_ids(first_50))

In [None]:
# 6) DataLoader and Coverage Assertions (Baseline Raw EEG)
from omegaconf import OmegaConf
from src.data.raw_datamodule import RawEEGDataModule
mlp_cfg = OmegaConf.load('configs/training_mlp.yaml')
mlp_cfg.data.splits.split_ratios = [0.7,0.15,0.15]
mlp_dm = RawEEGDataModule(mlp_cfg)
mlp_dm.prepare_data()
mlp_dm.setup('fit')
train_loader = mlp_dm.train_dataloader()
val_loader = mlp_dm.val_dataloader()
print('Train batches:', len(train_loader), 'Val batches:', len(val_loader))
unique_ids = set()
for batch in train_loader:
    unique_ids.update(batch['label_ids'])
print('Unique train label_ids covered:', len(unique_ids))
print('Train dataset size:', len(train_loader.dataset))

In [None]:
# 7) Throughput and Timing Instrumentation (baseline single batch)
import time
baseline_batch = next(iter(train_loader))
start = time.time(); eeg = baseline_batch['eeg_signal'].to('cpu'); cpu_load = time.time()-start
start = time.time(); eeg_cuda = eeg.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu')); h2d = time.time()-start
print(f'Host load {cpu_load*1000:.1f} ms, H2D {h2d*1000:.1f} ms, batch shape {eeg_cuda.shape}')

In [None]:
# 8) Preprocessing and Caching Verification
from importlib import reload
import src.data.raw_eeg_dataset as red
reload(red)
ds = mlp_dm.train_dataset
item0 = ds[0]
sig0 = item0['eeg_signal'].clone()
# Re-extract to simulate transform; compare hash
item0b = ds[0]
sig0b = item0b['eeg_signal']
same = torch.allclose(sig0, sig0b)
print('Signal deterministic across calls:', bool(same))

In [None]:
# 9) Training Loop Accounting (1 tiny epoch)
from src.lightning_trainer.mlp_lightning_module import EEGMLPLightningModule
import pytorch_lightning as pl
cfg = mlp_cfg.copy()
cfg.trainer.max_epochs = 1
cfg.loss.type = 'kl'
logger = False
trainer = pl.Trainer(max_epochs=1, accelerator='cpu', devices=1, logger=logger, enable_progress_bar=True, fast_dev_run=False)
model = EEGMLPLightningModule(cfg)
start=time.time()
trainer.fit(model, train_loader, val_loader)
elapsed = time.time()-start
print('One epoch elapsed seconds:', elapsed)
print('global_step:', trainer.global_step)

In [None]:
# 10) Early Stopping & Fast-Dev-Run Checks (graph config)
print('Graph max epochs:', getattr(train_cfg,'num_epochs', None))
print('Graph early stopping:', dict(train_cfg.early_stopping))
print('Graph smoke_test:', getattr(train_cfg,'smoke_test', False))

In [None]:
# 11) Multiprocessing / Distributed Config
print('MLP num_workers:', mlp_cfg.data.loader.num_workers)
print('Graph num_workers:', train_cfg.data.num_workers)
print('Potential distributed flags present?', any(k in OmegaConf.to_yaml(train_cfg) for k in ['strategy','ddp']))

In [None]:
# 12) GPU/CPU Utilization & Precision
print('Mixed precision in graph cfg:', getattr(getattr(train_cfg,'hardware',None),'mixed_precision', None) or getattr(train_cfg,'mixed_precision', None))
print('Torch matmul precision:', torch.get_float32_matmul_precision())

In [None]:
# 13) Unit Tests for Counts and Steps (inline quick asserts)
steps_per_epoch = len(train_loader)
expected = int(np.ceil(len(train_loader.dataset)/mlp_cfg.data.loader.batch_size))
print('Steps/epoch (observed/expected):', steps_per_epoch, expected)
assert steps_per_epoch in (expected, max(1, expected-1)), 'Unexpected steps per epoch (drop_last?)'
# Non-empty batches
assert all(batch['eeg_signal'].shape[0] > 0 for batch in [baseline_batch]), 'Empty batch encountered'
print('Basic unit checks: OK')

In [None]:
# 14) Overfit-on-Tiny-Subset Sanity Check
tiny_indices = list(range(64))
from torch.utils.data import Subset
tiny_loader = torch.utils.data.DataLoader(Subset(train_loader.dataset, tiny_indices), batch_size=32, shuffle=True)
model_tiny = EEGMLPLightningModule(cfg)
opt = torch.optim.Adam(model_tiny.parameters(), lr=1e-2)
for ep in range(3):
    losses=[]
    for b in tiny_loader:
        x = b['eeg_signal']
        y = b['target']
        opt.zero_grad()
        logits = model_tiny(x)
        logp = torch.log_softmax(logits, dim=-1)
        loss = torch.nn.functional.kl_div(logp, y, reduction='batchmean')
        loss.backward(); opt.step()
        losses.append(loss.item())
    print('Tiny subset epoch', ep+1, 'loss', sum(losses)/len(losses))

In [None]:
# 15) Structured Logging to File
log_dir = Path('logs/diagnostics'); log_dir.mkdir(parents=True, exist_ok=True)
with open(log_dir/'env.json','w') as f:
    json.dump({'python':'ok','torch':torch.__version__}, f)
print('Saved logs to', log_dir)

In [None]:
# 16) Validation/Test Coverage Assertions
mlp_dm.setup('test')
test_loader = mlp_dm.test_dataloader()
test_ids=set()
for b in test_loader:
    test_ids.update(b['label_ids'])
print('Validation size:', len(val_loader.dataset),' Test size:', len(test_loader.dataset))
print('Unique test label_ids:', len(test_ids))

In [3]:
# Full Experiment Orchestrator – toggles

RUN_BASELINE = True   # set False to skip baseline training from notebook
RUN_GRAPH = True      # enable graph training folds from notebook (can be long)

BASELINE_N_SPLITS = 5
GRAPH_FOLDS = [0,1,2,3]

# Use project-relative path (repo root is parent of this notebooks dir)
BASELINE_CONFIG = '../configs/training_mlp.yaml'

GRAPH_CONFIG_SRC = '../configs/train_4fold.yaml'
GRAPH_CONFIG_NOTEBOOK = '../configs/train_4fold_notebook.yaml'

WANDB_PROJECT_BASELINE = None  # use value from config if None
WANDB_PROJECT_GRAPH = None     # use value from config if None

In [24]:
# Holdout test configuration (patient-level split)
HOLDOUT_TEST_FRAC = 0.20  # 20% of patients reserved for test-only
HOLDOUT_SEED = 42
HOLDOUT_TEST_CSV = '../data/raw/test_holdout.csv'
HOLDOUT_TRAIN_CSV = '../data/raw/train_unique_excl_holdout.csv'


In [12]:
# Baseline Cross-Validation Training (Raw EEG) - with path & env fixes
import math, copy, time, pandas as pd, numpy as np, torch, sys
from pathlib import Path
# Ensure repository root and src are on sys.path
NOTEBOOK_CWD = Path.cwd()
if NOTEBOOK_CWD.name == 'notebooks':
    REPO_ROOT = NOTEBOOK_CWD.parent
else:
    REPO_ROOT = NOTEBOOK_CWD
SRC_DIR = REPO_ROOT / 'src'
RAW_ROOT = REPO_ROOT / 'data' / 'raw'
for p in (REPO_ROOT, SRC_DIR):
    if str(p) not in sys.path:
        sys.path.insert(0, str(p))

import pytorch_lightning as pl
from omegaconf import OmegaConf
from sklearn.model_selection import GroupKFold
try:
    from src.data.raw_datamodule import RawEEGDataModule
    from src.lightning_trainer.mlp_lightning_module import EEGMLPLightningModule
except ModuleNotFoundError as e:
    print('Import failed after path injection. sys.path:')
    for p in sys.path: print(' ', p)
    raise e

class LiveBaselineCallback(pl.Callback):
    def __init__(self):
        self.start_time = None
    def _ensure_start(self):
        if self.start_time is None:
            self.start_time = time.time()
    def on_train_start(self, trainer, pl_module):
        self.start_time = time.time()
    def on_validation_start(self, trainer, pl_module):
        # sanity check can trigger validation before train starts
        self._ensure_start()
    def on_validation_epoch_end(self, trainer, pl_module):
        self._ensure_start()
        elapsed = time.time() - self.start_time
        val_loss = trainer.callback_metrics.get('val/loss_epoch') or trainer.callback_metrics.get('val/loss')
        acc = trainer.callback_metrics.get('val/acc_macro')
        f1 = trainer.callback_metrics.get('val/f1_macro')
        if val_loss is not None and acc is not None and f1 is not None:
            print(f"[Baseline] Epoch {trainer.current_epoch+1}/{trainer.max_epochs} elapsed={elapsed/60:.2f}m val_loss={float(val_loss):.4f} acc={float(acc):.3f} f1={float(f1):.3f}")

if RUN_BASELINE:
    baseline_cfg = OmegaConf.load(BASELINE_CONFIG)
    # Fix relative paths when running from notebooks
    if 'data' in baseline_cfg:
        # If we created a train-excl-holdout CSV, prefer it for training
        excl_csv = RAW_ROOT / 'train_unique_excl_holdout.csv'
        if excl_csv.exists():
            baseline_cfg.data.metadata_csv = str(excl_csv)
        elif 'metadata_csv' in baseline_cfg.data:
            baseline_cfg.data.metadata_csv = str(REPO_ROOT / baseline_cfg.data.metadata_csv)
        if 'raw_eeg' in baseline_cfg.data and 'base_dir' in baseline_cfg.data.raw_eeg:
            baseline_cfg.data.raw_eeg.base_dir = str(REPO_ROOT / baseline_cfg.data.raw_eeg.base_dir)
    if 'checkpointing' in baseline_cfg and 'dirpath' in baseline_cfg.checkpointing:
        baseline_cfg.checkpointing.dirpath = str(REPO_ROOT / baseline_cfg.checkpointing.dirpath)
    # Adapt accelerator for Linux/CUDA environment
    if torch.cuda.is_available():
        baseline_cfg.trainer.accelerator = 'cuda'
        baseline_cfg.trainer.devices = 1
    else:
        baseline_cfg.trainer.accelerator = 'cpu'
        baseline_cfg.trainer.devices = 1
    # Soft vs hard labels toggle
    USE_SOFT = True  # toggle to False for consensus hard labels
    if USE_SOFT:
        baseline_cfg.data.target_mode = 'votes'
    else:
        baseline_cfg.data.target_mode = 'consensus'
        baseline_cfg.data.raw_eeg.label_to_index = {k:i for i,k in enumerate(['Seizure','LPD','GPD','LRDA','GRDA','Other'])}
    pl.seed_everything(int(getattr(baseline_cfg,'seed',42)), workers=True)
    dm = RawEEGDataModule(baseline_cfg)
    dm.prepare_data(); dm.setup('fit')
    patient_ids = np.array(dm.patient_ids())
    gkf = GroupKFold(n_splits=BASELINE_N_SPLITS)
    fold_results = []
    for fold,(train_idx,val_idx) in enumerate(gkf.split(patient_ids, groups=patient_ids)):
        print(f"\n=== Baseline Fold {fold+1}/{BASELINE_N_SPLITS} ===")
        train_ids = [str(pid) for pid in patient_ids[train_idx]]
        val_ids   = [str(pid) for pid in patient_ids[val_idx]]
        train_ds = dm.dataset_for_patients(train_ids)
        val_ds   = dm.dataset_for_patients(val_ids)
        train_loader = dm.dataloader(train_ds, shuffle=True)
        val_loader   = dm.dataloader(val_ds, shuffle=False)
        model = EEGMLPLightningModule(baseline_cfg)
        callbacks=[LiveBaselineCallback()]
        if hasattr(baseline_cfg,'checkpointing'):
            from pytorch_lightning.callbacks import ModelCheckpoint
            ck = baseline_cfg.checkpointing
            fold_dir = Path(str(ck.dirpath)) / f"notebook_fold_{fold}"
            fold_dir.mkdir(parents=True, exist_ok=True)
            callbacks.append(ModelCheckpoint(dirpath=str(fold_dir), monitor=ck.monitor, mode=ck.mode, save_top_k=int(ck.save_top_k), save_last=bool(ck.save_last)))
        trainer = pl.Trainer(max_epochs=int(baseline_cfg.trainer.max_epochs), accelerator=baseline_cfg.trainer.accelerator, devices=baseline_cfg.trainer.devices, precision=baseline_cfg.trainer.precision, gradient_clip_val=baseline_cfg.trainer.gradient_clip_val, log_every_n_steps=baseline_cfg.trainer.log_every_n_steps, callbacks=callbacks, enable_progress_bar=True, benchmark=getattr(baseline_cfg.trainer,'benchmark',True))
        trainer.fit(model, train_loader, val_loader)
        val_loss = trainer.callback_metrics.get('val/loss_epoch') or trainer.callback_metrics.get('val/loss')
        acc = trainer.callback_metrics.get('val/acc_macro')
        f1 = trainer.callback_metrics.get('val/f1_macro')
        fold_results.append({'fold':fold,'val_loss':float(val_loss) if val_loss else None,'acc_macro':float(acc) if acc else None,'f1_macro':float(f1) if f1 else None})
    baseline_results_df = pd.DataFrame(fold_results)
    display(baseline_results_df)
else:
    print('Baseline training skipped (RUN_BASELINE=False).')

Seed set to 42



=== Baseline Fold 1/5 ===


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | EEGMLPBaseline   | 34.2 K | train
1 | train_metrics | MetricCollection | 0      | train
2 | val_metrics   | MetricCollection | 0      | train
3 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
34.2 K    Trainable params
0         Non-trainable params
34.2 K    Total params
0.137     Total estimated model params size (MB)
20        Modules in train mode

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


[Baseline] Epoch 1/30 elapsed=0.00m val_loss=1.2790 acc=0.184 f1=0.124


/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 1/30 elapsed=1.62m val_loss=1.2940 acc=0.181 f1=0.085


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 2/30 elapsed=1.91m val_loss=1.2883 acc=0.183 f1=0.086


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 3/30 elapsed=2.19m val_loss=1.2950 acc=0.183 f1=0.111


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 4/30 elapsed=2.46m val_loss=1.3084 acc=0.193 f1=0.147


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 5/30 elapsed=2.74m val_loss=1.3137 acc=0.212 f1=0.165


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 6/30 elapsed=3.01m val_loss=1.3253 acc=0.206 f1=0.166


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 7/30 elapsed=3.29m val_loss=1.3745 acc=0.209 f1=0.171


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 8/30 elapsed=3.57m val_loss=1.3507 acc=0.217 f1=0.185


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 9/30 elapsed=3.84m val_loss=1.3840 acc=0.215 f1=0.184


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 10/30 elapsed=4.10m val_loss=1.4395 acc=0.211 f1=0.183


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 11/30 elapsed=4.37m val_loss=1.4274 acc=0.219 f1=0.185


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 12/30 elapsed=4.64m val_loss=1.4983 acc=0.213 f1=0.179


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 13/30 elapsed=4.92m val_loss=1.4851 acc=0.221 f1=0.194


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 14/30 elapsed=5.18m val_loss=1.5066 acc=0.217 f1=0.190


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 15/30 elapsed=5.46m val_loss=1.5423 acc=0.221 f1=0.195


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 16/30 elapsed=5.72m val_loss=1.5707 acc=0.218 f1=0.196


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 17/30 elapsed=5.99m val_loss=1.6046 acc=0.214 f1=0.189


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 18/30 elapsed=6.26m val_loss=1.6124 acc=0.224 f1=0.199


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 19/30 elapsed=6.52m val_loss=1.6518 acc=0.214 f1=0.189


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 20/30 elapsed=6.79m val_loss=1.6542 acc=0.228 f1=0.203


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 21/30 elapsed=7.06m val_loss=1.6664 acc=0.221 f1=0.197


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 22/30 elapsed=7.33m val_loss=1.6978 acc=0.215 f1=0.190


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 23/30 elapsed=7.61m val_loss=1.6813 acc=0.223 f1=0.200


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 24/30 elapsed=7.88m val_loss=1.6925 acc=0.219 f1=0.196


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 25/30 elapsed=8.16m val_loss=1.7048 acc=0.222 f1=0.198


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 26/30 elapsed=8.44m val_loss=1.7180 acc=0.227 f1=0.205


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 27/30 elapsed=8.71m val_loss=1.7277 acc=0.226 f1=0.204


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 28/30 elapsed=8.98m val_loss=1.7310 acc=0.227 f1=0.203


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 29/30 elapsed=9.25m val_loss=1.7324 acc=0.230 f1=0.207


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 30/30 elapsed=9.52m val_loss=1.7422 acc=0.227 f1=0.204


`Trainer.fit` stopped: `max_epochs=30` reached.



=== Baseline Fold 2/5 ===


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | EEGMLPBaseline   | 34.2 K | train
1 | train_metrics | MetricCollection | 0      | train
2 | val_metrics   | MetricCollection | 0      | train
3 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
34.2 K    Trainable params
0         Non-trainable params
34.2 K    Total params
0.137     Total estimated model params size (MB)
20        Modules in train mode

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


[Baseline] Epoch 1/30 elapsed=0.00m val_loss=1.2725 acc=0.188 f1=0.103


/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 1/30 elapsed=1.53m val_loss=1.2477 acc=0.183 f1=0.097


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 2/30 elapsed=1.80m val_loss=1.2427 acc=0.186 f1=0.104


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 3/30 elapsed=2.07m val_loss=1.2375 acc=0.189 f1=0.113


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 4/30 elapsed=2.35m val_loss=1.2392 acc=0.210 f1=0.165


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 5/30 elapsed=2.62m val_loss=1.2529 acc=0.220 f1=0.183


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 6/30 elapsed=2.90m val_loss=1.2625 acc=0.231 f1=0.196


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 7/30 elapsed=3.17m val_loss=1.3039 acc=0.226 f1=0.193


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 8/30 elapsed=3.46m val_loss=1.3075 acc=0.231 f1=0.198


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 9/30 elapsed=3.73m val_loss=1.3304 acc=0.244 f1=0.214


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 10/30 elapsed=4.00m val_loss=1.3177 acc=0.244 f1=0.215


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 11/30 elapsed=4.29m val_loss=1.3689 acc=0.242 f1=0.208


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 12/30 elapsed=4.56m val_loss=1.3830 acc=0.244 f1=0.218


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 13/30 elapsed=4.83m val_loss=1.4048 acc=0.253 f1=0.229


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 14/30 elapsed=5.10m val_loss=1.4334 acc=0.247 f1=0.224


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 15/30 elapsed=5.37m val_loss=1.4446 acc=0.248 f1=0.219


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 16/30 elapsed=5.64m val_loss=1.4824 acc=0.254 f1=0.227


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 17/30 elapsed=5.92m val_loss=1.4794 acc=0.246 f1=0.217


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 18/30 elapsed=6.19m val_loss=1.5055 acc=0.239 f1=0.216


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 19/30 elapsed=6.47m val_loss=1.5292 acc=0.264 f1=0.238


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 20/30 elapsed=6.75m val_loss=1.5411 acc=0.245 f1=0.219


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 21/30 elapsed=7.03m val_loss=1.5373 acc=0.247 f1=0.227


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 22/30 elapsed=7.30m val_loss=1.5767 acc=0.245 f1=0.224


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 23/30 elapsed=7.59m val_loss=1.5831 acc=0.241 f1=0.220


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 24/30 elapsed=7.86m val_loss=1.5891 acc=0.247 f1=0.226


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 25/30 elapsed=8.13m val_loss=1.5984 acc=0.245 f1=0.220


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 26/30 elapsed=8.41m val_loss=1.6100 acc=0.248 f1=0.225


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 27/30 elapsed=8.69m val_loss=1.6243 acc=0.245 f1=0.220


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 28/30 elapsed=8.96m val_loss=1.6194 acc=0.247 f1=0.225


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 29/30 elapsed=9.23m val_loss=1.6208 acc=0.249 f1=0.224


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


[Baseline] Epoch 30/30 elapsed=9.51m val_loss=1.6295 acc=0.246 f1=0.222

=== Baseline Fold 3/5 ===


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | EEGMLPBaseline   | 34.2 K | train
1 | train_metrics | MetricCollection | 0      | train
2 | val_metrics   | MetricCollection | 0      | train
3 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
34.2 K    Trainable params
0         Non-trainable params
34.2 K    Total params
0.137     Total estimated model params size (MB)
20        Modules in train mode

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


[Baseline] Epoch 1/30 elapsed=0.00m val_loss=1.5632 acc=0.232 f1=0.107


/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 1/30 elapsed=1.62m val_loss=1.2820 acc=0.182 f1=0.105


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 2/30 elapsed=1.91m val_loss=1.2780 acc=0.183 f1=0.110


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 3/30 elapsed=2.18m val_loss=1.2863 acc=0.200 f1=0.151


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 4/30 elapsed=2.45m val_loss=1.2838 acc=0.200 f1=0.157


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 5/30 elapsed=2.72m val_loss=1.3013 acc=0.223 f1=0.195


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 6/30 elapsed=3.00m val_loss=1.3204 acc=0.216 f1=0.185


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 7/30 elapsed=3.27m val_loss=1.3379 acc=0.216 f1=0.188


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 8/30 elapsed=3.53m val_loss=1.3426 acc=0.224 f1=0.198


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 9/30 elapsed=3.80m val_loss=1.3498 acc=0.223 f1=0.194


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 10/30 elapsed=4.07m val_loss=1.3979 acc=0.224 f1=0.201


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 11/30 elapsed=4.34m val_loss=1.3942 acc=0.224 f1=0.201


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 12/30 elapsed=4.61m val_loss=1.4297 acc=0.227 f1=0.207


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 13/30 elapsed=4.88m val_loss=1.4398 acc=0.229 f1=0.209


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 14/30 elapsed=5.15m val_loss=1.4458 acc=0.232 f1=0.213


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 15/30 elapsed=5.42m val_loss=1.4460 acc=0.233 f1=0.210


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 16/30 elapsed=5.69m val_loss=1.4939 acc=0.233 f1=0.208


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 17/30 elapsed=5.97m val_loss=1.5160 acc=0.235 f1=0.218


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 18/30 elapsed=6.24m val_loss=1.4932 acc=0.231 f1=0.213


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 19/30 elapsed=6.51m val_loss=1.5291 acc=0.234 f1=0.215


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 20/30 elapsed=6.79m val_loss=1.5377 acc=0.234 f1=0.220


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 21/30 elapsed=7.06m val_loss=1.5564 acc=0.226 f1=0.214


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 22/30 elapsed=7.33m val_loss=1.5660 acc=0.243 f1=0.227


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 23/30 elapsed=7.61m val_loss=1.5652 acc=0.245 f1=0.229


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 24/30 elapsed=7.88m val_loss=1.5865 acc=0.233 f1=0.218


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 25/30 elapsed=8.15m val_loss=1.5875 acc=0.238 f1=0.220


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 26/30 elapsed=8.43m val_loss=1.6036 acc=0.237 f1=0.219


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 27/30 elapsed=8.71m val_loss=1.6057 acc=0.240 f1=0.223


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 28/30 elapsed=8.99m val_loss=1.6103 acc=0.238 f1=0.222


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 29/30 elapsed=9.28m val_loss=1.6160 acc=0.238 f1=0.222


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


[Baseline] Epoch 30/30 elapsed=9.56m val_loss=1.6201 acc=0.240 f1=0.224

=== Baseline Fold 4/5 ===


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | EEGMLPBaseline   | 34.2 K | train
1 | train_metrics | MetricCollection | 0      | train
2 | val_metrics   | MetricCollection | 0      | train
3 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
34.2 K    Trainable params
0         Non-trainable params
34.2 K    Total params
0.137     Total estimated model params size (MB)
20        Modules in train mode

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


[Baseline] Epoch 1/30 elapsed=0.00m val_loss=1.5370 acc=0.218 f1=0.127


/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 1/30 elapsed=1.54m val_loss=1.2370 acc=0.175 f1=0.109


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 2/30 elapsed=1.81m val_loss=1.2273 acc=0.176 f1=0.113


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 3/30 elapsed=2.09m val_loss=1.2313 acc=0.198 f1=0.158


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 4/30 elapsed=2.36m val_loss=1.2209 acc=0.214 f1=0.182


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 5/30 elapsed=2.63m val_loss=1.2355 acc=0.221 f1=0.195


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 6/30 elapsed=2.89m val_loss=1.2633 acc=0.229 f1=0.203


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 7/30 elapsed=3.16m val_loss=1.2486 acc=0.241 f1=0.224


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 8/30 elapsed=3.43m val_loss=1.2656 acc=0.233 f1=0.217


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 9/30 elapsed=3.71m val_loss=1.2815 acc=0.245 f1=0.221


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 10/30 elapsed=3.98m val_loss=1.3040 acc=0.241 f1=0.222


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 11/30 elapsed=4.27m val_loss=1.3237 acc=0.240 f1=0.223


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 12/30 elapsed=4.54m val_loss=1.3347 acc=0.246 f1=0.225


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 13/30 elapsed=4.81m val_loss=1.3700 acc=0.245 f1=0.228


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 14/30 elapsed=5.09m val_loss=1.4106 acc=0.231 f1=0.217


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 15/30 elapsed=5.37m val_loss=1.4077 acc=0.250 f1=0.229


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 16/30 elapsed=5.64m val_loss=1.4234 acc=0.257 f1=0.236


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 17/30 elapsed=5.91m val_loss=1.4172 acc=0.252 f1=0.234


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 18/30 elapsed=6.18m val_loss=1.4404 acc=0.240 f1=0.227


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 19/30 elapsed=6.45m val_loss=1.4735 acc=0.268 f1=0.251


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 20/30 elapsed=6.72m val_loss=1.4451 acc=0.255 f1=0.237


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 21/30 elapsed=7.00m val_loss=1.4766 acc=0.254 f1=0.239


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 22/30 elapsed=7.28m val_loss=1.4875 acc=0.257 f1=0.241


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 23/30 elapsed=7.55m val_loss=1.5091 acc=0.248 f1=0.234


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 24/30 elapsed=7.83m val_loss=1.5040 acc=0.250 f1=0.236


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 25/30 elapsed=8.10m val_loss=1.5299 acc=0.244 f1=0.230


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 26/30 elapsed=8.37m val_loss=1.5288 acc=0.250 f1=0.235


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 27/30 elapsed=8.65m val_loss=1.5343 acc=0.247 f1=0.231


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 28/30 elapsed=8.93m val_loss=1.5354 acc=0.250 f1=0.233


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 29/30 elapsed=9.21m val_loss=1.5339 acc=0.249 f1=0.232


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


[Baseline] Epoch 30/30 elapsed=9.49m val_loss=1.5362 acc=0.254 f1=0.236

=== Baseline Fold 5/5 ===


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | EEGMLPBaseline   | 34.2 K | train
1 | train_metrics | MetricCollection | 0      | train
2 | val_metrics   | MetricCollection | 0      | train
3 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
34.2 K    Trainable params
0         Non-trainable params
34.2 K    Total params
0.137     Total estimated model params size (MB)
20        Modules in train mode

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


[Baseline] Epoch 1/30 elapsed=0.00m val_loss=1.5795 acc=0.142 f1=0.026


/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 1/30 elapsed=1.00m val_loss=1.2860 acc=0.174 f1=0.089


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 2/30 elapsed=1.29m val_loss=1.2790 acc=0.177 f1=0.096


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 3/30 elapsed=1.56m val_loss=1.2748 acc=0.190 f1=0.125


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 4/30 elapsed=1.83m val_loss=1.2740 acc=0.200 f1=0.137


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 5/30 elapsed=2.11m val_loss=1.2750 acc=0.206 f1=0.154


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 6/30 elapsed=2.40m val_loss=1.2955 acc=0.205 f1=0.163


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 7/30 elapsed=2.66m val_loss=1.2931 acc=0.225 f1=0.194


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 8/30 elapsed=2.93m val_loss=1.3077 acc=0.228 f1=0.196


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 9/30 elapsed=3.20m val_loss=1.3229 acc=0.228 f1=0.200


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 10/30 elapsed=3.48m val_loss=1.3400 acc=0.231 f1=0.205


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 11/30 elapsed=3.75m val_loss=1.3810 acc=0.236 f1=0.212


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 12/30 elapsed=4.03m val_loss=1.4031 acc=0.234 f1=0.211


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 13/30 elapsed=4.31m val_loss=1.4103 acc=0.249 f1=0.228


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 14/30 elapsed=4.58m val_loss=1.4380 acc=0.232 f1=0.211


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 15/30 elapsed=4.85m val_loss=1.4908 acc=0.251 f1=0.227


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 16/30 elapsed=5.13m val_loss=1.4874 acc=0.238 f1=0.217


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 17/30 elapsed=5.41m val_loss=1.5184 acc=0.236 f1=0.220


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 18/30 elapsed=5.68m val_loss=1.5242 acc=0.239 f1=0.218


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 19/30 elapsed=5.95m val_loss=1.5526 acc=0.244 f1=0.228


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 20/30 elapsed=6.23m val_loss=1.5443 acc=0.243 f1=0.225


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 21/30 elapsed=6.50m val_loss=1.5732 acc=0.248 f1=0.230


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 22/30 elapsed=6.77m val_loss=1.6041 acc=0.233 f1=0.217


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 23/30 elapsed=7.04m val_loss=1.5935 acc=0.239 f1=0.217


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 24/30 elapsed=7.32m val_loss=1.5965 acc=0.242 f1=0.223


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 25/30 elapsed=7.60m val_loss=1.6192 acc=0.246 f1=0.226


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 26/30 elapsed=7.87m val_loss=1.6262 acc=0.241 f1=0.224


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 27/30 elapsed=8.15m val_loss=1.6227 acc=0.241 f1=0.223


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 28/30 elapsed=8.42m val_loss=1.6298 acc=0.238 f1=0.218


Validation: |          | 0/? [00:00<?, ?it/s]

[Baseline] Epoch 29/30 elapsed=8.69m val_loss=1.6381 acc=0.244 f1=0.224


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


[Baseline] Epoch 30/30 elapsed=8.97m val_loss=1.6377 acc=0.240 f1=0.221


Unnamed: 0,fold,val_loss,acc_macro,f1_macro
0,0,1.742204,0.227072,0.203808
1,1,1.629546,0.246142,0.221545
2,2,1.620068,0.2396,0.224147
3,3,1.536242,0.253637,0.236389
4,4,1.637727,0.240466,0.220667


In [13]:
# Graph Model Cross-Validation Training (with path fixes)
import shutil, time
from omegaconf import OmegaConf
import pytorch_lightning as pl
from pathlib import Path
from src.train import train as graph_train

class LiveGraphCallback(pl.Callback):
    def __init__(self):
        self.start=None
    def on_train_start(self, trainer, pl_module):
        self.start=time.time()
    def on_validation_epoch_end(self, trainer, pl_module):
        if self.start is None:
            self.start=time.time()
        elapsed=time.time()-self.start
        vloss=trainer.callback_metrics.get('val/loss_epoch') or trainer.callback_metrics.get('val/loss')
        if vloss is not None:
            print(f"[Graph] Epoch {trainer.current_epoch+1}/{trainer.max_epochs} elapsed={elapsed/60:.2f}m val_loss={float(vloss):.4f}")

if RUN_GRAPH:
    # Copy original config to notebook-specific path (avoid overwriting source)
    if Path(GRAPH_CONFIG_NOTEBOOK).exists():
        Path(GRAPH_CONFIG_NOTEBOOK).unlink()
    shutil.copy(GRAPH_CONFIG_SRC, GRAPH_CONFIG_NOTEBOOK)
    fold_metrics=[]
    for f in GRAPH_FOLDS:
        print(f"\n=== Graph Fold {f} ===")
        cfg = OmegaConf.load(GRAPH_CONFIG_NOTEBOOK)
        # Fix relative paths to project root
        REPO_ROOT = Path.cwd().parent if Path.cwd().name == 'notebooks' else Path.cwd()
        if 'data' in cfg:
            if 'data_dir' in cfg.data:
                cfg.data.data_dir = str(REPO_ROOT / cfg.data.data_dir)
            if 'train_csv' in cfg.data:
                cfg.data.train_csv = str(REPO_ROOT / cfg.data.train_csv)
        cfg.data.current_fold = f
        # Persist patched config for this fold
        OmegaConf.save(cfg, GRAPH_CONFIG_NOTEBOOK)
        # Run training (train() prints its own summaries)
        trainer, model, dm = graph_train(train_config_path=GRAPH_CONFIG_NOTEBOOK,
                                         model_config_path=None,
                                         wandb_project=cfg.get('wandb_project','hms-brain-activity'),
                                         wandb_name=f"notebook-fold{f}")
        vloss = trainer.callback_metrics.get('val/loss_epoch') or trainer.callback_metrics.get('val/loss')
        fold_metrics.append({'fold':f,'val_loss':float(vloss) if vloss else None})
    import pandas as pd
    display(pd.DataFrame(fold_metrics))
else:
    print('Graph training skipped (RUN_GRAPH=False).')


=== Graph Fold 0 ===

HMS Multi-Modal GNN Training
Model Type: multi_modal
Model Config: /workspace/Kaggle-HMS/configs/model.yaml
Train Config: ../configs/train_4fold_notebook.yaml
WandB Project: hms-brain-activity-KL
WandB Run: notebook-fold0
Fold: 0/3

Initializing DataModule...
Note: preload_patients=True → forcing num_workers=0 to avoid shared-memory mmap pressure.
Stratifying by evaluator bins only


Preloading patients: 100%|██████████| 1459/1459 [04:31<00:00,  5.37it/s]
Preloading patients: 100%|██████████| 491/491 [01:15<00:00,  6.49it/s]



Dataset Setup - Fold 0/3:
  Train: 1459 patients, 16132 samples
  Val:   491 patients, 4051 samples

  Stratification by evaluator bins:
    Train: {0: 10965, 1: 212, 2: 3646, 3: 1234, 4: 75}
    Val: {0: 2726, 1: 47, 2: 909, 3: 349, 4: 20}

  Class weights: [0.6017259359359741, 0.6564667224884033, 1.2223418951034546, 1.9777363538742065, 1.240186333656311, 0.3015425503253937]

Initializing Model...

Model Architecture:
  Eeg Output Dim       256
  Spec Output Dim      256
  Fusion Output Dim    512
  Num Classes          6
  Total Params         2,019,080
  Trainable Params     2,019,080



[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msserkanakin[0m ([33mgraph-ml-project[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | HMSMultiModalGNN | 2.0 M  | train
1 | criterion     | KLDivLoss        | 0      | train
2 | train_metrics | MetricCollection | 0      | train
3 | val_metrics   | MetricCollection | 0      | train
4 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
8.076     Total estimated 

Starting training...



/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved. New best score: 1.630
Epoch 0, global step 31: 'val/loss' reached 1.62971 (best 1.62971), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.6297.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.056 >= min_delta = 0.0. New best score: 1.574
Epoch 0, global step 62: 'val/loss' reached 1.57409 (best 1.57409), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.5741.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.061 >= min_delta = 0.0. New best score: 1.513
Epoch 0, global step 93: 'val/loss' reached 1.51294 (best 1.51294), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.5129.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.049 >= min_delta = 0.0. New best score: 1.464
Epoch 0, global step 124: 'val/loss' reached 1.46352 (best 1.46352), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.4635.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 158: 'val/loss' reached 1.50518 (best 1.46352), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=01-val/loss=1.5052.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.019 >= min_delta = 0.0. New best score: 1.445
Epoch 1, global step 189: 'val/loss' reached 1.44480 (best 1.44480), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=01-val/loss=1.4448.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 220: 'val/loss' reached 1.47258 (best 1.44480), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=01-val/loss=1.4726.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 251: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 285: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 316: 'val/loss' reached 1.45421 (best 1.44480), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=02-val/loss=1.4542.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.014 >= min_delta = 0.0. New best score: 1.430
Epoch 2, global step 347: 'val/loss' reached 1.43035 (best 1.43035), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=02-val/loss=1.4304.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 378: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 412: 'val/loss' reached 1.43919 (best 1.43035), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=03-val/loss=1.4392.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 443: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 474: 'val/loss' reached 1.43349 (best 1.43035), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=03-val/loss=1.4335.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 505: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 539: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.013 >= min_delta = 0.0. New best score: 1.417
Epoch 4, global step 570: 'val/loss' reached 1.41706 (best 1.41706), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=04-val/loss=1.4171.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 601: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 632: 'val/loss' reached 1.43277 (best 1.41706), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=04-val/loss=1.4328.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 666: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 697: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 728: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 759: 'val/loss' reached 1.42938 (best 1.41706), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=05-val/loss=1.4294.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 793: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 824: 'val/loss' reached 1.42484 (best 1.41706), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=06-val/loss=1.4248.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.000 >= min_delta = 0.0. New best score: 1.417
Epoch 6, global step 855: 'val/loss' reached 1.41661 (best 1.41661), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=06-val/loss=1.4166.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 886: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 920: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 951: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 982: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 1013: 'val/loss' reached 1.42208 (best 1.41661), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=07-val/loss=1.4221.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 1047: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.013 >= min_delta = 0.0. New best score: 1.404
Epoch 8, global step 1078: 'val/loss' reached 1.40370 (best 1.40370), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=08-val/loss=1.4037.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 1109: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 1140: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 1174: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 1205: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 1236: 'val/loss' reached 1.41517 (best 1.40370), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=09-val/loss=1.4152.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 1267: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10, global step 1301: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10, global step 1332: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10, global step 1363: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.007 >= min_delta = 0.0. New best score: 1.397
Epoch 10, global step 1394: 'val/loss' reached 1.39712 (best 1.39712), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=10-val/loss=1.3971.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.001 >= min_delta = 0.0. New best score: 1.396
Epoch 11, global step 1428: 'val/loss' reached 1.39621 (best 1.39621), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=11-val/loss=1.3962.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 1459: 'val/loss' reached 1.39714 (best 1.39621), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=11-val/loss=1.3971.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 1490: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 1521: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 1555: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 1586: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 1617: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 1648: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13, global step 1682: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13, global step 1713: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val/loss_epoch did not improve in the last 10 records. Best score: 1.396. Signaling Trainer to stop.
Epoch 13, global step 1744: 'val/loss' was not in top 3



Testing best model...

Stratifying by evaluator bins only


Restoring states from the checkpoint path at /workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=11-val/loss=1.3962.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=11-val/loss=1.3962.ckpt
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/acc            0.4653172194957733
     test/loss_epoch        1.3962048292160034
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

Training Complete!
Best checkpoint: /workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=11-val/loss=1.3962.ckpt
WandB run: https://wandb.ai/graph-ml-project/hms-brain-activity-KL/runs/txwuewbl?apiKey=83eeac404a7e1eaa760f387613ede32c2f8b0c4f


=== Graph Fold 1 ===

HMS Multi-Modal GNN Training
Model Type: multi_modal
Model Config: /workspace/Kaggle-HMS/configs/model.yaml
Train Config: ../configs/train_4fold_notebook.yaml
WandB Project: hms-brain-activity-KL
WandB Run: notebook-fold1
Fold: 1/3

Preloading patients: 100%|██████████| 1463/1463 [04:39<00:00,  5.23it/s]
Preloading patients: 100%|██████████| 487/487 [01:22<00:00,  5.90it/s]
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:751: Checkpoint directory /workspace/Kaggle-HMS/notebooks/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model 


Dataset Setup - Fold 1/3:
  Train: 1463 patients, 15376 samples
  Val:   487 patients, 4807 samples

  Stratification by evaluator bins:
    Train: {0: 10533, 1: 185, 2: 3443, 3: 1152, 4: 63}
    Val: {0: 3158, 1: 74, 2: 1112, 3: 431, 4: 32}

  Class weights: [0.5678414106369019, 0.6723662614822388, 1.2046376466751099, 1.919015884399414, 1.3145467042922974, 0.3215923607349396]

Initializing Model...

Model Architecture:
  Eeg Output Dim       256
  Spec Output Dim      256
  Fusion Output Dim    512
  Num Classes          6
  Total Params         2,019,080
  Trainable Params     2,019,080

Starting training...



/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved. New best score: 1.509
Epoch 0, global step 30: 'val/loss' reached 1.50889 (best 1.50889), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.5089.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.061 >= min_delta = 0.0. New best score: 1.448
Epoch 0, global step 60: 'val/loss' reached 1.44798 (best 1.44798), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.4480.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 90: 'val/loss' reached 1.46399 (best 1.44798), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.4640.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 120: 'val/loss' reached 1.45289 (best 1.44798), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.4529.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 151: 'val/loss' reached 1.44828 (best 1.44798), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=01-val/loss=1.4483.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.015 >= min_delta = 0.0. New best score: 1.433
Epoch 1, global step 181: 'val/loss' reached 1.43276 (best 1.43276), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=01-val/loss=1.4328.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.038 >= min_delta = 0.0. New best score: 1.395
Epoch 1, global step 211: 'val/loss' reached 1.39489 (best 1.39489), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=01-val/loss=1.3949.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 241: 'val/loss' reached 1.40125 (best 1.39489), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=01-val/loss=1.4013.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 272: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.004 >= min_delta = 0.0. New best score: 1.390
Epoch 2, global step 302: 'val/loss' reached 1.39040 (best 1.39040), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=02-val/loss=1.3904.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 332: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.015 >= min_delta = 0.0. New best score: 1.375
Epoch 2, global step 362: 'val/loss' reached 1.37520 (best 1.37520), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=02-val/loss=1.3752.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 393: 'val/loss' reached 1.37975 (best 1.37520), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=03-val/loss=1.3798.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 423: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.002 >= min_delta = 0.0. New best score: 1.373
Epoch 3, global step 453: 'val/loss' reached 1.37272 (best 1.37272), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=03-val/loss=1.3727.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 483: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 514: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 544: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 574: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 604: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.001 >= min_delta = 0.0. New best score: 1.372
Epoch 5, global step 635: 'val/loss' reached 1.37192 (best 1.37192), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=05-val/loss=1.3719.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 665: 'val/loss' reached 1.37253 (best 1.37192), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=05-val/loss=1.3725.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 695: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 725: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 756: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 786: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.015 >= min_delta = 0.0. New best score: 1.357
Epoch 6, global step 816: 'val/loss' reached 1.35732 (best 1.35732), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=06-val/loss=1.3573.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 846: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 877: 'val/loss' reached 1.36162 (best 1.35732), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=07-val/loss=1.3616.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 907: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 937: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 967: 'val/loss' reached 1.36392 (best 1.35732), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=07-val/loss=1.3639.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 998: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 1028: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 1058: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 1088: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val/loss_epoch did not improve in the last 10 records. Best score: 1.357. Signaling Trainer to stop.
Epoch 9, global step 1119: 'val/loss' was not in top 3



Testing best model...

Stratifying by evaluator bins only


Restoring states from the checkpoint path at /workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=06-val/loss=1.3573.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=06-val/loss=1.3573.ckpt
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/acc            0.49760764837265015
     test/loss_epoch        1.3573429584503174
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

Training Complete!
Best checkpoint: /workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=06-val/loss=1.3573.ckpt
WandB run: https://wandb.ai/graph-ml-project/hms-brain-activity-KL/runs/txwuewbl?apiKey=83eeac404a7e1eaa760f387613ede32c2f8b0c4f


=== Graph Fold 2 ===

HMS Multi-Modal GNN Training
Model Type: multi_modal
Model Config: /workspace/Kaggle-HMS/configs/model.yaml
Train Config: ../configs/train_4fold_notebook.yaml
WandB Project: hms-brain-activity-KL
WandB Run: notebook-fold2
Fold: 2/


Preloading patients:  38%|███▊      | 552/1463 [01:46<01:20, 11.28it/s]IOStream.flush timed out
Preloading patients: 100%|██████████| 1463/1463 [04:18<00:00,  5.66it/s]
Preloading patients: 100%|██████████| 487/487 [01:30<00:00,  5.36it/s]
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:751: Checkpoint directory /workspace/Kaggle-HMS/notebooks/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/mo


Dataset Setup - Fold 2/3:
  Train: 1463 patients, 13826 samples
  Val:   487 patients, 6357 samples

  Stratification by evaluator bins:
    Train: {0: 9265, 1: 188, 2: 3140, 3: 1160, 4: 73}
    Val: {0: 4426, 1: 71, 2: 1415, 3: 423, 4: 22}

  Class weights: [0.5786168575286865, 0.740247905254364, 0.9352887272834778, 2.454005002975464, 1.0068542957305908, 0.2849871516227722]

Initializing Model...

Model Architecture:
  Eeg Output Dim       256
  Spec Output Dim      256
  Fusion Output Dim    512
  Num Classes          6
  Total Params         2,019,080
  Trainable Params     2,019,080

Starting training...



/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved. New best score: 1.596
Epoch 0, global step 27: 'val/loss' reached 1.59589 (best 1.59589), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.5959.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.009 >= min_delta = 0.0. New best score: 1.587
Epoch 0, global step 54: 'val/loss' reached 1.58691 (best 1.58691), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.5869.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 81: 'val/loss' reached 1.66845 (best 1.58691), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.6684.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.016 >= min_delta = 0.0. New best score: 1.571
Epoch 0, global step 108: 'val/loss' reached 1.57113 (best 1.57113), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.5711.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 136: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 163: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.004 >= min_delta = 0.0. New best score: 1.567
Epoch 1, global step 190: 'val/loss' reached 1.56724 (best 1.56724), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=01-val/loss=1.5672.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.027 >= min_delta = 0.0. New best score: 1.540
Epoch 1, global step 217: 'val/loss' reached 1.54039 (best 1.54039), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=01-val/loss=1.5404.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 245: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 272: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.004 >= min_delta = 0.0. New best score: 1.536
Epoch 2, global step 299: 'val/loss' reached 1.53617 (best 1.53617), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=02-val/loss=1.5362.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.004 >= min_delta = 0.0. New best score: 1.532
Epoch 2, global step 326: 'val/loss' reached 1.53240 (best 1.53240), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=02-val/loss=1.5324.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.006 >= min_delta = 0.0. New best score: 1.526
Epoch 3, global step 354: 'val/loss' reached 1.52628 (best 1.52628), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=03-val/loss=1.5263.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 381: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 408: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.002 >= min_delta = 0.0. New best score: 1.524
Epoch 3, global step 435: 'val/loss' reached 1.52434 (best 1.52434), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=03-val/loss=1.5243.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 463: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 490: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 517: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.002 >= min_delta = 0.0. New best score: 1.522
Epoch 4, global step 544: 'val/loss' reached 1.52240 (best 1.52240), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=04-val/loss=1.5224.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 572: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 599: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 626: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 653: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 681: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.000 >= min_delta = 0.0. New best score: 1.522
Epoch 6, global step 708: 'val/loss' reached 1.52205 (best 1.52205), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=06-val/loss=1.5221.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 735: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 762: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 790: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 817: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 844: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 871: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 899: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 926: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 953: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val/loss_epoch did not improve in the last 10 records. Best score: 1.522. Signaling Trainer to stop.
Epoch 8, global step 980: 'val/loss' was not in top 3



Testing best model...

Stratifying by evaluator bins only


Restoring states from the checkpoint path at /workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=06-val/loss=1.5221.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=06-val/loss=1.5221.ckpt
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/acc            0.3847726881504059
     test/loss_epoch         1.522058129310608
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

Training Complete!
Best checkpoint: /workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=06-val/loss=1.5221.ckpt
WandB run: https://wandb.ai/graph-ml-project/hms-brain-activity-KL/runs/txwuewbl?apiKey=83eeac404a7e1eaa760f387613ede32c2f8b0c4f


=== Graph Fold 3 ===

HMS Multi-Modal GNN Training
Model Type: multi_modal
Model Config: /workspace/Kaggle-HMS/configs/model.yaml
Train Config: ../configs/train_4fold_notebook.yaml
WandB Project: hms-brain-activity-KL
WandB Run: notebook-fold3
Fold: 3/3

Preloading patients: 100%|██████████| 1465/1465 [04:53<00:00,  4.99it/s]
Preloading patients: 100%|██████████| 485/485 [02:10<00:00,  3.72it/s]
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:751: Checkpoint directory /workspace/Kaggle-HMS/notebooks/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model 


Dataset Setup - Fold 3/3:
  Train: 1465 patients, 15215 samples
  Val:   485 patients, 4968 samples

  Stratification by evaluator bins:
    Train: {0: 10310, 1: 192, 2: 3436, 3: 1203, 4: 74}
    Val: {0: 3381, 1: 67, 2: 1119, 3: 380, 4: 21}

  Class weights: [0.5921630859375, 0.7224341630935669, 1.1196582317352295, 1.9971253871917725, 1.2723388671875, 0.29628053307533264]

Initializing Model...

Model Architecture:
  Eeg Output Dim       256
  Spec Output Dim      256
  Fusion Output Dim    512
  Num Classes          6
  Total Params         2,019,080
  Trainable Params     2,019,080

Starting training...



/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved. New best score: 1.618
Epoch 0, global step 29: 'val/loss' reached 1.61799 (best 1.61799), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.6180.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.034 >= min_delta = 0.0. New best score: 1.584
Epoch 0, global step 58: 'val/loss' reached 1.58427 (best 1.58427), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.5843.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.009 >= min_delta = 0.0. New best score: 1.576
Epoch 0, global step 87: 'val/loss' reached 1.57556 (best 1.57556), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.5756.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 116: 'val/loss' reached 1.58638 (best 1.57556), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=00-val/loss=1.5864.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 148: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.025 >= min_delta = 0.0. New best score: 1.550
Epoch 1, global step 177: 'val/loss' reached 1.55021 (best 1.55021), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=01-val/loss=1.5502.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 206: 'val/loss' reached 1.56005 (best 1.55021), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=01-val/loss=1.5600.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.021 >= min_delta = 0.0. New best score: 1.529
Epoch 1, global step 235: 'val/loss' reached 1.52881 (best 1.52881), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=01-val/loss=1.5288.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.020 >= min_delta = 0.0. New best score: 1.509
Epoch 2, global step 267: 'val/loss' reached 1.50854 (best 1.50854), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=02-val/loss=1.5085.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 296: 'val/loss' reached 1.51174 (best 1.50854), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=02-val/loss=1.5117.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 325: 'val/loss' reached 1.51806 (best 1.50854), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=02-val/loss=1.5181.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 354: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 386: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.008 >= min_delta = 0.0. New best score: 1.500
Epoch 3, global step 415: 'val/loss' reached 1.50041 (best 1.50041), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=03-val/loss=1.5004.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 444: 'val/loss' reached 1.50840 (best 1.50041), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=03-val/loss=1.5084.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 473: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.006 >= min_delta = 0.0. New best score: 1.495
Epoch 4, global step 505: 'val/loss' reached 1.49490 (best 1.49490), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=04-val/loss=1.4949.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss_epoch improved by 0.009 >= min_delta = 0.0. New best score: 1.486
Epoch 4, global step 534: 'val/loss' reached 1.48579 (best 1.48579), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=04-val/loss=1.4858.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 563: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 592: 'val/loss' reached 1.49195 (best 1.48579), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=04-val/loss=1.4920.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 624: 'val/loss' reached 1.49415 (best 1.48579), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=05-val/loss=1.4942.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 653: 'val/loss' reached 1.49240 (best 1.48579), saving model to '/workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=05-val/loss=1.4924.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 682: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 711: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 743: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 772: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 801: 'val/loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val/loss_epoch did not improve in the last 10 records. Best score: 1.486. Signaling Trainer to stop.
Epoch 6, global step 830: 'val/loss' was not in top 3



Testing best model...

Stratifying by evaluator bins only


Restoring states from the checkpoint path at /workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=04-val/loss=1.4858.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=04-val/loss=1.4858.ckpt
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/acc            0.4142512083053589
     test/loss_epoch         1.485782265663147
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

Training Complete!
Best checkpoint: /workspace/Kaggle-HMS/notebooks/checkpoints/hms-epoch=04-val/loss=1.4858.ckpt
WandB run: https://wandb.ai/graph-ml-project/hms-brain-activity-KL/runs/txwuewbl?apiKey=83eeac404a7e1eaa760f387613ede32c2f8b0c4f



Unnamed: 0,fold,val_loss
0,0,
1,1,
2,2,
3,3,


In [None]:
# Test Set vs Validation Clarification
#
# In this notebook, any metrics labeled with prefix 'test/' for the graph model come from
# calling `Trainer.test(datamodule=dm, ckpt_path=...)`. For the HMS project DataModule,
# the `test_dataloader()` is intentionally the validation split (a held-out fold) so that
# we can standardize evaluation without accessing the private Kaggle test set (labels are
# not available). Thus:
#   - Baseline manual evaluation loops directly over each fold's validation partition.
#   - Graph model evaluation reuses Lightning's `test()` which internally runs on the
#     validation split for that fold.
# Therefore, all reported per-fold losses and accuracies are validation metrics (not
# unseen competition test data). We *do not* evaluate on the unlabeled Kaggle test set here.
#
# Below, we also extend graph inference to locate per-fold checkpoints in the project
# `checkpoints/` directory (pattern: `online-fold{f}-epoch=epoch=...`) in addition to WandB
# run artifact directories, selecting the lowest-loss checkpoint per fold when multiple
# are present.

In [16]:
# Inference & Validation Analysis (Baseline + Graph)
# - Loads best checkpoints per fold
# - Evaluates on the corresponding validation split (reused as test for graph)
# - Aggregates val loss, accuracy, F1, ECE, confusion matrix for baseline; loss & acc for graph

# NOTE: See preceding cell for explanation that 'test' metrics here are actually per-fold
# validation metrics (we do not use unlabeled Kaggle competition test set).

# --------------------------------------------------
import os, re, json, glob, math, time
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf

# Ensure imports for project modules
from pathlib import Path as _P
NOTEBOOK_CWD = _P.cwd()
REPO_ROOT = NOTEBOOK_CWD.parent if NOTEBOOK_CWD.name == 'notebooks' else NOTEBOOK_CWD
import sys
for _p in (REPO_ROOT, REPO_ROOT/'src'):
    if str(_p) not in sys.path:
        sys.path.insert(0, str(_p))

from src.data.raw_datamodule import RawEEGDataModule
from src.lightning_trainer.mlp_lightning_module import EEGMLPLightningModule
from src.data.graph_datamodule import HMSDataModule
from src.lightning_trainer import HMSLightningModule, HMSEEGOnlyLightningModule

# -----------------------------
# Helpers
# -----------------------------
def ece_score(probs: np.ndarray, labels: np.ndarray, n_bins: int = 15) -> float:
    conf = probs.max(axis=1)
    preds = probs.argmax(axis=1)
    correct = (preds == labels).astype(np.float32)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for i in range(n_bins):
        msk = (conf > bins[i]) & (conf <= bins[i+1]) if i < n_bins-1 else (conf > bins[i]) & (conf <= bins[i+1])
        if not np.any(msk):
            continue
        acc_bin = correct[msk].mean()
        conf_bin = conf[msk].mean()
        ece += (msk.mean()) * abs(acc_bin - conf_bin)
    return float(ece)

def macro_f1(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int = 6) -> float:
    f1s = []
    for c in range(n_classes):
        tp = np.sum((y_true == c) & (y_pred == c))
        fp = np.sum((y_true != c) & (y_pred == c))
        fn = np.sum((y_true == c) & (y_pred != c))
        if tp + fp == 0 or tp + fn == 0:
            f1s.append(0.0)
            continue
        precision = tp / max(1, (tp + fp))
        recall = tp / max(1, (tp + fn))
        if precision + recall == 0:
            f1s.append(0.0)
        else:
            f1s.append(2 * precision * recall / (precision + recall))
    return float(np.mean(f1s))

def confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int = 6) -> np.ndarray:
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        cm[int(t), int(p)] += 1
    return cm

def parse_loss_from_name(path: Path) -> float | None:
    m = re.search(r"loss=([0-9]+\.[0-9]+)\.ckpt$", str(path))
    if not m:
        m = re.search(r"loss_epoch=([0-9]+\.[0-9]+)\.ckpt$", str(path))
    return float(m.group(1)) if m else None

# --------------------------------------------------
# Baseline per-fold validation evaluation
# --------------------------------------------------
baseline_rows = []
if RUN_BASELINE:
    baseline_cfg = OmegaConf.load(BASELINE_CONFIG)
    if 'data' in baseline_cfg:
        if 'metadata_csv' in baseline_cfg.data:
            baseline_cfg.data.metadata_csv = str(REPO_ROOT / baseline_cfg.data.metadata_csv)
        if 'raw_eeg' in baseline_cfg.data and 'base_dir' in baseline_cfg.data.raw_eeg:
            baseline_cfg.data.raw_eeg.base_dir = str(REPO_ROOT / baseline_cfg.data.raw_eeg.base_dir)
    if 'checkpointing' in baseline_cfg and 'dirpath' in baseline_cfg.checkpointing:
        baseline_cfg.checkpointing.dirpath = str(REPO_ROOT / baseline_cfg.checkpointing.dirpath)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dm_full = RawEEGDataModule(baseline_cfg)
    dm_full.prepare_data(); dm_full.setup('fit')
    patient_ids = np.array(dm_full.patient_ids())
    from sklearn.model_selection import GroupKFold
    gkf = GroupKFold(n_splits=BASELINE_N_SPLITS)

    ckpt_root = Path(getattr(baseline_cfg.checkpointing, 'dirpath', REPO_ROOT/'artifacts/checkpoints'))
    fold_ckpts = {}
    for f in range(BASELINE_N_SPLITS):
        fold_dir = ckpt_root / f"notebook_fold_{f}"
        if not fold_dir.exists():
            continue
        ckpts = list(fold_dir.rglob('*.ckpt'))
        if not ckpts:
            continue
        with_parsed = [(c, parse_loss_from_name(c)) for c in ckpts]
        parsed = [p for p in with_parsed if p[1] is not None]
        if parsed:
            best = min(parsed, key=lambda t: t[1])[0]
        else:
            best = max(ckpts, key=lambda p: p.stat().st_mtime)
        fold_ckpts[f] = best

    for fold, (_, val_idx) in enumerate(gkf.split(patient_ids, groups=patient_ids)):
        if fold not in fold_ckpts:
            print(f"[Baseline] No checkpoint for fold {fold}, skipping")
            continue
        print(f"\n[Baseline] Evaluating fold {fold} from {fold_ckpts[fold]}")
        val_ids = [str(pid) for pid in patient_ids[val_idx]]
        val_ds = dm_full.dataset_for_patients(val_ids)
        val_loader = dm_full.dataloader(val_ds, shuffle=False)

        model = EEGMLPLightningModule.load_from_checkpoint(str(fold_ckpts[fold]), config=baseline_cfg, strict=False)
        model.eval().to(device)

        loss_type = (baseline_cfg.loss.type or 'kl').lower()
        total_loss, n_items = 0.0, 0
        all_probs, all_targets = [], []
        with torch.no_grad():
            for b in val_loader:
                x = b['eeg_signal'].to(device)
                t = b['target'].to(device)
                logits = model(x)
                if loss_type == 'kl':
                    logp = torch.log_softmax(logits, dim=-1)
                    loss = torch.nn.functional.kl_div(logp, t, reduction='batchmean')
                    probs = torch.softmax(logits, dim=-1)
                    hard_t = t.argmax(dim=1)
                else:
                    loss = torch.nn.functional.cross_entropy(logits, t.argmax(dim=1))
                    probs = torch.softmax(logits, dim=-1)
                    hard_t = t.argmax(dim=1)
                total_loss += float(loss) * x.size(0)
                n_items += x.size(0)
                all_probs.append(probs.detach().cpu().numpy())
                all_targets.append(hard_t.detach().cpu().numpy())
        val_loss = total_loss / max(1, n_items)
        all_probs = np.vstack(all_probs)
        y_true = np.concatenate(all_targets)
        y_pred = all_probs.argmax(axis=1)
        acc = float((y_pred == y_true).mean())
        f1 = macro_f1(y_true, y_pred)
        ece = ece_score(all_probs, y_true)
        cm = confusion_matrix(y_true, y_pred, n_classes=6)
        baseline_rows.append({'model':'baseline','fold':fold,'val_loss':val_loss,'acc_macro':acc,'f1_macro':f1,'ece':ece,'n':int(n_items)})
        print(f"[Baseline] fold={fold} val_loss={val_loss:.4f} acc={acc:.3f} f1={f1:.3f} ece={ece:.3f}")
else:
    print('Baseline inference skipped.')

# --------------------------------------------------
# Graph per-fold validation evaluation
# --------------------------------------------------
graph_rows = []
if RUN_GRAPH:
    # 1) Direct checkpoint root discovery
    fold_to_ckpt = {}
    ckpt_root = REPO_ROOT / 'checkpoints'
    if ckpt_root.exists():
        for f in GRAPH_FOLDS:
            # directories created during training with pattern online-fold{f}-epoch=epoch=X-val_loss=val
            dirs = [d for d in ckpt_root.glob(f'online-fold{f}-epoch=epoch=*') if d.is_dir()]
            ckpts = []
            for d in dirs:
                ckpts.extend(d.glob('*.ckpt'))
            if ckpts:
                parsed = [(c, parse_loss_from_name(c)) for c in ckpts]
                parsed_valid = [p for p in parsed if p[1] is not None]
                if parsed_valid:
                    best = min(parsed_valid, key=lambda t: t[1])[0]
                else:
                    best = max(ckpts, key=lambda p: p.stat().st_mtime)
                fold_to_ckpt[f] = best

    # 2) WandB fallback if any folds missing
    missing_folds = [f for f in GRAPH_FOLDS if f not in fold_to_ckpt]
    wandb_root = REPO_ROOT / 'logs' / 'wandb'
    if missing_folds and wandb_root.exists():
        for run_dir in sorted(wandb_root.glob('run-*')):
            files_dir = run_dir / 'files'
            meta = files_dir / 'wandb-metadata.json'
            if not meta.exists():
                continue
            try:
                md = json.loads(meta.read_text())
            except Exception:
                continue
            run_name = md.get('name') or ''
            m = re.search(r"notebook-fold(\d+)", run_name)
            if not m:
                continue
            fold = int(m.group(1))
            if fold not in missing_folds:
                continue
            ckpts = list(files_dir.rglob('*.ckpt'))
            if not ckpts:
                continue
            parsed = [(c, parse_loss_from_name(c)) for c in ckpts]
            parsed_valid = [p for p in parsed if p[1] is not None]
            if parsed_valid:
                best = min(parsed_valid, key=lambda t: t[1])[0]
            else:
                best = max(ckpts, key=lambda p: p.stat().st_mtime)
            fold_to_ckpt[fold] = best

    if not fold_to_ckpt:
        print('[Graph] No checkpoints discovered for requested folds.')
    # 3) Evaluate discovered checkpoints using Trainer.test with loaded model
    for f in sorted(fold_to_ckpt.keys()):
        ckpt_path = fold_to_ckpt[f]
        print(f"\n[Graph] Evaluating fold {f} from {ckpt_path}")
        cfg = OmegaConf.load(GRAPH_CONFIG_SRC if Path(GRAPH_CONFIG_SRC).exists() else GRAPH_CONFIG_NOTEBOOK)
        if 'data' in cfg:
            if 'data_dir' in cfg.data:
                cfg.data.data_dir = str(REPO_ROOT / cfg.data.data_dir)
            if 'train_csv' in cfg.data:
                cfg.data.train_csv = str(REPO_ROOT / cfg.data.train_csv)
        cfg.data.current_fold = f
        dm = HMSDataModule(
            data_dir=cfg.data.data_dir,
            train_csv=cfg.data.train_csv,
            batch_size=cfg.batch_size,
            n_folds=cfg.data.n_folds,
            current_fold=cfg.data.current_fold,
            stratify_by_class=cfg.data.get('stratify_by_class', True),
            stratify_by_evaluators=cfg.data.get('stratify_by_evaluators', False),
            evaluator_bins=cfg.data.get('evaluator_bins', [0,5,10,15,20,999]),
            min_evaluators=cfg.data.get('min_evaluators', 0),
            num_workers=0,
            pin_memory=True,
            prefetch_factor=None,
            shuffle_seed=cfg.data.shuffle_seed,
            preload_patients=True,
            use_cache_server=False,
        )
        dm.setup('fit')
        accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
        trainer = pl.Trainer(accelerator=accelerator, devices=1, logger=False, enable_progress_bar=False, precision='bf16-mixed' if torch.cuda.is_available() else 32)
        # Load the LightningModule from checkpoint (detect module class)
        model = None
        for cls in (HMSLightningModule, HMSEEGOnlyLightningModule):
            try:
                model = cls.load_from_checkpoint(str(ckpt_path))
                break
            except Exception:
                model = None
        if model is None:
            print(f"[Graph] Could not load model from {ckpt_path}; skipping fold {f}")
            continue
        results = trainer.test(model=model, datamodule=dm)
        if results and isinstance(results, list):
            r = results[0]
            val_loss = float(r.get('test/loss_epoch') or r.get('test/loss') or np.nan)
            acc = float(r.get('test/acc') or r.get('test/acc_macro') or np.nan)
        else:
            val_loss, acc = np.nan, np.nan
        graph_rows.append({'model':'graph','fold':f,'val_loss':val_loss,'acc':acc})
        print(f"[Graph] fold={f} val_loss={val_loss:.4f} acc={acc:.3f}")
else:
    print('Graph inference skipped.')

# --------------------------------------------------
# Aggregate & display
# --------------------------------------------------
rows = []
if baseline_rows:
    rows.extend(baseline_rows)
if graph_rows:
    rows.extend(graph_rows)

if rows:
    summary_df = pd.DataFrame(rows)
    display(summary_df)
    agg_rows = []
    for model_name, sub in summary_df.groupby('model'):
        if 'val_loss' in sub:
            agg_rows.append({
                'model': model_name,
                'val_loss_mean': float(sub['val_loss'].mean()),
                'val_loss_std': float(sub['val_loss'].std(ddof=1)) if len(sub)>1 else 0.0,
                'n_folds': int(sub['fold'].nunique()),
                'acc_mean': float(sub['acc_macro'].mean() if 'acc_macro' in sub else sub.get('acc', pd.Series(dtype=float)).mean() if 'acc' in sub else np.nan),
                'f1_mean': float(sub['f1_macro'].mean() if 'f1_macro' in sub else np.nan),
                'ece_mean': float(sub['ece'].mean() if 'ece' in sub else np.nan),
            })
    agg_df = pd.DataFrame(agg_rows)
    print('\nAggregate (per model):')
    display(agg_df)
else:
    print('No evaluation rows collected (check checkpoints presence).')


[Baseline] Evaluating fold 0 from /workspace/Kaggle-HMS/artifacts/checkpoints/notebook_fold_0/last.ckpt
[Baseline] fold=0 val_loss=1.7420 acc=0.341 f1=0.234 ece=0.287

[Baseline] Evaluating fold 1 from /workspace/Kaggle-HMS/artifacts/checkpoints/notebook_fold_1/last.ckpt
[Baseline] fold=1 val_loss=1.6297 acc=0.385 f1=0.242 ece=0.241

[Baseline] Evaluating fold 2 from /workspace/Kaggle-HMS/artifacts/checkpoints/notebook_fold_2/last.ckpt
[Baseline] fold=2 val_loss=1.6198 acc=0.387 f1=0.250 ece=0.228

[Baseline] Evaluating fold 3 from /workspace/Kaggle-HMS/artifacts/checkpoints/notebook_fold_3/last.ckpt
[Baseline] fold=3 val_loss=1.5360 acc=0.414 f1=0.257 ece=0.193

[Baseline] Evaluating fold 4 from /workspace/Kaggle-HMS/artifacts/checkpoints/notebook_fold_4/last.ckpt
[Baseline] fold=4 val_loss=1.6377 acc=0.373 f1=0.253 ece=0.248

[Graph] Evaluating fold 0 from /workspace/Kaggle-HMS/checkpoints/online-fold0-epoch=epoch=01-val_loss=val/loss_epoch=1.4006.ckpt
Stratifying by evaluator bins 

Preloading patients: 100%|██████████| 1459/1459 [05:14<00:00,  4.64it/s]
Preloading patients: 100%|██████████| 491/491 [01:15<00:00,  6.49it/s]
Using bfloat16 Automatic Mixed Precision (AMP)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



Dataset Setup - Fold 0/3:
  Train: 1459 patients, 16132 samples
  Val:   491 patients, 4051 samples

  Stratification by evaluator bins:
    Train: {0: 10965, 1: 212, 2: 3646, 3: 1234, 4: 75}
    Val: {0: 2726, 1: 47, 2: 909, 3: 349, 4: 20}

  Class weights: [0.6017259359359741, 0.6564667224884033, 1.2223418951034546, 1.9777363538742065, 1.240186333656311, 0.3015425503253937]

Stratifying by evaluator bins only


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=511` in the `DataLoader` to improve performance.


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/acc            0.46087387204170227
     test/loss_epoch        1.4259835481643677
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[Graph] fold=0 val_loss=1.4260 acc=0.461


Unnamed: 0,model,fold,val_loss,acc_macro,f1_macro,ece,n,acc
0,baseline,0,1.742001,0.340847,0.234105,0.28706,4392.0,
1,baseline,1,1.629694,0.385027,0.242128,0.241044,4488.0,
2,baseline,2,1.619837,0.386801,0.249727,0.22844,4470.0,
3,baseline,3,1.536041,0.414337,0.257192,0.192687,3041.0,
4,baseline,4,1.637708,0.373154,0.253255,0.247707,3792.0,
5,graph,0,1.425984,,,,,0.460874



Aggregate (per model):


Unnamed: 0,model,val_loss_mean,val_loss_std,n_folds,acc_mean,f1_mean,ece_mean
0,baseline,1.633056,0.073295,5,0.380033,0.247282,0.239388
1,graph,1.425984,0.0,1,,,


## Test vs Validation Overview (Recap)
All fold metrics shown in this notebook are computed on each fold's validation split.

- Baseline model: manual forward pass over the validation dataset for that fold.
- Graph model: `Trainer.test(...)` is invoked; the HMS DataModule maps test to validation.
- Kaggle competition test set (unlabeled) is NOT used here because labels are unavailable.

Why reuse validation as test? This pattern standardizes evaluation API calls and lets us log
metrics under the `test/` prefix for Lightning while still comparing folds fairly.

If you later need true hold-out performance beyond cross-validation, you'd reserve one fold
(or a stratified slice) exclusively for final evaluation and not use it for model selection.

In [18]:
# Baseline vs Graph Combined Comparison & Deltas
import pandas as pd, numpy as np
if 'summary_df' not in globals():
    print('summary_df not found; run the inference cell above first.')
else:
    base = summary_df[summary_df.model=='baseline'].copy()
    graph = summary_df[summary_df.model=='graph'].copy()
    if base.empty or graph.empty:
        print('Missing baseline or graph rows; cannot compare.')
    else:
        # Harmonize column names
        base_ren = base.rename(columns={'val_loss':'baseline_val_loss','acc_macro':'baseline_acc'})
        graph_ren = graph.rename(columns={'val_loss':'graph_val_loss','acc':'graph_acc'})
        merged = pd.merge(base_ren[['fold','baseline_val_loss','baseline_acc']], graph_ren[['fold','graph_val_loss','graph_acc']], on='fold', how='inner').sort_values('fold')
        if merged.empty:
            print('No overlapping folds to compare.')
        else:
            merged['delta_loss'] = merged['graph_val_loss'] - merged['baseline_val_loss']  # negative better
            merged['delta_acc'] = merged['graph_acc'] - merged['baseline_acc']            # positive better
            display(merged)
            agg = {
                'folds_compared': int(merged.fold.nunique()),
                'baseline_loss_mean': float(merged.baseline_val_loss.mean()),
                'graph_loss_mean': float(merged.graph_val_loss.mean()),
                'delta_loss_mean': float(merged.delta_loss.mean()),
                'baseline_acc_mean': float(merged.baseline_acc.mean()),
                'graph_acc_mean': float(merged.graph_acc.mean()),
                'delta_acc_mean': float(merged.delta_acc.mean()),
            }
            print('\nAggregate deltas (graph - baseline):')
            display(pd.DataFrame([agg]))
            # Optional: export
            merged.to_csv('comparison_folds.csv', index=False)
            pd.DataFrame([agg]).to_csv('comparison_aggregate.csv', index=False)
            print('Saved comparison_folds.csv and comparison_aggregate.csv')

Unnamed: 0,fold,baseline_val_loss,baseline_acc,graph_val_loss,graph_acc,delta_loss,delta_acc
0,0,1.742001,0.340847,1.425984,0.460874,-0.316018,0.120027



Aggregate deltas (graph - baseline):


Unnamed: 0,folds_compared,baseline_loss_mean,graph_loss_mean,delta_loss_mean,baseline_acc_mean,graph_acc_mean,delta_acc_mean
0,1,1.742001,1.425984,-0.316018,0.340847,0.460874,0.120027


Saved comparison_folds.csv and comparison_aggregate.csv


In [4]:
# Holdout configuration constants (idempotent)
from pathlib import Path
NOTEBOOK_CWD = Path.cwd()
REPO_ROOT = NOTEBOOK_CWD.parent if NOTEBOOK_CWD.name == 'notebooks' else NOTEBOOK_CWD
RAW_ROOT = REPO_ROOT / 'data' / 'raw'
HOLDOUT_SEED = 42
HOLDOUT_FRAC = 0.20  # fraction of patients for holdout
HOLDOUT_TEST_CSV = RAW_ROOT / 'test_holdout.csv'
HOLDOUT_TRAIN_EXCL_CSV = RAW_ROOT / 'train_unique_excl_holdout.csv'
PRIMARY_TRAIN_UNIQUE = RAW_ROOT / 'train_unique.csv'
print('Repo root:', REPO_ROOT)
print('Raw root:', RAW_ROOT)
print('Holdout test CSV path:', HOLDOUT_TEST_CSV)
print('Holdout train-excl CSV path:', HOLDOUT_TRAIN_EXCL_CSV)

Repo root: /workspace/Kaggle-HMS
Raw root: /workspace/Kaggle-HMS/data/raw
Holdout test CSV path: /workspace/Kaggle-HMS/data/raw/test_holdout.csv
Holdout train-excl CSV path: /workspace/Kaggle-HMS/data/raw/train_unique_excl_holdout.csv


In [5]:
# Create holdout CSVs from train_unique.csv (patient-level split)
import pandas as pd, numpy as np, random
from pathlib import Path

# Paths from constants cell
NOTEBOOK_CWD = Path.cwd()
REPO_ROOT = NOTEBOOK_CWD.parent if NOTEBOOK_CWD.name == 'notebooks' else NOTEBOOK_CWD
RAW_ROOT = REPO_ROOT / 'data' / 'raw'

PRIMARY_TRAIN_UNIQUE = RAW_ROOT / 'train_unique.csv'
HOLDOUT_TEST_CSV = RAW_ROOT / 'test_holdout.csv'
HOLDOUT_TRAIN_EXCL_CSV = RAW_ROOT / 'train_unique_excl_holdout.csv'

# Config with defaults if not defined
try:
    HOLDOUT_SEED
except NameError:
    HOLDOUT_SEED = 42
try:
    HOLDOUT_FRAC
except NameError:
    HOLDOUT_FRAC = 0.20

assert PRIMARY_TRAIN_UNIQUE.exists(), f'Missing {PRIMARY_TRAIN_UNIQUE}'
df = pd.read_csv(PRIMARY_TRAIN_UNIQUE)
assert 'patient_id' in df.columns and 'label_id' in df.columns, 'train_unique.csv missing required columns'

patients = sorted(df['patient_id'].unique())
rng = random.Random(int(HOLDOUT_SEED))
rng.shuffle(patients)
cut = max(1, int(len(patients) * float(HOLDOUT_FRAC)))
holdout_patients = set(patients[:cut])

holdout_df = df[df['patient_id'].isin(holdout_patients)].copy()
train_excl_df = df[~df['patient_id'].isin(holdout_patients)].copy()

# Save with same schema
HOLDOUT_TEST_CSV.parent.mkdir(parents=True, exist_ok=True)
holdout_df.to_csv(HOLDOUT_TEST_CSV, index=False)
train_excl_df.to_csv(HOLDOUT_TRAIN_EXCL_CSV, index=False)

print('Holdout CSVs written:')
print('  Holdout (test):', HOLDOUT_TEST_CSV, len(holdout_df), 'rows, patients:', holdout_df['patient_id'].nunique())
print('  Train excl:', HOLDOUT_TRAIN_EXCL_CSV, len(train_excl_df), 'rows, patients:', train_excl_df['patient_id'].nunique())

Holdout CSVs written:
  Holdout (test): /workspace/Kaggle-HMS/data/raw/test_holdout.csv 3780 rows, patients: 390
  Train excl: /workspace/Kaggle-HMS/data/raw/train_unique_excl_holdout.csv 16403 rows, patients: 1560


In [10]:
# Holdout evaluation for baseline (soft & hard) and graph models
from pathlib import Path
import pandas as pd, numpy as np, torch
from torch.utils.data import DataLoader
from omegaconf import OmegaConf
from src.data.raw_eeg_dataset import RawEEGDataset, raw_eeg_collate_fn
from src.data.graph_dataset import HMSDataset, collate_graphs
from src.lightning_trainer.mlp_lightning_module import EEGMLPLightningModule
from src.lightning_trainer.graph_lightning_module import HMSLightningModule
import torch.nn.functional as F

NOTEBOOK_CWD = Path.cwd()
REPO_ROOT = NOTEBOOK_CWD.parent if NOTEBOOK_CWD.name == 'notebooks' else NOTEBOOK_CWD
RAW_ROOT = REPO_ROOT / 'data' / 'raw'
# Reuse constants from config cell (if rerun order different, redefine safely)
HOLDOUT_TEST_CSV = RAW_ROOT / 'test_holdout.csv'
HOLDOUT_TRAIN_EXCL_CSV = RAW_ROOT / 'train_unique_excl_holdout.csv'
assert HOLDOUT_TEST_CSV.exists() and HOLDOUT_TRAIN_EXCL_CSV.exists(), 'Holdout CSVs missing; run split cell first.'

# Load metadata
holdout_df = pd.read_csv(HOLDOUT_TEST_CSV)
train_excl_df = pd.read_csv(HOLDOUT_TRAIN_EXCL_CSV)
print('Holdout rows:', len(holdout_df), 'Train excl rows:', len(train_excl_df))

# ---------- Baseline Evaluation (Soft + Hard) ----------
base_cfg = OmegaConf.load(BASELINE_CONFIG)
# Patch metadata path to train exclusion for any training references (not strictly needed for pure eval)
base_cfg.data.metadata_csv = str(HOLDOUT_TRAIN_EXCL_CSV)
vote_keys = list(getattr(base_cfg.data, 'vote_keys', ['seizure_vote','lpd_vote','gpd_vote','lrda_vote','grda_vote','other_vote']))
# Robust consensus mapping (capitalize first letter only if not already uppercase entire token)
consensus_labels = ['Seizure','LPD','GPD','LRDA','GRDA','Other']
label_to_index = {lab: i for i, lab in enumerate(consensus_labels)}

# Build label metadata mapping for RawEEGDataset
meta_rows = {}
for r in holdout_df.to_dict('records'):
    lid = str(r['label_id'])
    votes = {k: float(r.get(k, 0.0)) for k in vote_keys if k in r}
    consensus = r.get('expert_consensus')
    # Normalize consensus capitalization
    if isinstance(consensus, str):
        up = consensus.upper()
        if up in ['SEIZURE','LPD','GPD','LRDA','GRDA','OTHER']:
            # Map uppercase back to canonical
            mapping_back = {'SEIZURE':'Seizure','LPD':'LPD','GPD':'GPD','LRDA':'LRDA','GRDA':'GRDA','OTHER':'Other'}
            consensus = mapping_back[up]
    meta_rows[lid] = {'votes': votes, 'expert_consensus': consensus}

records = []
for r in holdout_df.to_dict('records'):
    records.append({
        'patient_id': r['patient_id'],
        'label_id': str(r['label_id']),
        'eeg_id': int(r['eeg_id']),
        'offset_seconds': float(r.get('eeg_label_offset_seconds', 0.0)),
    })

# Common raw EEG params
raw_base_dir = RAW_ROOT
split = 'train'  # holdout derived from train_unique
sr = int(getattr(base_cfg.data.raw_eeg, 'sampling_rate', 200))
label_sec = float(getattr(base_cfg.data.raw_eeg, 'label_window_sec', 10.0))
context_sec = float(getattr(base_cfg.data.raw_eeg, 'context_window_sec', 50.0))
normalize = bool(getattr(base_cfg.data.raw_eeg, 'normalize', True))

# Soft labels dataset
soft_ds = RawEEGDataset(records, raw_base_dir=raw_base_dir, split=split, sampling_rate=sr,
                        label_window_sec=label_sec, context_window_sec=context_sec, normalize=normalize,
                        label_metadata=meta_rows, vote_keys=vote_keys, target_mode='votes', label_to_index=label_to_index)
# Hard labels dataset
hard_ds = RawEEGDataset(records, raw_base_dir=raw_base_dir, split=split, sampling_rate=sr,
                        label_window_sec=label_sec, context_window_sec=context_sec, normalize=normalize,
                        label_metadata=meta_rows, vote_keys=vote_keys, target_mode='consensus', label_to_index=label_to_index)

soft_loader = DataLoader(soft_ds, batch_size=64, shuffle=False, collate_fn=raw_eeg_collate_fn)
hard_loader = DataLoader(hard_ds, batch_size=64, shuffle=False, collate_fn=raw_eeg_collate_fn)

# Discover baseline checkpoints (use the 5 folds trained earlier)
ckpt_root = REPO_ROOT / 'artifacts' / 'checkpoints'
baseline_ckpts = sorted(ckpt_root.glob('notebook_fold_*/*.ckpt'))
assert baseline_ckpts, f'No baseline checkpoints found under {ckpt_root}'
print('Found baseline checkpoints:', len(baseline_ckpts))

# Evaluate soft & hard per checkpoint and aggregate
soft_metrics = []
hard_metrics = []

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for ckpt in baseline_ckpts:
    model = EEGMLPLightningModule.load_from_checkpoint(str(ckpt), config=base_cfg, strict=False)
    model.eval().to(device)
    # Soft
    total_loss = 0.0
    total_acc = 0.0
    total_n = 0
    with torch.no_grad():
        for batch in soft_loader:
            x = batch['eeg_signal'].to(device)
            t = batch['target'].to(device)
            logits = model(x)
            logp = torch.log_softmax(logits, dim=-1)
            loss = F.kl_div(logp, t, reduction='batchmean')
            pred = logits.argmax(dim=-1)
            true = t.argmax(dim=-1)
            acc = (pred == true).float().mean().item()
            total_loss += loss.item() * x.size(0)
            total_acc += acc * x.size(0)
            total_n += x.size(0)
    soft_metrics.append({'ckpt': ckpt.name, 'loss': total_loss/total_n, 'acc': total_acc/total_n})
    # Hard
    total_loss = 0.0
    total_acc = 0.0
    total_n = 0
    with torch.no_grad():
        for batch in hard_loader:
            x = batch['eeg_signal'].to(device)
            t = batch['target'].to(device)
            logits = model(x)
            loss = F.cross_entropy(logits, t)
            pred = logits.argmax(dim=-1)
            acc = (pred == t).float().mean().item()
            total_loss += loss.item() * x.size(0)
            total_acc += acc * x.size(0)
            total_n += x.size(0)
    hard_metrics.append({'ckpt': ckpt.name, 'loss': total_loss/total_n, 'acc': total_acc/total_n})

soft_df = pd.DataFrame(soft_metrics)
hard_df = pd.DataFrame(hard_metrics)
print('\nBaseline Holdout Soft Metrics (KL on probs):')
print(soft_df.head())
print('Soft mean loss:', soft_df['loss'].mean(), 'acc:', soft_df['acc'].mean())
print('\nBaseline Holdout Hard Metrics (CE on class):')
print(hard_df.head())
print('Hard mean loss:', hard_df['loss'].mean(), 'acc:', hard_df['acc'].mean())

# ---------- Graph Model Holdout Evaluation (using best available ckpt per fold) ----------
checkpoints_dir = REPO_ROOT / 'checkpoints'
all_graph_ckpts = list(checkpoints_dir.glob('online-fold*-epoch=epoch=*/*.ckpt'))
assert all_graph_ckpts, f'No graph checkpoints found under {checkpoints_dir}'
print('Graph ckpts found:', len(all_graph_ckpts))
# Parse loss from filename pattern loss_epoch=*.ckpt if present
fold_best = {}
for ck in all_graph_ckpts:
    loss_val = None
    if 'loss_epoch=' in ck.name:
        try:
            part = ck.name.split('loss_epoch=')[1].split('.ckpt')[0]
            loss_val = float(part)
        except Exception:
            pass
    fold_id = None
    for token in ck.parts:
        if token.startswith('online-fold'):
            # token like online-fold0-epoch=epoch=04-val_loss=val
            seg = token.replace('online-fold','').split('-')[0]
            try:
                fold_id = int(seg)
            except Exception:
                fold_id = 0
    if fold_id is None:
        continue
    prev = fold_best.get(fold_id)
    if prev is None or (loss_val is not None and loss_val < prev[1]):
        fold_best[fold_id] = (ck, loss_val if loss_val is not None else 1e9)

# Graph holdout dataset
graph_meta = holdout_df.copy()
gds = HMSDataset(data_dir=REPO_ROOT/'data'/'processed', metadata_df=graph_meta, is_train=False, preload_patients=False)
gloader = DataLoader(gds, batch_size=2, shuffle=False, collate_fn=collate_graphs)

graph_holdout_rows = []
for fold_id,(ck,_) in fold_best.items():
    gcfg = OmegaConf.load(GRAPH_CONFIG_NOTEBOOK)
    gmodel = HMSLightningModule.load_from_checkpoint(str(ck), config=gcfg, strict=False)
    gmodel.eval().to(device)
    total_loss = 0.0
    total_acc = 0.0
    total_n = 0
    with torch.no_grad():
        for batch in gloader:
            eeg_graphs = [g.to(device) for g in batch['eeg_graphs']]
            spec_graphs = [g.to(device) for g in batch['spec_graphs']]
            targets = batch['targets'].to(device)
            logits = gmodel(eeg_graphs, spec_graphs)
            loss = F.cross_entropy(logits, targets)
            pred = logits.argmax(dim=-1)
            acc = (pred == targets).float().mean().item()
            bs = targets.size(0)
            total_loss += loss.item() * bs
            total_acc += acc * bs
            total_n += bs
    graph_holdout_rows.append({'fold': fold_id, 'ckpt': ck.name, 'loss': total_loss/total_n, 'acc': total_acc/total_n})

graph_holdout_df = pd.DataFrame(graph_holdout_rows)
print('\nGraph Holdout Metrics (first 5 rows):')
print(graph_holdout_df.head())
print('Graph mean loss:', graph_holdout_df['loss'].mean(), 'acc:', graph_holdout_df['acc'].mean())

# Combined summary
summary = {
    'baseline_soft_loss_mean': soft_df['loss'].mean(),
    'baseline_soft_acc_mean': soft_df['acc'].mean(),
    'baseline_hard_loss_mean': hard_df['loss'].mean(),
    'baseline_hard_acc_mean': hard_df['acc'].mean(),
    'graph_loss_mean': graph_holdout_df['loss'].mean(),
    'graph_acc_mean': graph_holdout_df['acc'].mean(),
}
print('\nHoldout Summary:', summary)


Holdout rows: 3780 Train excl rows: 16403
Found baseline checkpoints: 20

Baseline Holdout Soft Metrics (KL on probs):
                    ckpt      loss       acc
0  epoch=0-step=329.ckpt  1.243473  0.346561
1  epoch=1-step=658.ckpt  1.187407  0.358995
2  epoch=2-step=987.ckpt  1.116365  0.432011
3              last.ckpt  0.305570  0.865079
4  epoch=1-step=654.ckpt  1.207363  0.358995
Soft mean loss: 0.9721764940015538 acc: 0.5101058201058202

Baseline Holdout Hard Metrics (CE on class):
                    ckpt      loss       acc
0  epoch=0-step=329.ckpt  1.561967  0.346561
1  epoch=1-step=658.ckpt  1.500826  0.358995
2  epoch=2-step=987.ckpt  1.420332  0.432011
3              last.ckpt  0.470858  0.865079
4  epoch=1-step=654.ckpt  1.523987  0.358995
Hard mean loss: 1.2498565788149203 acc: 0.5101058201058202
Graph ckpts found: 3

Baseline Holdout Soft Metrics (KL on probs):
                    ckpt      loss       acc
0  epoch=0-step=329.ckpt  1.243473  0.346561
1  epoch=1-step=658.