# Phase encoder testing

In [1]:
import torch
from utils.preprocess import VectorizeSTFT


def normalize(x: torch.Tensor) -> torch.Tensor:
    """Min-max scale ``x`` along its last dimension.

    Every slice along the last dimension is rescaled independently so that
    its minimum maps to 0 and its maximum maps to 1. A slice whose values are
    all equal has a zero range; its divisor is replaced by 1, so the slice
    maps to zeros instead of triggering a division by zero.

    Args:
        x: Tensor to rescale. It is not modified.

    Returns:
        A tensor with the same shape as ``x``.
    """
    lowest, _ = x.min(dim=-1, keepdim=True)
    highest, _ = x.max(dim=-1, keepdim=True)

    # Constant slices have a zero range; divide those by 1 instead.
    value_range = highest - lowest
    safe_range = torch.where(
        value_range == 0, torch.ones_like(value_range), value_range
    )

    return (x - lowest) / safe_range


def preprocess_data(x: torch.Tensor) -> torch.Tensor:
    """Convert ``x`` into min-max normalised STFT magnitudes.

    The input is passed through ``VectorizeSTFT``, the absolute value of the
    result is taken, and the magnitudes are rescaled with ``normalize``.

    Args:
        x: Input tensor handed unchanged to ``VectorizeSTFT``.

    Returns:
        The normalised magnitudes; the shape is whatever ``VectorizeSTFT``
        produces.
    """
    return normalize(torch.abs(VectorizeSTFT(x)))

In [2]:
from dataset import CHBMITDataset
from torch.utils.data import Dataset


class PreparedDataset(Dataset):
    """Dataset that serves preprocessed samples together with their labels.

    The whole ``dataset.data`` tensor is run through ``preprocess_data`` once
    at construction time, so indexing afterwards is a plain lookup.
    """

    def __init__(self, dataset: CHBMITDataset) -> None:
        """Preprocess ``dataset.data`` eagerly and keep ``dataset.labels`` as is.

        Args:
            dataset: Source dataset exposing ``data`` and ``labels``.
        """
        self.data = preprocess_data(dataset.data)
        self.labels = dataset.labels

    def __len__(self) -> int:
        """Return the number of preprocessed samples."""
        return len(self.data)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(features, label)`` pair stored at ``idx``.

        Args:
            idx: Index of the sample.

        Returns:
            A 2-tuple of the preprocessed sample and its label converted to
            a boolean tensor.
        """
        # Output of preprocess_data (normalised STFT magnitudes), not the raw
        # EEG window; its shape is determined by VectorizeSTFT.
        features = self.data[idx]
        label = self.labels[idx].bool()  # Label: False (interictal) or True (ictal)
        return features, label

In [3]:
# Directory handed to CHBMITDataset; presumably holds already-processed
# CHB-MIT recordings (inferred from the path name — confirm against dataset.py).
data_path = "./CHB-MIT/processed"
dataset = CHBMITDataset(data_path)
# Wrapping applies preprocess_data to the whole dataset once, up front.
prepared_dataset = PreparedDataset(dataset)

In [4]:
from torch.utils.data import DataLoader

# Fixed-order batches of 32 samples, loaded by 4 worker processes.
data_loader = DataLoader(
    dataset=prepared_dataset,
    batch_size=32,
    num_workers=4,
    shuffle=False,
)

In [14]:
import optuna
from tqdm import tqdm

from encoder import PhaseEncoderExpand

def objective(trial: optuna.Trial) -> float:
    """Optuna objective: encode/decode reconstruction RMSE over ``data_loader``.

    Samples ``phase_window`` from [1, 16], builds a ``PhaseEncoderExpand``
    with ``normalize=False``, and runs every batch of the module-level
    ``data_loader`` through an encode/decode round trip.

    Args:
        trial: Optuna trial used to sample the hyper-parameters.

    Returns:
        Square root of the mean squared reconstruction error taken over
        every element of every batch (lower is better). NaN when the loader
        yields no data.
    """
    params = {
        "phase_window": trial.suggest_int("phase_window", 1, 16),
        "normalize": False,
    }

    pe_encoder = PhaseEncoderExpand(**params)

    # Accumulate the squared-error sum and the element count instead of
    # per-batch means: batches can differ in size (typically the last one),
    # and a mean of per-batch means would overweight the smaller batch.
    squared_error_sum = 0.0
    element_count = 0

    # Pure evaluation: nothing is back-propagated, so skip autograd tracking.
    with torch.no_grad():
        for data, _ in tqdm(data_loader, leave=False):
            decoded = pe_encoder.decode(pe_encoder.encode(data))
            squared_error = (data - decoded) ** 2
            # Sum in float64 to limit rounding error on large batches.
            squared_error_sum += squared_error.sum(dtype=torch.float64).item()
            element_count += squared_error.numel()

    if element_count == 0:
        # Same outcome as before for an empty loader (mean over nothing).
        return float("nan")

    return (squared_error_sum / element_count) ** 0.5

In [15]:
from config import DB_CONN_STRING

# Create the "Phase Encoder" study in the storage behind DB_CONN_STRING, or
# re-attach to it when it already exists (load_if_exists=True), so trials
# accumulate across notebook runs. The objective's return value is minimised.
study = optuna.create_study(
    direction="minimize",
    study_name="Phase Encoder",
    storage=DB_CONN_STRING,
    load_if_exists=True,
)
# Alternative kept for reference: a study created without a storage argument.
# study = optuna.create_study(
#     direction="minimize",
#     study_name="Step Forward Encoder",
# )

[I 2025-04-29 01:16:18,981] Using an existing study with name 'Phase Encoder' instead of creating a new one.


In [None]:
# Run 50 trials of the objective; each trial iterates over all of data_loader.
study.optimize(objective, n_trials=50)

  3%|▎         | 5/157 [00:02<01:16,  1.99it/s]