In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import matplotlib.pyplot as plt
import librosa
import librosa.display
from tqdm import tqdm
import json
import os
from sklearn.model_selection import train_test_split

In [2]:
seed = 1
torch.manual_seed(seed)
chord_templates:dict = json.load(open('./chord_templates.json'))

In [3]:
class ChordDetector(Dataset):
    def __init__(self, train:bool, chord_template:dict = json.load(open('./chord_templates.json')), data_location:str = './data/', sr = 44100, hop = 256):
        super(ChordDetector, self).__init__()
        self.chord_template = chord_template
        self.data = []
        self.sr = sr
        self.hop = hop
        for file in os.listdir(data_location):
            chord_true = torch.Tensor(self.chord_template[self._extract_chord_name(file)])
            y, sr = librosa.load(data_location+file, sr = sr)
            chroma = torch.Tensor(librosa.feature.chroma_cens(y=y, sr = sr, hop_length=hop)).T
            self.data.append((chroma, chord_true))
        X_train, X_test, _, _ = train_test_split(self.data, self.data, test_size=0.2, random_state=seed)
        if train:
            self.data = X_train
        else:
            self.data = X_test

    def _extract_chord_name(self, file):
        main = file[:file.index('-')]
        if file[file.index('-')+2]=='i':
            return main+'m'
        return main
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

In [4]:
train_data = ChordDetector(train=True)
val_data = ChordDetector(train=False)

In [5]:
train_loader = DataLoader(
    train_data, 
    batch_size=64,
    shuffle=True
)
val_loader = DataLoader(
    val_data, 
    batch_size=64,
    shuffle=True
)

In [6]:
class GRU(nn.Module):
    def __init__(self, input_size = 12, hidden_size = 64, num_layers = 1, num_classes = 12, bidirectional = True) -> None:
        super(GRU, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first = True, bidirectional=bidirectional)
        if(bidirectional):
            self.fc = nn.Linear(hidden_size*2, num_classes)
        else:
            self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        if(self.bidirectional):
            h0 = torch.zeros(2*self.num_layers, x.size(0), self.hidden_size)
        else:
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.gru(x, h0)
        out = out[:,-1,:] # Since we only want the output of the last cell
        out = self.fc(out)
        return(out)

In [7]:
device = torch.device('cpu')
EPOCHS = 100

In [8]:
model = GRU().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 2e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, 0.995)

In [9]:
def predict(model, audio, chroma_req = True, chord_templates:dict = json.load(open('./chord_templates.json')), sr = 44100, hop = 256):
    if chroma_req:
        chroma = torch.Tensor(librosa.feature.chroma_cens(y=audio, sr = sr, hop_length=hop)).T.unsqueeze(0)
    else:
        chroma = audio
    with torch.no_grad():
        outputs = nn.functional.softmax(model(chroma), 1)[0]
    min_val = 120
    min_key = ''
    for key, val in chord_templates.items():
        out = torch.norm(torch.Tensor(val) - outputs)
        if min_val >= out:
            min_val = out
            min_key = key
    return min_key

In [10]:
best_accuracy = 0
best_weights = model.state_dict()

In [11]:
for epoch in range(EPOCHS):
    model.train()
    for i, (chroma, labels) in enumerate(train_loader):
        chroma = chroma.to(device)
        labels = labels.to(device)

        outputs = model(chroma)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()
    with torch.no_grad():
        model.eval()
        count = 0
        for y in tqdm(val_data):
            pred = predict(model, y[0].unsqueeze(0), False)
            min_val = 120
            min_key = ''
            for key, val in chord_templates.items():
                out = torch.norm(torch.Tensor(val) - y[1])
                if min_val >= out:
                    min_val = out
                    min_key = key
            count+=(pred==min_key)
        acc = 100*count/len(val_data)
        if best_accuracy <= acc:
            best_accuracy = acc
            best_weights = model.state_dict()
        print(f"Val Accuracy = {acc:.2f}%")
    print(f'Epoch {epoch+1}/{EPOCHS}; Loss: {loss.item():.4f}')

100%|██████████| 58/58 [00:01<00:00, 35.15it/s]


Val Accuracy = 8.62%
Epoch 1/100; Loss: 7.3584


