In [1]:
%reload_ext autoreload
%autoreload all

%aimport -torch
%aimport -matplotlib
%aimport -seaborn
%aimport -numpy
%aimport -pandas
%aimport -scipy
%aimport -lightning 

In [2]:
from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataset, CHBMITDataModule

# Load the dataset
dataset = CHBMITDataset(PROCESSED_DATA_DIR / "stft_normalized.h5")

In [None]:
datamodule = CHBMITDataModule(dataset, batch_size=128, worker=0)

In [4]:
import optuna
from loguru import logger

import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from optuna.integration import PyTorchLightningPruningCallback

from eeg_snn_encoder.encoders import PhaseEncoderExpand
from eeg_snn_encoder.callback import TrackBestMetric
from eeg_snn_encoder.models.classifier import ModelConfig
from eeg_snn_encoder.models.lightning import LitSeizureClassifier, OptimizerConfig
import torch

torch.set_float32_matmul_precision('medium')


def objective(trial: optuna.Trial) -> float:
    model_params: ModelConfig = {
        "threshold": trial.suggest_float("threshold", 0.01, 0.5),
        "slope": trial.suggest_float("slope", 1.0, 20.0),
        "beta": trial.suggest_float("beta", 0.1, 0.99),
        "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
        "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
    }

    optimizer_params: OptimizerConfig = {
        "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
        "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
    }

    encoder_params = {
        "phase_window": trial.suggest_int("phase_window", 1, 8),
    }

    spike_encoder = PhaseEncoderExpand(**encoder_params)
    
    lit_model = LitSeizureClassifier(
        model_config=model_params,
        optimizer_config=optimizer_params,
        spike_encoder=spike_encoder,
    )

    metric_tracker = TrackBestMetric(
        monitor="val_mse",
        mode="min",
    )

    trainer = pl.Trainer(
        max_epochs=20,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        enable_model_summary=False,
        enable_checkpointing=False,
        callbacks=[metric_tracker, PyTorchLightningPruningCallback(trial, monitor="val_mse"), EarlyStopping(monitor="val_loss", mode="min", patience=5)],
        logger=False,
    )

    trainer.fit(lit_model, datamodule=datamodule)

    return metric_tracker.best_metric

In [5]:
import os

# Initialize the Optuna study
sampler = optuna.samplers.TPESampler(multivariate=True)
pruner = optuna.pruners.HyperbandPruner()
study = optuna.create_study(
    direction="minimize",
    study_name="model-tuning-pe-mse",
    storage=os.environ["OPTUNA_CONN_STRING"],
    load_if_exists=True,
    sampler=sampler,
    pruner=pruner,
)

[I 2025-05-08 20:24:37,459] A new study created in RDB with name: model-tuning-pe-mse


In [6]:
study.optimize(objective, n_trials=50)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

New best value for val_mse: 0.4921875


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-05-08 20:26:53,551] Trial 0 finished with value: 0.4921875 and parameters: {'threshold': 0.4879322817798597, 'slope': 5.428118231898833, 'beta': 0.932198218747456, 'dropout_rate1': 0.8681331990121793, 'dropout_rate2': 0.5962990930658194, 'lr': 8.000097642494357e-06, 'weight_decay': 5.154598853290409e-05, 'scheduler_factor': 0.5919197004699149, 'scheduler_patience': 8, 'phase_window': 1}. Best is trial 0 with value: 0.4921875.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

New best value for val_mse: 0.4921875


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

New best value for val_mse: 0.48738381266593933


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-05-08 20:32:04,703] Trial 1 finished with value: 0.48738381266593933 and parameters: {'threshold': 0.33582774883549793, 'slope': 10.712773130240958, 'beta': 0.1386926578041482, 'dropout_rate1': 0.7679295510390521, 'dropout_rate2': 0.8473805946213264, 'lr': 8.473889237600508e-06, 'weight_decay': 1.6336202933051403e-05, 'scheduler_factor': 0.4420478490448556, 'scheduler_patience': 2, 'phase_window': 4}. Best is trial 1 with value: 0.48738381266593933.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

New best value for val_mse: 0.5234375


Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...
[W 2025-05-08 20:32:41,819] Trial 2 failed with parameters: {'threshold': 0.4622270455430369, 'slope': 7.072840984269882, 'beta': 0.8125426041374195, 'dropout_rate1': 0.7358276034357659, 'dropout_rate2': 0.35161835295089383, 'lr': 2.6656344273999863e-06, 'weight_decay': 8.196445188452801e-06, 'scheduler_factor': 0.8926956556368142, 'scheduler_patience': 5, 'phase_window': 6} because of the following error: NameError("name 'exit' is not defined").
Traceback (most recent call last):
  File "/workspace/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/snn-encoder-test/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/workspace/snn-encoder-test/.venv/lib/python3

NameError: name 'exit' is not defined

In [None]:
logger.info(f"Encoder: Phase Encoding,trial, best_param: {study.best_params}")
logger.info(f"Encoder: Phase Encoding,trial, best_score: {study.best_value}")