In [4]:
# Quick minimal preprocessing if no processed patients exist
from pathlib import Path
import pandas as pd
import torch
from omegaconf import OmegaConf
from src.data.utils.eeg_process import EEGGraphBuilder, select_eeg_channels
from src.data.utils.spectrogram_process import SpectrogramGraphBuilder, filter_spectrogram_columns

proc_dir = Path('data/processed')
proc_dir.mkdir(parents=True, exist_ok=True)
existing = list(proc_dir.glob('patient_*.pt'))
if len(existing) == 0:
    print('No processed patients found; generating a tiny sample for smoke training...')
    cfg = OmegaConf.load('configs/graphs.yaml')
    df = pd.read_csv(cfg.paths.train_csv)
    df.columns = df.columns.str.strip()
    # Pick first patient with at least 1-2 samples
    first_pid = int(df['patient_id'].iloc[0])
    df_pid = df[df['patient_id'] == first_pid].head(2).reset_index(drop=True)
    print('Using patient', first_pid, 'with', len(df_pid), 'samples')

    # Build graph builders
    eeg_builder = EEGGraphBuilder(
        sampling_rate=cfg.eeg.sampling_rate,
        window_size=cfg.eeg.window_size,
        stride=cfg.eeg.stride,
        bands=dict(cfg.eeg.bands),
        coherence_threshold=cfg.eeg.coherence.threshold,
        nperseg_factor=cfg.eeg.coherence.nperseg_factor,
        channels=list(cfg.eeg.channels),
        apply_bandpass=cfg.eeg.preprocessing.bandpass_filter.enabled,
        bandpass_low=cfg.eeg.preprocessing.bandpass_filter.lowcut,
        bandpass_high=cfg.eeg.preprocessing.bandpass_filter.highcut,
        bandpass_order=cfg.eeg.preprocessing.bandpass_filter.order,
        apply_notch=cfg.eeg.preprocessing.notch_filter.enabled,
        notch_freq=cfg.eeg.preprocessing.notch_filter.frequency,
        notch_q=cfg.eeg.preprocessing.notch_filter.quality_factor,
        apply_normalize=cfg.eeg.preprocessing.normalize.enabled,
    )
    spec_builder = SpectrogramGraphBuilder(
        window_size=cfg.spectrogram.window_size,
        stride=cfg.spectrogram.stride,
        regions=list(cfg.spectrogram.regions),
        bands=dict(cfg.spectrogram.bands),
        aggregation=cfg.spectrogram.aggregation,
        spatial_edges=cfg.spectrogram.spatial_edges,
        apply_preprocessing=cfg.spectrogram.preprocessing.enabled,
        clip_min=cfg.spectrogram.preprocessing.clip_min,
        clip_max=cfg.spectrogram.preprocessing.clip_max,
    )

    # Produce graphs for up to 2 labels and save in one patient file
    patient_data = {}
    for _, row in df_pid.iterrows():
        eeg_path = f"{cfg.paths.train_eegs}/{row.eeg_id}.parquet"
        eeg_df = pd.read_parquet(eeg_path)
        eeg_offset = int(row.eeg_label_offset_seconds)
        eeg_start = eeg_offset * 200
        eeg_end = eeg_start + 50 * 200
        eeg_window_df = eeg_df.iloc[eeg_start:eeg_end].copy()
        eeg_array = select_eeg_channels(eeg_window_df, list(cfg.eeg.channels))
        eeg_graphs = eeg_builder.process_eeg_signal(eeg_array)

        spec_path = f"{cfg.paths.train_spectrograms}/{row.spectrogram_id}.parquet"
        spec_df = pd.read_parquet(spec_path)
        spec_offset = int(row.spectrogram_label_offset_seconds)
        spec_window_df = spec_df[(spec_df['time'] >= spec_offset) & (spec_df['time'] < spec_offset + 600)].copy()
        spec_window_df = filter_spectrogram_columns(spec_window_df, list(cfg.spectrogram.regions))
        spec_graphs = spec_builder.process_spectrogram(spec_window_df)

        label_to_index = dict(cfg.label_to_index)
        target = label_to_index.get(str(row.expert_consensus).strip(), -1)
        if target == -1:
            continue
        patient_data[int(row.label_id)] = {
            'eeg_graphs': eeg_graphs,
            'spec_graphs': spec_graphs,
            'target': int(target),
        }

    out_path = proc_dir / f'patient_{first_pid}.pt'
    if patient_data:
        torch.save(patient_data, out_path)
        print('Saved minimal patient file:', out_path)
    else:
        print('Failed to create minimal patient data; please run full preprocessing.')
else:
    print('Found', len(existing), 'processed patient files; skipping minimal preprocessing.')

Found 1950 processed patient files; skipping minimal preprocessing.


