In [1]:
# Load the autoreload extension (re-loading it if it is already active).
%reload_ext autoreload
# Reload all imported modules before each cell executes, except the modules
# excluded with `%aimport -<name>` below.
%autoreload all

# Exclude third-party libraries from autoreload. Presumably only the project
# package (eeg_snn_encoder) is edited while the kernel is live - confirm.
%aimport -torch
%aimport -matplotlib
%aimport -seaborn
%aimport -numpy
%aimport -pandas
%aimport -scipy
%aimport -lightning 

In [2]:
from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataset, CHBMITDataModule

# Build the CHB-MIT dataset from the "stft_normalized.h5" file that lives in
# the project's processed-data directory.
dataset = CHBMITDataset(
    PROCESSED_DATA_DIR / "stft_normalized.h5",
)

[32m2025-05-07 16:16:51.848[0m | [1mINFO    [0m | [36meeg_snn_encoder.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /root/snn-encoder-test[0m


In [3]:
# Wrap the dataset in the project's data module, using batches of 256.
# NOTE(review): `worker=0` presumably controls the DataLoader worker count
# (Lightning warns about few workers in the run output) - confirm against
# CHBMITDataModule's signature.
datamodule = CHBMITDataModule(
    dataset,
    batch_size=256,
    worker=0,
)

In [None]:
import optuna
from loguru import logger

import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from optuna.integration import PyTorchLightningPruningCallback

from eeg_snn_encoder.encoders import PoissonEncoderExpand
from eeg_snn_encoder.models.classifier import ModelConfig
from eeg_snn_encoder.models.lightning import LitSeizureClassifier, OptimizerConfig
import torch

# Trade float32 matmul precision for speed: "medium" lets PyTorch use
# lower-precision internal math for float32 matrix multiplications.
torch.set_float32_matmul_precision("medium")


def objective(trial: optuna.Trial) -> float:
    """Train and evaluate one Poisson-encoded classifier for an Optuna trial.

    Samples model, optimizer and encoder hyperparameters from ``trial``, fits a
    ``LitSeizureClassifier`` on the notebook-level ``datamodule`` for at most
    20 epochs, runs the test loop, logs the test metrics and returns the test
    F1 score.

    Args:
        trial: Optuna trial used to sample hyperparameters. It is also passed
            to the pruning callback, which monitors ``val_f1``.

    Returns:
        The ``test_f1`` metric as a plain Python ``float``.

    Raises:
        KeyError: If one of the expected ``test_*`` metrics is missing from
            ``trainer.callback_metrics`` after the test loop.

    Note:
        The pruning callback may stop a trial during ``trainer.fit``; Optuna
        handles that case itself, so nothing after ``fit`` runs for a pruned
        trial.
    """
    # Spiking-classifier hyperparameters.
    model_params: ModelConfig = {
        "threshold": trial.suggest_float("threshold", 0.01, 0.5),
        "slope": trial.suggest_float("slope", 1.0, 20.0),
        "beta": trial.suggest_float("beta", 0.1, 0.99),
        "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
        "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
    }

    # Optimizer / LR-scheduler hyperparameters (log-uniform for lr and decay).
    optimizer_params: OptimizerConfig = {
        "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
        "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
    }

    # Encoder hyperparameters; the encoder seed is fixed so that only
    # `interval_freq` varies between trials.
    encoder_params = {
        "interval_freq": trial.suggest_int("interval_freq", 1, 8),
        "random_seed": 47,
    }

    spike_encoder = PoissonEncoderExpand(**encoder_params)

    lit_model = LitSeizureClassifier(
        model_config=model_params,
        optimizer_config=optimizer_params,
        spike_encoder=spike_encoder,
    )

    # Pruning follows validation F1 (the quantity the study maximises), while
    # early stopping follows validation loss.
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_f1")
    early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=5)

    trainer = pl.Trainer(
        max_epochs=20,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        enable_model_summary=False,
        enable_checkpointing=False,
        callbacks=[pruning_callback, early_stopping],
        logger=False,
    )

    trainer.fit(lit_model, datamodule=datamodule)
    trainer.test(lit_model, datamodule=datamodule)

    # Convert every metric to a plain Python float. Previously only test_loss
    # was converted, so the remaining metrics were logged - and test_f1 was
    # returned - without conversion, contradicting the `-> float` annotation.
    metrics = trainer.callback_metrics
    test_loss = float(metrics["test_loss"])
    test_acc = float(metrics["test_acc"])
    test_f1 = float(metrics["test_f1"])
    test_mse = float(metrics["test_mse"])
    test_total_spikes = float(metrics["test_total_spikes"])

    logger.info(f"Encoder: Poisson Encoding,trial: {trial.number}, test_loss:{test_loss}, test_mse:{test_mse}, test_acc:{test_acc}, test_f1:{test_f1}, test_total_spikes:{test_total_spikes}")

    # NOTE(review): the study optimises on a *test* metric, so the test split
    # takes part in hyperparameter selection. Consider returning a validation
    # metric and keeping the test split for the final report only.
    return test_f1

In [None]:
import os

# Create the Optuna study, or reconnect to it if a study with this name already
# exists in the storage backend (load_if_exists=True).
# Sampling: TPE with multivariate=True. Pruning: Hyperband.
sampler = optuna.samplers.TPESampler(multivariate=True)
pruner = optuna.pruners.HyperbandPruner()
study = optuna.create_study(
    # Identity / persistence
    study_name="model-tuning-poisson-f1",
    storage=os.environ["OPTUNA_CONN_STRING"],  # KeyError if the variable is unset
    load_if_exists=True,
    # Search strategy
    sampler=sampler,
    pruner=pruner,
    # The objective's return value is maximised.
    direction="maximize",
)



[I 2025-05-07 16:16:59,453] Using an existing study with name 'model-tuning-poisson' instead of creating a new one.


In [6]:
# Run the search: 50 calls to `objective` in this invocation.
study.optimize(
    objective,
    n_trials=50,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/root/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
/root/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/root/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-07 16:21:09.948[0m | [1mINFO    [0m | [36m__main__[0m:[36mobjective[0m:[36m68[0m - [1mEncoder: Poisson Encoding,trial: 7, test_loss:19.406694412231445, test_mse:0.18617020547389984, test_acc:0.813829779624939, test_f1:0.8114399909973145, test_total_spikes:91312.7890625[0m


[I 2025-05-07 16:21:11,296] Trial 7 finished with value: 20.61240577697754 and parameters: {'threshold': 0.06560935122788826, 'slope': 19.515178794292677, 'beta': 0.7485738238159846, 'dropout_rate1': 0.41280634724450427, 'dropout_rate2': 0.7297685730505821, 'lr': 3.9738292608377826e-05, 'weight_decay': 1.95489622971466e-05, 'scheduler_factor': 0.4689936503974681, 'scheduler_patience': 8, 'interval_freq': 2}. Best is trial 7 with value: 20.61240577697754.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-07 16:23:07.050[0m | [1mINFO    [0m | [36m__main__[0m:[36mobjective[0m:[36m68[0m - [1mEncoder: Poisson Encoding,trial: 8, test_loss:97.5, test_mse:0.5066489577293396, test_acc:0.4933510720729828, test_f1:0.0, test_total_spikes:136286.921875[0m


[I 2025-05-07 16:23:08,176] Trial 8 finished with value: 97.5 and parameters: {'threshold': 0.13543629133600155, 'slope': 1.4562157287618578, 'beta': 0.1878660998154414, 'dropout_rate1': 0.36738842747154665, 'dropout_rate2': 0.6703615532512789, 'lr': 4.4097695688757595e-06, 'weight_decay': 2.3495693929948604e-06, 'scheduler_factor': 0.916302258643388, 'scheduler_patience': 3, 'interval_freq': 3}. Best is trial 7 with value: 20.61240577697754.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-07 16:24:28.622[0m | [1mINFO    [0m | [36m__main__[0m:[36mobjective[0m:[36m68[0m - [1mEncoder: Poisson Encoding,trial: 9, test_loss:65.0, test_mse:0.4720744788646698, test_acc:0.5279255509376526, test_f1:0.0, test_total_spikes:91505.2421875[0m


[I 2025-05-07 16:24:29,763] Trial 9 finished with value: 65.0 and parameters: {'threshold': 0.47766221437528594, 'slope': 3.415176850122207, 'beta': 0.7650657163926403, 'dropout_rate1': 0.10466178057202334, 'dropout_rate2': 0.862543640999285, 'lr': 2.465541986353358e-05, 'weight_decay': 1.2765899128472682e-05, 'scheduler_factor': 0.9323598597302472, 'scheduler_patience': 6, 'interval_freq': 2}. Best is trial 7 with value: 20.61240577697754.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-07 16:27:30.141[0m | [1mINFO    [0m | [36m__main__[0m:[36mobjective[0m:[36m68[0m - [1mEncoder: Poisson Encoding,trial: 10, test_loss:162.5, test_mse:0.5146276354789734, test_acc:0.4853723347187042, test_f1:0.0, test_total_spikes:229140.984375[0m


[I 2025-05-07 16:27:31,471] Trial 10 finished with value: 162.5 and parameters: {'threshold': 0.27306423830758164, 'slope': 6.575736541618038, 'beta': 0.10943582708158718, 'dropout_rate1': 0.886582638835099, 'dropout_rate2': 0.684205476829867, 'lr': 7.664996224886574e-05, 'weight_decay': 3.089517044723675e-05, 'scheduler_factor': 0.3378172967268379, 'scheduler_patience': 4, 'interval_freq': 5}. Best is trial 7 with value: 20.61240577697754.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...
[W 2025-05-07 16:28:12,202] Trial 11 failed with parameters: {'threshold': 0.2779736358200623, 'slope': 5.914736155688917, 'beta': 0.2554263490552968, 'dropout_rate1': 0.4210321291762418, 'dropout_rate2': 0.22476221609992833, 'lr': 6.00008372146847e-06, 'weight_decay': 8.785706042667946e-06, 'scheduler_factor': 0.9622559614920545, 'scheduler_patience': 2, 'interval_freq': 5} because of the following error: NameError("name 'exit' is not defined").
Traceback (most recent call last):
  File "/root/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/root/snn-encoder-test/.venv/lib/python3.12/site-packag

NameError: name 'exit' is not defined

In [None]:
# Report the best trial recorded in the study: first its sampled
# hyperparameters (`study.best_params`), then its objective value
# (`study.best_value`).
logger.info(f"Encoder: Poisson Encoding,trial, best_param: {study.best_params}")
logger.info(f"Encoder: Poisson Encoding,trial, best_score: {study.best_value}")

[32m2025-05-07 15:17:07.216[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mEncoder: BSA Encoding,trial, best_param: {'threshold': 0.05837571787591716, 'slope': 6.7082922718644795, 'beta': 0.6703615532512789, 'dropout_rate1': 0.3867650737818362, 'dropout_rate2': 0.2650897828875425, 'lr': 6.82947151662786e-05, 'weight_decay': 3.48210294726769e-06, 'scheduler_factor': 0.3434576235217472, 'scheduler_patience': 10, 'win_size': 3, 'cutoff': 0.7497922013805774, 'encoder_threshold': 0.015185576141913593}[0m
[32m2025-05-07 15:17:07.640[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mEncoder: BSA Encoding,trial, best_score: 8.897629737854004[0m
