# Tuning the model


## Load the data
Load the preprocessed data, which was prepared in HDF5 format. The data is split into training, validation, and test sets. The training set is used to train the model, the validation set is used to tune the hyperparameters, and the test set is used to evaluate the final model performance.

In [1]:
from dataset import CHBMITPreprocessedDataset

# Open the preprocessed CHB-MIT data stored in the HDF5 file produced by the
# preprocessing step; the path is relative to the notebook's working directory.
dataset = CHBMITPreprocessedDataset('./CHB-MIT/processed_data.h5')

## Assemble the model

We will use the PyTorch Lightning library to build the model pipeline from the base model.

In [2]:
from dataset.lightning import CHBMITPreprocessedDataModule

# Wrap the dataset in the project's Lightning data module with a batch size of
# 32. This single instance is shared by every tuning objective defined below.
datamodule = CHBMITPreprocessedDataModule(dataset, batch_size=32)

In [7]:
import optuna
from encoder import StepForwardEncoder
from optuna.integration import PyTorchLightningPruningCallback
from model.classifier import ModelConfig
from model.lightning import OptimizerConfig, LitSTFTPreprocessedSeizureClassifier
import lightning.pytorch as pl


def objective_sf(trial: optuna.trial.Trial) -> float:
    """Optuna objective for the classifier fed by a step-forward spike encoder.

    Draws the classifier, optimizer and encoder hyperparameters from
    ``trial``, fits the Lightning module on the notebook-level ``datamodule``
    for at most 15 epochs, and reports the final validation loss.

    Args:
        trial: Optuna trial that supplies the hyperparameter values.

    Returns:
        The last ``val_loss`` value found in the trainer's callback metrics.
    """
    suggest_float = trial.suggest_float

    # Classifier hyperparameters.
    classifier_cfg: ModelConfig = dict(
        threshold=suggest_float("threshold", 0.01, 0.5),
        slope=suggest_float("slope", 1.0, 20.0),
        beta=suggest_float("beta", 0.1, 0.99),
        dropout_rate1=suggest_float("dropout_rate1", 0.1, 0.99),
        dropout_rate2=suggest_float("dropout_rate2", 0.1, 0.99),
    )

    # Optimizer / LR-scheduler hyperparameters (rates searched on a log scale).
    optim_cfg: OptimizerConfig = dict(
        lr=suggest_float("lr", 1e-6, 1e-4, log=True),
        weight_decay=suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        scheduler_factor=suggest_float("scheduler_factor", 0.1, 0.99),
        scheduler_patience=trial.suggest_int("scheduler_patience", 1, 10),
    )

    # The encoder threshold has its own parameter name so it does not collide
    # with the classifier's "threshold".
    sf_threshold = suggest_float("encoder_threshold", 0.1, 0.99, log=True)
    encoder = StepForwardEncoder(threshold=sf_threshold, normalize=False)

    classifier = LitSTFTPreprocessedSeizureClassifier(
        model_config=classifier_cfg,
        optimizer_config=optim_cfg,
        spike_encoder=encoder,
    )

    # Report val_loss to the trial so the study's pruner can act on it.
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_loss")
    trainer = pl.Trainer(
        max_epochs=15,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        callbacks=[pruning_callback],
    )
    trainer.fit(classifier, datamodule=datamodule)

    final_val_loss = trainer.callback_metrics["val_loss"]
    return final_val_loss.item()

In [8]:
import optuna
from encoder import BSAEncoder


