In [1]:
# Load (or reload) IPython's autoreload extension and reload every imported
# module before each cell runs, so edits to project code are picked up without
# restarting the kernel.
%reload_ext autoreload
%autoreload all

# Exclude these third-party libraries from autoreload (a leading "-" marks a
# module as "do not reload").
%aimport -torch
%aimport -matplotlib
%aimport -seaborn
%aimport -numpy
%aimport -pandas
%aimport -scipy
%aimport -lightning 

In [2]:
from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataset, CHBMITDataModule

# Open the preprocessed dataset file "stft_normalized.h5" located under PROCESSED_DATA_DIR.
# NOTE(review): presumably STFT-transformed, normalized CHB-MIT EEG windows — confirm against the preprocessing step.
dataset = CHBMITDataset(PROCESSED_DATA_DIR / "stft_normalized.h5")

[32m2025-05-07 08:17:29.303[0m | [1mINFO    [0m | [36meeg_snn_encoder.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /root/snn-encoder-test[0m


In [3]:
# Wrap the dataset in the project's data module; `objective` passes this
# notebook-level object to trainer.fit() and trainer.test().
# NOTE(review): `worker=0` presumably maps to the DataLoader `num_workers` setting — confirm in CHBMITDataModule.
datamodule = CHBMITDataModule(dataset, batch_size=512, worker=0)

In [4]:
import os

import optuna

# Attach to the existing Optuna study "model-tuning-be" (load_study does not
# create one). The storage URL comes from the OPTUNA_CONN_STRING environment
# variable; a missing variable raises KeyError here.
# The multivariate TPE sampler is used for any parameters that still need sampling.
sampler = optuna.samplers.TPESampler(multivariate=True)
study = optuna.load_study(
    study_name="model-tuning-be",
    storage=os.environ["OPTUNA_CONN_STRING"],
    sampler=sampler,
)



In [5]:
# Fetch only the COMPLETE trials; no deep copy is requested because the
# trials are only read here.
complete_trial = study.get_trials(
    deepcopy=False,
    states=(optuna.trial.TrialState.COMPLETE,),
)

# Order them by objective value, smallest first.
sorted_trials = sorted(complete_trial, key=lambda frozen: frozen.value)

# Short-list the trials whose objective value is below 15.
filtered_trials = [frozen for frozen in sorted_trials if frozen.value < 15]
filtered_trials

[FrozenTrial(number=44, state=1, values=[10.841240882873535], datetime_start=datetime.datetime(2025, 5, 6, 20, 20, 45, 275730), datetime_complete=datetime.datetime(2025, 5, 6, 20, 22, 42, 907178), params={'threshold': 0.07090851207541435, 'slope': 8.107275688930079, 'beta': 0.7440672774049281, 'dropout_rate1': 0.20034293181105056, 'dropout_rate2': 0.5865898197113356, 'lr': 9.362082441139168e-05, 'weight_decay': 1.3936923384285494e-06, 'scheduler_factor': 0.4595170872743502, 'scheduler_patience': 8, 'max_window': 5, 'n_max': 4, 't_max': 0, 't_min': 0}, user_attrs={}, system_attrs={}, intermediate_values={0: 17.381244659423828, 1: 17.244325637817383, 2: 19.790807723999023, 3: 18.930034637451172, 4: 15.946797370910645, 5: 14.786341667175293, 6: 14.04427433013916, 7: 13.216415405273438, 8: 12.813944816589355, 9: 13.05215072631836, 10: 12.538644790649414, 11: 12.573296546936035, 12: 12.258035659790039, 13: 12.028778076171875, 14: 11.209694862365723, 15: 10.950955390930176, 16: 10.6388492584

In [6]:
from loguru import logger

import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from optuna.integration import PyTorchLightningPruningCallback

from eeg_snn_encoder.encoders import BurstEncoder
from eeg_snn_encoder.models.classifier import ModelConfig
from eeg_snn_encoder.models.lightning import LitSeizureClassifier, OptimizerConfig
import torch

# Select PyTorch's "medium" float32 matmul precision, which permits
# lower-precision internal computation in exchange for speed.
torch.set_float32_matmul_precision('medium')


def objective(trial: optuna.Trial) -> float:
    """Train and score one burst-encoded seizure classifier for an Optuna trial.

    Hyper-parameters for the classifier, the optimizer/scheduler and the burst
    encoder are drawn from ``trial``. The model is fitted on the notebook-level
    ``datamodule`` for at most 50 epochs, with an Optuna pruning callback and
    early stopping (patience 5) both monitoring ``val_loss``. Afterwards the
    test loop is run and its metrics are logged.

    Args:
        trial: Optuna trial used to draw the hyper-parameters; it is also
            handed to the pruning callback.

    Returns:
        The ``val_loss`` entry of ``trainer.callback_metrics`` read right after
        fitting, as a Python float.
    """
    suggest_float = trial.suggest_float
    suggest_int = trial.suggest_int

    # Classifier hyper-parameters.
    classifier_cfg: ModelConfig = {
        "threshold": suggest_float("threshold", 0.01, 0.5),
        "slope": suggest_float("slope", 1.0, 20.0),
        "beta": suggest_float("beta", 0.1, 0.99),
        "dropout_rate1": suggest_float("dropout_rate1", 0.1, 0.99),
        "dropout_rate2": suggest_float("dropout_rate2", 0.1, 0.99),
    }

    # Optimizer and LR-scheduler hyper-parameters (lr / weight_decay on a log scale).
    optim_cfg: OptimizerConfig = {
        "lr": suggest_float("lr", 1e-6, 1e-4, log=True),
        "weight_decay": suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        "scheduler_factor": suggest_float("scheduler_factor", 0.1, 0.99),
        "scheduler_patience": suggest_int("scheduler_patience", 1, 10),
    }

    # Burst-encoder hyper-parameters; each upper bound depends on the values
    # drawn before it, so the order of these calls matters.
    max_window = suggest_int("max_window", 4, 16)
    n_max = suggest_int("n_max", 1, max_window)
    t_max = suggest_int("t_max", 0, max_window // n_max)
    t_min = suggest_int("t_min", 0, t_max)

    burst_encoder = BurstEncoder(
        max_window=max_window,
        n_max=n_max,
        t_max=t_max,
        t_min=t_min,
    )

    classifier = LitSeizureClassifier(
        model_config=classifier_cfg,
        optimizer_config=optim_cfg,
        spike_encoder=burst_encoder,
    )

    # Both callbacks watch the validation loss.
    pruning_cb = PyTorchLightningPruningCallback(trial, monitor="val_loss")
    early_stop_cb = EarlyStopping(monitor="val_loss", mode="min", patience=5)

    runner = pl.Trainer(
        max_epochs=50,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        enable_model_summary=False,
        enable_checkpointing=False,
        callbacks=[pruning_cb, early_stop_cb],
        logger=False,
    )

    runner.fit(classifier, datamodule=datamodule)

    # Read the validation loss before the test loop runs.
    val_loss = runner.callback_metrics["val_loss"].item()

    runner.test(classifier, datamodule=datamodule)

    test_metrics = runner.callback_metrics
    test_loss = test_metrics["test_loss"].item()
    test_acc = test_metrics["test_acc"]
    test_f1 = test_metrics["test_f1"]
    test_mse = test_metrics["test_mse"]
    test_total_spikes = test_metrics["test_total_spikes"]

    logger.info(f"Encoder: Burst Encoding,trial: {trial.number}, test_loss:{test_loss}, test_mse:{test_mse}, test_acc:{test_acc}, test_f1:{test_f1}, test_total_spikes:{test_total_spikes}")

    return val_loss

In [7]:
# Queue the parameter set of every short-listed trial so that the following
# optimize() call evaluates those exact parameters again with `objective`.
# NOTE(review): re-running this cell enqueues the same parameter sets again.
for trial in filtered_trials:
    study.enqueue_trial(params=trial.params)

# Run as many trials as were short-listed.
study.optimize(objective, n_trials=len(filtered_trials))

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/root/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
/root/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/root/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-07 08:26:35.666[0m | [1mINFO    [0m | [36m__main__[0m:[36mobjective[0m:[36m74[0m - [1mEncoder: Burst Encoding,trial: 53, test_loss:9.631075859069824, test_mse:0.20079787075519562, test_acc:0.7992021441459656, test_f1:0.7941028475761414, test_total_spikes:25019882.0[0m


[I 2025-05-07 08:26:36,998] Trial 53 finished with value: 10.168556213378906 and parameters: {'threshold': 0.07090851207541435, 'slope': 8.107275688930079, 'beta': 0.7440672774049281, 'dropout_rate1': 0.20034293181105056, 'dropout_rate2': 0.5865898197113356, 'lr': 9.362082441139168e-05, 'weight_decay': 1.3936923384285494e-06, 'scheduler_factor': 0.4595170872743502, 'scheduler_patience': 8, 'max_window': 5, 'n_max': 4, 't_max': 0, 't_min': 0}. Best is trial 53 with value: 10.168556213378906.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-05-07 08:26:48,695] Trial 54 pruned. Trial was pruned at epoch 0.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-07 08:29:24.163[0m | [1mINFO    [0m | [36m__main__[0m:[36mobjective[0m:[36m74[0m - [1mEncoder: Burst Encoding,trial: 55, test_loss:12.105426788330078, test_mse:0.26196807622909546, test_acc:0.7380319237709045, test_f1:0.7806776165962219, test_total_spikes:20279166.0[0m


[I 2025-05-07 08:29:25,477] Trial 55 finished with value: 12.73603630065918 and parameters: {'threshold': 0.0398863898955059, 'slope': 18.522680857220692, 'beta': 0.7809379322752477, 'dropout_rate1': 0.5100500559710461, 'dropout_rate2': 0.730158315841459, 'lr': 9.379778169610137e-05, 'weight_decay': 4.3660072956328376e-05, 'scheduler_factor': 0.6932955598656476, 'scheduler_patience': 6, 'max_window': 11, 'n_max': 9, 't_max': 0, 't_min': 0}. Best is trial 53 with value: 10.168556213378906.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-05-07 08:30:13,608] Trial 56 pruned. Trial was pruned at epoch 6.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-05-07 08:31:09,312] Trial 57 pruned. Trial was pruned at epoch 7.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-05-07 08:31:20,049] Trial 58 pruned. Trial was pruned at epoch 0.


In [8]:
# Log the best hyper-parameters and the best objective value currently
# recorded in the study.
logger.info(f"Encoder: Burst Encoding,trial, best_param: {study.best_params}")
logger.info(f"Encoder: Burst Encoding,trial, best_score: {study.best_value}")

[32m2025-05-06 20:36:15.798[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mEncoder: Burst Encoding,trial, best_param: {'threshold': 0.07090851207541435, 'slope': 8.107275688930079, 'beta': 0.7440672774049281, 'dropout_rate1': 0.20034293181105056, 'dropout_rate2': 0.5865898197113356, 'lr': 9.362082441139168e-05, 'weight_decay': 1.3936923384285494e-06, 'scheduler_factor': 0.4595170872743502, 'scheduler_patience': 8, 'max_window': 5, 'n_max': 4, 't_max': 0, 't_min': 0}[0m
[32m2025-05-06 20:36:16.180[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mEncoder: Burst Encoding,trial, best_score: 10.841240882873535[0m
