In [1]:
# Enable IPython's autoreload extension so edits to project modules are picked
# up without restarting the kernel. Mode "all" reloads every imported module
# (except those excluded with %aimport below) before each cell executes.
%reload_ext autoreload
%autoreload all

# Exclude third-party libraries from autoreload; only project code is expected
# to change while the notebook is open (presumably also because re-importing
# these large packages is slow and can disturb their internal state).
%aimport -torch
%aimport -matplotlib
%aimport -seaborn
%aimport -numpy
%aimport -pandas
%aimport -scipy
%aimport -lightning 

In [2]:
import optuna
import torch

from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataset

# Set PyTorch's float32 matrix-multiplication precision mode to "high"
# (see the torch.set_float32_matmul_precision docs for the speed/accuracy trade-off).
torch.set_float32_matmul_precision("high")

[32m2025-05-18 02:20:02.859[0m | [1mINFO    [0m | [36meeg_snn_encoder.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /home/jupyter-group55/snn-encoder-test[0m


In [3]:
# Load the dataset from the HDF5 file "stft_normalized.h5" inside the project's
# processed-data directory (presumably normalized STFT features — confirm
# against the preprocessing step that writes this file).
dataset = CHBMITDataset(PROCESSED_DATA_DIR / "stft_normalized.h5")

In [4]:
import torch
from torch.utils.data import DataLoader, random_split

# torch.manual_seed seeds PyTorch's global RNG and returns that generator, so
# this both fixes the split below and seeds later global-RNG consumers.
generator = torch.manual_seed(47)
# 80/20 random split of the dataset into training and held-out test subsets.
train_dataset, test_dataset = random_split(dataset, [0.8, 0.2], generator=generator)

In [5]:
# Single source of truth for the batch size so the two loaders cannot drift apart.
BATCH_SIZE = 512

# Training batches are reshuffled on every pass over the loader.
# pin_memory=True requests page-locked host tensors for host-to-device copies.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
)

# The test loader keeps dataset order so evaluation is repeatable.
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
)

In [16]:
# Encoder types to tune, in the order their studies are run. Each entry must be
# a key of ENCODER_TUNING_FUNCTIONS.
# NOTE(review): the abbreviations presumably stand for burst, phase, Poisson,
# BSA, step-forward and threshold-based encoding — confirm against
# eeg_snn_encoder.tuning.
tuning_list = ["be", "pe", "poisson", "bsa", "sf", "tbr"]

In [24]:
import math

from tqdm.notebook import tqdm

from eeg_snn_encoder.tuning import ENCODER_TUNING_FUNCTIONS


def create_encoder_objective(encoder_type: str):
    """Build an Optuna objective that scores one encoder type by reconstruction RMSE.

    The returned objective builds an encoder from the trial's suggested
    hyperparameters, pushes every batch of the notebook-level ``train_loader``
    through encode -> decode, and returns the root-mean-square error between
    the decoded output and the input, taken over all elements seen.

    Args:
        encoder_type: Key into ``ENCODER_TUNING_FUNCTIONS`` selecting the
            factory that builds an encoder from an ``optuna.Trial``.

    Returns:
        A callable ``objective(trial) -> float`` suitable for ``study.optimize``.

    Raises:
        ValueError: If ``encoder_type`` is not a key of ``ENCODER_TUNING_FUNCTIONS``.
    """
    if encoder_type not in ENCODER_TUNING_FUNCTIONS:
        valid_types = list(ENCODER_TUNING_FUNCTIONS.keys())
        raise ValueError(f"Unsupported encoder type: {encoder_type}. Choose from: {valid_types}")

    # Use the GPU when present, but stay runnable on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def objective(trial: optuna.Trial):
        """Return the reconstruction RMSE of the encoder configured by ``trial``.

        Raises:
            optuna.TrialPruned: If the pruner stops the trial on an intermediate value.
            ValueError: If ``train_loader`` yields no data.
        """
        encoder = ENCODER_TUNING_FUNCTIONS[encoder_type](trial)

        data = tqdm(train_loader, desc=f"Training {encoder_type} encoder", leave=False)

        squared_error_sum = 0.0
        element_count = 0

        # No backward pass is ever run here, so skip autograd bookkeeping.
        with torch.no_grad():
            # Labels are not needed for a reconstruction error, so they are
            # never moved to the device.
            for idx, (x, _) in enumerate(data):
                x = x.to(device=device)

                # Round trip: encode, then decode with parameters derived from the input.
                encoded_data = encoder.encode(x)
                decoded_params = encoder.get_decode_params(x)
                # Crop the last axis back to the input's length in case decode
                # returns a longer one.
                decoded_data = encoder.decode(encoded_data, decoded_params)[..., : x.shape[-1]]

                squared_error_sum += torch.sum((decoded_data - x) ** 2).item()
                element_count += x.numel()

                # Report the running RMSE so the pruner can stop weak trials early.
                trial.report(math.sqrt(squared_error_sum / element_count), step=idx)
                if trial.should_prune():
                    # Remove the progress-bar widget before leaving the loop early.
                    data.container.close()
                    raise optuna.TrialPruned()

        if element_count == 0:
            raise ValueError("train_loader yielded no data; cannot compute RMSE")

        return math.sqrt(squared_error_sum / element_count)

    return objective

In [26]:
import os

from IPython.display import clear_output
from loguru import logger

# Tune each encoder type in turn, with one Optuna study per type.
for tuning in tuning_list:
    # Wipe the previous iteration's cell output so only the current study's logs show.
    clear_output(wait=True)
    logger.info(f"Starting tuning encoder for {tuning}")

    # NOTE(review): the sampler is unseeded, so suggested hyperparameters differ between runs.
    sampler = optuna.samplers.TPESampler()

    pruner = optuna.pruners.HyperbandPruner()

    # The storage URL comes from the environment (KeyError if the variable is unset).
    # With load_if_exists=True, a study of the same name already in that storage
    # is reused instead of raising.
    study = optuna.create_study(
        direction="minimize",
        study_name=f"encoder-tuning-{tuning}",
        storage=os.environ["OPTUNA_CONN_STRING_CPE"],
        load_if_exists=True,
        sampler=sampler,
        pruner=pruner,
    )

    objective = create_encoder_objective(tuning)

    # n_trials applies per optimize() call, so re-running this cell adds
    # another 250 trials to an existing study.
    study.optimize(objective, n_trials=250)

    logger.info(f"Finished tuning encoder for {tuning}")