# BSA encoder testing: Optuna hyperparameter search minimizing reconstruction MSE

In [1]:
from dataset import CHBMITPreprocessedDataset

# Preprocessed CHB-MIT data, read from a single HDF5 file on disk.
dataset = CHBMITPreprocessedDataset(
    './CHB-MIT/processed_data.h5'
)

In [2]:
from torch.utils.data import DataLoader

# Evaluation-only loader. A fixed batch order (shuffle=False) means a given
# step index covers the same samples in every Optuna trial, which keeps the
# per-step intermediate values comparable for pruning.
data_loader = DataLoader(
    dataset,
    num_workers=4,
    batch_size=32,
    shuffle=False,
)

In [3]:
import optuna
import torch
from tqdm import tqdm

from encoder import BSAEncoder


def objective(trial: optuna.Trial) -> float:
    """Score one BSA hyperparameter configuration by reconstruction error.

    Samples ``win_size``, ``cutoff`` and ``threshold`` from *trial* and builds
    a ``BSAEncoder`` with ``normalize=False``. For every batch of the
    module-level ``data_loader``, it encodes and then decodes the signal and
    measures the squared error between input and reconstruction. The running
    MSE is reported to Optuna after each batch so the study's pruner can stop
    unpromising trials early.

    Args:
        trial: Optuna trial used to sample hyperparameters and to report
            intermediate values.

    Returns:
        Mean squared reconstruction error over every element seen, or
        ``nan`` if ``data_loader`` yields no batches.

    Raises:
        optuna.TrialPruned: If the pruner decides to stop this trial.
    """
    params = {
        "win_size": trial.suggest_int("win_size", 1, 16),
        "cutoff": trial.suggest_float("cutoff", 0.01, 1),
        "threshold": trial.suggest_float("threshold", 0.01, 1),
        "normalize": False,
    }
    bsa_encoder = BSAEncoder(**params)

    # Fall back to CPU so the notebook still runs without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Running totals. Weighting each batch's MSE by its element count gives
    # the true dataset MSE even when the final batch is smaller than
    # batch_size. Scalar sums also avoid rebuilding a tensor from the whole
    # history on every step, which is quadratic in the number of batches.
    sq_err_total = 0.0
    n_total = 0

    # Pure evaluation: the encode/decode round trip needs no autograd graph.
    with torch.no_grad():
        for step, (data, _) in enumerate(tqdm(data_loader, leave=False)):
            data = data.to(device=device)
            # The encoded intermediate is a temporary, freed once decode returns.
            decoded = bsa_encoder.decode(bsa_encoder.encode(data))
            n = data.numel()
            batch_mse = torch.mean((data - decoded) ** 2).item()
            # Drop references promptly to keep peak device memory down.
            del data, decoded

            sq_err_total += batch_mse * n
            n_total += n

            trial.report(sq_err_total / n_total, step=step)
            if trial.should_prune():
                raise optuna.TrialPruned()

    if n_total == 0:
        # Matches the original behaviour (mean of an empty tensor is nan).
        return float("nan")
    return sq_err_total / n_total

In [6]:
from config import DB_CONN_STRING

# Persistent study: re-running this cell resumes the existing "BSA-tuning"
# study from the configured storage instead of failing or starting over.
study = optuna.create_study(
    study_name="BSA-tuning",
    storage=DB_CONN_STRING,
    load_if_exists=True,
    direction="minimize",  # objective returns a reconstruction error
    pruner=optuna.pruners.HyperbandPruner(),
)

[I 2025-04-29 23:42:31,133] Using an existing study with name 'BSA-tuning' instead of creating a new one.


In [5]:
# Each call runs 50 new trials on top of whatever the study already holds.
study.optimize(func=objective, n_trials=50)

[I 2025-04-29 23:25:01,983] Trial 50 pruned.   
[I 2025-04-29 23:25:02,593] Trial 51 pruned.   
[I 2025-04-29 23:25:03,143] Trial 52 pruned.   
[I 2025-04-29 23:25:03,500] Trial 53 pruned.   
[I 2025-04-29 23:25:04,163] Trial 54 pruned.   
[I 2025-04-29 23:25:04,517] Trial 55 pruned.   
[I 2025-04-29 23:25:04,867] Trial 56 pruned.   
[I 2025-04-29 23:25:05,906] Trial 57 pruned.    
[I 2025-04-29 23:25:06,286] Trial 58 pruned.   
[I 2025-04-29 23:25:07,351] Trial 59 pruned.    
[I 2025-04-29 23:25:09,885] Trial 60 pruned.    
[I 2025-04-29 23:25:10,259] Trial 61 pruned.   
[I 2025-04-29 23:25:14,783] Trial 62 finished with value: 0.029187312349677086 and parameters: {'win_size': 6, 'cutoff': 0.24976747334792332, 'threshold': 0.3258294527482926}. Best is trial 16 with value: 0.028090717270970345.
[I 2025-04-29 23:25:15,804] Trial 63 pruned.    
[I 2025-04-29 23:25:16,177] Trial 64 pruned.   
[I 2025-04-29 23:25:20,755] Trial 65 finished with value: 0.0301541555672884 and parameters: {'wi