# Tuning the model


## Load the data
Load the preprocessed data, which was prepared in HDF5 format. The data is split into training, validation, and test sets: the training set is used to train the model, the validation set to tune the hyperparameters, and the test set to evaluate the final model's performance.

In [1]:
from dataset import CHBMITPreprocessedDataset

# Location of the preprocessed CHB-MIT data in HDF5 format.
h5_path = './CHB-MIT/processed_data.h5'
dataset = CHBMITPreprocessedDataset(h5_path)

## Assemble the model

We use the PyTorch Lightning library to build the training pipeline around the base model.

In [2]:
from dataset.lightning import CHBMITPreprocessedDataModule

# Wrap the dataset in a Lightning DataModule (batch_size=32).
batch_size = 32
datamodule = CHBMITPreprocessedDataModule(dataset, batch_size=batch_size)

In [3]:
import optuna
from encoder import TBREncoder
from optuna.integration import PyTorchLightningPruningCallback
from model.classifier import ModelConfig
from model.lightning import OptimizerConfig, LitSTFTPreprocessedSeizureClassifier
import lightning.pytorch as pl


def objective_tbr(trial: optuna.trial.Trial) -> float:
    """Optuna objective for the classifier fed by a ``TBREncoder``.

    Draws classifier, optimizer and encoder hyperparameters from ``trial``,
    trains ``LitSTFTPreprocessedSeizureClassifier`` on the notebook-level
    ``datamodule`` for at most 15 epochs, and returns the ``val_loss`` value
    left in ``trainer.callback_metrics`` when training ends. A
    ``PyTorchLightningPruningCallback`` monitoring ``val_loss`` lets the
    study's pruner stop the trial early.

    Args:
        trial: The Optuna trial supplying hyperparameter suggestions.

    Returns:
        The final reported validation loss as a Python float.
    """
    suggest_float = trial.suggest_float

    # Classifier hyperparameters (keys follow ModelConfig).
    model_params: ModelConfig = {
        "threshold": suggest_float("threshold", 0.01, 0.5),
        "slope": suggest_float("slope", 1.0, 20.0),
        "beta": suggest_float("beta", 0.1, 0.99),
        "dropout_rate1": suggest_float("dropout_rate1", 0.1, 0.99),
        "dropout_rate2": suggest_float("dropout_rate2", 0.1, 0.99),
    }

    # Optimizer / scheduler hyperparameters; lr and weight decay are
    # searched on a log scale.
    optimizer_params: OptimizerConfig = {
        "lr": suggest_float("lr", 1e-6, 1e-4, log=True),
        "weight_decay": suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        "scheduler_factor": suggest_float("scheduler_factor", 0.1, 0.99),
        "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
    }

    # The encoder threshold is also log-scaled; normalize=False is passed
    # through to the encoder unchanged.
    tbr_encoder = TBREncoder(
        threshold=suggest_float("encoder_threshold", 0.1, 0.99, log=True),
        normalize=False,
    )

    classifier = LitSTFTPreprocessedSeizureClassifier(
        model_config=model_params,
        optimizer_config=optimizer_params,
        spike_encoder=tbr_encoder,
    )

    pruning_cb = PyTorchLightningPruningCallback(trial, monitor="val_loss")
    trainer = pl.Trainer(
        max_epochs=15,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        callbacks=[pruning_cb],
    )
    trainer.fit(classifier, datamodule=datamodule)

    final_val_loss = trainer.callback_metrics["val_loss"]
    return final_val_loss.item()

In [4]:
# import optuna
# from encoder import PoissonEncoder
# from optuna.integration import PyTorchLightningPruningCallback
# from model.classifier import ModelConfig
# from model.lightning import OptimizerConfig, LitSTFTPreprocessedSeizureClassifier
# import lightning.pytorch as pl


# def objective_rate(trial: optuna.trial.Trial) -> float:
#     model_params: ModelConfig = {
#         "threshold": trial.suggest_float("threshold", 0.01, 0.5),
#         "slope": trial.suggest_float("slope", 1.0, 20.0),
#         "beta": trial.suggest_float("beta", 0.1, 0.99),
#         "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
#         "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
#     }

