In [None]:
# Reload edited project modules (e.g. eeg_snn_encoder) automatically, so changes
# to the package are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload all

# Exclude large third-party libraries from autoreloading. They are not edited in
# this notebook; reloading them is presumably slow and can break class identity
# for objects created before the reload (e.g. isinstance checks) — keep this list
# in sync if more heavy dependencies are imported.
# NOTE: comments must stay on their own lines — text after a line magic is passed
# to the magic as part of its argument.
%aimport -torch
%aimport -matplotlib
%aimport -seaborn
%aimport -numpy
%aimport -pandas
%aimport -scipy
%aimport -lightning 

In [None]:
import optuna
import torch

from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataModule, CHBMITDataset

# Let PyTorch pick faster, reduced-precision internal algorithms (e.g. TF32 where the
# hardware supports it) for float32 matrix multiplies — a small precision cost for
# noticeably faster training during tuning.
torch.set_float32_matmul_precision("high")

In [3]:
# Load the CHB-MIT dataset from the normalized-STFT HDF5 file in the processed-data directory.
stft_path = PROCESSED_DATA_DIR / "stft_normalized.h5"
dataset = CHBMITDataset(stft_path)

In [4]:
# Batch size for every tuning run that uses this data module.
BATCH_SIZE = 768
# worker=0 presumably means batches are loaded in the main process (no worker
# subprocesses) — confirm against CHBMITDataModule.
datamodule = CHBMITDataModule(dataset, batch_size=BATCH_SIZE, worker=0)

In [8]:
# Encoder types to tune; each gets its own Optuna study in the loop below.
# The expansions are presumed from common SNN spike-encoding names — confirm in
# eeg_snn_encoder: be = burst encoding?, pe = phase encoding?, poisson = Poisson
# rate coding, bsa = Ben's Spiker Algorithm?, sf = step-forward?,
# tbr = threshold-based representation?
tuning_list = ["be", "pe", "poisson", "bsa", "sf", "tbr"]

In [9]:
import os

from loguru import logger

# NOTE(review): this import currently fails with
#   ImportError: cannot import name 'create_objective' from 'eeg_snn_encoder.tuning'
# The objective factory is either named differently or not defined yet in
# eeg_snn_encoder/tuning.py — confirm the exported name there and update this line.
from eeg_snn_encoder.tuning import create_objective

# Target number of *finished* trials (complete, pruned or failed) per study.
# Studies are resumed (load_if_exists=True), so re-running this cell only tops each
# study up to this budget instead of adding another full batch every time.
N_TRIALS = 50
# Random-sampling warm-up for the TPE sampler, also used as the number of trials
# the median pruner waits for before it starts pruning.
N_STARTUP_TRIALS = 10

# Read the storage URL once, up front, so a missing variable fails with a clear
# message rather than a bare KeyError inside the loop.
optuna_storage = os.environ.get("OPTUNA_CONN_STRING")
if not optuna_storage:
    raise RuntimeError(
        "Environment variable OPTUNA_CONN_STRING is not set; "
        "it must hold the Optuna storage URL used for the tuning studies."
    )

for tuning in tuning_list:
    logger.info(f"Starting tuning for {tuning}")

    sampler = optuna.samplers.TPESampler(
        n_startup_trials=N_STARTUP_TRIALS, multivariate=True, group=True
    )
    pruner = optuna.pruners.MedianPruner(n_startup_trials=N_STARTUP_TRIALS, n_warmup_steps=1)

    study = optuna.create_study(
        direction="minimize",
        study_name=f"model-tuning-{tuning}-new",
        storage=optuna_storage,
        load_if_exists=True,
        sampler=sampler,
        pruner=pruner,
    )

    # Trials left over from earlier runs count toward the budget; RUNNING/WAITING
    # trials (e.g. orphaned by a crashed kernel) are not counted as finished.
    n_finished = sum(1 for t in study.get_trials(deepcopy=False) if t.state.is_finished())
    n_remaining = max(N_TRIALS - n_finished, 0)

    if n_remaining > 0:
        # Keyword spellings (`moniter_*`) presumably mirror the factory's signature —
        # do not "fix" them here without changing tuning.py as well.
        objective = create_objective(
            encoder_type=tuning, datamodule=datamodule, moniter_metric="val_mse", moniter_mode="min"
        )
        study.optimize(objective, n_trials=n_remaining)
    else:
        logger.info(
            f"Study for {tuning} already has {n_finished} finished trials; skipping optimization"
        )

    # best_trial/best_value/best_params raise ValueError when no trial has COMPLETE
    # state (e.g. every trial was pruned or failed); report and move on instead of
    # aborting the remaining encoders.
    completed = study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,))
    if not completed:
        logger.warning(f"No completed trials for {tuning}; skipping best-trial report")
        continue

    logger.info(f"Best trial for {tuning}: {study.best_trial}")
    logger.info(f"Best value for {tuning}: {study.best_value}")
    logger.info(f"Best params for {tuning}: {study.best_params}")

ImportError: cannot import name 'create_objective' from 'eeg_snn_encoder.tuning' (/home/jupyter-group55/snn-encoder-test/eeg_snn_encoder/tuning.py)