# BSA encoder testing

In [6]:
from dataset import CHBMITDataset

data_path = "./CHB-MIT/processed"
dataset = CHBMITDataset(data_path)

In [2]:
import torch

from utils.preprocess import VectorizeSTFT


def normalize(x: torch.Tensor) -> torch.Tensor:
    x_min = x.min(dim=-1, keepdim=True).values
    x_max = x.max(dim=-1, keepdim=True).values

    diff = x_max - x_min
    diff[diff == 0] = 1.0

    return (x - x_min) / diff


def preprocess_data(x: torch.Tensor) -> torch.Tensor:
    stft_data = VectorizeSTFT(x)
    magnitudes = torch.abs(stft_data)
    normalized = normalize(magnitudes)
    return normalized

In [4]:
from torch.utils.data import Dataset


class PreparedDataset(Dataset):
    def __init__(self, dataset: CHBMITDataset) -> None:
        data = preprocess_data(dataset.data.to(device="cuda"))
        self.data = data
        self.labels = dataset.labels

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        eeg_raw = self.data[idx]  # EEG data of shape (22, 2048)
        label = self.labels[idx].bool()  # Label: 0 (interictal) or 1 (ictal)
        return eeg_raw, label

In [5]:
# from torch.utils.data import DataLoader

prepared_dataset = PreparedDataset(dataset)
del dataset
# data_loader = DataLoader(prepared_dataset, batch_size=1024, shuffle=True, num_workers=8)

RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [5]:
import optuna
from encoder import BSAEncoder
from utils.snr import SNRCalculator


def objective(trial: optuna.Trial) -> float:
    params = {
        "win_size": trial.suggest_int("win_size", 1, 16),
        "cutoff": trial.suggest_float("cutoff", 0.01, 0.99),
        "threshold": trial.suggest_float("threshold", 0.01, 0.99),
    }

    bsa_encoder = BSAEncoder(**params)

    batch_size = 32  # or whatever fits
    decoded_batches = []

    with torch.no_grad():
        # Encode the whole dataset ONCE
        encoded_data = bsa_encoder.encode(prepared_dataset.data)

        # Now batch only the decoding
        for start in range(0, encoded_data.size(0), batch_size):
            end = min(start + batch_size, encoded_data.size(0))
            encoded_batch = encoded_data[start:end]

            decoded_batch = bsa_encoder.decode(encoded_batch)
            decoded_batches.append(decoded_batch)

    # After all decoding is done, concatenate
    all_decoded = torch.cat(decoded_batches, dim=0)

    snr = SNRCalculator.calculate_overall_snr(encoded_data, all_decoded)
    # # Final MSE
    # mse = torch.nn.functional.mse_loss(all_decoded, prepared_dataset.data).item()
    # return mse
    return snr

In [6]:
from config import DB_CONN_STRING

study = optuna.create_study(
    direction="maximize",
    study_name="BSA SNR metric",
    storage=DB_CONN_STRING,
    load_if_exists=True,
)
# study = optuna.create_study(
#     direction="minimize",
#     study_name="BSA mse Memory test",
# )

[I 2025-04-29 00:46:04,333] Using an existing study with name 'BSA SNR metric' instead of creating a new one.


In [None]:
study.optimize(objective, n_trials=50)

[I 2025-04-29 00:46:33,568] Trial 11 finished with value: 0.6063430309295654 and parameters: {'win_size': 5, 'cutoff': 0.26505247213031147, 'threshold': 0.5021550717154726}. Best is trial 6 with value: 0.8596798181533813.
[I 2025-04-29 00:47:23,823] Trial 12 finished with value: -1.2047570943832397 and parameters: {'win_size': 12, 'cutoff': 0.6882673778640046, 'threshold': 0.7442478075180808}. Best is trial 6 with value: 0.8596798181533813.


In [None]:
study.best_params