In [6]:
# Reload the autoreload extension, then switch autoreload to mode "all".
%reload_ext autoreload
%autoreload all

# `%aimport -<module>` marks a module as excluded from autoreloading. The
# third-party libraries listed below are excluded, so edits to the remaining
# (project) modules are picked up without restarting the kernel.
%aimport -torch
%aimport -matplotlib
%aimport -seaborn
%aimport -numpy
%aimport -pandas
%aimport -scipy
%aimport -lightning 

In [7]:
from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataset, CHBMITDataModule

# Load the dataset from the "stft_normalized.h5" file under PROCESSED_DATA_DIR.
# NOTE(review): judging by the file name this is presumably an HDF5 file of
# normalized STFT features — confirm against the preprocessing step.
dataset = CHBMITDataset(PROCESSED_DATA_DIR / "stft_normalized.h5")

In [8]:
# Wrap the dataset in the project's data module, using a batch size of 768.
# NOTE(review): `worker=0` presumably sets the number of DataLoader worker
# processes (0 = load in the main process) — confirm against CHBMITDataModule.
datamodule = CHBMITDataModule(dataset, batch_size=768, worker=0)

In [None]:
import optuna
from loguru import logger

import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from optuna.integration import PyTorchLightningPruningCallback

from eeg_snn_encoder.encoders import BurstEncoder
from eeg_snn_encoder.models.classifier import ModelConfig
from eeg_snn_encoder.models.lightning import LitSeizureClassifier, OptimizerConfig
import torch

# Select the "medium" internal precision for float32 matrix multiplications.
# NOTE(review): per the PyTorch docs this setting trades some numerical
# precision for speed on supported hardware — confirm that is acceptable here.
torch.set_float32_matmul_precision('medium')


def objective(trial: optuna.Trial) -> float:
    """Train and evaluate one Burst-encoded classifier for a single Optuna trial.

    Samples classifier, optimizer and encoder hyperparameters from ``trial``,
    fits a ``LitSeizureClassifier`` on the notebook-level ``datamodule`` for at
    most 20 epochs (with early stopping and Optuna pruning both monitoring
    ``val_loss``), then runs the test loop and logs the test metrics.

    Args:
        trial: Optuna trial used to sample the hyperparameters; it is also
            handed to the pruning callback.

    Returns:
        The ``val_loss`` entry of ``trainer.callback_metrics`` read right after
        fitting finishes, as a Python float.
    """
    # Classifier hyperparameters, passed to the model as ``model_config``.
    model_params: ModelConfig = {
        "threshold": trial.suggest_float("threshold", 0.01, 0.5),
        "slope": trial.suggest_float("slope", 1.0, 20.0),
        "beta": trial.suggest_float("beta", 0.1, 0.99),
        "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
        "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
    }

    # Optimizer/scheduler settings; lr and weight_decay are sampled on a log scale.
    optimizer_params: OptimizerConfig = {
        "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
        "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
    }

    # Encoder hyperparameters are drawn in sequence because each upper bound
    # depends on values sampled earlier:
    #   n_max <= max_window, t_max <= max_window // n_max, t_min <= t_max.
    max_window = trial.suggest_int("max_window", 4, 16)
    n_max = trial.suggest_int("n_max", 1, max_window)
    t_max = trial.suggest_int("t_max", 0, max_window // n_max)
    t_min = trial.suggest_int("t_min", 0, t_max)

    encoder_params = {
        "max_window": max_window,
        "n_max": n_max,
        "t_max": t_max,
        "t_min": t_min,
    }

    spike_encoder = BurstEncoder(**encoder_params)

    lit_model = LitSeizureClassifier(
        model_config=model_params,
        optimizer_config=optimizer_params,
        spike_encoder=spike_encoder,
    )

    # Checkpointing, model summary and experiment logging are disabled: only the
    # metrics of each trial are needed during the search.
    trainer = pl.Trainer(
        max_epochs=20,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        enable_model_summary=False,
        enable_checkpointing=False,
        callbacks=[
            PyTorchLightningPruningCallback(trial, monitor="val_loss"),
            EarlyStopping(monitor="val_loss", mode="min", patience=5),
        ],
        logger=False,
    )

    trainer.fit(lit_model, datamodule=datamodule)

    # Read the validation loss before the test loop runs; callback_metrics is
    # consulted again afterwards for the test metrics.
    val_loss = trainer.callback_metrics["val_loss"].item()

    trainer.test(lit_model, datamodule=datamodule)

    # Convert every test metric to a plain Python number (previously only
    # test_loss was converted), so the log line shows values instead of a mix
    # of numbers and tensor reprs.
    test_metrics = trainer.callback_metrics
    test_loss = test_metrics["test_loss"].item()
    test_acc = test_metrics["test_acc"].item()
    test_f1 = test_metrics["test_f1"].item()
    test_mse = test_metrics["test_mse"].item()
    test_total_spikes = test_metrics["test_total_spikes"].item()

    logger.info(f"Encoder: Burst Encoding,trial: {trial.number}, test_loss:{test_loss}, test_mse:{test_mse}, test_acc:{test_acc}, test_f1:{test_f1}, test_total_spikes:{test_total_spikes}")

    return val_loss

In [None]:
import os

# Attach to the "model-tuning-be" study in the storage named by the
# OPTUNA_CONN_STRING environment variable, creating the study when it does not
# exist yet. A multivariate TPE sampler proposes the hyperparameters and the
# objective value is minimized.
sampler = optuna.samplers.TPESampler(multivariate=True)
study = optuna.create_study(
    storage=os.environ["OPTUNA_CONN_STRING"],
    study_name="model-tuning-be",
    sampler=sampler,
    load_if_exists=True,
    direction="minimize",
)

[I 2025-05-06 19:25:39,053] Using an existing study with name 'model-tuning-tbr' instead of creating a new one.


In [None]:
# Run 49 trials of `objective` on the study.
# NOTE(review): the study is created with load_if_exists=True, so these trials
# are presumably added on top of any trials already in storage — confirm the
# intended total.
study.optimize(objective, n_trials=49)

In [None]:
# Log the study's best hyperparameters and its best objective value.
logger.info(f"Encoder: Burst Encoding,trial, best_param: {study.best_params}")
logger.info(f"Encoder: Burst Encoding,trial, best_score: {study.best_value}")

[32m2025-05-06 19:26:44.905[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mEncoder: TBR Encoding,trial, best_param: {'threshold': 0.0305685971329455, 'slope': 8.82839671962939, 'beta': 0.8562602268238131, 'dropout_rate1': 0.1605071202767286, 'dropout_rate2': 0.3320934135789956, 'lr': 3.922547202908979e-05, 'weight_decay': 3.174904373757126e-05, 'scheduler_factor': 0.9103134810378268, 'scheduler_patience': 7, 'tbr_threshold': 2}[0m
[32m2025-05-06 19:26:45.352[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mEncoder: TBR Encoding,trial, best_score: 13.77458381652832[0m