100%|██████████| 58/58 [00:01<00:00, 35.41it/s]


Val Accuracy = 12.07%
Epoch 2/100; Loss: 7.1157


100%|██████████| 58/58 [00:01<00:00, 35.20it/s]


Val Accuracy = 12.07%
Epoch 3/100; Loss: 7.0163


100%|██████████| 58/58 [00:01<00:00, 35.21it/s]


Val Accuracy = 13.79%
Epoch 4/100; Loss: 7.2554


100%|██████████| 58/58 [00:01<00:00, 35.27it/s]


Val Accuracy = 13.79%
Epoch 5/100; Loss: 6.8297


100%|██████████| 58/58 [00:01<00:00, 35.16it/s]


Val Accuracy = 10.34%
Epoch 6/100; Loss: 6.7561


100%|██████████| 58/58 [00:01<00:00, 35.12it/s]


Val Accuracy = 15.52%
Epoch 7/100; Loss: 6.3772


100%|██████████| 58/58 [00:01<00:00, 35.30it/s]


Val Accuracy = 15.52%
Epoch 8/100; Loss: 6.6958


100%|██████████| 58/58 [00:01<00:00, 35.43it/s]


Val Accuracy = 18.97%
Epoch 9/100; Loss: 6.2860


100%|██████████| 58/58 [00:01<00:00, 34.90it/s]


Val Accuracy = 20.69%
Epoch 10/100; Loss: 6.5636


100%|██████████| 58/58 [00:01<00:00, 35.15it/s]


Val Accuracy = 18.97%
Epoch 11/100; Loss: 6.1682


100%|██████████| 58/58 [00:01<00:00, 35.25it/s]


Val Accuracy = 29.31%
Epoch 12/100; Loss: 6.4378


100%|██████████| 58/58 [00:01<00:00, 34.63it/s]


Val Accuracy = 36.21%
Epoch 13/100; Loss: 5.7291


100%|██████████| 58/58 [00:01<00:00, 35.16it/s]


Val Accuracy = 56.90%
Epoch 14/100; Loss: 5.5114


100%|██████████| 58/58 [00:01<00:00, 35.11it/s]


Val Accuracy = 60.34%
Epoch 15/100; Loss: 4.8651


100%|██████████| 58/58 [00:01<00:00, 35.00it/s]


Val Accuracy = 65.52%
Epoch 16/100; Loss: 4.6357


100%|██████████| 58/58 [00:01<00:00, 35.20it/s]


Val Accuracy = 63.79%
Epoch 17/100; Loss: 4.3387


100%|██████████| 58/58 [00:01<00:00, 34.96it/s]


Val Accuracy = 68.97%
Epoch 18/100; Loss: 4.3384


100%|██████████| 58/58 [00:01<00:00, 35.40it/s]


Val Accuracy = 79.31%
Epoch 19/100; Loss: 4.1976


100%|██████████| 58/58 [00:01<00:00, 34.88it/s]


Val Accuracy = 82.76%
Epoch 20/100; Loss: 4.1227


100%|██████████| 58/58 [00:01<00:00, 34.99it/s]


Val Accuracy = 89.66%
Epoch 21/100; Loss: 4.1136


100%|██████████| 58/58 [00:01<00:00, 35.37it/s]


Val Accuracy = 89.66%
Epoch 22/100; Loss: 3.9689


100%|██████████| 58/58 [00:01<00:00, 35.22it/s]


Val Accuracy = 91.38%
Epoch 23/100; Loss: 3.8635


100%|██████████| 58/58 [00:01<00:00, 35.54it/s]


Val Accuracy = 94.83%
Epoch 24/100; Loss: 3.7963


100%|██████████| 58/58 [00:01<00:00, 35.32it/s]


Val Accuracy = 91.38%
Epoch 25/100; Loss: 3.6618


100%|██████████| 58/58 [00:01<00:00, 35.17it/s]


Val Accuracy = 96.55%
Epoch 26/100; Loss: 3.8045


100%|██████████| 58/58 [00:01<00:00, 35.54it/s]


Val Accuracy = 94.83%
Epoch 27/100; Loss: 3.6168


100%|██████████| 58/58 [00:01<00:00, 35.26it/s]


