In [1]:
%reload_ext autoreload
%autoreload all

%aimport -torch
%aimport -matplotlib
%aimport -seaborn
%aimport -numpy
%aimport -pandas
%aimport -scipy
%aimport -lightning 

In [2]:
from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataset, CHBMITDataModule

# Load the dataset
dataset = CHBMITDataset(PROCESSED_DATA_DIR / "stft_normalized.h5")

[32m2025-05-07 12:19:14.314[0m | [1mINFO    [0m | [36meeg_snn_encoder.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /root/snn-encoder-test[0m


In [3]:
datamodule = CHBMITDataModule(dataset, batch_size=128, worker=20)

In [None]:
import optuna
from loguru import logger

import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from optuna.integration import PyTorchLightningPruningCallback

from eeg_snn_encoder.encoders import PhaseEncoderExpand
from eeg_snn_encoder.models.classifier import ModelConfig
from eeg_snn_encoder.models.lightning import LitSeizureClassifier, OptimizerConfig
import torch

torch.set_float32_matmul_precision('medium')


def objective(trial: optuna.Trial) -> float:
    model_params: ModelConfig = {
        "threshold": trial.suggest_float("threshold", 0.01, 0.5),
        "slope": trial.suggest_float("slope", 1.0, 20.0),
        "beta": trial.suggest_float("beta", 0.1, 0.99),
        "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
        "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
    }

    optimizer_params: OptimizerConfig = {
        "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
        "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
    }

    encoder_params = {
        "phase_window": trial.suggest_int("phase_window", 1, 8),
    }

    spike_encoder = PhaseEncoderExpand(**encoder_params)
    
    lit_model = LitSeizureClassifier(
        model_config=model_params,
        optimizer_config=optimizer_params,
        spike_encoder=spike_encoder,
    )

    trainer = pl.Trainer(
        max_epochs=20,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        enable_model_summary=False,
        enable_checkpointing=False,
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_f1"), EarlyStopping(monitor="val_loss", mode="min", patience=5)],
        logger=False,
    )

    trainer.fit(lit_model, datamodule=datamodule)
    trainer.test(lit_model, datamodule=datamodule)

    test_loss = trainer.callback_metrics["test_loss"].item()
    test_acc = trainer.callback_metrics["test_acc"]
    test_f1 = trainer.callback_metrics["test_f1"]
    test_mse = trainer.callback_metrics["test_mse"]
    test_total_spikes = trainer.callback_metrics["test_total_spikes"]

    logger.info(f"Encoder: Phase Encoding,trial: {trial.number}, test_loss:{test_loss}, test_mse:{test_mse}, test_acc:{test_acc}, test_f1:{test_f1}, test_total_spikes:{test_total_spikes}")

    return test_f1

In [None]:
import os

# Initialize the Optuna study
sampler = optuna.samplers.TPESampler(multivariate=True)
pruner = optuna.pruners.HyperbandPruner()
study = optuna.create_study(
    direction="maximize",
    study_name="model-tuning-pe-f1",
    storage=os.environ["OPTUNA_CONN_STRING"],
    load_if_exists=True,
    sampler=sampler,
    pruner=pruner,
)



KeyboardInterrupt: 

: 

In [None]:
study.optimize(objective, n_trials=50)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-07 12:09:53.591[0m | [1mINFO    [0m | [36m__main__[0m:[36mobjective[0m:[36m67[0m - [1mEncoder: Phase Encoding,trial: 5, test_loss:97.5, test_mse:0.4614361822605133, test_acc:0.5385638475418091, test_f1:0.0, test_total_spikes:148475.328125[0m


[I 2025-05-07 12:09:55,011] Trial 5 finished with value: 97.5 and parameters: {'threshold': 0.2907938153998132, 'slope': 11.00900911242831, 'beta': 0.6283036379272648, 'dropout_rate1': 0.9411066516471832, 'dropout_rate2': 0.13476081249379904, 'lr': 1.0115449352715596e-06, 'weight_decay': 1.3796108704998097e-06, 'scheduler_factor': 0.4589527231754994, 'scheduler_patience': 4, 'phase_window': 3}. Best is trial 0 with value: 8.962407112121582.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-07 12:12:06.830[0m | [1mINFO    [0m | [36m__main__[0m:[36mobjective[0m:[36m67[0m - [1mEncoder: Phase Encoding,trial: 6, test_loss:97.5, test_mse:0.49468085169792175, test_acc:0.5053191781044006, test_f1:0.0, test_total_spikes:149140.046875[0m


[I 2025-05-07 12:12:08,259] Trial 6 finished with value: 97.5 and parameters: {'threshold': 0.28493413511552995, 'slope': 13.812801226092452, 'beta': 0.6589009609228247, 'dropout_rate1': 0.5762691651639437, 'dropout_rate2': 0.6894789883364355, 'lr': 3.27528709884739e-06, 'weight_decay': 8.073315501439542e-05, 'scheduler_factor': 0.16974811578745763, 'scheduler_patience': 6, 'phase_window': 3}. Best is trial 0 with value: 8.962407112121582.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

[W 2025-05-07 12:12:19,064] Trial 7 failed with parameters: {'threshold': 0.026690653533421864, 'slope': 3.2078251419427652, 'beta': 0.24752014192501334, 'dropout_rate1': 0.6348040254741011, 'dropout_rate2': 0.5444378583498093, 'lr': 3.095002648700594e-05, 'weight_decay': 7.539585894032446e-05, 'scheduler_factor': 0.4836430382864343, 'scheduler_patience': 10, 'phase_window': 5} because of the following error: OutOfMemoryError('CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 20.12 MiB is free. Process 2867789 has 1.69 GiB memory in use. Process 2888004 has 21.92 GiB memory in use. Of the allocated memory 20.28 GiB is allocated by PyTorch, and 1.18 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)').
Trace

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 20.12 MiB is free. Process 2867789 has 1.69 GiB memory in use. Process 2888004 has 21.92 GiB memory in use. Of the allocated memory 20.28 GiB is allocated by PyTorch, and 1.18 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
logger.info(f"Encoder: Phase Encoding,trial, best_param: {study.best_params}")
logger.info(f"Encoder: Phase Encoding,trial, best_score: {study.best_value}")