def objective_bsa(trial: optuna.trial.Trial) -> float:
    """Optuna objective for the classifier fed by a BSA spike encoder.

    Draws the classifier, optimizer and BSA-encoder hyperparameters from
    ``trial``, fits the Lightning module on the notebook-level ``datamodule``
    for at most 15 epochs, and reports the final validation loss.

    Args:
        trial: Optuna trial that supplies the hyperparameter values.

    Returns:
        The last ``val_loss`` value found in the trainer's callback metrics.
    """
    model_params: ModelConfig = {
        "threshold": trial.suggest_float("threshold", 0.01, 0.5),
        "slope": trial.suggest_float("slope", 1.0, 20.0),
        "beta": trial.suggest_float("beta", 0.1, 0.99),
        "dropout_rate1": trial.suggest_float("dropout_rate1", 0.1, 0.99),
        "dropout_rate2": trial.suggest_float("dropout_rate2", 0.1, 0.99),
    }

    optimizer_params: OptimizerConfig = {
        "lr": trial.suggest_float("lr", 1e-6, 1e-4, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        "scheduler_factor": trial.suggest_float("scheduler_factor", 0.1, 0.99),
        "scheduler_patience": trial.suggest_int("scheduler_patience", 1, 10),
    }

    # The encoder threshold must use its own parameter name. Optuna returns the
    # already-sampled value when a name is suggested twice in one trial, so
    # reusing "threshold" here would tie the encoder to the classifier
    # threshold and never explore the (0.01, 1) range. "encoder_threshold"
    # matches the name used by the step-forward objective.
    # NOTE(review): trials stored before this rename have no
    # "encoder_threshold" parameter; consider a fresh study name when resuming.
    encoder_params = {
        "win_size": trial.suggest_int("win_size", 1, 16),
        "cutoff": trial.suggest_float("cutoff", 0.01, 1),
        "threshold": trial.suggest_float("encoder_threshold", 0.01, 1),
        "normalize": False,
    }

    spike_encoder = BSAEncoder(**encoder_params)

    model = LitSTFTPreprocessedSeizureClassifier(
        model_config=model_params,
        optimizer_config=optimizer_params,
        spike_encoder=spike_encoder,
    )

    trainer = pl.Trainer(
        max_epochs=15,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        # Report val_loss to the trial so the study's pruner can act on it.
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_loss")],
    )

    trainer.fit(model, datamodule=datamodule)

    return trainer.callback_metrics["val_loss"].item()

In [9]:
import optuna
from encoder import PhaseEncoderExpand
from optuna.integration import PyTorchLightningPruningCallback
from model.classifier import ModelConfig
from model.lightning import OptimizerConfig, LitSTFTPreprocessedSeizureClassifier
import lightning.pytorch as pl


def objective_pe(trial: optuna.trial.Trial) -> float:
    """Optuna objective for the classifier fed by a phase spike encoder.

    Draws the classifier, optimizer and phase-encoder hyperparameters from
    ``trial``, fits the Lightning module on the notebook-level ``datamodule``
    for at most 15 epochs, and reports the final validation loss.

    Args:
        trial: Optuna trial that supplies the hyperparameter values.

    Returns:
        The last ``val_loss`` value found in the trainer's callback metrics.
    """
    # Classifier hyperparameters, sampled in a fixed order.
    threshold = trial.suggest_float("threshold", 0.01, 0.5)
    slope = trial.suggest_float("slope", 1.0, 20.0)
    beta = trial.suggest_float("beta", 0.1, 0.99)
    dropout_rate1 = trial.suggest_float("dropout_rate1", 0.1, 0.99)
    dropout_rate2 = trial.suggest_float("dropout_rate2", 0.1, 0.99)
    classifier_cfg: ModelConfig = {
        "threshold": threshold,
        "slope": slope,
        "beta": beta,
        "dropout_rate1": dropout_rate1,
        "dropout_rate2": dropout_rate2,
    }

    # Optimizer / LR-scheduler hyperparameters (rates searched on a log scale).
    lr = trial.suggest_float("lr", 1e-6, 1e-4, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
    scheduler_factor = trial.suggest_float("scheduler_factor", 0.1, 0.99)
    scheduler_patience = trial.suggest_int("scheduler_patience", 1, 10)
    optim_cfg: OptimizerConfig = {
        "lr": lr,
        "weight_decay": weight_decay,
        "scheduler_factor": scheduler_factor,
        "scheduler_patience": scheduler_patience,
    }

    # Encoder hyperparameter: only the phase window is tuned.
    phase_window = trial.suggest_int("phase_window", 1, 16)
    encoder = PhaseEncoderExpand(phase_window=phase_window, normalize=False)

    classifier = LitSTFTPreprocessedSeizureClassifier(
        model_config=classifier_cfg,
        optimizer_config=optim_cfg,
        spike_encoder=encoder,
    )

    # Report val_loss to the trial so the study's pruner can act on it.
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_loss")
    trainer = pl.Trainer(
        max_epochs=15,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        callbacks=[pruning_callback],
    )
    trainer.fit(classifier, datamodule=datamodule)

    final_val_loss = trainer.callback_metrics["val_loss"]
    return final_val_loss.item()

In [10]:
from config import DB_CONN_STRING
import optuna

def create_and_run_study(study_name: str, objective, n_trials: int):
    """Create (or resume) a stored Optuna study and run it.

    The study minimizes the objective's return value, samples with CMA-ES,
    prunes with Hyperband, and is stored at ``DB_CONN_STRING``. Because
    ``load_if_exists=True``, an existing study with the same name is reused
    instead of raising.

    Args:
        study_name: Name under which the study is stored.
        objective: Callable taking an Optuna trial and returning the value to
            minimize.
        n_trials: Number of trials to run in this call.

    Returns:
        The study object after optimization has finished.
    """
    study_options = dict(
        direction="minimize",
        study_name=study_name,
        sampler=optuna.samplers.CmaEsSampler(),
        storage=DB_CONN_STRING,
        load_if_exists=True,
        pruner=optuna.pruners.HyperbandPruner(),
    )
    tuning_study = optuna.create_study(**study_options)
    tuning_study.optimize(objective, n_trials=n_trials)
    return tuning_study

In [11]:
# Tuning runs to execute, one per spike encoder.
# Each entry is (study name, objective function, number of trials).
experiment = [
    ("Classifier_SF_Tuning", objective_sf, 100),
    ("Classifier_BSA_Tuning", objective_bsa, 100),
    ("Classifier_PE_Tuning", objective_pe, 100)
]

In [None]:
# Run the configured studies one after another. The returned study objects are
# not kept here; results live in the storage used by create_and_run_study.
for name, objective_fn, n_trials in experiment:
    create_and_run_study(name, objective_fn, n_trials)

[I 2025-05-01 01:46:49,388] A new study created in RDB with name: Classifier_SF_Tuning
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type               | Params | Mode 
-----------------------------------------------------
0 | model | EEGSpikeClassifier | 824 K  | train
-----------------------------------------------------
824 K     Trainable params
0         Non-trainable pa

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]