In [1]:
# Load (or reload) IPython's autoreload extension and have it re-import all
# modules before each cell executes, so edits to project code are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload all

# A leading "-" excludes the named module from autoreloading; the third-party
# libraries below are opted out.
%aimport -torch
%aimport -matplotlib
%aimport -seaborn
%aimport -numpy
%aimport -pandas
%aimport -scipy
%aimport -lightning 

In [2]:
from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataset, CHBMITDataModule

# Load the dataset from the "stft_normalized.h5" file in the processed-data directory.
dataset = CHBMITDataset(PROCESSED_DATA_DIR / "stft_normalized.h5")

[32m2025-05-07 18:17:33.404[0m | [1mINFO    [0m | [36meeg_snn_encoder.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /root/snn-encoder-test[0m


In [3]:
# Wrap the dataset in the project's data module; this `datamodule` is what
# `objective` passes to trainer.fit / trainer.test.
# NOTE(review): batch_size=726 is an unexplained magic number, and `worker=0`
# presumably sets the DataLoader worker count — confirm against CHBMITDataModule.
datamodule = CHBMITDataModule(dataset, batch_size=726, worker=0)

In [4]:
import optuna
from loguru import logger

import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from optuna.integration import PyTorchLightningPruningCallback

from eeg_snn_encoder.encoders import BSAEncoder
from eeg_snn_encoder.models.classifier import ModelConfig
from eeg_snn_encoder.models.lightning import LitSeizureClassifier, OptimizerConfig
import torch
from IPython.display import clear_output

# Allow float32 matrix multiplications to use reduced-precision (bfloat16)
# internal computation when a fast kernel is available, trading some numerical
# precision for speed.
torch.set_float32_matmul_precision('medium')


def objective(trial: optuna.Trial) -> float:
    """Optuna objective: train and evaluate one BSA-encoder + classifier configuration.

    Samples hyperparameters for the classifier, the optimizer/scheduler and the
    BSA spike encoder from ``trial``, fits a ``LitSeizureClassifier`` for at most
    20 epochs on the notebook-level ``datamodule``, then evaluates it with
    ``trainer.test`` and logs the resulting test metrics.

    Training can stop early in two ways:
      * the Optuna pruning callback, which reports ``val_f1`` to the study's pruner;
      * ``EarlyStopping`` on ``val_loss`` (mode "min", patience 5).

    Args:
        trial: The Optuna trial used to sample hyperparameters and to report
            intermediate values for pruning.

    Returns:
        The ``test_f1`` metric of the trained model as a plain Python float.

    Raises:
        optuna.TrialPruned: Raised by the pruning callback when the pruner
            decides to stop this trial.
        KeyError: If an expected ``test_*`` metric was not logged by the model.
    """
    # Replace the previous trial's progress bars/logs instead of stacking them;
    # wait=True defers the clear until new output arrives, avoiding flicker.
    clear_output(wait=True)

    # --- Search space -------------------------------------------------------
    model_params: ModelConfig = {
        "threshold": trial.suggest_float("threshold", 0.01, 0.5),
        "slope": trial.suggest_float("slope", 1.0, 20.0),
        "beta": trial.suggest_float("beta", 0.1, 0.99),
        "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
        "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
    }

    optimizer_params: OptimizerConfig = {
        "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
        "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
    }

    encoder_params = {
        "win_size": trial.suggest_int("win_size", 1, 16),
        "cutoff": trial.suggest_float("cutoff", 0.01, 1),
        # Stored in the study as "encoder_threshold" so it does not collide with
        # the classifier's "threshold" parameter above.
        "threshold": trial.suggest_float("encoder_threshold", 0.01, 1),
    }

    # --- Model and trainer --------------------------------------------------
    spike_encoder = BSAEncoder(**encoder_params)

    lit_model = LitSeizureClassifier(
        model_config=model_params,
        optimizer_config=optimizer_params,
        spike_encoder=spike_encoder,
    )

    callbacks = [
        PyTorchLightningPruningCallback(trial, monitor="val_f1"),
        EarlyStopping(monitor="val_loss", mode="min", patience=5),
    ]

    trainer = pl.Trainer(
        max_epochs=20,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        enable_model_summary=False,
        enable_checkpointing=False,
        callbacks=callbacks,
        logger=False,
    )

    # --- Train and evaluate -------------------------------------------------
    trainer.fit(lit_model, datamodule=datamodule)
    trainer.test(lit_model, datamodule=datamodule)

    # Convert every metric to a plain Python float so the log line shows numbers
    # rather than tensor reprs and the return value matches the `-> float`
    # annotation (previously only test_loss was unwrapped).
    metrics = trainer.callback_metrics
    test_loss = float(metrics["test_loss"])
    test_acc = float(metrics["test_acc"])
    test_f1 = float(metrics["test_f1"])
    test_mse = float(metrics["test_mse"])
    test_total_spikes = float(metrics["test_total_spikes"])

    logger.info(f"Encoder: BSA Encoding,trial: {trial.number}, test_loss:{test_loss}, test_mse:{test_mse}, test_acc:{test_acc}, test_f1:{test_f1}, test_total_spikes:{test_total_spikes}")

    # NOTE(review): the study maximizes the *test* F1 while pruning is driven by
    # `val_f1`. Selecting hyperparameters on the test split makes the reported
    # best score optimistic; consider returning the validation F1 (captured
    # right after `trainer.fit`) and reserving the test split for a final check.
    return test_f1

In [5]:
import os

# Set up the persistent Optuna study used for this encoder's search:
#   * TPE sampler with multivariate sampling and a fixed seed (47);
#   * Hyperband pruner;
#   * storage URL taken from the OPTUNA_CONN_STRING environment variable
#     (raises KeyError when it is not set);
#   * an already-existing study with the same name is reused, not recreated.
sampler = optuna.samplers.TPESampler(seed=47, multivariate=True)
pruner = optuna.pruners.HyperbandPruner()
study = optuna.create_study(
    storage=os.environ["OPTUNA_CONN_STRING"],
    sampler=sampler,
    pruner=pruner,
    study_name="model-tuning-bsa-f1",
    direction="maximize",
    load_if_exists=True,
)

[I 2025-05-07 18:17:41,002] A new study created in RDB with name: model-tuning-bsa-f1


In [None]:
# Run 50 trials of `objective`; when the study was loaded from storage these
# are added on top of the trials it already contains.
study.optimize(objective, n_trials=50)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
# Log the best trial's hyperparameters and its objective value.
# NOTE(review): the saved output of this cell is timestamped earlier than the
# study-creation cell's output and shows a best_score above 1; if the objective
# is an F1 score that output is likely stale — re-run before relying on it.
logger.info(f"Encoder: BSA Encoding,trial, best_param: {study.best_params}")
logger.info(f"Encoder: BSA Encoding,trial, best_score: {study.best_value}")

[32m2025-05-07 15:17:07.216[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mEncoder: BSA Encoding,trial, best_param: {'threshold': 0.05837571787591716, 'slope': 6.7082922718644795, 'beta': 0.6703615532512789, 'dropout_rate1': 0.3867650737818362, 'dropout_rate2': 0.2650897828875425, 'lr': 6.82947151662786e-05, 'weight_decay': 3.48210294726769e-06, 'scheduler_factor': 0.3434576235217472, 'scheduler_patience': 10, 'win_size': 3, 'cutoff': 0.7497922013805774, 'encoder_threshold': 0.015185576141913593}[0m
[32m2025-05-07 15:17:07.640[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mEncoder: BSA Encoding,trial, best_score: 8.897629737854004[0m
