In [4]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchaudio
from speechbrain.pretrained import EncoderClassifier
from torch.utils.data import DataLoader, Dataset

In [5]:
# --- Paths -----------------------------------------------------------------
DATA_PATH = Path("../data")
WEIGHTS_PATH = Path("speechbrain/google_speech_command_xvector")
EXP_NAME = WEIGHTS_PATH.name  # last path component, reused as the experiment name

# --- Audio / training settings ----------------------------------------------
MAX_AUDIO_LEN = 16000  # fixed clip length, in samples at the sampling rate (sr)
# Fall back to the CPU so the notebook still runs on a machine without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 128
N_EPOCHS = 100
VAL_ITER = 10

# Keyword -> integer label used as the classification target.
CLASSES = {
    'yes': 0,
    'no': 1,
    'up': 2,
    'down': 3,
    'left': 4,
    'right': 5,
    'on': 6,
    'off': 7,
    'stop': 8,
    'go': 9,
}

In [6]:
# Load the pretrained classifier; its files are cached under `savedir`.
enc_classifier = EncoderClassifier.from_hparams(
    # `source` looks like a Hub repo id ("<org>/<model>"), not a filesystem
    # path, so hand it over as a forward-slash string: a Path object would
    # become "org\\model" on Windows and may be rejected by repo-id validation.
    source=WEIGHTS_PATH.as_posix(),
    savedir=Path("pretrained_models") / EXP_NAME,
    run_opts={"device": DEVICE},
)

In [None]:
# Expose the individual parts of the pretrained classifier under short names
# so later cells can call them one at a time.
audio_normalizer = enc_classifier.audio_normalizer

# Sub-modules registered on the classifier; presumably applied in this order
# (features -> normalisation -> embedding) — TODO confirm against the hparams.
feature_extractor1, normalizer, embedding_model = (
    enc_classifier.mods.compute_features,
    enc_classifier.mods.mean_var_norm,
    enc_classifier.mods.embedding_model,
)

