In [152]:
import torch.nn as nn
import torch.optim as torch_optim
import torch
from torch import Tensor
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import numpy as np
import os
import json
import re
import librosa

## Genre Classifier Model

In [153]:
class GenreClassifier(nn.Module):
    def __init__(self, inputs):
        super(GenreClassifier, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(inputs.shape[1] * inputs.shape[2], 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 64)
        self.fc4 = nn.Linear(64, 10)
        
        self.optimizer = torch_optim.Adam(model.parameters(), lr=0.001)
        self.loss      = nn.CrossEntropyLoss()
        
    def forward(self, x):
        x = self.flatten(Tensor(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

### Data Pre-processing

In [154]:
SAMPLE_RATE = 22050
DURATION = 30
SAMPLES_PER_TRACK = SAMPLE_RATE * DURATION
N_FFT = 2048
HOP_LENGTH = 512
N_MFCC = 13
MFCC_OUT_PATH = "mfcc_batches.json"

def segmented_batch_save_mfcc(
    dataset_path,
    n_mfcc=N_MFCC,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_segments=5
):
    extractings = {"mapping": [], "labels": [], "MFCCs": []}
    N_SAMPLES_PER_SEGMENT = int(SAMPLES_PER_TRACK / n_segments)
    EXPECTED_MFCC_PER_SEGMENT = np.ceil(N_SAMPLES_PER_SEGMENT / hop_length)
    
    for genre_idx, (dirpath, _dirnames, filenames) in enumerate(os.walk(dataset_path)):
        if dirpath is not dataset_path:
            dirpath_components = dirpath.split("/")
            semantic_label = dirpath_components[-1]
            extractings["mapping"].append(semantic_label)
            for f in filenames:
                filepath = os.path.join(dirpath, f)
                signal, sr = librosa.load(filepath, sr=SAMPLE_RATE)
                for s in range(n_segments):
                    start_sample = N_SAMPLES_PER_SEGMENT * s
                    finish_sample = start_sample + N_SAMPLES_PER_SEGMENT
                    mfcc = librosa.feature.mfcc(
                        y=signal[start_sample:finish_sample],
                        sr=sr,
                        n_fft=n_fft,
                        n_mfcc=n_mfcc,
                        hop_length=hop_length
                    )
                    MFCC = mfcc.T
                    if len(MFCC) == EXPECTED_MFCC_PER_SEGMENT:
                        extractings["MFCCs"].append(MFCC.tolist())
                        extractings["labels"].append(genre_idx - 1)
    with open(MFCC_OUT_PATH, "w") as out:
        json.dump(extractings, out, indent=4)

### Trainer Class

In [191]:
def load_data(proc_dataset_path=MFCC_OUT_PATH):
    with open(proc_dataset_path, "r") as fp:
        data = json.load(fp)
    inputs = np.array(data["MFCCs"])
    targets = np.array(data["labels"])
    return inputs, targets

class Trainer:
    def __init__(self, model, inputs, targets):
        self.model = model
        self.inputs = inputs
        self.targets = targets
    def split_data(self, train_size=0.8, test_size=0.2):
        assert(train_size+test_size == 1)
        x_train, x_test, y_train, y_test = train_test_split(
            self.inputs,
            self.targets,
            test_size=test_size,
        )
        
        return x_train, x_test, y_train, y_test
    def find_latest_ckpt(self, ckpt_dir):
        ckpts = []
        for filename in os.listdir(ckpt_dir):
            if filename.startswith("ckpt"):
                ckpts.append(filename)
        for ckpt in sorted(ckpts):
            ckpt_num = int(re.findall(r"\d+", ckpt)[0])
        new_ckpt_num = ckpt_num
        return new_ckpt_num
    def train(
        self,
        epochs=10,
        batch_size=64,
        lr=0.0001,
    ):
        latest_ckpt_n = self.find_latest_ckpt("checkpoint")
        msg = "Spinning up from ckpt "
        ckpt_path = "checkpoint/ckpt_" + str(latest_ckpt_n) + ".pth"
        ckpt_state = torch.load(ckpt_path)
        self.model.load_state_dict(ckpt_state)
        
        x_train, x_test, y_train, y_test = self.split_data(train_size=0.7, test_size=0.3)
        for epoch in range(epochs):
            for batch in range(len(x_train)//batch_size):
                batch_x = x_train[batch * batch_size : (batch+1) * batch_size]
                batch_y = y_train[batch * batch_size : (batch + 1) * batch_size]
                x = self.model.forward(batch_x)
                self.model.optimizer.zero_grad()
                loss = self.model.loss(x, torch.tensor(batch_y))
                loss.backward()
                self.model.optimizer.step()

        print("Saving model weights...")
        CKPT_DIR = "checkpoint"
        latest_ckpt_n = self.find_latest_ckpt(CKPT_DIR)
        checkpoint = {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.model.optimizer.state_dict(),
            "epoch": epoch
        }
        ckpt_file_new = "checkpoint/ckpt_" + str(latest_ckpt_n+1) + ".pth"
        torch.save(self.model.state_dict(), ckpt_file_new)
                

## Perform Pre-processing

In [192]:
UNPROCESSED_DATASET = "example-train"
segmented_batch_save_mfcc(UNPROCESSED_DATASET)

## Perform Training

In [193]:
inputs, targets = load_data(MFCC_OUT_PATH)

In [194]:
model = GenreClassifier(inputs)
trainer = Trainer(model, inputs, targets)

trainer.train()


<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Epoch 1 complete.
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Epoch 2 complete.
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Epoch 3 complete.
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Epoch 4 complete.
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Epoch 5 complete.
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Epoch 6 complete.
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Epoch 7 complete.
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Epoch 8 complete.
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Epoch 9 complete.
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Epoch 10 complete.
Saving model weights...
