In [1]:
# Enable IPython's autoreload so edits to project modules (e.g. eeg_snn_encoder)
# are picked up without restarting the kernel; `all` reloads every module.
%reload_ext autoreload
%autoreload all

# Exclude large third-party packages from autoreloading.
# Comments are kept on separate lines: text after a line magic is passed to it as arguments.
# NOTE(review): `%aimport -name` appears to skip only the exact module `name`;
# submodules such as `lightning.pytorch` may still be checked — confirm this is sufficient.
%aimport -torch
%aimport -matplotlib
%aimport -seaborn
%aimport -numpy
%aimport -pandas
%aimport -scipy
%aimport -lightning 

In [2]:
from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataset, CHBMITDataModule

# Processed HDF5 file under PROCESSED_DATA_DIR holding the STFT features
# (the filename suggests they are already normalized).
STFT_DATASET_PATH = PROCESSED_DATA_DIR / "stft_normalized.h5"
dataset = CHBMITDataset(STFT_DATASET_PATH)

[32m2025-05-06 17:31:23.924[0m | [1mINFO    [0m | [36meeg_snn_encoder.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /root/snn-encoder-test[0m


In [3]:
# Batching settings passed to CHBMITDataModule; `worker` is presumably the
# number of data-loading worker processes — confirm against the class.
BATCH_SIZE = 300
NUM_WORKERS = 20

datamodule = CHBMITDataModule(
    dataset,
    batch_size=BATCH_SIZE,
    worker=NUM_WORKERS,
)

In [4]:
import optuna
from loguru import logger
import gc

import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from optuna.integration import PyTorchLightningPruningCallback

from eeg_snn_encoder.encoders import PhaseEncoderExpand
from eeg_snn_encoder.models.classifier import EEGSTFTSpikeClassifier, ModelConfig
from eeg_snn_encoder.models.lightning import LitSeizureClassifier, OptimizerConfig
import torch

# Trade a little float32 matmul accuracy for speed: "medium" lets PyTorch use
# reduced internal precision (e.g. bfloat16) when a fast kernel is available.
torch.set_float32_matmul_precision(precision="medium")


def objective(trial: optuna.Trial, max_epochs: int = 30) -> float:
    """Train and evaluate one phase-encoded spiking classifier for an Optuna trial.

    Samples model, optimizer and encoder hyperparameters from ``trial``, trains a
    ``LitSeizureClassifier`` on the module-level ``datamodule``, then runs the test
    split for reporting only. Test metrics are logged and stored on the trial as
    user attributes, so they persist in the study storage.

    Args:
        trial: Optuna trial used to sample hyperparameters. The pruning callback
            also uses it to report ``val_loss`` to the study's pruner.
        max_epochs: Maximum number of training epochs per trial (default 30).

    Returns:
        The ``val_loss`` found in ``trainer.callback_metrics`` after ``fit``
        returns. This is the last logged value, not necessarily the best epoch.

    Raises:
        optuna.TrialPruned: Propagated from ``PyTorchLightningPruningCallback``
            when the pruner decides to stop the trial early.
    """
    model_params: ModelConfig = {
        "threshold": trial.suggest_float("threshold", 0.01, 0.5),
        "slope": trial.suggest_float("slope", 1.0, 20.0),
        "beta": trial.suggest_float("beta", 0.1, 0.99),
        "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
        "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
    }

    optimizer_params: OptimizerConfig = {
        "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
        "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
    }

    encoder_params: dict[str, int] = {
        "phase_window": trial.suggest_int("phase_window", 1, 4),
    }

    spike_encoder = PhaseEncoderExpand(**encoder_params)

    lit_model = LitSeizureClassifier(
        model_config=model_params,
        optimizer_config=optimizer_params,
        spike_encoder=spike_encoder,
    )

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        enable_model_summary=False,
        enable_checkpointing=False,
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_loss")],
        logger=False,
    )

    try:
        trainer.fit(lit_model, datamodule=datamodule)

        # Read the objective before trainer.test(), which may replace the
        # contents of trainer.callback_metrics.
        val_loss = float(trainer.callback_metrics["val_loss"])

        # The test split is evaluated for reporting only; the value returned to
        # the sampler is val_loss, so the search never optimizes on test data.
        trainer.test(lit_model, datamodule=datamodule)
        test_metrics = {
            key: float(trainer.callback_metrics[key])
            for key in ("test_loss", "test_mse", "test_acc", "test_f1", "test_total_spikes")
        }
        # Store the metrics in the study storage so they outlive this notebook's output.
        for key, value in test_metrics.items():
            trial.set_user_attr(key, value)

        logger.info(
            f"Encoder: Phase Encoding,trial: {trial.number}, "
            f"test_loss:{test_metrics['test_loss']}, test_mse:{test_metrics['test_mse']}, "
            f"test_acc:{test_metrics['test_acc']}, test_f1:{test_metrics['test_f1']}, "
            f"test_total_spikes:{test_metrics['test_total_spikes']}"
        )

        return val_loss
    finally:
        # Many trials run in one kernel. Lightning objects can hold reference
        # cycles (e.g. model <-> trainer), so drop them, force a collection and
        # release cached CUDA blocks. This runs even for pruned or failed trials.
        del trainer, lit_model, spike_encoder
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [5]:
import os

# Storage URL of the persisted Optuna study (an RDB connection string).
storage_url = os.environ.get("OPTUNA_CONN_STRING")
if not storage_url:
    raise RuntimeError(
        "OPTUNA_CONN_STRING is not set; export the Optuna storage URL before running the study."
    )

# Multivariate/grouped TPE sampling; Hyperband prunes weak trials early using the
# val_loss values that the pruning callback in `objective` reports.
sampler = optuna.samplers.TPESampler(multivariate=True, group=True)
pruner = optuna.pruners.HyperbandPruner()
study = optuna.create_study(
    direction="minimize",
    study_name="model-tuning-pe",
    storage=storage_url,
    load_if_exists=True,
    sampler=sampler,
    pruner=pruner,
)

# Start the search from a hand-picked baseline configuration. The study is
# reloaded from storage on each run (load_if_exists=True), so skip_if_exists
# prevents the baseline from being enqueued again every time this cell re-runs.
study.enqueue_trial(
    {
        "threshold": 0.087, "slope": 10.5, "beta": 0.74,
        "dropout_rate1": 0.73, "dropout_rate2": 0.38,
        "lr": 3.2e-5, "weight_decay": 1.4e-5,
        "scheduler_factor": 0.22, "scheduler_patience": 6,
        "phase_window": 1,
    },
    skip_if_exists=True,
)

# Run the hyperparameter search. Because the study is persisted, each execution
# of this cell adds another n_trials trials to it.
study.optimize(objective, n_trials=100)

logger.info(f"Encoder: Phase Encoding,trial, best_param: {study.best_params}")
logger.info(f"Encoder: Phase Encoding,trial, best_score: {study.best_value}")

[I 2025-05-06 17:31:31,855] A new study created in RDB with name: model-tuning-pe
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-06 17:34:46.400[0m | [1mINFO    [0m | [36m__main__[0m:[36mobjective[0m:[36m73[0m - [1mEncoder: Phase Encoding,trial: 0, test_loss:10.09262466430664, test_mse:0.2154255360364914, test_acc:0.7845744490623474, test_f1:0.8007233738899231, test_total_spikes:8040768.0[0m


[I 2025-05-06 17:34:47,655] Trial 0 finished with value: 8.962407112121582 and parameters: {'threshold': 0.087, 'slope': 10.5, 'beta': 0.74, 'dropout_rate1': 0.73, 'dropout_rate2': 0.38, 'lr': 3.2e-05, 'weight_decay': 1.4e-05, 'scheduler_factor': 0.22, 'scheduler_patience': 6, 'phase_window': 1}. Best is trial 0 with value: 8.962407112121582.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-06 17:38:09.957[0m | [1mINFO    [0m | [36m__main__[0m:[36mobjective[0m:[36m73[0m - [1mEncoder: Phase Encoding,trial: 1, test_loss:32.5, test_mse:0.519946813583374, test_acc:0.480053186416626, test_f1:0.0, test_total_spikes:8029000.0[0m


[I 2025-05-06 17:38:11,193] Trial 1 finished with value: 32.5 and parameters: {'threshold': 0.24827909612636825, 'slope': 7.461128070228822, 'beta': 0.20681230283272706, 'dropout_rate1': 0.9157200152181891, 'dropout_rate2': 0.9896738655948264, 'lr': 7.392690655972209e-06, 'weight_decay': 1.7230289275202613e-06, 'scheduler_factor': 0.8323937354903089, 'scheduler_patience': 9, 'phase_window': 1}. Best is trial 0 with value: 8.962407112121582.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-06 17:41:34.137[0m | [1mINFO    [0m | [36m__main__[0m:[36mobjective[0m:[36m73[0m - [1mEncoder: Phase Encoding,trial: 2, test_loss:32.5, test_mse:0.48803192377090454, test_acc:0.5119680762290955, test_f1:0.0, test_total_spikes:8084091.0[0m


[I 2025-05-06 17:41:35,650] Trial 2 finished with value: 32.5 and parameters: {'threshold': 0.3275105791469493, 'slope': 18.25638595075481, 'beta': 0.32779030976279694, 'dropout_rate1': 0.29115557432899886, 'dropout_rate2': 0.711863362295149, 'lr': 1.9104746655058147e-06, 'weight_decay': 6.432062890112568e-05, 'scheduler_factor': 0.9360586840144578, 'scheduler_patience': 6, 'phase_window': 1}. Best is trial 0 with value: 8.962407112121582.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...
[W 2025-05-06 17:42:20,695] Trial 3 failed with parameters: {'threshold': 0.04656758720581827, 'slope': 18.26698352506645, 'beta': 0.41580025562393175, 'dropout_rate1': 0.22946306763812302, 'dropout_rate2': 0.4644468069453238, 'lr': 2.382435458861475e-06, 'weight_decay': 1.5837467350383952e-05, 'scheduler_factor': 0.7182383025607609, 'scheduler_patience': 1, 'phase_window': 4} because of the following error: NameError("name 'exit' is not defined").
Traceback (most recent call last):
  File "/root/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/root/snn-encoder-test/.venv/lib/python3.12/site-pack

NameError: name 'exit' is not defined