# MelSpecClassifier

Use the mel spectrogram of each WAV file and a CNN with 2D convolutions to classify its genre.
The spectrograms are already provided as images in the GTZAN dataset. Each image is cropped before being passed into the model.

In [ ]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import tempfile
from PIL import Image
import torchvision.transforms as transforms
import torch.utils.data as Data
import os
from PIL import ImageOps
from torch.utils.data import SubsetRandomSampler
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import ASHAScheduler

## Constant parameters used in training

Run `setup.sh` to mount Google Drive containing GTZAN

In [ ]:
# Root folder of the GTZAN mel-spectrogram images (mounted from Google Drive by setup.sh).
GTZAN_MEL = "/content/drive/MyDrive/GTZAN/Data/images_original/"

# Border (left, top, right, bottom) trimmed from every image with ImageOps.crop.
PREPROCESS_CROP = (54, 35, 42, 35)

# Size of the uncropped images -- presumably [width, height]; TODO confirm.
IMAGE_INPUT_DIMENSIONS = [432, 288]

# Genre name -> integer class label, numbered in alphabetical order.
GENRES = {
    genre_name: class_id
    for class_id, genre_name in enumerate(
        ("blues", "classical", "country", "disco", "hiphop",
         "jazz", "metal", "pop", "reggae", "rock")
    )
}

# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device", DEVICE)

Create a `Dataset` for the mel spectrograms.

In [ ]:
class ImageDataset(Data.Dataset):
    """In-memory dataset of cropped GTZAN mel-spectrogram images.

    Every file under ``root/<genre>/`` is opened, cropped with
    ``PREPROCESS_CROP`` and converted to a tensor with ``ToTensor``. Its label
    is the integer id of ``<genre>`` in ``GENRES``.

    Args:
        root: Directory containing one sub-directory per genre. Defaults to
            ``GTZAN_MEL`` (looked up at construction time).

    Raises:
        KeyError: if a genre sub-directory name is not a key of ``GENRES``.
    """

    def __init__(self, root=None):
        root = GTZAN_MEL if root is None else root
        self.images = []
        self.labels = []

        # Built once: ToTensor holds no per-image state.
        to_tensor = transforms.ToTensor()

        # os.listdir order is arbitrary and may differ between processes
        # (e.g. Ray trial workers vs. the final test run). dataset_split()
        # shuffles *indices*, so the file order must be deterministic for the
        # train/validation/test split to be the same everywhere.
        for genre in sorted(os.listdir(root)):
            genre_dir = os.path.join(root, genre)
            if not os.path.isdir(genre_dir):
                # Skip stray files (e.g. .DS_Store) next to the genre folders.
                continue
            for song in sorted(os.listdir(genre_dir)):
                abs_path = os.path.join(genre_dir, song)
                # The context manager closes the file handle once the pixels
                # have been read into the tensor.
                with Image.open(abs_path) as image:
                    # The images in the dataset were rendered from librosa mel
                    # spectrograms; crop so only the spectrogram itself is
                    # passed into the CNN.
                    image_cropped = ImageOps.crop(image, PREPROCESS_CROP)
                    self.images.append(to_tensor(image_cropped))
                # Convert genre tag to its associated class id.
                self.labels.append(GENRES[genre])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

The `MelSpecTrainer` model is a CNN with 2 convolutional layers, followed by 2 hidden linear layers and a linear classifier.
There are few datapoints compared to the number of input dimensions.
To avoid overfitting, the hidden linear layers have few output features, and only 2 convolutional and 2 hidden linear layers are used.

