# Music genre classifier
## Pretrained classifiers
This notebook should be seen as the third step in a series of notebooks aimed at building an ML audio classifier.

We continue our journey of music classification by training more complex models, such as CNNs and RNNs.
After this, we will see how our results compare against a pretrained model.
If you missed our previous steps, you can find them here:

- [preprocessing](https://github.com/pmhalvor/public-data/blob/master/notebooks/music-genre/preprocess.py)
- [traditional classifiers](https://github.com/pmhalvor/public-data/blob/master/notebooks/music-genre/classifiers.py) (note: currently only on branch [add/classifiers](https://github.com/pmhalvor/public-data/blob/add/classifiers/notebooks/music-genre/classifiers.ipynb))


## Goal
Test open-source pretrained models on our dataset, and compare the results to our own models.

## Dataset
The dataset contains 1000 audio tracks, each 30 seconds long. It covers 10 genres, each represented by 100 tracks. The tracks are all 22050Hz mono 16-bit audio files in .wav format.
In [preprocess.py](preprocess.py), we convert the .wav files to MFCC features and store them as PyTorch tensors (`mfcc.pt`). Labels and file paths are stored as NumPy arrays.

## Source
https://www.kaggle.com/datasets/andradaolteanu/gtzan-dataset-music-genre-classification/ (accessed 2023-10-20)

In [1]:
from functools import partial
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, f1_score
from skorch import NeuralNetClassifier
from tqdm import tqdm

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
import numpy as np
from transformers import pipeline

# audio @16kHz
audio = np.random.randn(30 * 16000)

pipe = pipeline("audio-classification", model="mtg-upf/discogs-maest-30s-pw-73e-ts")
pipe(audio)

[{'score': 0.08050668239593506, 'label': 'Non-Music---Field Recording'},
 {'score': 0.06392410397529602, 'label': 'Electronic---Noise'},
 {'score': 0.03836221247911453, 'label': 'Electronic---Experimental'},
 {'score': 0.03440581634640694, 'label': 'Electronic---Glitch'},
 {'score': 0.02098962664604187, 'label': 'Non-Music---Political'}]

In [3]:
pipe(np.random.rand(10000000))

[{'score': 0.11369985342025757, 'label': 'Electronic---Noise'},
 {'score': 0.09564154595136642, 'label': 'Electronic---Experimental'},
 {'score': 0.056505050510168076, 'label': 'Non-Music---Field Recording'},
 {'score': 0.05239567905664444, 'label': 'Electronic---Glitch'},
 {'score': 0.0440823957324028, 'label': 'Electronic---Abstract'}]

In [4]:
pipe(np.random.rand(3))


[{'score': 0.021648230031132698, 'label': 'Electronic---House'},
 {'score': 0.013834518380463123, 'label': 'Electronic---Synth-pop'},
 {'score': 0.011615021154284477, 'label': 'Electronic---Experimental'},
 {'score': 0.011414248496294022, 'label': 'Funk / Soul---Funk'},
 {'score': 0.011334141716361046, 'label': 'Folk, World, & Country---Folk'}]

In [5]:
"""
MAEST is designed to accept data in different input formats:

1D: 16kHz audio waveform is assumed.
2D: mel-spectrogram is assumed (frequency, time).
3D: batched mel-spectrogram (batch, frequency, time).
4D: batched mel-spectrogram plus singleton channel axis (batch, 1, frequency, time).
"""

'\nMAEST is designed to accept data in different input formats:\n\n1D: 16kHz audio waveform is assumed.\n2D: mel-spectrogram is assumed (frequency, time).\n3D: batched mel-spectrogram (batch, frequency, time).\n4D: batched mel-spectrogram plus singleton channel axis (batch, 1, frequency, time).\n'
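
The rule described in the note above, where the number of dimensions alone decides how the input is read, can be restated as a small helper. This is only an illustrative sketch: `describe_maest_input` is a hypothetical name, not part of MAEST or `transformers`, and it does nothing more than repeat the four listed cases.

```python
def describe_maest_input(shape):
    """Return how an array of the given shape would be interpreted.

    Hypothetical helper that only restates the four cases listed above.
    """
    interpretations = {
        1: "16kHz audio waveform",
        2: "mel-spectrogram (frequency, time)",
        3: "batched mel-spectrogram (batch, frequency, time)",
        4: "batched mel-spectrogram with channel axis (batch, 1, frequency, time)",
    }
    ndim = len(shape)
    if ndim not in interpretations:
        raise ValueError(f"Unsupported number of dimensions: {ndim}")
    if ndim == 4 and shape[1] != 1:
        raise ValueError("4D input is expected to have a singleton channel axis")
    return interpretations[ndim]


# A flat vector is always read as raw audio, whatever it actually contains.
print(describe_maest_input((30 * 16000,)))
# The 2D shape below is an arbitrary example, not a value from this notebook.
print(describe_maest_input((96, 1000)))
```

The practical consequence is that flattening any feature matrix into one dimension makes it look like a waveform to the model.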

# Load data

In [6]:
mfcc_tensor = torch.load("mfcc.pt")
covariance_tensor =  torch.load("covariance.pt")
file_paths = np.load("file_paths.npy")
labels = np.load("labels.npy")

In [7]:
labels_to_idx = {label: idx for idx, label in enumerate(np.unique(labels))}
idx_to_labels = {idx: label for idx, label in enumerate(np.unique(labels))}
labels_to_idx

{'blues': 0,
 'classical': 1,
 'country': 2,
 'disco': 3,
 'hiphop': 4,
 'jazz': 5,
 'metal': 6,
 'pop': 7,
 'reggae': 8,
 'rock': 9}

# Preprocess data
https://github.com/palonso/MAEST/blob/main/datasets/mtt/preprocess.py 

## Train test split
Might delete this later.

In [8]:
# Reshape the data into a 2D array (num_samples, num_features)
num_samples, num_frames, num_mfcc = mfcc_tensor.shape
mfcc_tensor_2d = mfcc_tensor.reshape(num_samples, num_frames * num_mfcc)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(mfcc_tensor_2d, labels, test_size=0.2, random_state=42)

# Get validation set
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=42)

In [9]:
X_train.shape

torch.Size([719, 38818])
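
The 719 training rows follow from the two chained splits. As a rough check of the arithmetic, the sketch below assumes 999 samples (the first dimension of `mfcc_tensor`) and assumes that the held-out share is rounded up with the remainder going to training, which is consistent with the shape printed above. `split_sizes` is a hypothetical helper, not a scikit-learn function.

```python
import math


def split_sizes(n_samples, test_size):
    """Return (n_train, n_test) for a fractional test_size.

    Assumes the held-out count is rounded up and the rest goes to training.
    """
    n_test = math.ceil(n_samples * test_size)
    return n_samples - n_test, n_test


n_total = 999  # first dimension of mfcc_tensor
n_train_val, n_test = split_sizes(n_total, 0.2)
n_train, n_val = split_sizes(n_train_val, 0.1)

print(n_train, n_val, n_test)  # 719 80 200
```

So roughly 72% of the tracks end up in the training set, 8% in validation and 20% in test.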

In [10]:
y_train[0]

'classical'

In [11]:
X_train[0]

tensor([ 68.4585,  -8.7303, -31.3657,  ...,   4.8434,  20.1321, -17.2632])

# Test feeding model

In [94]:
file = "../../melbands.npy"
with open(file, 'rb') as f:
    blues = np.load(f, allow_pickle=True)

blues.shape

(1877, 96)

In [95]:
blues.reshape(1877* 96)

array([3.064 , 2.951 , 2.645 , ..., 0.2291, 0.2312, 0.2028], dtype=float16)

In [96]:
pipe(blues.reshape(1877* 96))

[{'score': 0.1781368851661682, 'label': 'Electronic---Noise'},
 {'score': 0.08388300985097885, 'label': 'Electronic---Experimental'},
 {'score': 0.07882019132375717, 'label': 'Electronic---Glitch'},
 {'score': 0.06479445099830627, 'label': 'Electronic---Abstract'},
 {'score': 0.053518954664468765, 'label': 'Electronic---Power Electronics'}]

In [12]:
pipe(X_train[0].numpy())

[{'score': 0.353606641292572, 'label': 'Electronic---Noise'},
 {'score': 0.04757178947329521, 'label': 'Electronic---Power Electronics'},
 {'score': 0.042233750224113464, 'label': 'Electronic---Experimental'},
 {'score': 0.03940293937921524, 'label': 'Rock---Noise'},
 {'score': 0.03595265746116638, 'label': 'Electronic---Industrial'}]

In [13]:
N = 10

for x, y in zip(X_train[:N], y_train[:N]):
    print(pipe(x.numpy()))
    print(y)
    print("-"*N)

[{'score': 0.353606641292572, 'label': 'Electronic---Noise'}, {'score': 0.04757178947329521, 'label': 'Electronic---Power Electronics'}, {'score': 0.042233750224113464, 'label': 'Electronic---Experimental'}, {'score': 0.03940293937921524, 'label': 'Rock---Noise'}, {'score': 0.03595265746116638, 'label': 'Electronic---Industrial'}]
classical
----------
[{'score': 0.30069082975387573, 'label': 'Electronic---Noise'}, {'score': 0.045320652425289154, 'label': 'Rock---Noise'}, {'score': 0.04330591484904289, 'label': 'Electronic---Power Electronics'}, {'score': 0.038458213210105896, 'label': 'Electronic---Rhythmic Noise'}, {'score': 0.037081748247146606, 'label': 'Electronic---Experimental'}]
blues
----------
[{'score': 0.3399991989135742, 'label': 'Electronic---Noise'}, {'score': 0.04738768935203552, 'label': 'Electronic---Power Electronics'}, {'score': 0.04102112725377083, 'label': 'Rock---Noise'}, {'score': 0.038403041660785675, 'label': 'Electronic---Experimental'}, {'score': 0.0318494848

It looks like the model is most likely to predict "noise" for all our samples.
This tells us that our inputs are not very similar to the data the model was trained on: a 1D array is treated as a 16kHz waveform, whereas we are passing flattened MFCC features.

We can try fine-tuning the model on a subsample of our dataset.

An alternative would be to tweak our preprocessing to match the preprocessing the model expects.
We can try that if this first step doesn't work.
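
To see how far these inputs are from what the model expects, it helps to ask how long each flattened vector would be if it really were 16kHz audio. The sketch below only redoes the arithmetic with the shapes printed earlier in this notebook; `seconds_at_16khz` is a hypothetical helper.

```python
SAMPLE_RATE = 16_000  # 1D inputs are assumed to be 16kHz waveforms


def seconds_at_16khz(n_values):
    """Duration a 1D input of n_values would have if read as 16kHz audio."""
    return n_values / SAMPLE_RATE


flattened_mfcc = 38818          # X_train.shape[1]
flattened_melbands = 1877 * 96  # blues.shape, flattened

print(round(seconds_at_16khz(flattened_mfcc), 2))      # 2.43
print(round(seconds_at_16khz(flattened_melbands), 2))  # 11.26
print(seconds_at_16khz(30 * 16000))                    # 30.0
```

In other words, each 30-second track reaches the model as about 2.4 seconds of values that are not audio samples at all.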

# Fine-tune model

In [14]:
X_train[2].numpy()

array([ 74.0644  , -13.225058, -21.110668, ...,  -8.67979 , -16.61049 ,
       -16.836916], dtype=float32)

In [15]:
X_train[:10].numpy()

array([[ 68.45853  ,  -8.730311 , -31.365671 , ...,   4.843421 ,
         20.132097 , -17.263208 ],
       [ 67.08611  ,  -1.9582493,  -1.1213299, ...,  -5.0454655,
        -17.204418 , -32.76073  ],
       [ 74.0644   , -13.225058 , -21.110668 , ...,  -8.67979  ,
        -16.61049  , -16.836916 ],
       ...,
       [ 76.175125 ,   4.0349603, -24.262367 , ...,  -3.176486 ,
        -10.253227 ,  -3.1574678],
       [ 63.21377  ,  -4.3379726,  -6.844071 , ...,  36.45512  ,
          2.5168054, -12.495918 ],
       [ 66.55072  ,  -8.971389 ,  17.160826 , ...,  14.640405 ,
         16.61913  ,   9.466786 ]], dtype=float32)

# Label mappings 

In [16]:
# taken from https://github.com/palonso/MAEST
discogs_labels_full = [
    "Blues---Boogie Woogie",
    "Blues---Chicago Blues",
    "Blues---Country Blues",
    "Blues---Delta Blues",
    "Blues---Electric Blues",
    "Blues---Harmonica Blues",
    "Blues---Jump Blues",
    "Blues---Louisiana Blues",
    "Blues---Modern Electric Blues",
    "Blues---Piano Blues",
    "Blues---Rhythm & Blues",
    "Blues---Texas Blues",
    "Brass & Military---Brass Band",
    "Brass & Military---Marches",
    "Brass & Military---Military",
    "Children's---Educational",
    "Children's---Nursery Rhymes",
    "Children's---Story",
    "Classical---Baroque",
    "Classical---Choral",
    "Classical---Classical",
    "Classical---Contemporary",
    "Classical---Impressionist",
    "Classical---Medieval",
    "Classical---Modern",
    "Classical---Neo-Classical",
    "Classical---Neo-Romantic",
    "Classical---Opera",
    "Classical---Post-Modern",
    "Classical---Renaissance",
    "Classical---Romantic",
    "Electronic---Abstract",
    "Electronic---Acid",
    "Electronic---Acid House",
    "Electronic---Acid Jazz",
    "Electronic---Ambient",
    "Electronic---Bassline",
    "Electronic---Beatdown",
    "Electronic---Berlin-School",
    "Electronic---Big Beat",
    "Electronic---Bleep",
    "Electronic---Breakbeat",
    "Electronic---Breakcore",
    "Electronic---Breaks",
    "Electronic---Broken Beat",
    "Electronic---Chillwave",
    "Electronic---Chiptune",
    "Electronic---Dance-pop",
    "Electronic---Dark Ambient",
    "Electronic---Darkwave",
    "Electronic---Deep House",
    "Electronic---Deep Techno",
    "Electronic---Disco",
    "Electronic---Disco Polo",
    "Electronic---Donk",
    "Electronic---Downtempo",
    "Electronic---Drone",
    "Electronic---Drum n Bass",
    "Electronic---Dub",
    "Electronic---Dub Techno",
    "Electronic---Dubstep",
    "Electronic---Dungeon Synth",
    "Electronic---EBM",
    "Electronic---Electro",
    "Electronic---Electro House",
    "Electronic---Electroclash",
    "Electronic---Euro House",
    "Electronic---Euro-Disco",
    "Electronic---Eurobeat",
    "Electronic---Eurodance",
    "Electronic---Experimental",
    "Electronic---Freestyle",
    "Electronic---Future Jazz",
    "Electronic---Gabber",
    "Electronic---Garage House",
    "Electronic---Ghetto",
    "Electronic---Ghetto House",
    "Electronic---Glitch",
    "Electronic---Goa Trance",
    "Electronic---Grime",
    "Electronic---Halftime",
    "Electronic---Hands Up",
    "Electronic---Happy Hardcore",
    "Electronic---Hard House",
    "Electronic---Hard Techno",
    "Electronic---Hard Trance",
    "Electronic---Hardcore",
    "Electronic---Hardstyle",
    "Electronic---Hi NRG",
    "Electronic---Hip Hop",
    "Electronic---Hip-House",
    "Electronic---House",
    "Electronic---IDM",
    "Electronic---Illbient",
    "Electronic---Industrial",
    "Electronic---Italo House",
    "Electronic---Italo-Disco",
    "Electronic---Italodance",
    "Electronic---Jazzdance",
    "Electronic---Juke",
    "Electronic---Jumpstyle",
    "Electronic---Jungle",
    "Electronic---Latin",
    "Electronic---Leftfield",
    "Electronic---Makina",
    "Electronic---Minimal",
    "Electronic---Minimal Techno",
    "Electronic---Modern Classical",
    "Electronic---Musique Concr\u00e8te",
    "Electronic---Neofolk",
    "Electronic---New Age",
    "Electronic---New Beat",
    "Electronic---New Wave",
    "Electronic---Noise",
    "Electronic---Nu-Disco",
    "Electronic---Power Electronics",
    "Electronic---Progressive Breaks",
    "Electronic---Progressive House",
    "Electronic---Progressive Trance",
    "Electronic---Psy-Trance",
    "Electronic---Rhythmic Noise",
    "Electronic---Schranz",
    "Electronic---Sound Collage",
    "Electronic---Speed Garage",
    "Electronic---Speedcore",
    "Electronic---Synth-pop",
    "Electronic---Synthwave",
    "Electronic---Tech House",
    "Electronic---Tech Trance",
    "Electronic---Techno",
    "Electronic---Trance",
    "Electronic---Tribal",
    "Electronic---Tribal House",
    "Electronic---Trip Hop",
    "Electronic---Tropical House",
    "Electronic---UK Garage",
    "Electronic---Vaporwave",
    "Folk, World, & Country---African",
    "Folk, World, & Country---Bluegrass",
    "Folk, World, & Country---Cajun",
    "Folk, World, & Country---Canzone Napoletana",
    "Folk, World, & Country---Catalan Music",
    "Folk, World, & Country---Celtic",
    "Folk, World, & Country---Country",
    "Folk, World, & Country---Fado",
    "Folk, World, & Country---Flamenco",
    "Folk, World, & Country---Folk",
    "Folk, World, & Country---Gospel",
    "Folk, World, & Country---Highlife",
    "Folk, World, & Country---Hillbilly",
    "Folk, World, & Country---Hindustani",
    "Folk, World, & Country---Honky Tonk",
    "Folk, World, & Country---Indian Classical",
    "Folk, World, & Country---La\u00efk\u00f3",
    "Folk, World, & Country---Nordic",
    "Folk, World, & Country---Pacific",
    "Folk, World, & Country---Polka",
    "Folk, World, & Country---Ra\u00ef",
    "Folk, World, & Country---Romani",
    "Folk, World, & Country---Soukous",
    "Folk, World, & Country---S\u00e9ga",
    "Folk, World, & Country---Volksmusik",
    "Folk, World, & Country---Zouk",
    "Folk, World, & Country---\u00c9ntekhno",
    "Funk / Soul---Afrobeat",
    "Funk / Soul---Boogie",
    "Funk / Soul---Contemporary R&B",
    "Funk / Soul---Disco",
    "Funk / Soul---Free Funk",
    "Funk / Soul---Funk",
    "Funk / Soul---Gospel",
    "Funk / Soul---Neo Soul",
    "Funk / Soul---New Jack Swing",
    "Funk / Soul---P.Funk",
    "Funk / Soul---Psychedelic",
    "Funk / Soul---Rhythm & Blues",
    "Funk / Soul---Soul",
    "Funk / Soul---Swingbeat",
    "Funk / Soul---UK Street Soul",
    "Hip Hop---Bass Music",
    "Hip Hop---Boom Bap",
    "Hip Hop---Bounce",
    "Hip Hop---Britcore",
    "Hip Hop---Cloud Rap",
    "Hip Hop---Conscious",
    "Hip Hop---Crunk",
    "Hip Hop---Cut-up/DJ",
    "Hip Hop---DJ Battle Tool",
    "Hip Hop---Electro",
    "Hip Hop---G-Funk",
    "Hip Hop---Gangsta",
    "Hip Hop---Grime",
    "Hip Hop---Hardcore Hip-Hop",
    "Hip Hop---Horrorcore",
    "Hip Hop---Instrumental",
    "Hip Hop---Jazzy Hip-Hop",
    "Hip Hop---Miami Bass",
    "Hip Hop---Pop Rap",
    "Hip Hop---Ragga HipHop",
    "Hip Hop---RnB/Swing",
    "Hip Hop---Screw",
    "Hip Hop---Thug Rap",
    "Hip Hop---Trap",
    "Hip Hop---Trip Hop",
    "Hip Hop---Turntablism",
    "Jazz---Afro-Cuban Jazz",
    "Jazz---Afrobeat",
    "Jazz---Avant-garde Jazz",
    "Jazz---Big Band",
    "Jazz---Bop",
    "Jazz---Bossa Nova",
    "Jazz---Contemporary Jazz",
    "Jazz---Cool Jazz",
    "Jazz---Dixieland",
    "Jazz---Easy Listening",
    "Jazz---Free Improvisation",
    "Jazz---Free Jazz",
    "Jazz---Fusion",
    "Jazz---Gypsy Jazz",
    "Jazz---Hard Bop",
    "Jazz---Jazz-Funk",
    "Jazz---Jazz-Rock",
    "Jazz---Latin Jazz",
    "Jazz---Modal",
    "Jazz---Post Bop",
    "Jazz---Ragtime",
    "Jazz---Smooth Jazz",
    "Jazz---Soul-Jazz",
    "Jazz---Space-Age",
    "Jazz---Swing",
    "Latin---Afro-Cuban",
    "Latin---Bai\u00e3o",
    "Latin---Batucada",
    "Latin---Beguine",
    "Latin---Bolero",
    "Latin---Boogaloo",
    "Latin---Bossanova",
    "Latin---Cha-Cha",
    "Latin---Charanga",
    "Latin---Compas",
    "Latin---Cubano",
    "Latin---Cumbia",
    "Latin---Descarga",
    "Latin---Forr\u00f3",
    "Latin---Guaguanc\u00f3",
    "Latin---Guajira",
    "Latin---Guaracha",
    "Latin---MPB",
    "Latin---Mambo",
    "Latin---Mariachi",
    "Latin---Merengue",
    "Latin---Norte\u00f1o",
    "Latin---Nueva Cancion",
    "Latin---Pachanga",
    "Latin---Porro",
    "Latin---Ranchera",
    "Latin---Reggaeton",
    "Latin---Rumba",
    "Latin---Salsa",
    "Latin---Samba",
    "Latin---Son",
    "Latin---Son Montuno",
    "Latin---Tango",
    "Latin---Tejano",
    "Latin---Vallenato",
    "Non-Music---Audiobook",
    "Non-Music---Comedy",
    "Non-Music---Dialogue",
    "Non-Music---Education",
    "Non-Music---Field Recording",
    "Non-Music---Interview",
    "Non-Music---Monolog",
    "Non-Music---Poetry",
    "Non-Music---Political",
    "Non-Music---Promotional",
    "Non-Music---Radioplay",
    "Non-Music---Religious",
    "Non-Music---Spoken Word",
    "Pop---Ballad",
    "Pop---Bollywood",
    "Pop---Bubblegum",
    "Pop---Chanson",
    "Pop---City Pop",
    "Pop---Europop",
    "Pop---Indie Pop",
    "Pop---J-pop",
    "Pop---K-pop",
    "Pop---Kay\u014dkyoku",
    "Pop---Light Music",
    "Pop---Music Hall",
    "Pop---Novelty",
    "Pop---Parody",
    "Pop---Schlager",
    "Pop---Vocal",
    "Reggae---Calypso",
    "Reggae---Dancehall",
    "Reggae---Dub",
    "Reggae---Lovers Rock",
    "Reggae---Ragga",
    "Reggae---Reggae",
    "Reggae---Reggae-Pop",
    "Reggae---Rocksteady",
    "Reggae---Roots Reggae",
    "Reggae---Ska",
    "Reggae---Soca",
    "Rock---AOR",
    "Rock---Acid Rock",
    "Rock---Acoustic",
    "Rock---Alternative Rock",
    "Rock---Arena Rock",
    "Rock---Art Rock",
    "Rock---Atmospheric Black Metal",
    "Rock---Avantgarde",
    "Rock---Beat",
    "Rock---Black Metal",
    "Rock---Blues Rock",
    "Rock---Brit Pop",
    "Rock---Classic Rock",
    "Rock---Coldwave",
    "Rock---Country Rock",
    "Rock---Crust",
    "Rock---Death Metal",
    "Rock---Deathcore",
    "Rock---Deathrock",
    "Rock---Depressive Black Metal",
    "Rock---Doo Wop",
    "Rock---Doom Metal",
    "Rock---Dream Pop",
    "Rock---Emo",
    "Rock---Ethereal",
    "Rock---Experimental",
    "Rock---Folk Metal",
    "Rock---Folk Rock",
    "Rock---Funeral Doom Metal",
    "Rock---Funk Metal",
    "Rock---Garage Rock",
    "Rock---Glam",
    "Rock---Goregrind",
    "Rock---Goth Rock",
    "Rock---Gothic Metal",
    "Rock---Grindcore",
    "Rock---Grunge",
    "Rock---Hard Rock",
    "Rock---Hardcore",
    "Rock---Heavy Metal",
    "Rock---Indie Rock",
    "Rock---Industrial",
    "Rock---Krautrock",
    "Rock---Lo-Fi",
    "Rock---Lounge",
    "Rock---Math Rock",
    "Rock---Melodic Death Metal",
    "Rock---Melodic Hardcore",
    "Rock---Metalcore",
    "Rock---Mod",
    "Rock---Neofolk",
    "Rock---New Wave",
    "Rock---No Wave",
    "Rock---Noise",
    "Rock---Noisecore",
    "Rock---Nu Metal",
    "Rock---Oi",
    "Rock---Parody",
    "Rock---Pop Punk",
    "Rock---Pop Rock",
    "Rock---Pornogrind",
    "Rock---Post Rock",
    "Rock---Post-Hardcore",
    "Rock---Post-Metal",
    "Rock---Post-Punk",
    "Rock---Power Metal",
    "Rock---Power Pop",
    "Rock---Power Violence",
    "Rock---Prog Rock",
    "Rock---Progressive Metal",
    "Rock---Psychedelic Rock",
    "Rock---Psychobilly",
    "Rock---Pub Rock",
    "Rock---Punk",
    "Rock---Rock & Roll",
    "Rock---Rockabilly",
    "Rock---Shoegaze",
    "Rock---Ska",
    "Rock---Sludge Metal",
    "Rock---Soft Rock",
    "Rock---Southern Rock",
    "Rock---Space Rock",
    "Rock---Speed Metal",
    "Rock---Stoner Rock",
    "Rock---Surf",
    "Rock---Symphonic Rock",
    "Rock---Technical Death Metal",
    "Rock---Thrash",
    "Rock---Twist",
    "Rock---Viking Metal",
    "Rock---Y\u00e9-Y\u00e9",
    "Stage & Screen---Musical",
    "Stage & Screen---Score",
    "Stage & Screen---Soundtrack",
    "Stage & Screen---Theme"
]

def get_tier_1_labels(labels, sep="---"):
    return set([
        label.split(sep)[0]
        for label in labels
    ])

discogs_tier_1_labels = get_tier_1_labels(discogs_labels_full)
discogs_tier_1_labels

{'Blues',
 'Brass & Military',
 "Children's",
 'Classical',
 'Electronic',
 'Folk, World, & Country',
 'Funk / Soul',
 'Hip Hop',
 'Jazz',
 'Latin',
 'Non-Music',
 'Pop',
 'Reggae',
 'Rock',
 'Stage & Screen'}

In [37]:
# manually build mapping
gtzan_to_discogs = {
    "blues": "Blues",
    "classical": "Classical",
    "country": "Folk, World, & Country",
    "disco": "Funk / Soul",
    "hiphop": "Hip Hop",
    "jazz": "Jazz",
    "metal": "Rock", # Metal is a subgenre of Rock
    "pop": "Pop",
    "reggae": "Reggae",
    "rock": "Rock"
}

discogs_to_gtzan = {v: k for k, v in gtzan_to_discogs.items()}

gtzan_labels = list(gtzan_to_discogs.keys())

set(labels) - set(gtzan_to_discogs), set(gtzan_to_discogs) - set(labels)

(set(), set())
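
One detail of the inverted mapping is worth checking: because both `metal` and `rock` map to `Rock`, inverting the dictionary keeps only one of them. The snippet below repeats the mapping from the cell above so that it runs on its own.

```python
gtzan_to_discogs = {
    "blues": "Blues",
    "classical": "Classical",
    "country": "Folk, World, & Country",
    "disco": "Funk / Soul",
    "hiphop": "Hip Hop",
    "jazz": "Jazz",
    "metal": "Rock",
    "pop": "Pop",
    "reggae": "Reggae",
    "rock": "Rock",
}

discogs_to_gtzan = {v: k for k, v in gtzan_to_discogs.items()}

# Later keys overwrite earlier ones, so "Rock" ends up pointing at "rock".
print(discogs_to_gtzan["Rock"])  # rock

# GTZAN genres that can no longer be reached through the inverted mapping.
unreachable = sorted(set(gtzan_to_discogs) - set(discogs_to_gtzan.values()))
print(unreachable)  # ['metal']
```

With this mapping a `metal` track can never be counted as correct, and `disco` is only matched through `Funk / Soul` labels, not through `Electronic---Disco`.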

# Fine-tune model

In [18]:
mfcc_tensor.shape

torch.Size([999, 2986, 13])

In [19]:
X_train.reshape(-1, 2986, 13).shape

torch.Size([719, 2986, 13])

In [20]:
batch_2d = X_train.reshape(-1, 2986, 13)[:100].numpy()
batch_2d[0].shape

(2986, 13)

In [21]:
# pipe(batch_2d[0])

In [23]:
pipe(X_train[0].numpy())

[{'score': 0.353606641292572, 'label': 'Electronic---Noise'},
 {'score': 0.04757178947329521, 'label': 'Electronic---Power Electronics'},
 {'score': 0.042233750224113464, 'label': 'Electronic---Experimental'},
 {'score': 0.03940293937921524, 'label': 'Rock---Noise'},
 {'score': 0.03595265746116638, 'label': 'Electronic---Industrial'}]

In [26]:
pipe.model

ASTForAudioClassification(
  (audio_spectrogram_transformer): ASTModel(
    (embeddings): ASTEmbeddings(
      (patch_embeddings): ASTPatchEmbeddings(
        (projection): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ASTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ASTLayer(
          (attention): ASTAttention(
            (attention): ASTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ASTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ASTIntermediate(
            (de

In [49]:
def output_to_label(output):
    """
    Map the top prediction returned by pipe to a GTZAN label.

    Output from pipe is a list of dicts, highest score first:
    [{'score': 0.353606641292572, 'label': 'Electronic---Noise'}, {'score': 0.04757178947329521, 'label': 'Electronic---Power Electronics'}, ...]

    Need to get it to a single GTZAN label such as "rock",
    or "NA" when the top-level Discogs genre has no GTZAN counterpart.
    """
    dlabel = output[0]["label"].split("---")[0]
    return discogs_to_gtzan.get(dlabel, "NA")
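
`output_to_label` only looks at the single highest-scoring entry. A possible variation, sketched here as an assumption rather than as the approach used in this notebook, is to sum the scores per top-level genre first, so that several related subgenres can outvote one unrelated label. `aggregate_to_gtzan` is a hypothetical helper, and the inverted mapping is written out so the block is self-contained.

```python
from collections import defaultdict

discogs_to_gtzan = {
    "Blues": "blues",
    "Classical": "classical",
    "Folk, World, & Country": "country",
    "Funk / Soul": "disco",
    "Hip Hop": "hiphop",
    "Jazz": "jazz",
    "Pop": "pop",
    "Reggae": "reggae",
    "Rock": "rock",
}


def aggregate_to_gtzan(output, sep="---"):
    """Sum scores per top-level Discogs genre and map the winner to GTZAN."""
    totals = defaultdict(float)
    for entry in output:
        totals[entry["label"].split(sep)[0]] += entry["score"]
    best = max(totals, key=totals.get)
    return discogs_to_gtzan.get(best, "NA")


# Top-5 output for X_train[0], copied from earlier in this notebook.
real_output = [
    {"score": 0.353606641292572, "label": "Electronic---Noise"},
    {"score": 0.04757178947329521, "label": "Electronic---Power Electronics"},
    {"score": 0.042233750224113464, "label": "Electronic---Experimental"},
    {"score": 0.03940293937921524, "label": "Rock---Noise"},
    {"score": 0.03595265746116638, "label": "Electronic---Industrial"},
]
print(aggregate_to_gtzan(real_output))  # NA

# Made-up scores, only to illustrate the difference from taking the top-1.
made_up_output = [
    {"score": 0.30, "label": "Electronic---House"},
    {"score": 0.25, "label": "Jazz---Bop"},
    {"score": 0.20, "label": "Jazz---Swing"},
]
print(aggregate_to_gtzan(made_up_output))  # jazz
```

Keep in mind that the outputs shown in this notebook hold only five entries each, so the sums cover just part of the probability mass.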


In [50]:
label_to_idx = {label: idx for idx, label in enumerate(gtzan_labels)}
idx_to_label = {idx: label for idx, label in enumerate(gtzan_labels)}

label_to_idx["NA"] = len(label_to_idx)
idx_to_label[len(idx_to_label)] = "NA"


In [67]:
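def label_mapping(label):
    # Discogs label -> top-level genre -> GTZAN label -> index ("NA" if unmapped)
    gtzan_label = discogs_to_gtzan.get(label.split("---")[0], "NA")
    return label_to_idx[gtzan_label]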

In [69]:
# model = pipe.model
optimizer = optim.Adam(pipe.model.parameters(), lr=0.01)
criterion = nn.BCELoss()

batch_size = 50

x_batch = X_train[:batch_size]
y_batch = y_train[:batch_size]

train_loss = []

for epoch in range(3):
    batch_loss = []
    for x, y in zip(x_batch, y_batch):
        optimizer.zero_grad()

        output = pipe(x.numpy())
        # predicted_label = output_to_label(output)
        
        # predicted_label = torch.tensor([label_to_idx[predicted_label]])
        # y = torch.tensor([label_to_idx[y]])

        # print(predicted_label.shape, y.shape)
        # print(predicted_label, y)
        
        # loss = criterion(predicted_label, y)
        scores = torch.tensor([out["score"] for out in output])
        labels = torch.tensor([label_mapping(out["label"]) for out in output])

        # Apply softmax to convert scores to probabilities
        probabilities = nn.functional.softmax(scores, dim=0)

        # CrossEntropyLoss expects raw scores, not probabilities
        # So, you can use negative log likelihood directly
        loss = nn.functional.nll_loss(torch.log(probabilities), labels)

        # # Assuming y is a single label
        # y = "cat1"
        # y = torch.tensor([label_mapping[y]])

        # # Calculate loss
        # loss = nn.functional.cross_entropy(scores.view(1, -1), y)

        loss.backward()
        optimizer.step()
        
        batch_loss.append(loss.item())

    train_loss.append(np.mean(batch_loss))
    print(f"Epoch {epoch} loss: {train_loss[-1]}")
    

IndexError: Target 10 is out of bounds.
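
The `IndexError` is only the first obstacle. The targets contain the index 10 (`NA`) while only five scores are passed to the loss, and, more fundamentally, the scores come back from `pipe` as plain Python floats. A tensor rebuilt from those floats has no connection to the model weights, so `loss.backward()` has nothing to update. The toy sketch below illustrates that second point with a hypothetical stand-in (`toy_model`, a single linear layer) instead of the real network; how to obtain logits from the actual pretrained model is not shown here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the pretrained network: 4 input features, 3 classes.
toy_model = nn.Linear(4, 3)
x = torch.randn(1, 4)
target = torch.tensor([2])  # class index, must be smaller than the number of classes

# (a) Mimic the pipeline route: scores leave the graph as Python floats.
with torch.no_grad():
    float_scores = toy_model(x).softmax(dim=-1)[0].tolist()
detached = torch.tensor(float_scores)
print(detached.requires_grad)  # False

# (b) Keep the logits as tensors so autograd can reach the weights.
logits = toy_model(x)
loss = nn.functional.cross_entropy(logits, target)
loss.backward()
print(toy_model.weight.grad is not None)  # True
```

Fine-tuning therefore needs a forward pass that returns logits as tensors, with one logit per target class, rather than the post-processed list of dictionaries.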

In [32]:
for y in y_batch:
    print(y)
    break

classical