Val Accuracy = 94.83%
Epoch 28/100; Loss: 3.5850


100%|██████████| 58/58 [00:01<00:00, 35.37it/s]


Val Accuracy = 89.66%
Epoch 29/100; Loss: 3.5750


100%|██████████| 58/58 [00:01<00:00, 35.52it/s]


Val Accuracy = 96.55%
Epoch 30/100; Loss: 3.5637


100%|██████████| 58/58 [00:01<00:00, 35.08it/s]


Val Accuracy = 93.10%
Epoch 31/100; Loss: 3.4880


100%|██████████| 58/58 [00:01<00:00, 34.99it/s]


Val Accuracy = 93.10%
Epoch 32/100; Loss: 3.4334


100%|██████████| 58/58 [00:01<00:00, 34.96it/s]


Val Accuracy = 98.28%
Epoch 33/100; Loss: 3.4473


100%|██████████| 58/58 [00:01<00:00, 35.02it/s]


Val Accuracy = 96.55%
Epoch 34/100; Loss: 3.4216


100%|██████████| 58/58 [00:01<00:00, 35.00it/s]


Val Accuracy = 96.55%
Epoch 35/100; Loss: 3.4255


100%|██████████| 58/58 [00:01<00:00, 33.74it/s]


Val Accuracy = 98.28%
Epoch 36/100; Loss: 3.4284


100%|██████████| 58/58 [00:01<00:00, 33.42it/s]


Val Accuracy = 96.55%
Epoch 37/100; Loss: 3.4163


100%|██████████| 58/58 [00:01<00:00, 33.65it/s]


Val Accuracy = 94.83%
Epoch 38/100; Loss: 3.3864


100%|██████████| 58/58 [00:01<00:00, 33.35it/s]


Val Accuracy = 94.83%
Epoch 39/100; Loss: 3.3809


100%|██████████| 58/58 [00:01<00:00, 32.98it/s]


Val Accuracy = 96.55%
Epoch 40/100; Loss: 3.3566


100%|██████████| 58/58 [00:01<00:00, 32.96it/s]


Val Accuracy = 96.55%
Epoch 41/100; Loss: 3.3544


100%|██████████| 58/58 [00:01<00:00, 33.05it/s]


Val Accuracy = 96.55%
Epoch 42/100; Loss: 3.3524


100%|██████████| 58/58 [00:01<00:00, 33.37it/s]


Val Accuracy = 96.55%
Epoch 43/100; Loss: 3.3447


100%|██████████| 58/58 [00:01<00:00, 32.85it/s]


Val Accuracy = 96.55%
Epoch 44/100; Loss: 3.3396


100%|██████████| 58/58 [00:01<00:00, 32.71it/s]


Val Accuracy = 96.55%
Epoch 45/100; Loss: 3.3371


100%|██████████| 58/58 [00:01<00:00, 33.24it/s]


Val Accuracy = 96.55%
Epoch 46/100; Loss: 3.3329


100%|██████████| 58/58 [00:01<00:00, 33.08it/s]


Val Accuracy = 96.55%
Epoch 47/100; Loss: 3.3242


100%|██████████| 58/58 [00:01<00:00, 33.20it/s]


Val Accuracy = 96.55%
Epoch 48/100; Loss: 3.3390


100%|██████████| 58/58 [00:01<00:00, 33.33it/s]


Val Accuracy = 96.55%
Epoch 49/100; Loss: 3.3262


100%|██████████| 58/58 [00:01<00:00, 33.19it/s]


Val Accuracy = 96.55%
Epoch 50/100; Loss: 3.3222


100%|██████████| 58/58 [00:01<00:00, 33.28it/s]


Val Accuracy = 96.55%
Epoch 51/100; Loss: 3.3209


100%|██████████| 58/58 [00:01<00:00, 33.09it/s]


Val Accuracy = 96.55%
Epoch 52/100; Loss: 3.3208


100%|██████████| 58/58 [00:01<00:00, 33.13it/s]


Val Accuracy = 96.55%
Epoch 53/100; Loss: 3.3211


100%|██████████| 58/58 [00:01<00:00, 32.88it/s]


Val Accuracy = 96.55%
Epoch 54/100; Loss: 3.3223