In [ ]:
class MelSpecTrainer(nn.Module):
    """CNN that maps a cropped mel-spectrogram image to 10 genre logits.

    Architecture: two ``Conv2d -> ReLU -> MaxPool2d`` stages, a flatten, two
    hidden ``Linear -> ReLU`` layers of widths ``l1`` and ``l2``, and a final
    ``Linear`` producing 10 logits.

    Args:
        l1: Output features of the first hidden linear layer.
        l2: Output features of the second hidden linear layer.
        input_size: ``(height, width)`` of one input image. Inputs must have
            4 channels, because the first convolution expects 4. Defaults to the
            size of an ``IMAGE_INPUT_DIMENSIONS`` image after cropping with
            ``PREPROCESS_CROP``.
    """

    def __init__(self, l1=256, l2=20, input_size=None):
        super().__init__()

        self.current_dimensions = IMAGE_INPUT_DIMENSIONS
        if input_size is None:
            input_size = self._default_input_size()

        # 4 input channels -- presumably the RGBA channels of the PNGs; TODO confirm.
        self.conv_layer_1 = nn.Sequential(nn.Conv2d(4, 32, 3),
                                          nn.ReLU(),
                                          nn.MaxPool2d(kernel_size=2, stride=3)
                                          )

        self.conv_layer_2 = nn.Sequential(nn.Conv2d(32, 16, 3),
                                          nn.ReLU(),
                                          nn.MaxPool2d(kernel_size=2, stride=3)
                                          )

        self.flatten_layer = nn.Flatten()

        # Size the first linear layer from a dummy forward pass instead of a
        # hard-coded constant. The previous value (12320) did not match the
        # conv output for the default 218x336 crop (16 * 23 * 36 = 13248), so
        # every forward pass failed with a shape mismatch.
        with torch.no_grad():
            dummy = torch.zeros(1, 4, *input_size)
            num_features = self._extract_features(dummy).shape[1]

        self.linear_layer_1 = nn.Sequential(nn.Linear(num_features, l1),
                                            nn.ReLU())

        self.linear_layer_2 = nn.Sequential(nn.Linear(l1, l2),
                                            nn.ReLU())

        self.classifier = nn.Linear(l2, 10)

    @staticmethod
    def _default_input_size():
        """Return ``(height, width)`` of a dataset image after ``PREPROCESS_CROP``."""
        # ImageOps.crop removes a (left, top, right, bottom) border.
        # Assumes IMAGE_INPUT_DIMENSIONS is (width, height), PIL's size order -- TODO confirm.
        width, height = IMAGE_INPUT_DIMENSIONS
        left, top, right, bottom = PREPROCESS_CROP
        return height - top - bottom, width - left - right

    def _extract_features(self, x):
        """Apply both convolution stages and flatten to ``(batch, features)``."""
        x = self.conv_layer_1(x)
        x = self.conv_layer_2(x)
        return self.flatten_layer(x)

    def forward(self, x):
        # Two 2D convolution stages, then flatten.
        x = self._extract_features(x)

        # Hidden linear layers and classifier.
        x = self.linear_layer_1(x)
        x = self.linear_layer_2(x)
        x = self.classifier(x)

        return x

Use the mel spectrograms to train a model that classifies the genre of each WAV file.
Split the data into test, train, and validation sets.

In [ ]:
def dataset_split(image_dataset, train_fraction=0.8, validation_fraction=0.1, seed=42):
    """Deterministically split the indices of a dataset into three subsets.

    Indices ``0 .. len(image_dataset) - 1`` are shuffled with a private
    ``random.Random(seed)``. The first ``int(n * train_fraction)`` become the
    training set, the next ``int(n * validation_fraction)`` the validation set,
    and the remainder the test set.

    Args:
        image_dataset: Anything supporting ``len()``.
        train_fraction: Fraction of samples used for training.
        validation_fraction: Fraction of samples used for validation.
        seed: Seed for the shuffle. A fixed seed makes the split identical
            on every call for the same dataset length.

    Returns:
        ``(test_indices, train_indices, validation_indices)``. Note the order.
    """
    indices = list(range(len(image_dataset)))
    # A private generator yields the same permutation as random.seed(seed)
    # followed by random.shuffle(), without resetting the global RNG that
    # other code in the process relies on.
    random.Random(seed).shuffle(indices)

    num_train = int(len(indices) * train_fraction)
    num_validation = int(len(indices) * validation_fraction)
    validation_end = num_train + num_validation

    train_indices = indices[:num_train]
    validation_indices = indices[num_train:validation_end]
    test_indices = indices[validation_end:]

    return test_indices, train_indices, validation_indices

