In [1]:
%reload_ext autoreload
%autoreload all

%aimport -torch
%aimport -matplotlib
%aimport -seaborn
%aimport -numpy
%aimport -pandas
%aimport -scipy
%aimport -lightning 

In [2]:
from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataset, CHBMITDataModule

# Load the dataset
dataset = CHBMITDataset(PROCESSED_DATA_DIR / "stft_normalized.h5")

[32m2025-05-07 15:19:16.829[0m | [1mINFO    [0m | [36meeg_snn_encoder.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /root/snn-encoder-test[0m


In [3]:
datamodule = CHBMITDataModule(dataset, batch_size=726, worker=0)

In [9]:
import os

import optuna

# Initialize the Optuna study
study = optuna.load_study(
    study_name="model-tuning-bsa",
    storage=os.environ["OPTUNA_CONN_STRING"],
)

In [5]:
complete_trial = study.get_trials(False, states=(optuna.trial.TrialState.COMPLETE,))
sorted_trials = sorted(complete_trial, key=lambda t: t.value)
filtered_trials = list(filter(lambda t: t.value < 15, sorted_trials))
filtered_trials

[FrozenTrial(number=1, state=1, values=[8.897629737854004], datetime_start=datetime.datetime(2025, 5, 7, 14, 46, 37, 985935), datetime_complete=datetime.datetime(2025, 5, 7, 14, 48, 14, 637899), params={'threshold': 0.05837571787591716, 'slope': 6.7082922718644795, 'beta': 0.6703615532512789, 'dropout_rate1': 0.3867650737818362, 'dropout_rate2': 0.2650897828875425, 'lr': 6.82947151662786e-05, 'weight_decay': 3.48210294726769e-06, 'scheduler_factor': 0.3434576235217472, 'scheduler_patience': 10, 'win_size': 3, 'cutoff': 0.7497922013805774, 'encoder_threshold': 0.015185576141913593}, user_attrs={}, system_attrs={'completed_rung_0': 14.532792091369629, 'completed_rung_1': 13.296854019165039, 'completed_rung_2': 10.877975463867188}, intermediate_values={0: 16.215518951416016, 1: 14.532792091369629, 2: 13.69881534576416, 3: 13.296854019165039, 4: 12.660313606262207, 5: 12.809510231018066, 6: 12.201603889465332, 7: 11.67796516418457, 8: 11.116118431091309, 9: 10.877975463867188, 10: 10.58395

In [7]:
from loguru import logger

import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from optuna.integration import PyTorchLightningPruningCallback

from eeg_snn_encoder.encoders import BSAEncoder
from eeg_snn_encoder.models.classifier import ModelConfig
from eeg_snn_encoder.models.lightning import LitSeizureClassifier, OptimizerConfig
import torch
from IPython.display import clear_output

torch.set_float32_matmul_precision('medium')


def objective(trial: optuna.Trial) -> float:
    clear_output(wait=True)
    model_params: ModelConfig = {
        "threshold": trial.suggest_float("threshold", 0.01, 0.5),
        "slope": trial.suggest_float("slope", 1.0, 20.0),
        "beta": trial.suggest_float("beta", 0.1, 0.99),
        "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
        "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
    }

    optimizer_params: OptimizerConfig = {
        "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
        "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
    }

    max_window = trial.suggest_int("max_window", 4, 16)
    n_max = trial.suggest_int("n_max", 1, max_window)
    t_max = trial.suggest_int("t_max", 0, max_window // n_max)
    t_min = trial.suggest_int("t_min", 0, t_max)

    encoder_params = {
        "win_size": trial.suggest_int("win_size", 1, 16),
        "cutoff": trial.suggest_float("cutoff", 0.01, 1),
        "threshold": trial.suggest_float("encoder_threshold", 0.01, 1),
    }

    spike_encoder = BSAEncoder(**encoder_params)
    
    lit_model = LitSeizureClassifier(
        model_config=model_params,
        optimizer_config=optimizer_params,
        spike_encoder=spike_encoder,
    )

    trainer = pl.Trainer(
        max_epochs=50,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        enable_model_summary=False,
        enable_checkpointing=False,
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_loss"), EarlyStopping(monitor="val_loss", mode="min", patience=5)],
        logger=False,
    )

    trainer.fit(lit_model, datamodule=datamodule)

    val_loss = trainer.callback_metrics["val_loss"].item()
    
    trainer.test(lit_model, datamodule=datamodule)

    test_loss = trainer.callback_metrics["test_loss"].item()
    test_acc = trainer.callback_metrics["test_acc"]
    test_f1 = trainer.callback_metrics["test_f1"]
    test_mse = trainer.callback_metrics["test_mse"]
    test_total_spikes = trainer.callback_metrics["test_total_spikes"]

    logger.info(f"Encoder: BSA Encoding,trial: {trial.number}, test_loss:{test_loss}, test_mse:{test_mse}, test_acc:{test_acc}, test_f1:{test_f1}, test_total_spikes:{test_total_spikes}")

    return val_loss

In [10]:
for trial in filtered_trials:
    study.enqueue_trial(params=trial.params)

study.optimize(objective, n_trials=len(filtered_trials))

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-05-07 15:35:37,879] Trial 59 pruned. Trial was pruned at epoch 20.


In [12]:
logger.info(f"Encoder: BSA Encoding,trial, best_param: {study.best_params}")
logger.info(f"Encoder: BSA Encoding,trial, best_score: {study.best_value}")

[32m2025-05-07 15:35:50.339[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mEncoder: BSA Encoding,trial, best_param: {'threshold': 0.05837571787591716, 'slope': 6.7082922718644795, 'beta': 0.6703615532512789, 'dropout_rate1': 0.3867650737818362, 'dropout_rate2': 0.2650897828875425, 'lr': 6.82947151662786e-05, 'weight_decay': 3.48210294726769e-06, 'scheduler_factor': 0.3434576235217472, 'scheduler_patience': 10, 'max_window': 13, 'n_max': 9, 't_max': 0, 't_min': 0, 'win_size': 3, 'cutoff': 0.7497922013805774, 'encoder_threshold': 0.015185576141913593}[0m
[32m2025-05-07 15:35:50.775[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mEncoder: BSA Encoding,trial, best_score: 8.295167922973633[0m
