In [1]:
from argparse import Namespace

import torch
from tqdm import tqdm

# Prefer the GPU whenever CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"We are using the following the device to train: {device}")

# Training arguments are collected in an argparse Namespace so that other
# cells can read (and extend) them through plain attribute access.
args = Namespace(
    audio_length=59049,  # every piece of audio has a length of 59049 samples
    sample_rate=22050,  # the sample rate of our audio
)

  from .autonotebook import tqdm as notebook_tqdm


We are using the following the device to train: cuda


There are 14 pitch classes, as follows: min, maj, dim, aug, min6, maj6, min7, minmaj7, maj7, 7, dim7, hdim7, sus2, sus4.

In [None]:
import os
import random

import numpy as np
import soundfile as sf
import torch
from torch.utils import data
from torchaudio_augmentations import (
    Compose,
    Delay,
    Gain,
    HighLowPass,
    Noise,
    PitchShift,
    PolarityInversion,
    RandomApply,
    RandomResizedCrop,
    Reverb,
)

# The 14 pitch class labels.
PITCH_CLASS = [
    "7",
    "min",
    "maj",
    "dim",
    "aug",
    "min6",
    "maj6",
    "min7",
    "maj7",
    "minmaj7",
    "dim7",
    "hdim7",
    "sus2",
    "sus4",
]

# Root labels: 16 note spellings followed by "N".
# NOTE(review): "N" presumably stands for "no root / no chord" -- confirm.
ROOT_CLASS = [
    "A",
    "Ab",
    "B",
    "Bb",
    "C",
    "C#",
    "D",
    "D#",
    "Db",
    "E",
    "Eb",
    "F",
    "F#",
    "G",
    "G#",
    "Gb",
    "N",
]

