<div>
    <img src='https://storage.googleapis.com/kaggle-datasets-images/568973/1032238/7ff23ec0b526773506bd5964d4f100d1/dataset-cover.jpg' />
</div>

In [None]:
import numpy as np
import pandas as pd

import os

import torch
from torch import optim
from torch import nn

import torchaudio
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from tqdm import tqdm

from sklearn.utils import shuffle
import matplotlib.pyplot as plt

<h1 id="dataset" style="color:#c7ced6; background:#dc9231; border:0.5px dotted;"> 
    <center>Dataset
        <a class="anchor-link" href="#dataset" target="_self">¶</a>
    </center>
</h1>

In [None]:
class MusicDS(Dataset):
    """Audio-clip dataset laid out as ``<path>/<label>/<clip file>``.

    Every immediate sub-directory of ``path`` is treated as one class label and
    every file inside it as one clip of that label. Items are
    ``(waveform, sample_rate, label_index)`` tuples: the clip is loaded through
    :meth:`get_sample` and then resized to ``target_length`` samples along its
    last (time) axis.

    Args:
        path: Root directory containing one sub-directory per label.
        target_length: Number of time samples each returned waveform is resized
            to. ``None`` keeps the class default (300134).
        random_state: Forwarded to ``sklearn.utils.shuffle`` so the item order
            can be made reproducible; ``None`` keeps unseeded shuffling.
    """

    # Default number of time samples every waveform is resized to in __getitem__.
    target_length = 300134

    def __init__(self, path, target_length=None, random_state=None):
        # Sorted so the label <-> index mapping does not depend on os.listdir order.
        labels = sorted(
            entry for entry in os.listdir(path)
            if os.path.isdir(os.path.join(path, entry))
        )
        self.idx_to_labels = {k: v for k, v in enumerate(labels)}
        self.labels_to_idx = {v: k for k, v in enumerate(labels)}
        if target_length is not None:
            self.target_length = target_length

        # Flat (label, file path) lists; unlike a 2-D array this also works when
        # label folders contain different numbers of files.
        song_labels, song_paths = [], []
        for label in labels:
            folder = os.path.join(path, label)
            for file_name in sorted(os.listdir(folder)):
                song_labels.append(label)
                song_paths.append(os.path.join(folder, file_name))

        self.labels, self.songs_lists = shuffle(
            np.array(song_labels), np.array(song_paths), random_state=random_state
        )

    def plot_specgram(self, waveform, sample_rate, title="Spectrogram", xlim=None):
        """Plot one spectrogram subplot per channel of ``waveform``.

        Args:
            waveform: 2-D tensor indexed as (channel, time).
            sample_rate: Sampling rate, passed to ``Axes.specgram`` as ``Fs``.
            title: Figure title.
            xlim: Optional x-axis limits applied to every subplot.
        """
        waveform = waveform.numpy()
        num_channels, _ = waveform.shape

        figure, axes = plt.subplots(num_channels, 1)
        if num_channels == 1:
            # plt.subplots returns a single Axes (not a sequence) for one row.
            axes = [axes]
        for c in range(num_channels):
            axes[c].specgram(waveform[c], Fs=sample_rate)
            if num_channels > 1:
                axes[c].set_ylabel(f'Channel {c+1}')
            if xlim:
                axes[c].set_xlim(xlim)
        figure.suptitle(title)
        plt.show(block=False)

    def print_stats(self, waveform, sample_rate=None, src=None):
        """Print shape, dtype and summary statistics of ``waveform``, then the tensor itself.

        Args:
            waveform: Tensor to describe.
            sample_rate: Printed when truthy.
            src: Optional source label printed as a header when truthy.
        """
        if src:
            print("-" * 10)
            print("Source:", src)
            print("-" * 10)
        if sample_rate:
            print("Sample Rate:", sample_rate)
        print("Shape:", tuple(waveform.shape))
        print("Dtype:", waveform.dtype)
        print(f" - Max:     {waveform.max().item():6.3f}")
        print(f" - Min:     {waveform.min().item():6.3f}")
        print(f" - Mean:    {waveform.mean().item():6.3f}")
        print(f" - Std Dev: {waveform.std().item():6.3f}")
        print()
        print(waveform)
        print()

    def get_sample(self, path, sample_rate=4000):
        """Load ``path`` through a fixed sox effects chain.

        Args:
            path: Audio file to load.
            sample_rate: Rate the audio is resampled to by the ``rate`` effect.

        Returns:
            The ``(waveform, sample_rate)`` pair returned by
            ``torchaudio.sox_effects.apply_effects_file``.
        """
        # NOTE(review): Net defaults to n_input=2, presumably relying on this
        # chain producing two channels (reverb output) — confirm.
        effects = [
            ["lowpass", "-1", "150"],  # apply single-pole lowpass filter
            ["speed", "0.9"],  # reduce the speed
                               # This only changes the sample rate, so a `rate`
                               # effect is needed afterwards to resample.
            ["rate", f"{sample_rate}"],
            ["reverb", "-w"],  # Reverberation gives some dramatic feeling
        ]
        return torchaudio.sox_effects.apply_effects_file(path, effects=effects)

    def __len__(self):
        """Number of clips in the dataset."""
        return len(self.songs_lists)

    def __getitem__(self, idx):
        """Return ``(waveform, sample_rate, label_index)`` for item ``idx``.

        If the clip at ``idx`` cannot be loaded, the following items (wrapping
        around to the start) are tried instead. The label returned is that of
        the clip that actually loaded, so waveform and label always match.

        Raises:
            IndexError: if ``idx`` is out of range.
            RuntimeError: if no clip in the dataset can be loaded.
        """
        waveform, sample_rate, idx = self._load_first_readable(idx)
        # F.interpolate expects (batch, channel, time): add, then drop, a batch axis.
        waveform = torch.unsqueeze(waveform, 0)
        waveform = F.interpolate(waveform, size=self.target_length)
        waveform = torch.squeeze(waveform, 0)
        return waveform, sample_rate, self.labels_to_idx[self.labels[idx]]

    def _load_first_readable(self, idx):
        """Load item ``idx`` or, if that fails, the next loadable item (wrapping)."""
        count = len(self.songs_lists)
        if not -count <= idx < count:
            raise IndexError(f"index {idx} out of range for dataset of size {count}")
        last_error = None
        for offset in range(count):
            candidate = (idx + offset) % count
            try:
                waveform, sample_rate = self.get_sample(self.songs_lists[candidate])
            except Exception as err:  # e.g. a corrupt / undecodable audio file
                last_error = err
                continue
            return waveform, sample_rate, candidate
        raise RuntimeError(
            f"no readable audio file found starting at index {idx}"
        ) from last_error