In [5]:
# Smoke training (1 epoch, 2 train batches, 1 val batch)
import importlib, src.lightning_trainer.graph_lightning_module as glm
importlib.reload(glm)
import src.lightning_trainer as lt
importlib.reload(lt)
import src.train as tr
importlib.reload(tr)
from src.train import train
trainer, model, datamodule = train(
    config_path='configs/model.yaml',
    wandb_project='hms-graphs',
    wandb_name='smoke-notebook',
    smoke=True, offline=True,
    limit_train_batches=2, limit_val_batches=1, max_epochs_override=1,
    batch_size_override=2, num_workers_override=0,
)
print('Smoke training done.')


HMS Multi-Modal GNN Training
Config: configs/model.yaml
WandB Project: hms-graphs
WandB Run: smoke-notebook

[Info] WANDB_MODE=offline (no internet required)
Initializing DataModule...

Dataset Setup - Fold 0/4:
  Train: 1557 patients, 17358 samples
  Val:   393 patients, 2825 samples

  Stratification by evaluator bins:
    Train: {0: 11795, 1: 241, 2: 3902, 3: 1333, 4: 87}
    Val: {0: 1896, 1: 18, 2: 653, 3: 250, 4: 8}

  Class weights: None (skipped)

Initializing Model...

Model Architecture:
  EEG output dim:    256
  Spec output dim:   256
  Fusion output dim: 512
  Num classes:       6
  Total parameters:  2,019,846
  Trainable params:  2,019,846