Create routines for training and validation. Perform hyperparameter tuning to find a better-performing model.

In [ ]:
def train_mel_spec_model(config, image_dataset=None):
    """Train one ``MelSpecTrainer`` for a Ray Tune trial, reporting every epoch.

    Args:
        config: Trial hyperparameters. Reads ``"l1"``, ``"l2"``, ``"lr"`` and
            ``"batch_size"``. ``"num_epochs"`` is optional (default 30).
        image_dataset: Optional pre-loaded ``ImageDataset``. When omitted, the
            dataset is loaded from disk by this trial.

    After each epoch, the mean validation loss and accuracy are reported to Ray
    Tune. Each report comes with a checkpoint file ``checkpoint.pt`` holding
    ``(model.state_dict(), optimizer.state_dict())``.
    """
    model = MelSpecTrainer(l1=config["l1"], l2=config["l2"])
    model.to(DEVICE)
    loss_fn = nn.CrossEntropyLoss()
    # The learning rate comes from config["lr"]. The original passed
    # config["l1"] (a layer width such as 256) as the learning rate.
    optimizer = optim.Adam(model.parameters(), lr=config["lr"])

    if image_dataset is None:
        image_dataset = ImageDataset()

    _, train_indices, validation_indices = dataset_split(image_dataset)

    # Train and validation loaders sample disjoint index subsets of one dataset.
    train_loader = Data.DataLoader(image_dataset,
                                   batch_size=config["batch_size"],
                                   sampler=SubsetRandomSampler(train_indices))
    validation_loader = Data.DataLoader(image_dataset,
                                        batch_size=config["batch_size"],
                                        sampler=SubsetRandomSampler(validation_indices))

    # Optional, so a search space that omits it still runs.
    num_epochs = config.get("num_epochs", 30)
    for epoch in range(num_epochs):
        _run_training_epoch(model, train_loader, loss_fn, optimizer, epoch)
        val_loss, accuracy = _evaluate_model(model, validation_loader, loss_fn)
        _report_to_tune(model, optimizer, val_loss, accuracy)
        print(f"Validation Loss: {val_loss}, Accuracy: {accuracy}")


def _run_training_epoch(model, loader, loss_fn, optimizer, epoch):
    """Run one optimisation pass over ``loader``, printing each batch's loss."""
    model.train()
    for batch_id, (images, labels) in enumerate(loader):
        # Predict and get loss
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        loss = loss_fn(model(images), labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"epoch: {epoch}, batch_id: {batch_id}, loss: {loss.item()}")


