## Create custom dataset
<https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>

In [29]:
import torchaudio
import torchvision.io
from sklearn.preprocessing import LabelEncoder
import librosa
from birdclassification.preprocessing.filtering import filter_recordings_30
from torch.utils.data import Dataset
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

class CustomAudioDataset(Dataset):
    def __init__(self, df, recording_dir, transform=None, target_transform=None):
        """
        Parameters
        ----------
        df: pd.DataFrame
            DataFrame of xeno-canto recordings
        recording_dir: str
            filepath to directory with recordings
        transform:
        target_transform:
        """
        
        df['filepath'] = df.apply(lambda x: f"{recording_dir}{x['Latin name']}/{str(x['id'])}.mp3" , axis=1)
        le = LabelEncoder()
        df['label'] = le.fit_transform(df['Latin name'])
        
        self.filepath = df['filepath'].to_numpy()
        self.label = df['label'].to_numpy()
        self.recording_dir = recording_dir
        self.transform = transform
        self.target_transform = target_transform
        self.le_name_mapping = dict(zip(le.classes_, le.transform(le.classes_)))

    def __len__(self):
        return self.filepath.size

    def __getitem__(self, idx):
        audio, sr = torchaudio.load(self.filepath[idx])
        label = self.label[idx]
        
        audio = torch.from_numpy(audio).type(torch.float32)
        label = torch.tensor(label, dtype=torch.int8)
        
        # augmented = pipeline(audio, sr)
        
        if self.transform:
            audio = self.transform(audio)
        if self.target_transform:
            label = self.target_transform(label)       
        
        return audio, label


## Split dataset

In [30]:
RECORDINGS_DIR = '/Users/zosia/Desktop/recordings_30/'

df = filter_recordings_30()

train_df, test_val_df = train_test_split(df, stratify=df['Latin name'], test_size=0.2)
val_df, test_df = train_test_split(test_val_df, stratify=test_val_df['Latin name'], test_size=0.5)

train_ds = CustomAudioDataset(train_df, recording_dir=RECORDINGS_DIR)
val_ds = CustomAudioDataset(val_df, recording_dir=RECORDINGS_DIR)
test_ds = CustomAudioDataset(test_df, recording_dir=RECORDINGS_DIR)

  recordings = pd.read_csv(filepath_recordings)


In [31]:
print(train_ds.__len__())
print(val_ds.__len__())
print(test_ds.__len__())

32968
4121
4121


In [33]:
# for i in range(5):
#     print(train_ds[i])


## DataLoader

In [34]:
train_dataloader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=64, shuffle=True)

In [35]:
# for element in train_dataloader:
#     print(element)

## Example of custom transform on audio

In [36]:
from birdclassification.visualization.plots import plot_waveform
audio, label = train_ds[3]
plot_waveform(audio, 20000)

TypeError: expected np.ndarray (got Tensor)

In [37]:
from birdclassification.preprocessing.spectrogram import generate_mel_spectrogram
from birdclassification.preprocessing.utils import get_loudest_index, cut_around_index
import random
from torchaudio.transforms import Resample
from birdclassification.preprocessing.augmentations_wrappers import InvertPolarity, AddWhiteNoise, PitchShifting, RandomGain, TimeShift, RandomChunk, TimeStretch

class AugmentationsPipeline(torch.nn.Module):
    def __init__(self, target_sr = 32000, n_fft = 512, hop_length = 3 * 128, sample_length = 3, number_of_bands = 64, fmin = 150, fmax = 15000):
        super().__init__()
        self.target_sr = target_sr
        self.augmentations = [InvertPolarity(), 
                              AddWhiteNoise(min_factor=0.1, max_factor=0.8), 
                              RandomGain(min_factor=0.5, max_factor=1.5), 
                              TimeShift(min_factor=0.1, max_factor=0.3), 
                              RandomChunk(sr = target_sr, min_factor=0.1 , max_factor=1), 
                              PitchShifting(sr = target_sr, min_semitones=1, max_semitones=10)]
        
        self.probabilities = [0.5 for i in range(len(self.augmentations))]
        self.get_spectrogram = generate_mel_spectrogram
        self.get_loudest_index = get_loudest_index
        self.cut_around_largest_index = cut_around_index
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.sample_length = sample_length
        self.number_of_bands = number_of_bands
        self.fmin = fmin
        self.fmax = fmax
        
        
    def forward(self, waveform: torch.Tensor) -> torch.Tensor: 
        #select loudest 3 second chunk
        peak = get_loudest_index(waveform, self.n_fft, self.hop_length)
        waveform = cut_around_index(waveform, peak, self.target_sr * self.sample_length)
        
        #augmentations
        if self.augmentations:
            n = random.randint(0, len(self.augmentations))
            selected = random.choices(list(self.augmentations), weights=self.probabilities, k=n)
            print(selected)
            aug = torch.nn.Sequential(*selected)
            waveform = aug(waveform)
        
        waveform = self.mix_down_if_necessary(waveform)
        
        #generate spectrogram
        spectrogram = self.get_spectrogram(waveform, self.target_sr, self.n_fft, self.hop_length)
        
        return spectrogram
    
    def mix_down_if_necessary(self, audio):
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim = 0,  keepdim=True)
        return audio

In [38]:
pipeline = AugmentationsPipeline()

augmented = pipeline(audio)
plot_waveform(augmented, 20000)

IndexError: tuple index out of range