In [None]:
# GTZAN layout: one sub-folder per genre under this root.
path = '../input/gtzan-dataset-music-genre-classification/Data/genres_original/'
music_ds = MusicDS(path)

# Inspect a single clip; the second tuple element is the loader's sample rate.
sample_index = 100
waveform, frame_num, label = music_ds[sample_index]
music_ds.plot_specgram(waveform, sample_rate=frame_num)
music_ds.print_stats(waveform, sample_rate=frame_num)

In [None]:
# Hold out 20% of the clips for evaluation.
test_size = int(len(music_ds) * 0.2)
train_size = len(music_ds) - test_size

# Seeded generator so the train/test partition is identical on every run.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = torch.utils.data.random_split(
    music_ds, [train_size, test_size], generator=split_generator
)

# Training batches are reshuffled every epoch; evaluation runs one clip at a time in fixed order.
train_loader = DataLoader(train_dataset, batch_size=4, num_workers=0, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=0, shuffle=False)

<h1 id="network" style="color:#c7ced6; background:#dc9231; border:0.5px dotted;"> 
    <center>Network
        <a class="anchor-link" href="#network" target="_self">¶</a>
    </center>
</h1>

In [None]:
class Net(nn.Module):
    """1-D convolutional classifier operating directly on waveforms.

    Four conv -> batch-norm -> ReLU -> max-pool stages are followed by a global
    average over time and a single linear layer.

    Args:
        n_input: Number of input channels.
        n_output: Number of output classes.
        stride: Stride of the first (kernel size 80) convolution.
        n_channel: Width of the first two stages; the last two use twice this.

    Input is (batch, n_input, time); output is (batch, 1, n_output)
    log-probabilities (log_softmax over the last dimension).
    """

    def __init__(self, n_input=2, n_output=10, stride=16, n_channel=32):
        super().__init__()
        wide = 2 * n_channel
        # Modules are created in the same order as the layer stack so parameter
        # order (and seeded initialisation) is unchanged.
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(n_channel, wide, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(wide)
        self.pool3 = nn.MaxPool1d(4)
        self.conv4 = nn.Conv1d(wide, wide, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(wide)
        self.pool4 = nn.MaxPool1d(4)
        self.fc1 = nn.Linear(wide, n_output)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
            (self.conv4, self.bn4, self.pool4),
        )
        for conv, norm, pool in stages:
            x = pool(F.relu(norm(conv(x))))
        # Average over all remaining time steps, leaving a time length of 1.
        x = F.avg_pool1d(x, x.shape[-1])
        # (batch, channels, 1) -> (batch, 1, channels) so fc1 acts on channels.
        logits = self.fc1(x.permute(0, 2, 1))
        return F.log_softmax(logits, dim=2)

In [None]:
# Prefer the first CUDA device when one is present; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

net = Net().to(device)

<h1 id="training" style="color:#c7ced6; background:#dc9231; border:0.5px dotted;"> 
    <center>Training
        <a class="anchor-link" href="#training" target="_self">¶</a>
    </center>
</h1>

In [None]:
# Adam with light weight decay; StepLR multiplies the learning rate by
# LR_DECAY every LR_STEP_EPOCHS calls to scheduler.step().
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-4
LR_STEP_EPOCHS = 20
LR_DECAY = 0.1

optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_EPOCHS, gamma=LR_DECAY)