class GTZANDataset(data.Dataset):
    """Audio dataset driven by a ``<split>_filtered.txt`` list file.

    Every item is a pair of views of one audio clip together with the integer
    label derived from the first path component of the clip's list entry.

    Args:
        data_path: Directory containing the split list files and the
            ``genres`` audio folder. A falsy value becomes ``""``, i.e. paths
            are resolved relative to the current working directory.
        split: Split name. ``"train"`` selects random cropping; any other
            value selects evenly spaced chunking.
        num_samples: Length, in samples, of every returned crop / chunk.
        num_chunks: Number of chunks cut from each clip for non-train splits.
        is_augmentation: When true, the two returned views are each passed
            independently through the augmentation chain.
        sample_rate: Sample rate handed to the rate-dependent augmentations.
            Defaults to 22050, the value that was previously hard-coded.
    """

    def __init__(
        self,
        data_path,
        split,
        num_samples,
        num_chunks,
        is_augmentation,
        sample_rate=22050,
    ):
        self.data_path = data_path or ""
        self.split = split
        self.num_samples = num_samples
        self.num_chunks = num_chunks
        self.is_augmentation = is_augmentation
        self.sample_rate = sample_rate
        # NOTE(review): GTZAN_GENRES is not defined in the visible part of this
        # notebook -- confirm it exists before this class is instantiated.
        self.genres = GTZAN_GENRES
        self._get_song_list()
        if is_augmentation:
            self._get_augmentations()

    def _get_song_list(self):
        """Load ``<data_path>/<split>_filtered.txt`` into ``self.song_list``."""
        list_filename = os.path.join(self.data_path, f"{self.split}_filtered.txt")
        with open(list_filename) as f:
            stripped = (line.strip() for line in f)
            # Skip blank lines (e.g. a stray newline at the end of the file):
            # an empty entry can never be resolved to a label or an audio file.
            self.song_list = [entry for entry in stripped if entry]

    def _get_augmentations(self):
        """Build the augmentation chain and store it in ``self.augmentation``."""
        sr = self.sample_rate
        transforms = [
            RandomResizedCrop(n_samples=self.num_samples),
            RandomApply([PolarityInversion()], p=0.8),
            RandomApply([Noise(min_snr=0.3, max_snr=0.5)], p=0.3),
            RandomApply([Gain()], p=0.2),
            RandomApply([HighLowPass(sample_rate=sr)], p=0.8),
            RandomApply([Delay(sample_rate=sr)], p=0.5),
            RandomApply(
                [PitchShift(n_samples=self.num_samples, sample_rate=sr)], p=0.4
            ),
            RandomApply([Reverb(sample_rate=sr)], p=0.3),
        ]
        self.augmentation = Compose(transforms=transforms)

    def _adjust_audio_length(self, wav):
        """Cut ``wav`` down to ``num_samples``-long pieces.

        For the ``"train"`` split a single random crop of ``num_samples``
        samples is returned. For every other split an array stacking
        ``num_chunks`` evenly spaced slices of ``num_samples`` samples is
        returned.

        Raises:
            ValueError: In the ``"train"`` split, if ``wav`` is shorter than
                ``num_samples``.
        """
        total = len(wav)
        if self.split == "train":
            last_start = total - self.num_samples
            if last_start < 0:
                raise ValueError(
                    f"audio has {total} samples, fewer than the requested "
                    f"{self.num_samples}"
                )
            # random.randint includes BOTH endpoints, so the upper bound is the
            # last valid crop offset itself. The previous "- 1" made that offset
            # unreachable and crashed on clips exactly num_samples long.
            start = random.randint(0, last_start)
            return wav[start : start + self.num_samples]

        hop = (total - self.num_samples) // self.num_chunks
        return np.array(
            [
                wav[i * hop : i * hop + self.num_samples]
                for i in range(self.num_chunks)
            ]
        )

    def get_augmentation(self, wav):
        """Return an augmented copy of the numpy array ``wav``.

        The array is converted to a tensor, given a leading axis for the
        augmentation chain, and converted back to numpy with that axis removed.
        ``self.augmentation`` only exists when the dataset was constructed with
        ``is_augmentation`` set.
        """
        return self.augmentation(torch.from_numpy(wav).unsqueeze(0)).squeeze(0).numpy()

    def __getitem__(self, index):
        """Return ``((wav_i, wav_j), genre_index)`` for list entry ``index``.

        ``wav_i`` and ``wav_j`` are two independently augmented views when
        augmentation is enabled; otherwise both are the same float32 array.
        """
        line = self.song_list[index]

        # the label is the first path component of the list entry
        genre_name = line.split("/")[0]
        genre_index = self.genres.index(genre_name)

        # load the audio; the sample rate reported by the file is not used
        audio_filename = os.path.join(self.data_path, "genres", line)
        wav, _ = sf.read(audio_filename)

        # crop / chunk to the configured length
        wav = self._adjust_audio_length(wav).astype("float32")

        if self.is_augmentation:
            wav_i = self.get_augmentation(wav)
            wav_j = self.get_augmentation(wav)
        else:
            wav_i = wav
            wav_j = wav

        return (wav_i, wav_j), genre_index

    def __len__(self):
        """Number of entries read from the split list file."""
        return len(self.song_list)


def get_dataloader(
    data_path=None,
    split="train",
    num_samples=22050 * 29,
    num_chunks=1,
    batch_size=16,
    num_workers=0,
    is_augmentation=False,
):
    """Build a ``DataLoader`` over a ``GTZANDataset`` for the given split.

    Args:
        data_path: Forwarded to ``GTZANDataset``.
        split: Split name; only ``"train"`` is shuffled.
        num_samples: Length, in samples, of each crop / chunk.
        num_chunks: Number of chunks per clip for non-train splits.
        batch_size: Batch size for the train split. For other splits it is
            divided by ``num_chunks`` (floor division, never below 1).
        num_workers: Number of ``DataLoader`` worker processes.
        is_augmentation: Forwarded to ``GTZANDataset``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    is_train = split == "train"
    if not is_train:
        # Non-train items each stack num_chunks chunks, so the batch size is
        # scaled down accordingly. Clamp to 1: a plain floor division yields 0
        # when num_chunks > batch_size, which DataLoader rejects.
        batch_size = max(1, batch_size // num_chunks)

    dataset = GTZANDataset(
        data_path, split, num_samples, num_chunks, is_augmentation
    )
    return data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=is_train,
        drop_last=False,
        num_workers=num_workers,
    )