In [1]:
# Reload edited modules automatically before each cell executes.
%reload_ext autoreload
%autoreload all

# Exclude heavy third-party libraries from autoreloading (a leading "-" marks a
# module as not to be reloaded); everything else, including the project
# package, is still reloaded.
%aimport -torch
%aimport -matplotlib
%aimport -seaborn
%aimport -numpy
%aimport -pandas
%aimport -scipy
%aimport -lightning 

In [2]:
from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataset, CHBMITDataModule

# Load the processed STFT dataset stored as HDF5 under the project's processed-data directory.
stft_h5_path = PROCESSED_DATA_DIR / "stft_normalized.h5"
dataset = CHBMITDataset(stft_h5_path)

[32m2025-05-08 17:59:19.574[0m | [1mINFO    [0m | [36meeg_snn_encoder.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /workspace/snn-encoder-test[0m


In [3]:
# Lightning data module wrapping `dataset`; the Optuna objective fits every trial on it.
datamodule = CHBMITDataModule(
    dataset,
    batch_size=128,
    worker=8,
)

In [4]:
def _value_below(limit):
    """Return a predicate that accepts a trial whose ``value`` is strictly below ``limit``."""
    return lambda trial: trial.value < limit


# Earlier Optuna studies to mine for good trials. Each entry pairs the study name
# with the cutoff a completed trial's objective value must beat to be re-used.
be_studies = [
    dict(name="Classifier_Rate_Tuning", score_limit=_value_below(10)),
    dict(name="model-tuning-poisson-mse", score_limit=_value_below(0.30)),
]

In [5]:
import optuna
import os

# Collect previously good trials from the studies listed in `be_studies`, each
# study's trials ordered best (lowest value) first, so they can be re-evaluated
# in the fine-tuning study.
test_trials = []

for source_study in be_studies:
    # Single quotes inside the f-string keep this valid before Python 3.12;
    # reusing the outer quote type requires PEP 701 (Python 3.12+).
    print(f"Loading good trial from {source_study['name']}")
    old_study = optuna.load_study(
        study_name=source_study["name"],
        storage=os.environ["OPTUNA_CONN_STRING"],
    )

    # The trials are only read here, so skip the defensive deep copy.
    complete_trials = old_study.get_trials(
        deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,)
    )
    filtered_trials = [t for t in complete_trials if source_study["score_limit"](t)]
    sorted_trials = sorted(filtered_trials, key=lambda t: t.value)

    test_trials += sorted_trials
    print(len(sorted_trials))

Loading good trial from Classifier_Rate_Tuning
8
Loading good trial from model-tuning-poisson-mse
6


In [6]:
# Target study for the fine-tuning run (minimises the value returned by
# `objective`). Pruning is disabled via NopPruner; an existing study with the
# same name is reopened rather than recreated.
optuna_storage = os.environ["OPTUNA_CONN_STRING"]
study = optuna.create_study(
    study_name="model-fine-tuning-poisson",
    storage=optuna_storage,
    direction="minimize",
    pruner=optuna.pruners.NopPruner(),
    load_if_exists=True,
)

[I 2025-05-08 17:59:37,423] A new study created in RDB with name: model-fine-tuning-poisson


In [7]:
from loguru import logger

import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from optuna.integration import PyTorchLightningPruningCallback

from eeg_snn_encoder.encoders import PoissonEncoderExpand
from eeg_snn_encoder.models.classifier import ModelConfig
from eeg_snn_encoder.models.lightning import LitSeizureClassifier, OptimizerConfig
import torch

# Trade some float32 matmul precision for speed ("medium"; see the torch docs for
# exactly which internal formats this permits).
torch.set_float32_matmul_precision(precision="medium")


def objective(trial: optuna.Trial) -> float:
    """Train one Poisson-encoded seizure classifier and return its validation MSE.

    Model, optimizer and ``PoissonEncoderExpand`` hyper-parameters are sampled
    from ``trial``; the model is then fitted on the notebook-level ``datamodule``
    for at most 50 epochs.

    Args:
        trial: Optuna trial used to sample hyper-parameters.

    Returns:
        The ``val_mse`` entry of ``trainer.callback_metrics`` after fitting,
        converted to a plain Python float (the study minimises it).

    Raises:
        KeyError: if fitting finished without ``val_mse`` having been logged.
    """
    # Seed Python/NumPy/torch (and dataloader workers) so a given parameter set
    # trains reproducibly; the encoder receives its own seed below.
    pl.seed_everything(47, workers=True)

    model_params: ModelConfig = {
        "threshold": trial.suggest_float("threshold", 0.01, 0.5),
        "slope": trial.suggest_float("slope", 1.0, 20.0),
        "beta": trial.suggest_float("beta", 0.1, 0.99),
        "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
        "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
    }

    optimizer_params: OptimizerConfig = {
        "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
        "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
    }

    encoder_params = {
        "interval_freq": trial.suggest_int("interval_freq", 1, 8),
        # Fixed encoder seed — presumably drives its spike sampling; confirm in
        # PoissonEncoderExpand.
        "random_seed": 47,
    }

    spike_encoder = PoissonEncoderExpand(**encoder_params)

    lit_model = LitSeizureClassifier(
        model_config=model_params,
        optimizer_config=optimizer_params,
        spike_encoder=spike_encoder,
    )

    trainer = pl.Trainer(
        max_epochs=50,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        enable_model_summary=False,
        callbacks=[
            # Reports val_mse to Optuna after each validation run; with the
            # NopPruner study created above it never actually prunes.
            PyTorchLightningPruningCallback(trial, monitor="val_mse"),
            # NOTE(review): early stopping watches val_loss while the objective
            # returns val_mse — confirm this split is intended.
            EarlyStopping(monitor="val_loss", mode="min", patience=10),
        ],
        logger=False,
    )

    trainer.fit(lit_model, datamodule=datamodule)

    # callback_metrics holds values from the most recent validation run (not
    # necessarily the best epoch). They are tensors, so convert to the float
    # this function is annotated to return.
    return float(trainer.callback_metrics["val_mse"])

In [9]:
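# Re-run the search: enqueue every previously good trial collected in
# `test_trials` so this study re-evaluates those exact configurations, then run
# one optimisation step per enqueued trial. Commented out (apparently after the
# search had already run); uncomment to reproduce.
# NOTE(review): trials from other studies may carry params that `objective` never
# suggests — presumably Optuna ignores those enqueued values; confirm.
# for trial in test_trials:
#     study.enqueue_trial(params=trial.params)

# study.optimize(objective, n_trials=len(test_trials))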

In [None]:
# Summarise the best trial of the Poisson fine-tuning study (the objective in this
# notebook uses PoissonEncoderExpand, so the label says "Poisson Encoding").
# NOTE(review): the recorded output of this cell lists params (max_window, n_max,
# t_max, t_min) that this notebook's objective never suggests, so it appears to
# come from a different objective/run — re-run to refresh it.
logger.info(f"Encoder: Poisson Encoding,trial, best_param: {study.best_params}")
logger.info(f"Encoder: Poisson Encoding,trial, best_score: {study.best_value}")

[32m2025-05-08 22:14:46.461[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mEncoder: Burst Encoding,trial, best_param: {'threshold': 0.07090851207541435, 'slope': 8.107275688930079, 'beta': 0.7440672774049281, 'dropout_rate1': 0.20034293181105056, 'dropout_rate2': 0.5865898197113356, 'lr': 9.362082441139168e-05, 'weight_decay': 1.3936923384285494e-06, 'scheduler_factor': 0.4595170872743502, 'scheduler_patience': 8, 'max_window': 5, 'n_max': 4, 't_max': 0, 't_min': 0}[0m
[32m2025-05-08 22:14:46.938[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mEncoder: Burst Encoding,trial, best_score: 0.16733068227767944[0m
