In [1]:
from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataset

# Build the dataset from the stft_normalized.h5 file under the processed-data directory
stft_h5_path = PROCESSED_DATA_DIR / "stft_normalized.h5"
dataset = CHBMITDataset(stft_h5_path)

[32m2025-05-06 01:05:32.910[0m | [1mINFO    [0m | [36meeg_snn_encoder.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: E:\Projects\snn-encoder-test[0m


In [2]:
import torch
from torch.utils.data import DataLoader, random_split

# Seeded generator so the split (and the training shuffle order) is reproducible.
generator = torch.Generator().manual_seed(42)

# Fractions of the full dataset: 70% train / 10% validation / 20% test.
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.7, 0.1, 0.2], generator=generator)

# Settings shared by all three loaders.
loader_kwargs = {"batch_size": 32, "num_workers": 8, "persistent_workers": True}

# random_split permutes the indices only once, so without shuffle=True every epoch
# would iterate the exact same batches in the exact same order. Reshuffle the
# training data each epoch, driven by the seeded generator for reproducibility.
train_loader = DataLoader(train_dataset, shuffle=True, generator=generator, **loader_kwargs)

# Evaluation loaders keep a fixed order.
val_loader = DataLoader(val_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

In [None]:
import optuna

from optuna.integration import PyTorchLightningPruningCallback
import pytorch_lightning as pl

from eeg_snn_encoder.encoders.global_temporal import PhaseEncoderExpand
from eeg_snn_encoder.models.classifier import EEGSTFTSpikeClassifier, ModelConfig
from eeg_snn_encoder.models.lightning import LitEvalSeizureClassifier, OptimizerConfig


def objective(trial: optuna.Trial) -> float:
    """Optuna objective: train one sampled configuration and report its ``val_loss``.

    Hyperparameters for the classifier, the optimizer/scheduler and the phase
    encoder are drawn from ``trial``. A model is then fitted for at most 15
    epochs on the notebook-level ``train_loader`` / ``val_loader`` with an
    Optuna pruning callback monitoring ``"val_loss"``.

    Args:
        trial: The Optuna trial used to sample hyperparameters and passed to
            the pruning callback.

    Returns:
        The ``"val_loss"`` entry of ``trainer.callback_metrics`` after fitting,
        converted with ``.item()``.
    """
    # NOTE: the suggest_* calls are kept in their original order
    # (model -> optimizer -> encoder).
    model_config: ModelConfig = {
        "threshold": trial.suggest_float("threshold", 0.01, 0.5),
        "slope": trial.suggest_float("slope", 1.0, 20.0),
        "beta": trial.suggest_float("beta", 0.1, 0.99),
        "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
        "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
    }

    # Learning rate and weight decay are sampled on a log scale.
    optimizer_config: OptimizerConfig = {
        "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
        "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
    }

    # Only the phase window is tuned; normalization stays disabled.
    spike_encoder = PhaseEncoderExpand(
        phase_window=trial.suggest_int("phase_window", 1, 8),
        normalize=False,
    )

    lit_model = LitEvalSeizureClassifier(
        model=EEGSTFTSpikeClassifier(config=model_config),
        optimizer_config=optimizer_config,
        spike_encoder=spike_encoder,
    )

    # Reports "val_loss" to Optuna during training so the trial can be pruned.
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_loss")

    trainer = pl.Trainer(
        max_epochs=15,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        callbacks=[pruning_callback],
    )
    trainer.fit(lit_model, train_loader, val_loader)

    final_val_loss = trainer.callback_metrics["val_loss"]
    return final_val_loss.item()