In [None]:
def train(model, epoch):
    """Run one optimisation pass over ``train_loader`` and return the per-batch losses.

    Uses the notebook-level ``train_loader``, ``optimizer`` and ``device``.

    Args:
        model: Network to update in place.
        epoch: Current epoch number (unused; kept for symmetry with ``test``).

    Returns:
        list[float]: NLL loss of every batch, in iteration order.
    """
    # test() switches the model to eval mode; switch back so BatchNorm uses and
    # updates batch statistics while training.
    model.train()

    losses = []
    for data, _, target in train_loader:
        data = data.to(device)
        target = target.to(device)

        output = model(data)  # Net returns (batch, 1, n_classes) log-probabilities
        # Drop only the singleton dim so a final batch of size 1 keeps its batch axis.
        loss = F.nll_loss(output.squeeze(1), target)
        losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return losses

In [None]:
def number_of_correct(pred, target):
    """Count how many predicted class indices equal the targets.

    ``pred`` is squeezed first so a (batch, 1) prediction compares
    element-wise with a (batch,) target. Returns a Python int.
    """
    matches = torch.eq(pred.squeeze(), target)
    return int(matches.sum())


def get_likely_index(tensor):
    """Return the index of the largest value along the last dimension."""
    return torch.argmax(tensor, dim=-1)


def test(model, epoch):
    """Return the accuracy of ``model`` on ``test_loader`` as a percentage.

    Uses the notebook-level ``test_loader`` and ``device``, and the helpers
    ``get_likely_index`` / ``number_of_correct``. Leaves the model in eval mode.

    Args:
        model: Network to evaluate.
        epoch: Current epoch number (unused; kept for symmetry with ``train``).
    """
    model.eval()
    correct = 0

    # Evaluation needs no gradients; skipping autograd bookkeeping saves memory and time.
    with torch.no_grad():
        for data, _, target in test_loader:
            data = data.to(device)
            target = target.to(device)

            output = model(data)
            pred = get_likely_index(output)
            correct += number_of_correct(pred, target)

    accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy

In [None]:
n_epoch = 80

losses = []      # mean training loss per epoch
accuracies = []  # test accuracy (%) per epoch

# tqdm shows progress plus the latest loss/accuracy instead of running silently for all epochs.
progress = tqdm(range(1, n_epoch + 1), desc="epochs")
for epoch in progress:
    loss = train(net, epoch)
    losses.append(sum(loss) / len(loss))

    accuracy = test(net, epoch)
    accuracies.append(accuracy)
    scheduler.step()

    progress.set_postfix(loss=f"{losses[-1]:.4f}", acc=f"{accuracy:.1f}%")

<h1 id="analysis" style="color:#c7ced6; background:#dc9231; border:0.5px dotted;"> 
    <center>Analysis
        <a class="anchor-link" href="#analysis" target="_self">¶</a>
    </center>
</h1>

In [None]:
# Test-set accuracy after each training epoch (epochs numbered from 1).
fig, ax = plt.subplots(figsize=(14, 8))
ax.plot(range(1, len(accuracies) + 1), accuracies)
ax.set(title='Accuracies', xlabel='Epoch', ylabel='Test accuracy (%)');

In [None]:
# Mean per-batch training NLL loss for each epoch (epochs numbered from 1).
fig, ax = plt.subplots(figsize=(14, 8))
ax.plot(range(1, len(losses) + 1), losses)
ax.set(title='Losses', xlabel='Epoch', ylabel='Mean training NLL loss');