In [1]:
from dataset import CHBMITDataset

data_path = "./CHB-MIT/processed"
dataset = CHBMITDataset(data_path)

In [2]:
from torch.utils.data import DataLoader

data_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

In [3]:
test_data, test_label = next(iter(data_loader))

In [4]:
import torch
from utils.preprocess import VectorizeSTFT

stft_data = VectorizeSTFT(test_data)
magnitudes = torch.abs(stft_data)

In [5]:
def normalize(x: torch.Tensor) -> torch.Tensor:
    x_min = x.min(dim=-1, keepdim=True).values
    x_max = x.max(dim=-1, keepdim=True).values

    diff = x_max - x_min
    diff[diff == 0] = 1.0

    return (x - x_min) / diff

In [6]:
normalized_magnitudes = normalize(magnitudes)

In [7]:
from encoder import BSAEncoder

bsa_encoder = BSAEncoder()
encoded_data = bsa_encoder.encode(magnitudes)
decoded_data = bsa_encoder.decode(encoded_data)

In [8]:
decoded_data.shape

torch.Size([32, 22, 129, 65])

In [9]:
from utils.snr import SNRCalculator

SNRCalculator.calculate_overall_snr(magnitudes, decoded_data)

-49.40974426269531

### Decoded with tuned parameter

In [10]:
params = {
    "win_size": 5,
    "cutoff": 0.06122595496894435,
    "threshold": 0.39005882375740064,
}

tuned_bsa_encoder = BSAEncoder(**params)

In [11]:
tuned_decoded_data = tuned_bsa_encoder.decode(encoded_data)

In [12]:
SNRCalculator.calculate_overall_snr(magnitudes, tuned_decoded_data)

-50.42816925048828