In [1]:
# Load (or reload) IPython's autoreload extension and ask it to re-import all
# modules before each cell runs, so edits to project code are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload all

# Exclude the third-party libraries below from autoreload: the leading "-"
# marks a module as "do not reload".
%aimport -torch
%aimport -matplotlib
%aimport -seaborn
%aimport -numpy
%aimport -pandas
%aimport -scipy
%aimport -lightning 

In [2]:
from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataset

# Build the dataset from the "stft_normalized.h5" file under PROCESSED_DATA_DIR.
stft_h5_path = PROCESSED_DATA_DIR / "stft_normalized.h5"
dataset = CHBMITDataset(stft_h5_path)

[32m2025-05-07 15:37:38.431[0m | [1mINFO    [0m | [36meeg_snn_encoder.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /root/snn-encoder-test[0m


In [3]:
import gc
import os

from loguru import logger
import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import optuna
from sklearn.model_selection import KFold, train_test_split
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

from eeg_snn_encoder.encoders import BSAEncoder
from eeg_snn_encoder.models.classifier import ModelConfig
from eeg_snn_encoder.models.lightning import LitSeizureClassifier, OptimizerConfig

In [5]:
# Load the existing "model-tuning-bsa" Optuna study. The storage URL is read
# from the OPTUNA_CONN_STRING environment variable (KeyError if it is unset).
study = optuna.load_study(
    storage=os.environ["OPTUNA_CONN_STRING"],
    study_name="model-tuning-bsa",
)

# Hyper-parameters of the study's best trial, split below by consumer.
best_params = study.best_trial.params

# Settings handed to the classifier as `model_config`.
model_params: ModelConfig = {
    name: best_params[name]
    for name in ("threshold", "slope", "beta", "dropout_rate1", "dropout_rate2")
}

# Settings handed to the Lightning module as `optimizer_config`.
optimizer_params: OptimizerConfig = {
    name: best_params[name]
    for name in ("lr", "weight_decay", "scheduler_factor", "scheduler_patience")
}

# Keyword arguments for the spike encoder. The study stores the encoder's
# threshold under "encoder_threshold"; it is passed on as plain "threshold".
encoder_params = dict(
    win_size=best_params["win_size"],
    cutoff=best_params["cutoff"],
    threshold=best_params["encoder_threshold"],
)

In [6]:
# K-fold cross-validation of the classifier with the BSA spike encoder.
#
# For each fold: split the non-test indices into train/validation subsets,
# train a fresh model with early stopping on "val_loss", evaluate it on the
# held-out fold, and append the test metrics (as plain Python floats) to
# `fold_results`.
#
# NOTE(review): no RNG seed for model initialisation or sampler shuffling is
# set in this cell, so metrics may differ between runs -- confirm whether
# seeding (e.g. `pl.seed_everything`) is wanted here.
N_SPLITS = 5
SPLIT_SEED = 42  # shared by the K-fold split and the train/validation split
VAL_FRACTION = 0.2  # share of each fold's non-test samples used for validation
BATCH_SIZE = 512
MAX_EPOCHS = 50
EARLY_STOPPING_PATIENCE = 5

# Metrics read back from `trainer.callback_metrics` after `trainer.test`.
TEST_METRIC_NAMES = (
    "test_loss",
    "test_acc",
    "test_precision",
    "test_recall",
    "test_f1",
    "test_mse",
    "test_total_spikes",
)

kfold = KFold(n_splits=N_SPLITS, shuffle=True, random_state=SPLIT_SEED)

fold_results = []

for fold, (train_val_ids, test_ids) in enumerate(kfold.split(dataset)):
    logger.info(f"Starting fold {fold + 1} of {kfold.n_splits} BSA")

    train_ids, val_ids = train_test_split(
        train_val_ids, test_size=VAL_FRACTION, random_state=SPLIT_SEED, shuffle=True
    )

    train_sampler = SubsetRandomSampler(train_ids)
    val_sampler = SubsetRandomSampler(val_ids)
    test_sampler = SubsetRandomSampler(test_ids)

    trainloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
    valloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=val_sampler)
    testloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler)

    # A fresh encoder and model per fold so no state carries over between folds.
    spike_encoder = BSAEncoder(**encoder_params)

    lit_model = LitSeizureClassifier(
        model_config=model_params,
        optimizer_config=optimizer_params,
        spike_encoder=spike_encoder,
    )

    trainer = pl.Trainer(
        max_epochs=MAX_EPOCHS,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        enable_checkpointing=False,
        callbacks=[
            EarlyStopping(monitor="val_loss", mode="min", patience=EARLY_STOPPING_PATIENCE)
        ],
        logger=False,
    )

    trainer.fit(lit_model, trainloader, valloader)
    trainer.test(lit_model, testloader)

    # Convert every metric to a Python float right away. Previously only
    # test_loss was converted, so the other metrics were stored as tensors and
    # had to be unwrapped later.
    metrics = {name: trainer.callback_metrics[name].item() for name in TEST_METRIC_NAMES}

    logger.info(
        f"Fold {fold + 1} - Test Loss: {metrics['test_loss']:.4f}, "
        f"Test Accuracy: {metrics['test_acc']:.4f}, "
        f"Test Precision: {metrics['test_precision']:.4f}, "
        f"Test Recall: {metrics['test_recall']:.4f}, "
        f"Test F1: {metrics['test_f1']:.4f}, "
        f"Test MSE: {metrics['test_mse']:.4f}, "
        f"Test Total Spikes: {metrics['test_total_spikes']:.4f}"
    )

    # "fold" stays 0-based in the stored results (the log lines are 1-based).
    fold_results.append({"fold": fold, **metrics})

    del lit_model
    del trainer
    del spike_encoder
    # Collect unreachable objects first so the memory they held is free by
    # the time the CUDA caching allocator is asked to release unused blocks.
    gc.collect()
    torch.cuda.empty_cache()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | EEGSTFTSpikeClassifier | 824 K  | train
---------------------------------------------------------
824 K     Trainable params
0         Non-trainable params
824 K     Total params
3.299     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


[32m2025-05-07 15:38:00.699[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mStarting fold 1 of 5 BSA[0m


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/root/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


/root/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/root/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-07 15:40:40.849[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m49[0m - [1mFold 1 - Test Loss: 8.6549, Test Accuracy: 0.8247, Test Precision: 0.8614, Test Recall: 0.7848, Test F1: 0.8212, Test MSE: 0.1753, Test Total Spikes: 30599.1270[0m


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | EEGSTFTSpikeClassifier | 824 K  | train
---------------------------------------------------------
824 K     Trainable params
0         Non-trainable params
824 K     Total params
3.299     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


[32m2025-05-07 15:40:41.243[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mStarting fold 2 of 5 BSA[0m


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-07 15:42:51.950[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m49[0m - [1mFold 2 - Test Loss: 9.3794, Test Accuracy: 0.7918, Test Precision: 0.7560, Test Recall: 0.8731, Test F1: 0.8103, Test MSE: 0.2082, Test Total Spikes: 31216.5020[0m


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | EEGSTFTSpikeClassifier | 824 K  | train
---------------------------------------------------------
824 K     Trainable params
0         Non-trainable params
824 K     Total params
3.299     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


[32m2025-05-07 15:42:52.356[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mStarting fold 3 of 5 BSA[0m


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-07 15:44:43.022[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m49[0m - [1mFold 3 - Test Loss: 8.9918, Test Accuracy: 0.8088, Test Precision: 0.7877, Test Recall: 0.8357, Test F1: 0.8108, Test MSE: 0.1912, Test Total Spikes: 30944.6934[0m


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | EEGSTFTSpikeClassifier | 824 K  | train
---------------------------------------------------------
824 K     Trainable params
0         Non-trainable params
824 K     Total params
3.299     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


[32m2025-05-07 15:44:43.430[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mStarting fold 4 of 5 BSA[0m


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-07 15:46:25.744[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m49[0m - [1mFold 4 - Test Loss: 8.9981, Test Accuracy: 0.8265, Test Precision: 0.8361, Test Recall: 0.8200, Test F1: 0.8277, Test MSE: 0.1735, Test Total Spikes: 30728.1738[0m


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | EEGSTFTSpikeClassifier | 824 K  | train
---------------------------------------------------------
824 K     Trainable params
0         Non-trainable params
824 K     Total params
3.299     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


[32m2025-05-07 15:46:26.150[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mStarting fold 5 of 5 BSA[0m


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-07 15:48:35.727[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m49[0m - [1mFold 5 - Test Loss: 9.9099, Test Accuracy: 0.7926, Test Precision: 0.7552, Test Recall: 0.8376, Test F1: 0.7941, Test MSE: 0.2074, Test Total Spikes: 30505.6680[0m


In [7]:
import pandas as pd


def _unwrap_scalar(value):
    """Return ``value.item()`` when the object has an ``item`` method, else ``value`` unchanged."""
    return value.item() if hasattr(value, "item") else value


# One row per fold, built from the dicts collected in `fold_results`.
results_df = pd.DataFrame(fold_results)

# Object-dtype columns may hold wrapped scalars; unwrap them element by element.
for object_col in results_df.select_dtypes(include=["object"]).columns:
    results_df[object_col] = results_df[object_col].apply(_unwrap_scalar)

# Index the table by fold number and display it as the cell output.
results_df = results_df.set_index("fold")
results_df

Unnamed: 0_level_0,test_loss,test_acc,test_precision,test_recall,test_f1,test_mse,test_total_spikes
fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,8.654949,0.824701,0.861408,0.78481,0.821231,0.175299,30599.126953
1,9.379429,0.791833,0.756023,0.873119,0.810349,0.208167,31216.501953
2,8.991779,0.808765,0.787742,0.83571,0.810755,0.191235,30944.693359
3,8.998121,0.82652,0.836147,0.819981,0.827728,0.17348,30728.173828
4,9.909909,0.792622,0.755183,0.837561,0.794106,0.207378,30505.667969


In [8]:
import json

from eeg_snn_encoder.config import REPORTS_DIR

# Persist the per-fold metrics (CSV) and the hyper-parameters that produced
# them (JSON) under REPORTS_DIR.

# Create the reports directory if it is missing so the cell also works on a
# fresh checkout.
# NOTE(review): assumes REPORTS_DIR is a pathlib.Path (it is combined with "/"
# below) -- confirm.
REPORTS_DIR.mkdir(parents=True, exist_ok=True)

results_df.to_csv(REPORTS_DIR / "bsa_model_results.csv", index=True)

params = {
    "model_params": model_params,
    "optimizer_params": optimizer_params,
    "encoder_params": encoder_params,
}

# Serialize before opening the file: opening with "w" truncates it, so a
# serialization error raised afterwards would leave an empty file behind.
params_json = json.dumps(params, indent=2)

with open(REPORTS_DIR / "bsa_model_param.json", "w", encoding="utf-8") as f:
    f.write(params_json)