In [2]:
import os
import pathlib
import random
from copy import deepcopy
from functools import partial
from time import perf_counter
from typing import Any

import numpy as np
import pandas as pd
import torch
import torchaudio
import torchaudio.transforms as T
from safetensors import safe_open
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader
from torch_audiomentations import (
    Compose,
    OneOf,
)
from torchvision.transforms import v2

# Global seed: passed to `seed_everything` and used as `random_state` for the dataset splits.
SEED = 654

def seed_everything(seed: int) -> None:
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for deterministic behaviour.

    Args:
        seed: Value used to seed all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. subprocesses);
    # the hash seed of the running interpreter is fixed at start-up.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one. Safe to call without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different algorithms between runs, which
    # defeats the deterministic flag above, so it has to be disabled.
    torch.backends.cudnn.benchmark = False


# Seed all RNGs before any data is split or any augmentation is sampled.
seed_everything(SEED)

# NOTE(review): recent torchaudio releases deprecate `set_audio_backend`;
# confirm it is still needed for the installed version.
torchaudio.set_audio_backend("soundfile")


In [3]:
def get_splits(
        data: pd.DataFrame | np.ndarray | list[Any],
        train_size: float,
        valid_size: float,
        test_size: float,
        stratify_col: str | None = None,
) -> tuple[Any, Any, Any]:
    """Split ``data`` into train / validation / test parts.

    The split is done in two steps with ``train_test_split``: first train
    versus the rest, then the rest into validation and test. Both steps use
    the module-level ``SEED`` as ``random_state``.

    Args:
        data: Collection to split. Must support ``data[stratify_col]`` when
            ``stratify_col`` is given.
        train_size: Fraction of ``data`` assigned to the train split.
        valid_size: Fraction of ``data`` assigned to the validation split.
        test_size: Fraction of ``data`` assigned to the test split. When the
            three fractions sum to less than 1, the leftover samples are
            dropped instead of being silently added to the test split.
        stratify_col: Optional key used to stratify both splitting steps.

    Returns:
        ``(train_split, valid_split, test_split)``.

    Raises:
        AssertionError: If the three fractions sum to more than 1.
    """
    # Fractions such as 0.7 + 0.2 + 0.1 do not add up to exactly 1.0 in
    # floating point, so compare with a small tolerance.
    tolerance = 1e-9
    total = train_size + valid_size + test_size
    assert total <= 1.0 + tolerance, "train_size + valid_size + test_size must not exceed 1.0"

    # Sizes for the second step are relative to what is left after the train split.
    remainder = 1 - train_size
    second_step_sizes: dict[str, float] = {"train_size": valid_size / remainder}
    if total < 1.0 - tolerance:
        # Honour the requested test fraction; otherwise the test split would
        # receive everything that is not train or validation.
        second_step_sizes["test_size"] = test_size / remainder

    train_split, valid_test = train_test_split(
        data,
        train_size=train_size,
        stratify=data[stratify_col] if stratify_col else None,
        random_state=SEED,
    )
    valid_split, test_split = train_test_split(
        valid_test,
        stratify=valid_test[stratify_col] if stratify_col else None,
        random_state=SEED,
        **second_step_sizes,
    )

    return train_split, valid_split, test_split

class AddGaussianNoise(torch.nn.Module):
    """Randomly add Gaussian noise to a tensor.

    With probability ``p`` the input is returned with
    ``noise * std + mean`` added, where ``noise`` is sampled from a standard
    normal distribution; otherwise the input is returned unchanged.

    The apply/skip decision uses Python's ``random`` module, while the noise
    itself comes from PyTorch's RNG, so both must be seeded for reproducibility.

    Args:
        mean: Offset added to the sampled noise.
        std: Scale applied to the sampled noise.
        p: Probability of applying the noise; must lie in ``[0, 1]``.
    """

    def __init__(self, mean: float = 0.0, std: float = 1.0, p: float = 0.5) -> None:
        super().__init__()
        assert 0 <= p <= 1
        self.std: float = std
        self.mean: float = mean
        self.p: float = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with noise added (probability ``p``) or ``x`` itself."""
        if random.random() >= self.p:
            return x
        # Sample on the input's device; a CPU-only noise tensor cannot be
        # added to an input that lives on another device.
        noise = torch.randn(x.size(), device=x.device)
        return x + noise * self.std + self.mean


class MinMaxNorm(torch.nn.Module):
    """Rescale a tensor linearly so its values span ``[new_min, new_max]``.

    The minimum and maximum are taken over the whole tensor. A constant input
    (minimum equals maximum) is mapped to ``new_min`` everywhere instead of
    producing NaNs from a 0/0 division.

    Args:
        new_min: Lower bound of the output range.
        new_max: Upper bound of the output range.
    """

    def __init__(self, new_min: float = 0.0, new_max: float = 1.0) -> None:
        super().__init__()
        self.new_min: float = new_min
        self.new_max: float = new_max

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` rescaled to ``[new_min, new_max]``."""
        x_min, x_max = x.min(), x.max()
        span = x_max - x_min
        # For a constant input the numerator is all zeros, so dividing by 1
        # instead of 0 yields zeros (-> new_min) rather than NaN.
        safe_span = torch.where(span == 0, torch.ones_like(span), span)
        return (x - x_min) / safe_span * (self.new_max - self.new_min) + self.new_min