[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


/venv/graph-ml-2/lib/python3.11/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
You are using a CUDA device ('NVIDIA H100 80GB HBM3') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


Starting training...



LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name                | Type               | Params | Mode 
-------------------------------------------------------------------
0 | model               | HMSMultiModalGNN   | 2.0 M  | train
1 | criterion           | CrossEntropyLoss   | 0      | train
2 | train_metrics       | MetricCollection   | 0      | train
3 | val_metrics         | MetricCollection   | 0      | train
4 | test_metrics        | MetricCollection   | 0      | train
5 | train_acc_per_class | MulticlassAccuracy | 0      | train
6 | val_acc_per_class   | MulticlassAccuracy | 0      | train
7 | test_acc_per_class  | MulticlassAccuracy | 0      | train
-------------------------------------------------------------------
2.0 M     Trainable pa


Dataset Setup - Fold 0/4:
  Train: 1557 patients, 17358 samples
  Val:   393 patients, 2825 samples

  Stratification by evaluator bins:
    Train: {0: 11795, 1: 241, 2: 3902, 3: 1333, 4: 87}
    Val: {0: 1896, 1: 18, 2: 653, 3: 250, 4: 8}

  Class weights: None (skipped)



/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 38. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved. New best score: 2.122
Epoch 0, global step 2: 'val/loss' reached 2.12158 (best 2.12158), saving model to '/workspace/Kaggle-HMS-2/checkpoints/hms-epoch=00-val/loss=2.1216.ckpt' as top 3
`Trainer.fit` stopped: `max_epochs=1` reached.



Testing best model...



Restoring states from the checkpoint path at /workspace/Kaggle-HMS-2/checkpoints/hms-epoch=00-val/loss=2.1216.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /workspace/Kaggle-HMS-2/checkpoints/hms-epoch=00-val/loss=2.1216.ckpt
/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

/venv/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 19. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/acc            0.10300885140895844
    test/acc_class_0        0.11460673809051514
    test/acc_class_1       0.012987012974917889
    test/acc_class_2                0.0
    test/acc_class_3        0.2631579041481018
    test/acc_class_4        0.6081632375717163
    test/acc_class_5       0.0017793594161048532
     test/acc_macro         0.08383481204509735
      test/f1_macro         0.07551635056734085
        test/loss                   nan
  test/precision_macro      0.07463131844997406
    test/recall_macro       0.08383481204509735
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

Training Complete!
Best checkpoin

In [None]:
# Locate best checkpoint for inference
from pathlib import Path
ckpt_dir = Path('checkpoints')
best_ckpt = None
if ckpt_dir.exists():
    # Prefer latest by mtime; compatible with our Trainer checkpoint naming
    paths = sorted(ckpt_dir.glob('*.ckpt'), key=lambda p: p.stat().st_mtime, reverse=True)
    best_ckpt = str(paths[0]) if paths else None
print('Best checkpoint:', best_ckpt)
assert best_ckpt is not None, 'No checkpoints found after training.'

Best checkpoint: checkpoints/last.ckpt


In [None]:
# Inference on test split (from processed training data) with timing logs
import importlib, src.predict as sp
importlib.reload(sp)
from src.predict import predict
from pathlib import Path

pred_out = 'notebooks/outputs/preds_smoke.csv'
Path('notebooks/outputs').mkdir(parents=True, exist_ok=True)

# Use the best (latest) checkpoint discovered earlier
print('Using checkpoint:', best_ckpt)

predict(
    config_path='configs/model.yaml',
    checkpoint_path=best_ckpt,
    output_csv=pred_out,
    batch_size_override=2,
    num_workers_override=0,
    max_batches=2,  # limit for quick verification; remove for full inference
)
print('Saved predictions to', pred_out)


HMS Multi-Modal GNN Training
Config: configs/model.yaml
WandB Project: hms-graphs
WandB Run: smoke-notebook

[Info] WANDB_MODE=offline (no internet required)
Initializing DataModule...


/root/.local/share/mamba/envs/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/root/.local/share/mamba/envs/graph-ml-2/lib/python3.11/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.



Dataset Setup - Fold 0/4:
  Train: 1557 patients, 17358 samples
  Val:   393 patients, 2825 samples

  Stratification by evaluator bins:
    Train: {0: 11795, 1: 241, 2: 3902, 3: 1333, 4: 87}
    Val: {0: 1896, 1: 18, 2: 653, 3: 250, 4: 8}

  Class weights: None (skipped)

Initializing Model...

Model Architecture:
  EEG output dim:    256
  Spec output dim:   256
  Fusion output dim: 512
  Num classes:       6
  Total parameters:  2,019,846
  Trainable params:  2,019,846

Starting training...



LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/root/.local/share/mamba/envs/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name                | Type               | Params | Mode 
-------------------------------------------------------------------
0 | model               | HMSMultiModalGNN   | 2.0 M  | train
1 | criterion           | CrossEntropyLoss   | 0      | train
2 | train_metrics       | MetricCollection   | 0      | train
3 | val_metrics         | MetricCollection   | 0      | train
4 | test_metrics        | MetricCollection   | 0      | train
5 | train_acc_per_class | MulticlassAccuracy | 0      | train
6 | val_acc_per_class   | MulticlassAccuracy | 0      | train
7 | test_acc_per_class  | MulticlassAccuracy | 0      | train
------------------------------------------------------------------


Dataset Setup - Fold 0/4:
  Train: 1557 patients, 17358 samples
  Val:   393 patients, 2825 samples

  Stratification by evaluator bins:
    Train: {0: 11795, 1: 241, 2: 3902, 3: 1333, 4: 87}
    Val: {0: 1896, 1: 18, 2: 653, 3: 250, 4: 8}

  Class weights: None (skipped)



/root/.local/share/mamba/envs/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=20` in the `DataLoader` to improve performance.
/root/.local/share/mamba/envs/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
/root/.local/share/mamba/envs/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=20` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

/root/.local/share/mamba/envs/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 38. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved. New best score: 1.917
Epoch 0, global step 2: 'val/loss' reached 1.91650 (best 1.91650), saving model to '/root/Kaggle-HMS-2/checkpoints/hms-epoch=00-val/loss=1.9165.ckpt' as top 3
Epoch 0, global step 2: 'val/loss' reached 1.91650 (best 1.91650), saving model to '/root/Kaggle-HMS-2/checkpoints/hms-epoch=00-val/loss=1.9165.ckpt' as top 3
`Trainer.fit` stopped: `max_epochs=1` reached.
`Trainer.fit` stopped: `max_epochs=1` reached.



Testing best model...



Restoring states from the checkpoint path at /root/Kaggle-HMS-2/checkpoints/hms-epoch=00-val/loss=1.9165.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /root/Kaggle-HMS-2/checkpoints/hms-epoch=00-val/loss=1.9165.ckpt
/root/.local/share/mamba/envs/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=20` in the `DataLoader` to improve performance.
Loaded model weights from the checkpoint at /root/Kaggle-HMS-2/checkpoints/hms-epoch=00-val/loss=1.9165.ckpt
/root/.local/share/mamba/envs/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` 

Testing: |          | 0/? [00:00<?, ?it/s]

/root/.local/share/mamba/envs/graph-ml-2/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 19. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/acc            0.35929203033447266
    test/acc_class_0        0.10898876190185547
    test/acc_class_1                0.0
    test/acc_class_2        0.06930693238973618
    test/acc_class_3                0.0
    test/acc_class_4                0.0
    test/acc_class_5        0.8042704463005066
     test/acc_macro         0.33421841263771057
      test/f1_macro         0.3019481897354126
        test/loss                   nan
  test/precision_macro      0.2935105264186859
    test/recall_macro       0.33421841263771057
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

Training Complete!
Best checkpoint: /root/Kaggle-HM

In [3]:
# Setup & Imports
import os, sys
from pathlib import Path
project_root = Path.cwd()
sys.path.insert(0, str(project_root))
print('Project root:', project_root)

# WANDB offline for smoke runs
os.environ.setdefault('WANDB_MODE', 'offline')
print('WANDB_MODE =', os.environ.get('WANDB_MODE'))

# Quick data check
from pathlib import Path
proc_dir = Path('data/processed')
assert proc_dir.exists(), f'Missing processed data at {proc_dir}. Run preprocessing first.'
patient_files = list(proc_dir.glob('patient_*.pt'))
print('Processed patients:', len(patient_files))
assert len(patient_files) > 0, 'No processed patient files found.'

Project root: /root/Kaggle-HMS-2
WANDB_MODE = offline
Processed patients: 1950


# HMS Graphs: Smoke Train + Inference
#
# This notebook runs a fast smoke test training and a quick inference pass
# using the processed graph data in `data/processed/`.
# - Training: 1 epoch, 2 train batches, 1 val batch, WANDB offline
# - Inference: runs on the test split from processed training data
#
# After verifying this works, run full training from terminal (see notes below).