100%|██████████| 58/58 [00:01<00:00, 32.29it/s]


Val Accuracy = 96.55%
Epoch 55/100; Loss: 3.3205


100%|██████████| 58/58 [00:01<00:00, 32.34it/s]


Val Accuracy = 96.55%
Epoch 56/100; Loss: 3.3171


100%|██████████| 58/58 [00:01<00:00, 31.50it/s]


Val Accuracy = 96.55%
Epoch 57/100; Loss: 3.3230


100%|██████████| 58/58 [00:01<00:00, 31.76it/s]


Val Accuracy = 96.55%
Epoch 58/100; Loss: 3.3134


100%|██████████| 58/58 [00:01<00:00, 30.80it/s]


Val Accuracy = 96.55%
Epoch 59/100; Loss: 3.3140


100%|██████████| 58/58 [00:01<00:00, 31.86it/s]


Val Accuracy = 96.55%
Epoch 60/100; Loss: 3.3152


100%|██████████| 58/58 [00:01<00:00, 32.72it/s]


Val Accuracy = 96.55%
Epoch 61/100; Loss: 3.3127


100%|██████████| 58/58 [00:01<00:00, 33.74it/s]


Val Accuracy = 96.55%
Epoch 62/100; Loss: 3.3140


100%|██████████| 58/58 [00:01<00:00, 32.43it/s]


Val Accuracy = 96.55%
Epoch 63/100; Loss: 3.3137


100%|██████████| 58/58 [00:01<00:00, 33.33it/s]


Val Accuracy = 96.55%
Epoch 64/100; Loss: 3.3119


100%|██████████| 58/58 [00:01<00:00, 33.76it/s]


Val Accuracy = 96.55%
Epoch 65/100; Loss: 3.3116


100%|██████████| 58/58 [00:01<00:00, 33.31it/s]


Val Accuracy = 96.55%
Epoch 66/100; Loss: 3.3130


100%|██████████| 58/58 [00:01<00:00, 33.39it/s]


Val Accuracy = 96.55%
Epoch 67/100; Loss: 3.3077


100%|██████████| 58/58 [00:01<00:00, 33.29it/s]


Val Accuracy = 96.55%
Epoch 68/100; Loss: 3.3099


100%|██████████| 58/58 [00:01<00:00, 33.30it/s]


Val Accuracy = 96.55%
Epoch 69/100; Loss: 3.3114


100%|██████████| 58/58 [00:01<00:00, 33.53it/s]


Val Accuracy = 96.55%
Epoch 70/100; Loss: 3.3104


100%|██████████| 58/58 [00:01<00:00, 32.89it/s]


Val Accuracy = 96.55%
Epoch 71/100; Loss: 3.3092


100%|██████████| 58/58 [00:01<00:00, 33.19it/s]


Val Accuracy = 96.55%
Epoch 72/100; Loss: 3.3123


100%|██████████| 58/58 [00:01<00:00, 32.99it/s]


Val Accuracy = 96.55%
Epoch 73/100; Loss: 3.3108


100%|██████████| 58/58 [00:01<00:00, 33.12it/s]


Val Accuracy = 96.55%
Epoch 74/100; Loss: 3.3065


100%|██████████| 58/58 [00:01<00:00, 32.70it/s]


Val Accuracy = 96.55%
Epoch 75/100; Loss: 3.3099


100%|██████████| 58/58 [00:01<00:00, 32.56it/s]


Val Accuracy = 96.55%
Epoch 76/100; Loss: 3.3066


100%|██████████| 58/58 [00:01<00:00, 32.94it/s]


Val Accuracy = 96.55%
Epoch 77/100; Loss: 3.3089


100%|██████████| 58/58 [00:01<00:00, 32.94it/s]


Val Accuracy = 96.55%
Epoch 78/100; Loss: 3.3075


100%|██████████| 58/58 [00:01<00:00, 32.99it/s]


Val Accuracy = 96.55%
Epoch 79/100; Loss: 3.3067


100%|██████████| 58/58 [00:01<00:00, 33.76it/s]


Val Accuracy = 96.55%
Epoch 80/100; Loss: 3.3055


100%|██████████| 58/58 [00:01<00:00, 33.83it/s]


