In [1]:
import optuna
import torch
import random
import numpy as np

from eeg_snn_encoder.config import PROCESSED_DATA_DIR
from eeg_snn_encoder.dataset import CHBMITDataModule, CHBMITDataset

torch.set_float32_matmul_precision("high")
torch.use_deterministic_algorithms(True, warn_only=True)

torch.manual_seed(47)
random.seed(47)
np.random.seed(47)

[32m2025-06-11 16:21:38.604[0m | [1mINFO    [0m | [36meeg_snn_encoder.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /workspace/snn-encoder-test[0m


In [2]:
# Load the dataset
dataset = CHBMITDataset(PROCESSED_DATA_DIR / "stft_normalized.h5")

In [3]:
datamodule = CHBMITDataModule(dataset, batch_size=768, worker=0)

In [4]:
tuning_encoder = "dummy"

In [5]:
if tuning_encoder not in ["be", "pe", "poisson", "bsa", "sf", "tbr", "dummy"]:
    raise ValueError(f"Unknown encoder type: {tuning_encoder}")

In [6]:
import os

tuned_study_name = f"model-tuning-{tuning_encoder}-new"

tuned_study = optuna.load_study(
    study_name=tuned_study_name,
        storage=os.environ["OPTUNA_CONN_STRING_CPE"],
)

In [13]:
sorted_trial = sorted(tuned_study.trials, key=lambda t: t.value or float("inf"))

In [20]:
import json

deduped_trials = []

param_set = set()

for trial in sorted_trial:
    param_text = json.dumps(trial.params, sort_keys=True)
    if param_text not in param_set:
        param_set.add(param_text)
        deduped_trials.append(trial)

print(f"Total trials: {len(deduped_trials)}")

Total trials: 304


In [None]:
from eeg_snn_encoder.tuning import create_objective
from loguru import logger

logger.info(f"Starting tuning for {tuning_encoder}")

sampler = optuna.samplers.TPESampler()
pruner = optuna.pruners.NopPruner()

study = optuna.create_study(
    direction="minimize",
    study_name=f"model-fine-tuning-{tuning_encoder}-new",
    storage=os.environ["OPTUNA_CONN_STRING"],
    load_if_exists=True,
    sampler=sampler,
    pruner=pruner,
)

objective = create_objective(
    encoder_type=tuning_encoder,
    datamodule=datamodule,
    moniter_metric="val_mse",
    moniter_mode="min",
    epochs=100
)

[I 2025-06-11 16:21:54,415] A new study created in RDB with name: model-fine-tuning-dummy-new


[32m2025-06-11 16:21:54.295[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mStarting tuning for dummy[0m


In [None]:
fine_tune_trials = deduped_trials[:5]

In [None]:
for trial in fine_tune_trials:
    logger.info(f"Starting trial {trial.number} with value {trial.value}")
    study.enqueue_trial(trial.params)

[32m2025-06-11 16:21:54.540[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mStarting trial 132 with value 0.18061089515686035[0m
[32m2025-06-11 16:21:54.559[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mStarting trial 324 with value 0.18193891644477844[0m
[32m2025-06-11 16:21:54.568[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mStarting trial 178 with value 0.18592298030853271[0m
[32m2025-06-11 16:21:54.575[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mStarting trial 305 with value 0.18592298030853271[0m
[32m2025-06-11 16:21:54.580[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mStarting trial 304 with value 0.19521912932395935[0m
[32m2025-06-11 16:21:54.584[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mStarting trial 314 with value 0.19521912932395935[0m
[32m2025-06-11 16:21:54.588[0m |

In [None]:
try:
    study.optimize(objective, n_trials=len(fine_tune_trials), gc_after_trial=True)
    logger.info(f"Best value for {tuning_encoder}: {study.best_value}")
    logger.info(f"Best params for {tuning_encoder}: {study.best_params}")

    logger.info(f"Complete tuning {tuning_encoder} for {len(fine_tune_trials)} trials")
except Exception as e:
    logger.exception(f"Error occurred during optimization for {tuning_encoder}: {e}")

In [24]:
for i in ["be", "pe", "poisson", "bsa", "sf", "tbr", "dummy"]:
    optuna.copy_study(
        from_study_name=f"model-fine-tuning-{i}-new",
        to_study_name=f"model-fine-tuning-{i}-new",
        from_storage=os.environ["OPTUNA_CONN_STRING"],
        to_storage=os.environ["OPTUNA_CONN_STRING_CPE"],
    )

[I 2025-06-11 21:09:51,004] A new study created in RDB with name: model-fine-tuning-be-new


KeyboardInterrupt: 