In [None]:
from typing import Optional

from matplotlib import pyplot as plt
from torch import Tensor

from src.data.loader import get_dataloader
from src.utils import load_config

In [None]:
config = load_config()
# One sample per batch, so every batch the loaders yield holds exactly one clip.
config.batch_size = 1
# Validation loaders over the same subset, differing only in the SNR (dB) setting:
# snr_db=None (presumably no added noise — TODO confirm), then 10 dB, then -10 dB.
dataloaders = [
    get_dataloader(config, subset="validation", snr_db=snr_db, shuffle=False)
    for snr_db in (None, 10, -10)
]
# Label lookups below are performed against the first (snr_db=None) loader's dataset.
dataset = dataloaders[0].dataset

In [None]:
def find_closest_index_for_label(_label: str) -> Optional[int]:
    """Return the index of the first sample in ``dataset`` labelled ``_label``.

    Despite the name, no distance or "closeness" is computed: the module-level
    ``dataset`` is scanned from index 0 upward and the first match wins.

    Args:
        _label: Label value compared against ``dataset.samples[i]["label"]``.

    Returns:
        The first matching index, or ``None`` if no sample carries ``_label``.
    """
    matching_indices = (
        idx for idx in range(len(dataset)) if dataset.samples[idx]["label"] == _label
    )
    return next(matching_indices, None)


def visualize_waveforms(_label: str) -> None:
    """Plot the waveform of the first ``_label`` sample once per dataloader.

    Every loader in ``dataloaders`` is indexed at the same position, which assumes
    all loaders' datasets share the same sample ordering (TODO confirm). The only
    visible difference between the loaders is their ``snr_db`` setting, which is
    shown in each plot's title.

    Args:
        _label: Class label to look up via ``find_closest_index_for_label``.

    Raises:
        ValueError: If no sample with ``_label`` exists in ``dataset``.
    """
    sample_idx = find_closest_index_for_label(_label)
    if sample_idx is None:
        raise ValueError(f"Label {_label} not found in dataset.")

    for loader in dataloaders:
        ds = loader.dataset
        # Element 0 of a dataset item is treated as the waveform; the leading dim
        # (presumably a channel axis — TODO confirm) is squeezed away for plotting.
        signal = ds[sample_idx][0].squeeze(0)
        plot_waveform(signal, title=f"Class `{_label}` waveform with SNR {ds.snr_db}")


def visualize_spectrograms(_label: str) -> None:
    """Plot the mel-spectrogram of the first ``_label`` sample once per dataloader.

    The spectrogram is the first element of the batch a loader produces. Because
    ``dataset[i][0]`` is plotted as a 1-D waveform elsewhere, the mel transform
    presumably happens in the loader's ``collate_fn`` (TODO confirm).

    Instead of iterating each loader from the start until batch ``index`` is reached
    (which loads and transforms every preceding sample, and silently plots nothing
    if that batch is never reached), the single sample is fetched directly and
    passed through the loader's own ``collate_fn``. That mirrors what a
    ``batch_size=1`` DataLoader does to build a batch. It also uses the same
    ``dl.dataset[index]`` addressing as ``visualize_waveforms``.

    Args:
        _label: Class label to look up via ``find_closest_index_for_label``.

    Raises:
        ValueError: If no sample with ``_label`` exists in ``dataset``.
    """
    index = find_closest_index_for_label(_label)

    if index is None:
        raise ValueError(f"Label {_label} not found in dataset.")

    for dl in dataloaders:
        # Build the one-sample batch the loader would emit for this index.
        # It is unpacked as (features, target), like the original loop did.
        x, _ = dl.collate_fn([dl.dataset[index]])
        # Drop the batch dim and the next leading dim (presumably channel) to get
        # the 2-D image plot_spectrogram draws as (mel bins, time frames).
        mel_spectrogram = x.squeeze(0).squeeze(0)
        plot_spectrogram(
            mel_spectrogram,
            title=f"Class `{_label}` Mel-spectrogram with SNR {dl.dataset.snr_db}"
        )


def plot_waveform(waveform: Tensor, title: str) -> None:
    """Draw a waveform against its sample index and display it.

    Args:
        waveform: Signal to plot. It is detached and moved to the CPU before
            conversion, so tensors that require grad or live on an accelerator
            can be plotted too.
        title: Figure title.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    # Tensor.numpy() raises for tensors that require grad or are not on the CPU.
    ax.plot(waveform.detach().cpu().numpy())
    ax.set(title=title, xlabel="Time (samples)", ylabel="Amplitude")
    plt.show()
    # Callers loop over classes x SNR levels; release each figure once shown so
    # open figures do not pile up on non-inline backends.
    plt.close(fig)


def plot_spectrogram(spectrogram: Tensor, title: str) -> None:
    """Render a 2-D spectrogram as an image with a colorbar and display it.

    Rows are drawn bottom-up (``origin="lower"``) and labelled as mel frequency
    bins; columns are labelled as time frames.

    Args:
        spectrogram: 2-D tensor to draw. It is detached and moved to the CPU
            before conversion, so tensors that require grad or live on an
            accelerator can be plotted too.
        title: Figure title.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    # Tensor.numpy() raises for tensors that require grad or are not on the CPU.
    image = ax.imshow(
        spectrogram.detach().cpu().numpy(), aspect="auto", origin="lower", cmap="viridis"
    )
    ax.set(title=title, xlabel="Time (frames)", ylabel="Mel Frequency Bins")
    # The "dB" suffix assumes the values are already log-scaled (TODO confirm
    # that the upstream mel transform outputs dB rather than power/amplitude).
    fig.colorbar(image, ax=ax, format="%+2.0f dB")
    plt.show()
    # Release the figure once shown; callers create many figures in a loop.
    plt.close(fig)

In [None]:
# For every configured class, show the waveform plots and then the
# mel-spectrogram plots: one figure per dataloader (SNR setting) for each view.
# Either visualize_* call raises ValueError if a class has no validation sample.
for label in config.classes:
    visualize_waveforms(label)
    visualize_spectrograms(label)