# Tuning the model


## Load the data
Load the preprocessed data, which was prepared in HDF5 format. The data is split into training, validation, and test sets. The training set is used to train the model, the validation set is used to tune the hyperparameters, and the test set is used to evaluate the final model performance.

In [1]:
from dataset import CHBMITPreprocessedDataset

# Open the preprocessed dataset stored in the HDF5 file under ./CHB-MIT.
dataset = CHBMITPreprocessedDataset('./CHB-MIT/processed_data.h5')

## Assemble the model

We will use the PyTorch Lightning library to build the model pipeline from the base model.

In [2]:
from dataset.lightning import CHBMITPreprocessedDataModule

# Wrap the dataset in the project's Lightning data module, batching 32 samples at a time.
# This module-level object is the one every tuning objective passes to `trainer.fit`.
datamodule = CHBMITPreprocessedDataModule(dataset, batch_size=32)

In [3]:
import optuna
from encoder import TBREncoder
from optuna.integration import PyTorchLightningPruningCallback
from model.classifier import ModelConfig
from model.lightning import OptimizerConfig, LitSTFTPreprocessedSeizureClassifier
import lightning.pytorch as pl


def objective_tbr(trial: optuna.trial.Trial) -> float:
    """Optuna objective for the classifier fed by a ``TBREncoder``.

    Samples the classifier, optimizer/scheduler and encoder hyperparameters
    from ``trial``, trains for at most 15 epochs on the module-level
    ``datamodule``, and returns the validation loss left in the trainer's
    callback metrics.

    Args:
        trial: Optuna trial used to sample hyperparameters; it is also handed
            to the pruning callback, which monitors ``val_loss``.

    Returns:
        The ``val_loss`` entry of ``trainer.callback_metrics`` as a float.
    """
    # Classifier hyperparameters. The suggest_* calls are kept in a fixed
    # order so the trial records parameters in the same sequence every run.
    threshold = trial.suggest_float("threshold", 0.01, 0.5)
    slope = trial.suggest_float("slope", 1.0, 20.0)
    beta = trial.suggest_float("beta", 0.1, 0.99)
    dropout_rate1 = trial.suggest_float("dropout_rate1", 0.1, 0.99)
    dropout_rate2 = trial.suggest_float("dropout_rate2", 0.1, 0.99)
    classifier_cfg: ModelConfig = {
        "threshold": threshold,
        "slope": slope,
        "beta": beta,
        "dropout_rate1": dropout_rate1,
        "dropout_rate2": dropout_rate2,
    }

    # Optimizer and scheduler hyperparameters; lr and weight_decay are
    # sampled on a log scale.
    lr = trial.suggest_float("lr", 1e-6, 1e-4, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
    scheduler_factor = trial.suggest_float("scheduler_factor", 0.1, 0.99)
    scheduler_patience = trial.suggest_int("scheduler_patience", 1, 10)
    optim_cfg: OptimizerConfig = {
        "lr": lr,
        "weight_decay": weight_decay,
        "scheduler_factor": scheduler_factor,
        "scheduler_patience": scheduler_patience,
    }

    # The encoder's own threshold is tuned too (log scale), independently of
    # the classifier's "threshold" above.
    encoder = TBREncoder(
        threshold=trial.suggest_float("encoder_threshold", 0.1, 0.99, log=True),
        normalize=False,
    )

    classifier = LitSTFTPreprocessedSeizureClassifier(
        model_config=classifier_cfg,
        optimizer_config=optim_cfg,
        spike_encoder=encoder,
    )

    # Lets the study's pruner act on this trial based on val_loss.
    pruning_hook = PyTorchLightningPruningCallback(trial, monitor="val_loss")
    trainer = pl.Trainer(
        max_epochs=15,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        callbacks=[pruning_hook],
    )
    trainer.fit(classifier, datamodule=datamodule)

    final_val_loss = trainer.callback_metrics["val_loss"]
    return final_val_loss.item()

In [4]:
# import optuna
# from encoder import PoissonEncoder
# from optuna.integration import PyTorchLightningPruningCallback
# from model.classifier import ModelConfig
# from model.lightning import OptimizerConfig, LitSTFTPreprocessedSeizureClassifier
# import lightning.pytorch as pl


# def objective_rate(trial: optuna.trial.Trial) -> float:
#     model_params: ModelConfig = {
#         "threshold": trial.suggest_float("threshold", 0.01, 0.5),
#         "slope": trial.suggest_float("slope", 1.0, 20.0),
#         "beta": trial.suggest_float("beta", 0.1, 0.99),
#         "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
#         "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
#     }

#     optimizer_params: OptimizerConfig = {
#         "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
#         "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
#         "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
#         "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
#     }
#     params = {
#         "interval_freq": trial.suggest_int("interval_freq", 1, 16),
#         "normalize": False
#     }

#     spike_encoder = PoissonEncoder(**params)

#     model = LitSTFTPreprocessedSeizureClassifier(
#         model_config=model_params,
#         optimizer_config=optimizer_params,
#         spike_encoder=spike_encoder,
#     )

#     trainer = pl.Trainer(
#         max_epochs=15,
#         accelerator="auto",
#         devices="auto",
#         strategy="auto",
#         callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_loss")],
#     )

#     trainer.fit(model, datamodule=datamodule)

#     return trainer.callback_metrics["val_loss"].item()

In [5]:
import optuna
from encoder import BurstEncoder
from optuna.integration import PyTorchLightningPruningCallback
from model.classifier import ModelConfig
from model.lightning import OptimizerConfig, LitSTFTPreprocessedSeizureClassifier
import lightning.pytorch as pl


def objective_be(trial: optuna.trial.Trial) -> float:
    """Optuna objective for the classifier fed by a ``BurstEncoder``.

    Samples the classifier, optimizer/scheduler and burst-encoder
    hyperparameters from ``trial``, trains for at most 15 epochs on the
    module-level ``datamodule``, and returns the validation loss left in the
    trainer's callback metrics.

    Args:
        trial: Optuna trial used to sample hyperparameters; it is also handed
            to the pruning callback, which monitors ``val_loss``.

    Returns:
        The ``val_loss`` entry of ``trainer.callback_metrics`` as a float.
    """
    # Classifier hyperparameters. The suggest_* calls are kept in a fixed
    # order so the trial records parameters in the same sequence every run.
    threshold = trial.suggest_float("threshold", 0.01, 0.5)
    slope = trial.suggest_float("slope", 1.0, 20.0)
    beta = trial.suggest_float("beta", 0.1, 0.99)
    dropout_rate1 = trial.suggest_float("dropout_rate1", 0.1, 0.99)
    dropout_rate2 = trial.suggest_float("dropout_rate2", 0.1, 0.99)
    classifier_cfg: ModelConfig = {
        "threshold": threshold,
        "slope": slope,
        "beta": beta,
        "dropout_rate1": dropout_rate1,
        "dropout_rate2": dropout_rate2,
    }

    # Optimizer and scheduler hyperparameters; lr and weight_decay are
    # sampled on a log scale.
    lr = trial.suggest_float("lr", 1e-6, 1e-4, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
    scheduler_factor = trial.suggest_float("scheduler_factor", 0.1, 0.99)
    scheduler_patience = trial.suggest_int("scheduler_patience", 1, 10)
    optim_cfg: OptimizerConfig = {
        "lr": lr,
        "weight_decay": weight_decay,
        "scheduler_factor": scheduler_factor,
        "scheduler_patience": scheduler_patience,
    }

    # Encoder hyperparameters. t_max is drawn from [t_min, 16], so
    # t_max >= t_min always holds.
    # NOTE(review): this makes t_max's range vary between trials; confirm the
    # sampler in use copes with a search space that is not static.
    t_lo = trial.suggest_int("t_min", 1, 16)
    t_hi = trial.suggest_int("t_max", t_lo, 16)
    window = trial.suggest_int("max_window", 1, 16)
    n_cap = trial.suggest_int("n_max", 1, 16)
    encoder = BurstEncoder(
        max_window=window,
        n_max=n_cap,
        t_min=t_lo,
        t_max=t_hi,
        normalize=False,
    )

    classifier = LitSTFTPreprocessedSeizureClassifier(
        model_config=classifier_cfg,
        optimizer_config=optim_cfg,
        spike_encoder=encoder,
    )

    # Lets the study's pruner act on this trial based on val_loss.
    pruning_hook = PyTorchLightningPruningCallback(trial, monitor="val_loss")
    trainer = pl.Trainer(
        max_epochs=15,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        callbacks=[pruning_hook],
    )
    trainer.fit(classifier, datamodule=datamodule)

    final_val_loss = trainer.callback_metrics["val_loss"]
    return final_val_loss.item()

In [6]:
# import optuna
# from encoder import DummyEncoder
# from optuna.integration import PyTorchLightningPruningCallback
# from model.classifier import ModelConfig
# from model.lightning import OptimizerConfig, LitSTFTPreprocessedSeizureClassifier
# import lightning.pytorch as pl


# def objective_base(trial: optuna.trial.Trial) -> float:
#     model_params: ModelConfig = {
#         "threshold": trial.suggest_float("threshold", 0.01, 0.5),
#         "slope": trial.suggest_float("slope", 1.0, 20.0),
#         "beta": trial.suggest_float("beta", 0.1, 0.99),
#         "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
#         "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
#     }

#     optimizer_params: OptimizerConfig = {
#         "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
#         "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
#         "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
#         "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
#     }

#     spike_encoder = DummyEncoder()

#     model = LitSTFTPreprocessedSeizureClassifier(
#         model_config=model_params,
#         optimizer_config=optimizer_params,
#         spike_encoder=spike_encoder,
#     )

#     trainer = pl.Trainer(
#         max_epochs=15,
#         accelerator="auto",
#         devices="auto",
#         strategy="auto",
#         callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_loss")],
#     )

#     trainer.fit(model, datamodule=datamodule)

#     return trainer.callback_metrics["val_loss"].item()

In [7]:
from config import DB_CONN_STRING
import optuna

def create_and_run_study(study_name: str, objective, n_trials: int):
    """Create (or reopen) a stored Optuna study and run trials on it.

    The study minimizes the objective's return value, uses a CMA-ES sampler
    and a Hyperband pruner, and is kept in the storage addressed by
    ``DB_CONN_STRING``. Because ``load_if_exists`` is set, a study that
    already exists under ``study_name`` is loaded instead of raising.

    Args:
        study_name: Name under which the study is stored.
        objective: Callable taking an Optuna trial and returning the value
            to minimize.
        n_trials: Number of trials passed to ``study.optimize``.

    Returns:
        The study object after ``optimize`` has returned.
    """
    study_options = {
        "direction": "minimize",
        "study_name": study_name,
        "sampler": optuna.samplers.CmaEsSampler(),
        "storage": DB_CONN_STRING,
        "load_if_exists": True,
        "pruner": optuna.pruners.HyperbandPruner(),
    }
    tuning_study = optuna.create_study(**study_options)
    tuning_study.optimize(objective, n_trials=n_trials)
    return tuning_study

In [8]:
# Studies to run, one tuple per study: (study name, objective function, number of trials).
# The TBR entry is commented out, so only the burst-encoder study is active.
experiment = [
    # ("Classifier_TBR_Tuning", objective_tbr, 100),
    ("Classifier_BE_Tuning", objective_be, 100)
]

In [9]:
# Run each configured study in turn; the returned study object is not kept.
for study_name, objective, n_trials in experiment:
    create_and_run_study(
        study_name=study_name,
        objective=objective,
        n_trials=n_trials,
    )

[I 2025-05-02 14:21:09,327] A new study created in RDB with name: Classifier_BE_Tuning
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type               | Params | Mode 
-----------------------------------------------------
0 | model | EEGSpikeClassifier | 824 K  | train
-----------------------------------------------------
824 K     Trainable params
0         Non-trainable pa

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

[W 2025-05-02 14:21:10,282] Trial 0 failed with parameters: {'threshold': 0.43167924203892594, 'slope': 3.3384099767505147, 'beta': 0.9742631434388955, 'dropout_rate1': 0.653976385990211, 'dropout_rate2': 0.6874588543612099, 'lr': 3.4801008169401985e-06, 'weight_decay': 2.462070744524151e-05, 'scheduler_factor': 0.166906470563762, 'scheduler_patience': 2, 't_min': 5, 't_max': 11, 'max_window': 11, 'n_max': 11} because of the following error: OutOfMemoryError('CUDA out of memory. Tried to allocate 24.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 2.38 MiB is free. Process 3098062 has 3.52 GiB memory in use. Process 3427738 has 5.06 GiB memory in use. Process 3383238 has 11.74 GiB memory in use. Process 3427694 has 2.81 GiB memory in use. Including non-PyTorch memory, this process has 492.00 MiB memory in use. Of the allocated memory 25.67 MiB is allocated by PyTorch, and 20.33 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting P

OutOfMemoryError: CUDA out of memory. Tried to allocate 24.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 2.38 MiB is free. Process 3098062 has 3.52 GiB memory in use. Process 3427738 has 5.06 GiB memory in use. Process 3383238 has 11.74 GiB memory in use. Process 3427694 has 2.81 GiB memory in use. Including non-PyTorch memory, this process has 492.00 MiB memory in use. Of the allocated memory 25.67 MiB is allocated by PyTorch, and 20.33 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)