In [None]:
class TrainData(Dataset):
    """Training set: clips padded to a fixed length and mixed with random noise.

    Every clip and every noise recording is loaded into memory in ``__init__``.
    Padding position, noise recording, noise offset and noise level are drawn
    at random on each ``__getitem__`` call, so the same index yields a
    different example every time.
    """

    def __init__(self, audio_dir: Path, noise_dir: Path) -> None:
        """Load all clips and noise recordings.

        Args:
            audio_dir: Directory with one sub-directory per class. The
                sub-directory name is stored as the label of each clip in it
                and is later looked up in ``CLASSES``.
            noise_dir: Directory whose files are loaded as noise recordings.
        """
        super().__init__()
        self.audio_len = MAX_AUDIO_LEN
        # __getitem__ needs a feature extractor; create it the same way
        # TestData does.
        # NOTE(review): MelCreator is not defined in the visible cells —
        # confirm it is defined/imported before this class is instantiated.
        self.mel_creator = MelCreator()

        self.audios = list()
        self.noises = list()
        self.classes = list()

        for class_dir in audio_dir.iterdir():
            # Stray files next to the class folders (e.g. a README) would make
            # the inner iterdir() raise, so skip anything that is not a folder.
            if not class_dir.is_dir():
                continue
            for wav_path in class_dir.iterdir():
                # The second value returned by torchaudio.load is discarded.
                # NOTE(review): presumably all files share one sampling rate — confirm.
                audio, _ = torchaudio.load(wav_path)
                self.audios.append(audio)
                self.classes.append(class_dir.name)

        for wav_path in noise_dir.iterdir():
            noise, _ = torchaudio.load(wav_path)
            self.noises.append(noise)

    def __len__(self):
        """Return the number of loaded clips."""
        return len(self.classes)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for clip ``idx`` with fresh random augmentation.

        ``features`` is whatever ``self.mel_creator`` returns for the padded,
        noise-mixed, peak-normalised clip; ``label`` is ``CLASSES[class_name]``.
        """
        # randint's upper bound is exclusive: use len(...) (not len(...) - 1)
        # so the last recording can be drawn and a single recording works.
        noise = self.noises[np.random.randint(0, len(self.noises))]

        if noise.shape[0] > 1:
            # Multi-channel noise: keep one randomly chosen channel.
            # TODO: try mean instead random
            channel = np.random.randint(0, noise.shape[0])
            noise = noise[channel].unsqueeze(0)

        audio = self.__pad_audio(self.audios[idx])
        audio = self.__add_noise(audio, noise, 0, 6)
        # Peak-normalise; the clamp keeps an all-zero clip from becoming NaN.
        audio = audio / audio.abs().max().clamp_min(1e-8)

        melspec = self.mel_creator(audio)

        # CLASSES is a dict (name -> label), so index it by key.
        return melspec, CLASSES[self.classes[idx]]

    def __pad_audio(self, audio):
        """Zero-pad ``audio`` along its last dimension to ``self.audio_len``.

        The clip is placed at a random offset inside the padded window. If the
        clip is longer than ``self.audio_len`` the right-hand padding becomes
        negative, which makes ``F.pad`` trim the tail instead.
        """
        missing = self.audio_len - audio.shape[-1]
        # +1 because randint excludes the upper bound; without it the clip
        # could never sit flush against the right edge.
        left = np.random.randint(0, missing + 1) if missing > 0 else 0
        pad_pattern = (left, missing - left)
        return F.pad(audio, pad_pattern, "constant").detach()

    def __add_noise(self, clean, noise, min_amp, max_amp):
        """Mix a random segment of ``noise`` into ``clean``.

        Args:
            clean: Clip to augment; indexed as ``(channels, samples)``.
            noise: Noise recording; indexed as ``(channels, samples)``.
            min_amp, max_amp: Range of the uniformly drawn noise level,
                expressed relative to the peak of ``clean``.

        Returns:
            ``(clean + scaled_noise) / (1 + noise_amp)``, or ``clean``
            unchanged when the chosen noise segment cannot be used.
        """
        noise_amp = np.random.uniform(min_amp, max_amp)
        target_len = clean.shape[1]

        if noise.shape[1] == 0:
            return clean
        if noise.shape[1] < target_len:
            # Noise shorter than the clip: tile it so a full-length segment
            # exists (otherwise randint below gets a non-positive bound).
            repeats = -(-target_len // noise.shape[1])  # ceil division
            noise = noise.repeat(1, repeats)

        # The noise recording is at least as long as the clip, so pick a
        # random starting point inside it.
        start = np.random.randint(0, noise.shape[1] - target_len + 1)
        noise_part = noise[:, start : start + target_len]

        noise_peak = noise_part.abs().max()
        # A peak of exactly 1 is skipped (presumably a clipped segment —
        # TODO confirm); a peak of 0 is silence and would divide by zero.
        if noise_peak == 1 or noise_peak == 0:
            return clean

        # Overlay the noise, scaled relative to the peak of the clean clip.
        noise_mult = clean.abs().max() / noise_peak * noise_amp
        return (clean + noise_part * noise_mult) / (1 + noise_amp)


class TestData(Dataset):
    """Evaluation set: labelled clips padded to a fixed length, without added noise."""

    def __init__(self, audio_dir: Path, markup_path) -> None:
        """Load every clip listed in the markup file.

        Args:
            audio_dir: Directory containing ``<file_name>.wav`` files.
            markup_path: CSV read with ``pd.read_csv``; each row must unpack
                into ``(file_name, category)``, where ``category`` is later
                looked up in ``CLASSES``.
        """
        super().__init__()
        self.audio_len = MAX_AUDIO_LEN
        # NOTE(review): MelCreator is not defined in the visible cells —
        # confirm it is defined/imported before this class is instantiated.
        self.mel_creator = MelCreator()

        # Despite the name, this list holds the loaded waveforms, not paths.
        self.audio_paths = list()
        self.classes = list()
        markup = pd.read_csv(markup_path)
        for file_name, category in markup.values:
            # The second value returned by torchaudio.load is discarded.
            audio, _ = torchaudio.load(audio_dir / f"{file_name}.wav")
            self.audio_paths.append(audio)
            self.classes.append(category)

    def __len__(self):
        """Return the number of loaded clips."""
        return len(self.classes)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for clip ``idx``.

        ``features`` is whatever ``self.mel_creator`` returns for the padded,
        peak-normalised clip; ``label`` is ``CLASSES[category]``.
        """
        audio = self.__pad_audio(self.audio_paths[idx])
        # Peak-normalise; the clamp keeps an all-zero clip from becoming NaN.
        audio = audio / audio.abs().max().clamp_min(1e-8)
        melspec = self.mel_creator(audio)

        # CLASSES is a dict (name -> label), so index it by key.
        return melspec, CLASSES[self.classes[idx]]

    def __pad_audio(self, audio):
        """Zero-pad ``audio`` along its last dimension to ``self.audio_len``.

        The clip is placed at a random offset inside the padded window. If the
        clip is longer than ``self.audio_len`` the right-hand padding becomes
        negative, which makes ``F.pad`` trim the tail instead.

        NOTE(review): the random offset makes evaluation non-deterministic
        unless NumPy's global RNG is seeded — confirm this is intended.
        """
        missing = self.audio_len - audio.shape[-1]
        # +1 because randint excludes the upper bound; without it the clip
        # could never sit flush against the right edge.
        left = np.random.randint(0, missing + 1) if missing > 0 else 0
        pad_pattern = (left, missing - left)
        return F.pad(audio, pad_pattern, "constant").detach()