Val Accuracy = 96.55%
Epoch 81/100; Loss: 3.3057


100%|██████████| 58/58 [00:01<00:00, 33.94it/s]


Val Accuracy = 96.55%
Epoch 82/100; Loss: 3.3038


100%|██████████| 58/58 [00:01<00:00, 34.04it/s]


Val Accuracy = 96.55%
Epoch 83/100; Loss: 3.3027


100%|██████████| 58/58 [00:01<00:00, 34.07it/s]


Val Accuracy = 96.55%
Epoch 84/100; Loss: 3.3041


100%|██████████| 58/58 [00:01<00:00, 33.82it/s]


Val Accuracy = 96.55%
Epoch 85/100; Loss: 3.3033


100%|██████████| 58/58 [00:01<00:00, 33.51it/s]


Val Accuracy = 96.55%
Epoch 86/100; Loss: 3.3031


100%|██████████| 58/58 [00:01<00:00, 33.92it/s]


Val Accuracy = 96.55%
Epoch 87/100; Loss: 3.3025


100%|██████████| 58/58 [00:01<00:00, 34.09it/s]


Val Accuracy = 96.55%
Epoch 88/100; Loss: 3.3021


100%|██████████| 58/58 [00:01<00:00, 32.69it/s]


Val Accuracy = 96.55%
Epoch 89/100; Loss: 3.3021


100%|██████████| 58/58 [00:01<00:00, 32.57it/s]


Val Accuracy = 96.55%
Epoch 90/100; Loss: 3.3024


100%|██████████| 58/58 [00:01<00:00, 32.60it/s]


Val Accuracy = 96.55%
Epoch 91/100; Loss: 3.3014


100%|██████████| 58/58 [00:01<00:00, 32.63it/s]


Val Accuracy = 96.55%
Epoch 92/100; Loss: 3.3018


100%|██████████| 58/58 [00:01<00:00, 32.62it/s]


Val Accuracy = 96.55%
Epoch 93/100; Loss: 3.3020


100%|██████████| 58/58 [00:01<00:00, 33.11it/s]


Val Accuracy = 96.55%
Epoch 94/100; Loss: 3.3017


100%|██████████| 58/58 [00:01<00:00, 33.13it/s]


Val Accuracy = 96.55%
Epoch 95/100; Loss: 3.3031


100%|██████████| 58/58 [00:01<00:00, 33.01it/s]


Val Accuracy = 96.55%
Epoch 96/100; Loss: 3.3027


100%|██████████| 58/58 [00:01<00:00, 32.66it/s]


Val Accuracy = 96.55%
Epoch 97/100; Loss: 3.3015


100%|██████████| 58/58 [00:01<00:00, 32.93it/s]


Val Accuracy = 96.55%
Epoch 98/100; Loss: 3.3018


100%|██████████| 58/58 [00:01<00:00, 33.01it/s]


Val Accuracy = 96.55%
Epoch 99/100; Loss: 3.3014


100%|██████████| 58/58 [00:01<00:00, 32.93it/s]

Val Accuracy = 96.55%
Epoch 100/100; Loss: 3.3012





In [18]:
model.load_state_dict(best_weights)
model.eval()

GRU(
  (gru): GRU(12, 64, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=128, out_features=12, bias=True)
)

In [19]:
y = librosa.load('./data/F#-Minor-1.wav')[0]
predict(model, y)

'F#m'

In [20]:
def extract_chord_name(file):
    main = file[:file.index('-')]
    if file[file.index('-')+2]=='i':
        return main+'m'
    return main

In [22]:
count = 0
for y in tqdm(val_data):
    model.eval()
    pred = predict(model, y[0].unsqueeze(0), False)
    min_val = 120
    min_key = ''
    for key, val in chord_templates.items():
        out = torch.norm(torch.Tensor(val) - y[1])
        if min_val >= out:
            min_val = out
            min_key = key
    count+=(pred==min_key)
print(f"Accuracy = {100*count/len(val_data):.2f}%")

100%|██████████| 58/58 [00:01<00:00, 32.32it/s]

Accuracy = 96.55%





In [23]:
torch.save(model.state_dict(), './models/chord_detector.pth')