#     optimizer_params: OptimizerConfig = {
#         "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
#         "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
#         "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
#         "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
#     }
#     params = {
#         "interval_freq": trial.suggest_int("interval_freq", 1, 16),
#         "normalize": False
#     }

#     spike_encoder = PoissonEncoder(**params)

#     model = LitSTFTPreprocessedSeizureClassifier(
#         model_config=model_params,
#         optimizer_config=optimizer_params,
#         spike_encoder=spike_encoder,
#     )

#     trainer = pl.Trainer(
#         max_epochs=15,
#         accelerator="auto",
#         devices="auto",
#         strategy="auto",
#         callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_loss")],
#     )

#     trainer.fit(model, datamodule=datamodule)

#     return trainer.callback_metrics["val_loss"].item()

In [5]:
import optuna
from encoder import BurstEncoder
from optuna.integration import PyTorchLightningPruningCallback
from model.classifier import ModelConfig
from model.lightning import OptimizerConfig, LitSTFTPreprocessedSeizureClassifier
import lightning.pytorch as pl


def objective_be(trial: optuna.trial.Trial) -> float:
    """Optuna objective for the classifier fed by a ``BurstEncoder``.

    Draws classifier, optimizer and burst-encoder hyperparameters from
    ``trial``, trains ``LitSTFTPreprocessedSeizureClassifier`` on the
    notebook-level ``datamodule`` for at most 15 epochs, and returns the
    ``val_loss`` value left in ``trainer.callback_metrics`` when training
    ends. A ``PyTorchLightningPruningCallback`` monitoring ``val_loss`` lets
    the study's pruner stop the trial early.

    Args:
        trial: The Optuna trial supplying hyperparameter suggestions.

    Returns:
        The final reported validation loss as a Python float.
    """
    suggest_float = trial.suggest_float
    suggest_int = trial.suggest_int

    # Classifier hyperparameters (keys follow ModelConfig).
    model_params: ModelConfig = {
        "threshold": suggest_float("threshold", 0.01, 0.5),
        "slope": suggest_float("slope", 1.0, 20.0),
        "beta": suggest_float("beta", 0.1, 0.99),
        "dropout_rate1": suggest_float("dropout_rate1", 0.1, 0.99),
        "dropout_rate2": suggest_float("dropout_rate2", 0.1, 0.99),
    }

    # Optimizer / scheduler hyperparameters; lr and weight decay are
    # searched on a log scale.
    optimizer_params: OptimizerConfig = {
        "lr": suggest_float("lr", 1e-6, 1e-4, log=True),
        "weight_decay": suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        "scheduler_factor": suggest_float("scheduler_factor", 0.1, 0.99),
        "scheduler_patience": suggest_int("scheduler_patience", 1, 10),
    }

    # t_max is drawn from [t_min, 16], so t_max >= t_min always holds.
    lower_t = suggest_int("t_min", 1, 16)
    upper_t = suggest_int("t_max", lower_t, 16)

    # Keyword arguments are evaluated left to right, so max_window and n_max
    # are suggested after t_min/t_max, matching the original ordering.
    burst_encoder = BurstEncoder(
        max_window=suggest_int("max_window", 1, 16),
        n_max=suggest_int("n_max", 1, 16),
        t_min=lower_t,
        t_max=upper_t,
        normalize=False,
    )

    classifier = LitSTFTPreprocessedSeizureClassifier(
        model_config=model_params,
        optimizer_config=optimizer_params,
        spike_encoder=burst_encoder,
    )

    pruning_cb = PyTorchLightningPruningCallback(trial, monitor="val_loss")
    trainer = pl.Trainer(
        max_epochs=15,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        callbacks=[pruning_cb],
    )
    trainer.fit(classifier, datamodule=datamodule)

    final_val_loss = trainer.callback_metrics["val_loss"]
    return final_val_loss.item()

