# Test best params from BSA on other Encoder
First we do the hyperparameter tuning for each encoder seperately using optuna. We using the TPESampler to do the tuning with 50 trials. In this file we using the best hyperparameter for each encoder to to train the model and do the cross validation using K-fold method.

In [1]:
%reload_ext autoreload
%autoreload all

%aimport -torch
%aimport -matplotlib
%aimport -seaborn
%aimport -numpy
%aimport -pandas
%aimport -scipy
%aimport -lightning 

## Load the data

In [2]:
from eeg_snn_encoder.config import MODELS_DIR, PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataset

# Load the dataset
dataset = CHBMITDataset(PROCESSED_DATA_DIR / "stft_normalized.h5")

[32m2025-05-10 01:20:05.711[0m | [1mINFO    [0m | [36meeg_snn_encoder.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /home/jupyter-group55/snn-encoder-test[0m


In [3]:
import gc
import os
import pandas as pd

from loguru import logger
import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import optuna
from sklearn.model_selection import KFold, train_test_split
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

from eeg_snn_encoder.models.classifier import ModelConfig
from eeg_snn_encoder.models.lightning import LitSeizureClassifier, OptimizerConfig
from eeg_snn_encoder.config import REPORTS_DIR

import json

In [4]:
from eeg_snn_encoder.encoders import PoissonEncoderExpand

torch.set_float32_matmul_precision('high')

## Load config and tuning
Load the best params from the intermediat value as optuna doesn't save the best value but latest one. And the tune the model and test it against the test set selected but K-Fold.

In [5]:
study_name = "model-fine-tuning-bsa"
study = optuna.load_study(
    study_name=study_name,
    storage=os.environ["OPTUNA_CONN_STRING"],
)

filtered_trials = [t for t in study.get_trials() if len(t.intermediate_values) > 35]

best_trial = min(
    filtered_trials,
    key=lambda t: min(t.intermediate_values.values(), default=float("inf")),
    default=None
)

best_params = best_trial.params

model_params: ModelConfig = {
    "threshold": best_params["threshold"],
    "slope": best_params["slope"],
    "beta": best_params["beta"],
    "dropout_rate1": best_params["dropout_rate1"],
    "dropout_rate2": best_params["dropout_rate2"],
}

optimizer_params: OptimizerConfig = {
    "lr": best_params["lr"],
    "weight_decay": best_params["weight_decay"],
    "scheduler_factor": best_params["scheduler_factor"],
    "scheduler_patience": best_params["scheduler_patience"],
}

encoder_params = {
    "interval_freq": 8,
    "random_seed": 47
}

In [6]:
kfold = KFold(n_splits=5, shuffle=True, random_state=42)

fold_results = []

batch_size = 128

for fold, (train_val_ids, test_ids) in enumerate(kfold.split(dataset)):
    logger.info(f"Starting fold {fold + 1} of {kfold.n_splits}")

    train_ids, val_ids = train_test_split(
        train_val_ids, test_size=0.2, random_state=42, shuffle=True
    )

    train_sampler = SubsetRandomSampler(train_ids)
    val_sampler = SubsetRandomSampler(val_ids)
    test_sampler = SubsetRandomSampler(test_ids)

    trainloader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
    valloader = DataLoader(dataset, batch_size=batch_size, sampler=val_sampler)
    testloader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)

    spike_encoder = PoissonEncoderExpand(**encoder_params)

    lit_model = LitSeizureClassifier(
        model_config=model_params,
        optimizer_config=optimizer_params,
        spike_encoder=spike_encoder,
    )

    trainer = pl.Trainer(
        max_epochs=50,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        default_root_dir=MODELS_DIR / f"test_encoder",
        callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=5)],
        logger=False,
    )

    trainer.fit(lit_model, trainloader, valloader)
    trainer.test(lit_model, testloader, ckpt_path="best")

    test_loss = trainer.callback_metrics["test_loss"].item()
    test_acc = trainer.callback_metrics["test_acc"].item()
    test_precision = trainer.callback_metrics["test_precision"].item()
    test_recall = trainer.callback_metrics["test_recall"].item()
    test_f1 = trainer.callback_metrics["test_f1"].item()
    test_mse = trainer.callback_metrics["test_mse"].item()
    test_total_spikes = trainer.callback_metrics["test_total_spikes"].item()

    logger.info(
        f"test-encoder ",
        f"Fold {fold + 1} - Test Loss: {test_loss:.4f}, "
        f"Test Accuracy: {test_acc:.4f}, "
        f"Test Precision: {test_precision:.4f}, "
        f"Test Recall: {test_recall:.4f}, "
        f"Test F1: {test_f1:.4f}, "
        f"Test MSE: {test_mse:.4f}, "
        f"Test Total Spikes: {test_total_spikes:.4f}"
    )

    fold_results.append(
        {
            "fold": fold,
            "test_loss": test_loss,
            "test_acc": test_acc,
            "test_precision": test_precision,
            "test_recall": test_recall,
            "test_f1": test_f1,
            "test_mse": test_mse,
            "test_total_spikes": test_total_spikes,
        }
    )

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/jupyter-group55/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/jupyter-group55/snn-encoder-test/models/test_encoder/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | EEGSTFTSpikeClassifier | 824 K  | train
---------------------------------------------------------
824 K     Trainable params
0         Non-trainable params
824 K     Total params
3.299     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


[32m2025-05-10 01:20:10.336[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1mStarting fold 1 of 5[0m


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/jupyter-group55/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/home/jupyter-group55/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Restoring states from the checkpoint path at /home/jupyter-group55/snn-encoder-test/models/test_encoder/checkpoints/epoch=14-step=390.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/jupyter-group55/snn-encoder-test/models/test_encoder/checkpoints/epoch=14-step=390.ckpt
/home/jupyter-group55/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | EEGSTFTSpikeClassifier | 824 K  | train
---------------------------------------------------------
824 K     Trainable params
0         Non-trainable params
824 K     Total params
3.299     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


[32m2025-05-10 01:26:39.831[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m51[0m - [1mtest-encoder [0m
[32m2025-05-10 01:26:39.831[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1mStarting fold 2 of 5[0m


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Restoring states from the checkpoint path at /home/jupyter-group55/snn-encoder-test/models/test_encoder/checkpoints/epoch=11-step=312.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/jupyter-group55/snn-encoder-test/models/test_encoder/checkpoints/epoch=11-step=312.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | EEGSTFTSpikeClassifier | 824 K  | train
---------------------------------------------------------
824 K     Trainable params
0         Non-trainable params
824 K     Total params
3.299     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


[32m2025-05-10 01:31:54.885[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m51[0m - [1mtest-encoder [0m
[32m2025-05-10 01:31:54.885[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1mStarting fold 3 of 5[0m


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Restoring states from the checkpoint path at /home/jupyter-group55/snn-encoder-test/models/test_encoder/checkpoints/epoch=11-step=312-v1.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/jupyter-group55/snn-encoder-test/models/test_encoder/checkpoints/epoch=11-step=312-v1.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | EEGSTFTSpikeClassifier | 824 K  | train
---------------------------------------------------------
824 K     Trainable params
0         Non-trainable params
824 K     Total params
3.299     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


[32m2025-05-10 01:37:11.467[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m51[0m - [1mtest-encoder [0m
[32m2025-05-10 01:37:11.467[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1mStarting fold 4 of 5[0m


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Restoring states from the checkpoint path at /home/jupyter-group55/snn-encoder-test/models/test_encoder/checkpoints/epoch=11-step=312-v2.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/jupyter-group55/snn-encoder-test/models/test_encoder/checkpoints/epoch=11-step=312-v2.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | EEGSTFTSpikeClassifier | 824 K  | train
---------------------------------------------------------
824 K     Trainable params
0         Non-trainable params
824 K     Total params
3.299     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


[32m2025-05-10 01:42:27.403[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m51[0m - [1mtest-encoder [0m
[32m2025-05-10 01:42:27.403[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1mStarting fold 5 of 5[0m


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Restoring states from the checkpoint path at /home/jupyter-group55/snn-encoder-test/models/test_encoder/checkpoints/epoch=11-step=312-v3.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/jupyter-group55/snn-encoder-test/models/test_encoder/checkpoints/epoch=11-step=312-v3.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-10 01:47:43.044[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m51[0m - [1mtest-encoder [0m


In [7]:
results_df = pd.DataFrame(fold_results)
results_df.set_index("fold", inplace=True)

results_df.to_csv(
    REPORTS_DIR / f"poisson_8_model_results_using_bsa.csv", index=True
)

params_file = REPORTS_DIR / f"poisson_8_model_params_using_bsa.json"

params_file.write_text(
    json.dumps(
        {
            "model_params": model_params,
            "optimizer_params": optimizer_params,
            "encoder_params": encoder_params,
        },
        indent=4,
    )
)
logger.info(f"Results and parameters saved to {REPORTS_DIR} for test-encoder")

[32m2025-05-10 01:47:43.063[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m20[0m - [1mResults and parameters saved to /home/jupyter-group55/snn-encoder-test/reports for test-encoder[0m