def _evaluate_model(model, loader, loss_fn):
    """Return ``(mean per-sample loss, accuracy)`` of ``model`` over ``loader``.

    Based on https://docs.ray.io/en/latest/tune/examples/tune-pytorch-cifar.html
    Returns ``(nan, nan)`` if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            batch_size = labels.size(0)
            total += batch_size
            correct += (predicted == labels).sum().item()
            # CrossEntropyLoss (default reduction) averages over the batch.
            # Re-weight by batch size so the result is a per-sample mean even
            # when the last batch is smaller. .item() keeps it a plain float.
            loss_sum += loss_fn(outputs, labels).item() * batch_size
    if total == 0:
        return float("nan"), float("nan")
    return loss_sum / total, correct / total


def _report_to_tune(model, optimizer, val_loss, accuracy):
    """Report metrics to Ray Tune with a ``(model_state, optimizer_state)`` checkpoint."""
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
        torch.save((model.state_dict(), optimizer.state_dict()), path)
        checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
        train.report({"loss": val_loss, "accuracy": accuracy},
                     checkpoint=checkpoint)
            

Create a routine for testing the model. The split used is 80% for training, 10% for validation, and 10% for testing.

In [ ]:
def test_mel_spec_model(best_result, image_dataset=None):
    """Evaluate the best Ray Tune trial's checkpoint on the held-out test split.

    Args:
        best_result: Ray Tune result of the best trial. ``config`` supplies
            ``"l1"``/``"l2"``, and ``checkpoint`` contains ``checkpoint.pt``
            holding the ``(model_state, optimizer_state)`` tuple saved by
            training.
        image_dataset: Optional pre-loaded ``ImageDataset``. It is loaded
            from disk when omitted.

    Returns:
        Test accuracy in percent (also printed).

    Raises:
        ValueError: if the test split is empty.
    """
    best_model = MelSpecTrainer(l1=best_result.config["l1"], l2=best_result.config["l2"])

    # Restore the trained weights. nn.Module has no ``load`` method, so the
    # original call failed and the weights were never applied. Read the saved
    # tuple with torch.load and apply only the model part.
    with best_result.checkpoint.as_directory() as checkpoint_dir:
        checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pt")
        # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
        model_state, _ = torch.load(checkpoint_path, map_location=DEVICE)
    best_model.load_state_dict(model_state)
    best_model.to(DEVICE)
    best_model.eval()

    if image_dataset is None:
        image_dataset = ImageDataset()

    # Same seeded split as training. This assumes ImageDataset lists the files
    # in the same order every time it is built.
    test_indices, _, _ = dataset_split(image_dataset)
    test_loader = Data.DataLoader(image_dataset, batch_size=5,
                                  sampler=SubsetRandomSampler(test_indices))

    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            predicted = best_model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("Test split is empty; cannot compute accuracy.")

    accuracy = (correct * 100) / total
    print(f"Best Model Accuracy {accuracy}%")
    return accuracy

# Main function

Here, we specify the ranges of the hyperparameters for Ray Tune to search over, and train the model with various hyperparameter combinations.

Then evaluate the best model found by Ray Tune on the test set.

In [ ]:
def run_mel_spec_classifier():
    """Tune ``MelSpecTrainer`` hyperparameters with Ray Tune + ASHA, then test the best trial.

    Each trial runs ``train_mel_spec_model``, which reports once per epoch.
    ASHA stops poorly performing trials early, but never before
    ``grace_period`` epochs. The trial with the lowest reported validation
    ``"loss"`` is then evaluated by ``test_mel_spec_model``.
    """
    num_epochs = 30

    # Search space: example ranges, adjust as needed. These were ``...``
    # placeholders before, which made every trial fail when building the
    # model and optimizer.
    config = {
        "l1": tune.choice([64, 128, 256, 512]),
        "l2": tune.choice([16, 32, 64]),
        "lr": tune.loguniform(1e-4, 1e-2),
        "batch_size": tune.choice([4, 8, 16]),
        "num_epochs": num_epochs,
    }

    # Only stop trials after at least 20 training iterations. One iteration is
    # one epoch, since train_mel_spec_model reports once per epoch.
    asha_scheduler = ASHAScheduler(time_attr='training_iteration',
                                   max_t=num_epochs,
                                   grace_period=20)

    # tune.with_resources requires an explicit per-trial resource request.
    trainable = tune.with_resources(
        train_mel_spec_model,
        resources={"cpu": 2, "gpu": 1 if torch.cuda.is_available() else 0},
    )

    tuner = tune.Tuner(trainable,
                       tune_config=tune.TuneConfig(
                           metric='loss',
                           mode="min",
                           scheduler=asha_scheduler,
                           num_samples=5,
                       ),
                       param_space=config)

    results = tuner.fit()
    best_result = results.get_best_result("loss", "min")

    test_mel_spec_model(best_result)


run_mel_spec_classifier()