In [1]:
import torch
from torch.utils.data import Dataset
import torchaudio
import pandas as pd
import os
import torchvision

In [2]:
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift,Shift

ModuleNotFoundError: No module named 'audiomentations'

In [None]:
# Default sample rate (Hz) that audio is resampled to before feature extraction.
SAMPLE_RATE = 22_050

In [None]:
class RepeatChannelsTransform:
    """Callable transform that tiles a tensor three times along its leading dimension."""

    # Repeat counts handed to ``Tensor.repeat``: 3x on the leading dim, 1x on the last two.
    _REPEATS = (3, 1, 1)

    def __call__(self, img):
        """Return ``img`` tiled with ``_REPEATS``, e.g. (1, H, W) -> (3, H, W)."""
        tiled = img.repeat(*self._REPEATS)
        return tiled
def _min_max_scale(x):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    Args:
        x: Non-empty tensor to rescale (min/max are taken over all elements).

    Returns:
        ``(x - x.min()) / (x.max() - x.min())``. When every element of ``x``
        is equal, the range is zero and an all-zero tensor is returned instead
        of the NaNs that a plain 0/0 division would produce.
    """
    lo = x.min()
    span = x.max() - lo
    # A constant tensor has span == 0; substituting 1 there yields zeros rather
    # than NaN, while any input with a positive range is scaled exactly as before.
    safe_span = torch.where(span > 0, span, torch.ones_like(span))
    return (x - lo) / safe_span


min_max_scaling = torchvision.transforms.Lambda(_min_max_scale)

In [None]:
# Default waveform augmentation chain, used as the default `augmenter`
# argument of ESCDataset. Each transform carries its own application
# probability `p`; the transforms are listed in the order noise -> shift ->
# time stretch -> pitch shift.
default_augmenter = Compose(
    transforms=[
        AddGaussianNoise(p=0.2, min_amplitude=0.001, max_amplitude=0.1),
        Shift(p=0.7, min_shift=-1, max_shift=1),
        TimeStretch(p=0.6, min_rate=0.8, max_rate=1.2),
        PitchShift(p=0.6, min_semitones=-4, max_semitones=4),
    ]
)

In [6]:
class ESCDataset(Dataset):
    """Map-style dataset yielding ``(spectrogram_image, label)`` pairs.

    The annotations CSV is split on its ``fold`` column: rows whose fold equals
    ``TEST_FOLD`` form the evaluation split, all other rows form the training
    split. Each item is loaded from ``audio_dir``, augmented (training split
    only), resampled to ``target_sample_rate``, turned into a mel spectrogram,
    resized to 224x224, tiled by ``RepeatChannelsTransform`` and min-max scaled.

    Args:
        annotations_file: Path to a CSV with a ``fold`` column. Column 0 is
            read as the audio file name and column 2 as the label.
        audio_dir: Directory that the file names in column 0 are relative to.
        train: If truthy, use every fold except ``TEST_FOLD`` and apply
            ``augmenter``; otherwise use only ``TEST_FOLD`` without augmentation.
        target_sample_rate: Sample rate (Hz) every clip is resampled to and
            that the mel filterbank is built for.
        augmenter: Callable invoked as ``augmenter(samples, sample_rate)`` on
            the NumPy waveform; ``None`` disables augmentation.
        device: Device the waveform and the transforms are moved to.

    NOTE(review): the channel tiling multiplies the channel count by 3, so a
    3-channel output presumes single-channel audio files -- confirm for the
    dataset in use.
    """

    # Fold held out for evaluation; all remaining folds are used for training.
    TEST_FOLD = 5

    def __init__(self,
                 annotations_file,
                 audio_dir,
                 train,
                 target_sample_rate = SAMPLE_RATE,
                 augmenter = default_augmenter,
                 device = 'cpu'):
        self.device = device
        self.train = train
        self.audio_dir = audio_dir
        self.augmenter = augmenter
        self.target_sample_rate = target_sample_rate

        annotations = pd.read_csv(annotations_file)
        in_test_fold = annotations['fold'] == self.TEST_FOLD
        self.annotations = annotations[~in_test_fold] if train else annotations[in_test_fold]

        # Resample modules keyed by (source_rate, target_rate); building one
        # recomputes its filter kernel, so each is created once and reused.
        self._resamplers = {}
        self.define_transforms()

    def define_transforms(self):
        """Build the spectrogram and image-style transforms on ``self.device``."""
        self.mel_spectrogram_transform = torchaudio.transforms.MelSpectrogram(
                # Must match the rate the waveform is resampled to, not the
                # module-level default, or the mel filterbank is mis-scaled.
                sample_rate = self.target_sample_rate,
                n_fft = 2048,
                hop_length = 512,
                n_mels = 128
            ).to(self.device)
        self.log_spectrogram_transform = torchaudio.transforms.AmplitudeToDB().to(self.device)
        self.vision_transforms = torchvision.transforms.Compose([
                torchvision.transforms.Resize((224, 224),antialias=True).to(self.device),
                RepeatChannelsTransform(),
                min_max_scaling
            ])

    def __len__(self):
        """Return the number of annotation rows in the selected split."""
        return len(self.annotations)

    def __getitem__(self,idx):
        """Load, augment (training only), resample and featurize item ``idx``.

        Returns:
            Tuple ``(signal, label)`` where ``signal`` is the output of
            ``vision_transforms`` applied to the mel spectrogram and ``label``
            is the value in column 2 of the annotations row.
        """
        sample_path = self._get_sample_path(idx)
        label = self._get_sample_label(idx)
        signal, sr = torchaudio.load(sample_path)
        # Augment on the CPU NumPy waveform at its native rate, before resampling.
        if self.train and self.augmenter is not None:
            signal = torch.from_numpy(self.augmenter(signal.numpy(),sr))
        signal = signal.to(self.device)
        signal = self._fix_sample_rate(signal,sr)

        signal = self.mel_spectrogram_transform(signal)
        # NOTE(review): log_spectrogram_transform is built in define_transforms
        # but is not applied here -- confirm whether dB scaling was intended.
        signal = self.vision_transforms(signal)
        return signal,label

    def _fix_sample_rate(self,signal,sr):
        """Resample ``signal`` from ``sr`` to ``target_sample_rate`` if they differ."""
        if sr == self.target_sample_rate:
            return signal
        key = (sr, self.target_sample_rate)
        resampler = self._resamplers.get(key)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sr,self.target_sample_rate).to(self.device)
            self._resamplers[key] = resampler
        return resampler(signal)

    def _get_sample_path(self,idx):
        """Return ``audio_dir`` joined with the file name in column 0 of row ``idx``."""
        return os.path.join(self.audio_dir,self.annotations.iloc[idx,0])

    def _get_sample_label(self,idx):
        """Return the label stored in column 2 of row ``idx``."""
        return self.annotations.iloc[idx,2]