In [6]:
# import optuna
# from encoder import DummyEncoder
# from optuna.integration import PyTorchLightningPruningCallback
# from model.classifier import ModelConfig
# from model.lightning import OptimizerConfig, LitSTFTPreprocessedSeizureClassifier
# import lightning.pytorch as pl


# def objective_base(trial: optuna.trial.Trial) -> float:
#     model_params: ModelConfig = {
#         "threshold": trial.suggest_float("threshold", 0.01, 0.5),
#         "slope": trial.suggest_float("slope", 1.0, 20.0),
#         "beta": trial.suggest_float("beta", 0.1, 0.99),
#         "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
#         "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
#     }

#     optimizer_params: OptimizerConfig = {
#         "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
#         "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
#         "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
#         "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
#     }

#     spike_encoder = DummyEncoder()

#     model = LitSTFTPreprocessedSeizureClassifier(
#         model_config=model_params,
#         optimizer_config=optimizer_params,
#         spike_encoder=spike_encoder,
#     )

#     trainer = pl.Trainer(
#         max_epochs=15,
#         accelerator="auto",
#         devices="auto",
#         strategy="auto",
#         callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_loss")],
#     )

#     trainer.fit(model, datamodule=datamodule)

#     return trainer.callback_metrics["val_loss"].item()

In [7]:
from config import DB_CONN_STRING
import optuna

def create_and_run_study(study_name: str, objective, n_trials: int, catch: tuple = ()):
    """Create (or resume) a persisted Optuna study and run more trials on it.

    The study minimises ``objective`` using a CMA-ES sampler and a Hyperband
    pruner, and is stored at ``DB_CONN_STRING``. Because ``load_if_exists``
    is True, an existing study with the same name is resumed. In that case
    ``n_trials`` trials are run *in addition to* those already stored, not
    up to a total of ``n_trials``.

    Args:
        study_name: Name of the study in the storage backend.
        objective: Callable taking an ``optuna.trial.Trial`` and returning
            the float to minimise.
        n_trials: Number of trials to run in this call.
        catch: Exception types that should mark a single trial as FAILED
            and let the study continue, e.g. ``(RuntimeError,)`` for a
            transient CUDA error. The default ``()`` keeps the previous
            behaviour: any exception from a trial aborts the optimisation.

    Returns:
        The ``optuna.Study`` after ``study.optimize`` returns.
    """
    # HyperbandPruner stops most trials early. By default CMA-ES ignores
    # PRUNED trials, so it would learn only from the few that finish.
    # Optuna's CmaEsSampler docs recommend consider_pruned_trials=True
    # when HyperbandPruner is used.
    sampler = optuna.samplers.CmaEsSampler(consider_pruned_trials=True)
    study = optuna.create_study(
        direction="minimize",
        study_name=study_name,
        sampler=sampler,
        storage=DB_CONN_STRING,
        load_if_exists=True,
        pruner=optuna.pruners.HyperbandPruner(),
    )
    study.optimize(objective, n_trials=n_trials, catch=catch)
    return study

In [8]:
# Studies to run, each as (study name, objective function, number of trials).
# Every entry ends with a trailing comma. Without it, uncommenting the next
# line turns two adjacent tuples into a call expression `(...)(...)`, which
# raises TypeError: 'tuple' object is not callable.
experiment = [
    ("Classifier_TBR_Tuning", objective_tbr, 100),
    # ("Classifier_BE_Tuning", objective_be, 100),
]

In [9]:
# Run each configured study in turn with its own trial budget.
for study_name, study_objective, budget in experiment:
    create_and_run_study(study_name, study_objective, n_trials=budget)

[I 2025-05-02 14:21:22,086] Using an existing study with name 'Classifier_TBR_Tuning' instead of creating a new one.
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[W 2025-05-02 14:21:22,533] Trial 8 failed with parameters: {'threshold': 0.24047997338687024, 'slope': 10.527643800292791, 'beta': 0.37752429156045253, 'dropout_rate1': 0.9104056643270606, 'dropout_rate2': 0.863434987819242, 'l

RuntimeError: CUDA error: CUDA-capable device(s) is/are busy or unavailable
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
