In [1]:
from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataset

# Build the CHB-MIT dataset from the normalized STFT HDF5 file in the processed-data directory.
stft_h5_path = PROCESSED_DATA_DIR / "stft_normalized.h5"
dataset = CHBMITDataset(stft_h5_path)

2025-05-06 01:05:32.910 | INFO     | eeg_snn_encoder.config:<module>:11 - PROJ_ROOT path is: E:\Projects\snn-encoder-test


In [2]:
import torch
from torch.utils.data import DataLoader, random_split

SEED = 42
BATCH_SIZE = 32
NUM_WORKERS = 8

# Seeded generator so the 70/10/20 train/val/test split is identical on every run.
generator = torch.Generator().manual_seed(SEED)

train_dataset, val_dataset, test_dataset = random_split(dataset, [0.7, 0.1, 0.2], generator=generator)

# Shuffle only the training set so each epoch sees a different sample order.
# It gets its own seeded generator, so the shuffling is reproducible under
# Restart & Run All. Evaluation loaders keep a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
    num_workers=NUM_WORKERS,
    persistent_workers=True,
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, persistent_workers=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, persistent_workers=True)

In [None]:
import mlflow
import os

# The tracking server comes from the environment; a missing variable raises KeyError here.
mlflow_uri = os.environ["MLFLOW_TRACKING_URI"]
mlflow.set_tracking_uri(uri=mlflow_uri)

In [None]:
import optuna

import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from optuna.integration import PyTorchLightningPruningCallback

from eeg_snn_encoder.encoders.global_temporal import PhaseEncoderExpand
from eeg_snn_encoder.models.classifier import EEGSTFTSpikeClassifier, ModelConfig
from eeg_snn_encoder.models.lightning import LitSeizureClassifier, OptimizerConfig


def objective(trial: optuna.Trial) -> float:
    """Train one sampled configuration and return its best validation loss.

    Samples model, optimizer and phase-encoder hyperparameters from ``trial``.
    Trains ``EEGSTFTSpikeClassifier`` (wrapped in ``LitSeizureClassifier``) on
    ``train_loader``/``val_loader`` with early stopping and Optuna pruning on
    ``val_loss``. Then evaluates on ``test_loader``. Everything is recorded in a
    nested MLflow run.

    The value returned to Optuna is a *validation* quantity, never a test one.
    Selecting hyperparameters on the test split would leak it into model
    selection and make the logged test scores optimistic. Returning validation
    loss also matches what the pruning callback reports each epoch.

    NOTE(review): studies stored by an earlier version of this objective hold
    test-loss values, which are not comparable. Use a fresh study name.

    Args:
        trial: Optuna trial used to sample hyperparameters and report progress.

    Returns:
        The lowest ``val_loss`` tracked by the ``EarlyStopping`` callback.
    """
    with mlflow.start_run(nested=True):
        model_params: ModelConfig = {
            "threshold": trial.suggest_float("threshold", 0.01, 0.5),
            "slope": trial.suggest_float("slope", 1.0, 20.0),
            "beta": trial.suggest_float("beta", 0.1, 0.99),
            "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
            "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
        }

        optimizer_params: OptimizerConfig = {
            "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
            "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
            "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
            "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
        }

        encoder_params = {
            "phase_window": trial.suggest_int("phase_window", 1, 8),
        }

        # Log the sampled hyperparameters before training, one flat MLflow param
        # per hyperparameter, so pruned or crashed trials still record them.
        # trial.params holds every value suggested above, under its Optuna name.
        mlflow.log_params(trial.params)

        spike_encoder = PhaseEncoderExpand(**encoder_params)
        model = EEGSTFTSpikeClassifier(config=model_params)
        lit_model = LitSeizureClassifier(
            model=model,
            optimizer_config=optimizer_params,
            spike_encoder=spike_encoder,
        )

        # Keep a handle on EarlyStopping to read the best val_loss it tracked.
        early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=5)
        trainer = pl.Trainer(
            max_epochs=15,
            accelerator="auto",
            devices="auto",
            strategy="auto",
            callbacks=[
                PyTorchLightningPruningCallback(trial, monitor="val_loss"),
                early_stopping,
            ],
        )

        trainer.fit(lit_model, train_loader, val_loader)
        # Read before trainer.test so it is independent of test-time metric state.
        best_val_loss = early_stopping.best_score.item()

        # The test split is only reported, never optimized against.
        trainer.test(lit_model, dataloaders=test_loader)
        results = trainer.callback_metrics
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "test_mse_count_loss": results["test_loss"].item(),
                "test_acc": results["test_acc"].item(),
                "test_f1": results["test_f1"].item(),
                "test_mse": results["test_mse"].item(),
                "test_total_spikes": results["test_total_spikes"].item(),
            }
        )

    return best_val_loss

In [None]:
def get_or_create_experiment(experiment_name):
    """Return the ID of the MLflow experiment called ``experiment_name``.

    An existing experiment with that name is reused. Otherwise a new one is
    created and its ID returned.
    """
    existing = mlflow.get_experiment_by_name(experiment_name)
    if not existing:
        return mlflow.create_experiment(experiment_name)
    return existing.experiment_id

# Resolve (or create) the MLflow experiment that the tuning runs are logged under.
experiment_id = get_or_create_experiment(experiment_name="Model tuning SF")

In [None]:
# Display name of the parent MLflow run; per-trial runs are nested under it.
run_name: str = "first_hyperopt_run"

In [None]:
with mlflow.start_run(experiment_id=experiment_id, run_name=run_name, nested=True):
    # Tag the parent run first so it is identifiable even if optimization fails.
    mlflow.set_tags(
        tags={
            "project": "EEG SNN Encoder",
            "optimizer_engine": "optuna",
            "encoder_type": "phase-encoder",
            "model_type": "EEGSTFTSpikeClassifier",
        }
    )

    # The study is persisted in external storage. With load_if_exists=True,
    # re-running this cell resumes the same study instead of starting over.
    # The sampler is seeded to match the seed used for the data split.
    study = optuna.create_study(
        direction="minimize",
        study_name="model-tuning-pe",
        storage=os.environ["OPTUNA_CONN_STRING"],
        load_if_exists=True,
        sampler=optuna.samplers.CmaEsSampler(seed=42),
        pruner=optuna.pruners.HyperbandPruner(),
    )

    # Each trial opens its own nested MLflow run inside `objective`.
    study.optimize(objective, n_trials=50)

    # study.best_params / best_value raise ValueError when no trial has
    # completed (e.g. every trial was pruned), so check first.
    completed_trials = study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,))
    if completed_trials:
        mlflow.log_params(study.best_params)
        # The objective returns a loss, not the separately logged test MSE.
        mlflow.log_metric("best_loss", study.best_value)
    