# Collect every .safetensors file below the parent of the working directory.
# The paths are sorted because directory traversal order is not guaranteed,
# and a seeded split is only reproducible when its input order is stable.
songs_path: list[pathlib.Path] = sorted(pathlib.Path(os.getcwd()).parent.rglob("*.safetensors"))
train, valid, test = get_splits(songs_path, train_size=0.7, valid_size=0.2, test_size=0.1, stratify_col=None)

In [4]:
class AudioDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(input, target)`` mel-spectrogram pairs from audio files.

    Each item is read from a ``.safetensors`` file holding an ``"audio"``
    tensor and a ``"sample_rate"`` tensor. A random crop of the audio is
    passed through two pipelines: the target pipeline (mel spectrogram,
    resize, min-max normalisation, cast to float16) and the input pipeline,
    which in ``"train"`` mode additionally applies noise and masking.

    Args:
        data_path: Paths of the ``.safetensors`` files.
        image_size: Height and width the spectrograms are resized to.
        sample_rate: Stored on the instance; the transforms use the sample
            rate read from each file instead.
        crop_size: Multiplied by the file's sample rate to obtain the crop
            length in frames (i.e. a duration in seconds).
        mode: One of ``"train"``, ``"valid"`` or ``"test"``; augmentation is
            only applied in ``"train"`` mode.
    """

    # Upper bound on random re-crops when a crop produces NaNs, so that a file
    # which never yields a valid crop raises instead of hanging the loader.
    MAX_CROP_ATTEMPTS: int = 100

    def __init__(
            self,
            data_path: np.ndarray | list[str] | list[pathlib.Path],
            image_size: int,
            sample_rate: int = 44100,
            crop_size: int = 30,
            mode: str = "train",
    ) -> None:
        assert mode in {"train", "valid", "test"}
        super().__init__()
        self.data_path: np.ndarray | list[str] | list[pathlib.Path] = data_path
        self.image_size: int = image_size
        self.sample_rate: int = sample_rate
        self.crop_size: int = crop_size
        self.mode: str = mode

    def _get_transforms(self, sample_rate) -> tuple[Compose, Compose]:
        """Build the ``(input, target)`` transform pipelines for ``sample_rate``."""
        target_steps = [
            T.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=512,
                win_length=512,
                hop_length=256,
                n_mels=256,
                normalized=True,
            ),
            v2.Resize(size=(self.image_size, self.image_size)),
            MinMaxNorm(),
            v2.ToDtype(torch.float16, scale=False),
        ]

        # Work on a copy so the target pipeline can never pick up the
        # augmentation steps, whatever Compose does with the list it is given.
        input_steps = list(target_steps)
        if self.mode == "train":
            # Noise goes on the raw waveform, masking right after the mel spectrogram.
            input_steps.insert(0, AddGaussianNoise(p=0.5))
            input_steps.insert(2, OneOf([T.TimeMasking(time_mask_param=100), T.FrequencyMasking(freq_mask_param=100)]))

        return Compose(input_steps), Compose(target_steps)

    def __len__(self) -> int:
        """Return the number of audio files."""
        return len(self.data_path)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Load file ``index`` and return its ``(input, target)`` spectrograms.

        Audio longer than the crop length is cropped at a random offset; a
        crop whose input or target contains NaNs is discarded and re-drawn.
        Shorter audio is transformed as a whole.

        Raises:
            RuntimeError: If no NaN-free crop is found within
                ``MAX_CROP_ATTEMPTS`` attempts.
        """
        with safe_open(self.data_path[index], framework="pt", device="cpu") as f:
            # Stored as a tensor; use a plain int for frame arithmetic and slicing.
            sample_rate = int(f.get_tensor("sample_rate").item())
            audio = f.get_tensor("audio")

        x_transforms, y_transforms = self._get_transforms(sample_rate)
        # Frames are on axis 1; presumably audio is (channels, frames) — TODO confirm.
        num_frames: int = audio.shape[1]
        crop_frames: int = self.crop_size * sample_rate

        if num_frames <= crop_frames:
            return x_transforms(audio), y_transforms(audio)

        for _ in range(self.MAX_CROP_ATTEMPTS):
            frame_offset = random.randint(0, num_frames - crop_frames)
            cropped_audio = audio[:, frame_offset : frame_offset + crop_frames]
            x_transformed, y_transformed = x_transforms(cropped_audio), y_transforms(cropped_audio)
            # Both tensors must be clean: the input may be NaN-free only
            # because noise was added, while the target is still NaN.
            if not (torch.isnan(x_transformed).any() or torch.isnan(y_transformed).any()):
                return x_transformed, y_transformed

        raise RuntimeError(
            f"No NaN-free crop found in {self.data_path[index]} after {self.MAX_CROP_ATTEMPTS} attempts"
        )

In [5]:
# dataloader = DataLoader(
#     dataset=AudioDataset(data_path=valid[178:179], image_size=256, mode="valid"),
#     batch_size=1,
#     num_workers=0,
#     shuffle=False,
#     pin_memory=True,
#     persistent_workers=False
# )
# for x, y in dataloader:
#     print(x.min(), x.max())
#     print()
#     print(y.min(), y.max())
#     break