In [1]:
import numpy as np
import torch
import copy
from torch.utils.data import Dataset, DataLoader
from torch import nn
import matplotlib.pyplot as plt
import librosa
import librosa.display
from tqdm import tqdm
import json
import os
from sklearn.model_selection import train_test_split

In [2]:
seed = 1
torch.manual_seed(seed)
chord_templates:dict = json.load(open('./chord_templates.json'))

In [3]:
class ChordDetector(Dataset):
    def __init__(self, train:bool, chord_template:dict = json.load(open('./chord_templates.json')), data_location:str = './data/', sr = 44100, hop = 256):
        super(ChordDetector, self).__init__()
        self.chord_template = chord_template
        self.data = []
        self.sr = sr
        self.hop = hop
        for file in os.listdir(data_location):
            chord_true = torch.Tensor(self.chord_template[self._extract_chord_name(file)])
            y, sr = librosa.load(data_location+file, sr = sr)
            chroma = torch.Tensor(librosa.feature.chroma_cens(y=y, sr = sr, hop_length=hop)).T
            self.data.append((chroma, chord_true))
        X_train, X_test, _, _ = train_test_split(self.data, self.data, test_size=0.2, random_state=seed)
        if train:
            self.data = X_train
        else:
            self.data = X_test

    def _extract_chord_name(self, file):
        main = file[:file.index('-')]
        if file[file.index('-')+1]=='D':
            return main+'dim'
        elif file[file.index('-')+1]=='A':
            return main+'aug'
        elif file[file.index('-')+2]=='i':
            return main+'m'
        return main
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

In [4]:
train_data = ChordDetector(train=True)
val_data = ChordDetector(train=False)

In [5]:
train_loader = DataLoader(
    train_data, 
    batch_size=64,
    shuffle=True
)
val_loader = DataLoader(
    val_data, 
    batch_size=64,
    shuffle=True
)

In [6]:
class GRU(nn.Module):
    def __init__(self, input_size = 12, hidden_size = 128, num_layers = 1, num_classes = 12, bidirectional = True) -> None:
        super(GRU, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first = True, bidirectional=bidirectional)
        if(bidirectional):
            self.fc = nn.Linear(hidden_size*2, num_classes)
        else:
            self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        if(self.bidirectional):
            h0 = torch.zeros(2*self.num_layers, x.size(0), self.hidden_size)
        else:
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.gru(x, h0)
        out = out[:,-1,:] # Since we only want the output of the last cell
        out = self.fc(out)
        return(out)

In [7]:
device = torch.device('cpu')
EPOCHS = 100

In [8]:
model = GRU().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 2e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, 0.995)

In [9]:
def predict(model, audio, chroma_req = True, chord_templates:dict = json.load(open('./chord_templates.json')), sr = 44100, hop = 256):
    if chroma_req:
        chroma = torch.Tensor(librosa.feature.chroma_cens(y=audio, sr = sr, hop_length=hop)).T.unsqueeze(0)
    else:
        chroma = audio
    with torch.no_grad():
        outputs = nn.functional.softmax(model(chroma), 1)[0]
    min_val = 120
    min_key = ''
    for key, val in chord_templates.items():
        out = torch.norm(torch.Tensor(val) - outputs)
        if min_val >= out:
            min_val = out
            min_key = key
    return min_key

In [16]:
best_accuracy = 0
best_weights = copy.deepcopy(model.state_dict())

In [17]:
for epoch in range(EPOCHS):
    model.train()
    for i, (chroma, labels) in enumerate(train_loader):
        chroma = chroma.to(device)
        labels = labels.to(device)

        outputs = model(chroma)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()
    with torch.no_grad():
        model.eval()
        count = 0
        for y in tqdm(val_data):
            pred = predict(model, y[0].unsqueeze(0), False)
            min_val = 120
            min_key = ''
            for key, val in chord_templates.items():
                out = torch.norm(torch.Tensor(val) - y[1])
                if min_val >= out:
                    min_val = out
                    min_key = key
            count+=(pred==min_key)
        acc = 100*count/len(val_data)
        if best_accuracy <= acc:
            best_accuracy = acc
            best_weights = copy.deepcopy(model.state_dict())
        print(f"Val Accuracy = {acc:.2f}%")
    print(f'Epoch {epoch+1}/{EPOCHS}; Loss: {loss.item():.4f}')

100%|██████████| 144/144 [00:05<00:00, 25.97it/s]


Val Accuracy = 11.81%
Epoch 1/100; Loss: 7.6251


100%|██████████| 144/144 [00:05<00:00, 26.00it/s]


Val Accuracy = 8.33%
Epoch 2/100; Loss: 7.6706


100%|██████████| 144/144 [00:05<00:00, 25.82it/s]


Val Accuracy = 9.03%
Epoch 3/100; Loss: 7.5646


100%|██████████| 144/144 [00:05<00:00, 25.70it/s]


Val Accuracy = 13.19%
Epoch 4/100; Loss: 7.0613


100%|██████████| 144/144 [00:05<00:00, 25.93it/s]


Val Accuracy = 11.81%
Epoch 5/100; Loss: 7.0231


100%|██████████| 144/144 [00:05<00:00, 25.85it/s]


Val Accuracy = 15.28%
Epoch 6/100; Loss: 7.0821


100%|██████████| 144/144 [00:05<00:00, 25.79it/s]


Val Accuracy = 16.67%
Epoch 7/100; Loss: 6.7899


100%|██████████| 144/144 [00:05<00:00, 25.85it/s]


Val Accuracy = 13.89%
Epoch 8/100; Loss: 6.9489


100%|██████████| 144/144 [00:05<00:00, 25.67it/s]


Val Accuracy = 13.89%
Epoch 9/100; Loss: 6.7855


100%|██████████| 144/144 [00:05<00:00, 25.81it/s]


Val Accuracy = 18.06%
Epoch 10/100; Loss: 6.8985


100%|██████████| 144/144 [00:05<00:00, 25.76it/s]


Val Accuracy = 16.67%
Epoch 11/100; Loss: 7.0006


100%|██████████| 144/144 [00:05<00:00, 25.75it/s]


Val Accuracy = 17.36%
Epoch 12/100; Loss: 6.7416


100%|██████████| 144/144 [00:05<00:00, 25.75it/s]


Val Accuracy = 15.97%
Epoch 13/100; Loss: 6.2694


100%|██████████| 144/144 [00:05<00:00, 25.76it/s]


Val Accuracy = 22.22%
Epoch 14/100; Loss: 6.0063


100%|██████████| 144/144 [00:05<00:00, 25.76it/s]


Val Accuracy = 21.53%
Epoch 15/100; Loss: 5.8597


100%|██████████| 144/144 [00:05<00:00, 25.82it/s]


Val Accuracy = 18.75%
Epoch 16/100; Loss: 5.9380


100%|██████████| 144/144 [00:05<00:00, 25.68it/s]


Val Accuracy = 38.89%
Epoch 17/100; Loss: 5.5931


100%|██████████| 144/144 [00:05<00:00, 25.74it/s]


Val Accuracy = 46.53%
Epoch 18/100; Loss: 5.2663


100%|██████████| 144/144 [00:05<00:00, 25.74it/s]


Val Accuracy = 47.22%
Epoch 19/100; Loss: 4.7882


100%|██████████| 144/144 [00:05<00:00, 25.76it/s]


Val Accuracy = 55.56%
Epoch 20/100; Loss: 5.0661


100%|██████████| 144/144 [00:05<00:00, 25.75it/s]


Val Accuracy = 65.28%
Epoch 21/100; Loss: 4.5454


100%|██████████| 144/144 [00:05<00:00, 25.68it/s]


Val Accuracy = 64.58%
Epoch 22/100; Loss: 4.5341


100%|██████████| 144/144 [00:05<00:00, 25.72it/s]


Val Accuracy = 76.39%
Epoch 23/100; Loss: 4.1901


100%|██████████| 144/144 [00:05<00:00, 25.72it/s]


Val Accuracy = 79.17%
Epoch 24/100; Loss: 4.3729


100%|██████████| 144/144 [00:05<00:00, 25.75it/s]


Val Accuracy = 78.47%
Epoch 25/100; Loss: 4.0240


100%|██████████| 144/144 [00:05<00:00, 25.77it/s]


Val Accuracy = 85.42%
Epoch 26/100; Loss: 3.9777


100%|██████████| 144/144 [00:05<00:00, 25.80it/s]


Val Accuracy = 82.64%
Epoch 27/100; Loss: 3.8577


100%|██████████| 144/144 [00:05<00:00, 25.80it/s]


Val Accuracy = 86.81%
Epoch 28/100; Loss: 3.8108


100%|██████████| 144/144 [00:05<00:00, 25.87it/s]


Val Accuracy = 86.81%
Epoch 29/100; Loss: 3.8150


100%|██████████| 144/144 [00:05<00:00, 25.77it/s]


Val Accuracy = 83.33%
Epoch 30/100; Loss: 3.7705


100%|██████████| 144/144 [00:05<00:00, 25.71it/s]


Val Accuracy = 87.50%
Epoch 31/100; Loss: 3.8227


100%|██████████| 144/144 [00:05<00:00, 25.77it/s]


Val Accuracy = 88.19%
Epoch 32/100; Loss: 3.5973


100%|██████████| 144/144 [00:05<00:00, 25.50it/s]


Val Accuracy = 89.58%
Epoch 33/100; Loss: 3.6910


100%|██████████| 144/144 [00:05<00:00, 26.02it/s]


Val Accuracy = 93.06%
Epoch 34/100; Loss: 3.5951


100%|██████████| 144/144 [00:05<00:00, 25.98it/s]


Val Accuracy = 90.97%
Epoch 35/100; Loss: 3.7054


100%|██████████| 144/144 [00:05<00:00, 25.75it/s]


Val Accuracy = 92.36%
Epoch 36/100; Loss: 3.7735


100%|██████████| 144/144 [00:05<00:00, 25.72it/s]


Val Accuracy = 92.36%
Epoch 37/100; Loss: 3.5438


100%|██████████| 144/144 [00:05<00:00, 25.96it/s]


Val Accuracy = 90.28%
Epoch 38/100; Loss: 3.6157


100%|██████████| 144/144 [00:05<00:00, 26.32it/s]


Val Accuracy = 92.36%
Epoch 39/100; Loss: 3.5037


100%|██████████| 144/144 [00:05<00:00, 25.78it/s]


Val Accuracy = 90.97%
Epoch 40/100; Loss: 3.5366


100%|██████████| 144/144 [00:05<00:00, 25.80it/s]


Val Accuracy = 79.17%
Epoch 41/100; Loss: 4.0286


100%|██████████| 144/144 [00:05<00:00, 25.73it/s]


Val Accuracy = 88.89%
Epoch 42/100; Loss: 3.6941


100%|██████████| 144/144 [00:05<00:00, 25.71it/s]


Val Accuracy = 84.72%
Epoch 43/100; Loss: 4.0582


100%|██████████| 144/144 [00:05<00:00, 25.82it/s]


Val Accuracy = 87.50%
Epoch 44/100; Loss: 3.7957


100%|██████████| 144/144 [00:05<00:00, 25.80it/s]


Val Accuracy = 88.89%
Epoch 45/100; Loss: 3.6122


100%|██████████| 144/144 [00:05<00:00, 26.21it/s]


Val Accuracy = 92.36%
Epoch 46/100; Loss: 3.6093


100%|██████████| 144/144 [00:05<00:00, 25.69it/s]


Val Accuracy = 91.67%
Epoch 47/100; Loss: 3.4888


100%|██████████| 144/144 [00:05<00:00, 26.12it/s]


Val Accuracy = 95.83%
Epoch 48/100; Loss: 3.5144


100%|██████████| 144/144 [00:05<00:00, 26.35it/s]


Val Accuracy = 94.44%
Epoch 49/100; Loss: 3.4864


100%|██████████| 144/144 [00:05<00:00, 25.82it/s]


Val Accuracy = 95.83%
Epoch 50/100; Loss: 3.4763


100%|██████████| 144/144 [00:05<00:00, 26.19it/s]


Val Accuracy = 94.44%
Epoch 51/100; Loss: 3.4275


100%|██████████| 144/144 [00:05<00:00, 25.77it/s]


Val Accuracy = 95.83%
Epoch 52/100; Loss: 3.4266


100%|██████████| 144/144 [00:05<00:00, 25.78it/s]


Val Accuracy = 94.44%
Epoch 53/100; Loss: 3.4717


100%|██████████| 144/144 [00:05<00:00, 26.03it/s]


Val Accuracy = 94.44%
Epoch 54/100; Loss: 3.4649


100%|██████████| 144/144 [00:05<00:00, 25.78it/s]


Val Accuracy = 93.06%
Epoch 55/100; Loss: 3.4238


100%|██████████| 144/144 [00:05<00:00, 25.71it/s]


Val Accuracy = 93.75%
Epoch 56/100; Loss: 3.5178


100%|██████████| 144/144 [00:05<00:00, 25.88it/s]


Val Accuracy = 86.81%
Epoch 57/100; Loss: 3.7711


100%|██████████| 144/144 [00:05<00:00, 25.97it/s]


Val Accuracy = 88.19%
Epoch 58/100; Loss: 3.7505


100%|██████████| 144/144 [00:05<00:00, 25.63it/s]


Val Accuracy = 85.42%
Epoch 59/100; Loss: 3.6126


100%|██████████| 144/144 [00:05<00:00, 25.64it/s]


Val Accuracy = 92.36%
Epoch 60/100; Loss: 3.6901


100%|██████████| 144/144 [00:05<00:00, 26.04it/s]


Val Accuracy = 93.75%
Epoch 61/100; Loss: 3.5134


100%|██████████| 144/144 [00:05<00:00, 25.75it/s]


Val Accuracy = 94.44%
Epoch 62/100; Loss: 3.4999


100%|██████████| 144/144 [00:05<00:00, 26.11it/s]


Val Accuracy = 93.75%
Epoch 63/100; Loss: 3.4112


100%|██████████| 144/144 [00:05<00:00, 26.20it/s]


Val Accuracy = 94.44%
Epoch 64/100; Loss: 3.4527


100%|██████████| 144/144 [00:05<00:00, 25.95it/s]


Val Accuracy = 94.44%
Epoch 65/100; Loss: 3.3881


100%|██████████| 144/144 [00:05<00:00, 26.26it/s]


Val Accuracy = 92.36%
Epoch 66/100; Loss: 3.4056


100%|██████████| 144/144 [00:05<00:00, 26.27it/s]


Val Accuracy = 93.06%
Epoch 67/100; Loss: 3.3688


100%|██████████| 144/144 [00:05<00:00, 26.12it/s]


Val Accuracy = 93.75%
Epoch 68/100; Loss: 3.3482


100%|██████████| 144/144 [00:05<00:00, 26.29it/s]


Val Accuracy = 93.75%
Epoch 69/100; Loss: 3.3428


100%|██████████| 144/144 [00:05<00:00, 26.25it/s]


Val Accuracy = 94.44%
Epoch 70/100; Loss: 3.3452


100%|██████████| 144/144 [00:05<00:00, 25.78it/s]


Val Accuracy = 93.06%
Epoch 71/100; Loss: 3.3283


100%|██████████| 144/144 [00:05<00:00, 25.85it/s]


Val Accuracy = 93.06%
Epoch 72/100; Loss: 3.3346


100%|██████████| 144/144 [00:05<00:00, 25.80it/s]


Val Accuracy = 93.06%
Epoch 73/100; Loss: 3.3545


100%|██████████| 144/144 [00:05<00:00, 25.70it/s]


Val Accuracy = 93.75%
Epoch 74/100; Loss: 3.3390


100%|██████████| 144/144 [00:05<00:00, 25.73it/s]


Val Accuracy = 93.06%
Epoch 75/100; Loss: 3.3183


100%|██████████| 144/144 [00:05<00:00, 25.78it/s]


Val Accuracy = 93.06%
Epoch 76/100; Loss: 3.3184


100%|██████████| 144/144 [00:05<00:00, 25.82it/s]


Val Accuracy = 93.06%
Epoch 77/100; Loss: 3.3201


100%|██████████| 144/144 [00:05<00:00, 25.77it/s]


Val Accuracy = 95.14%
Epoch 78/100; Loss: 3.3152


100%|██████████| 144/144 [00:05<00:00, 25.63it/s]


Val Accuracy = 93.06%
Epoch 79/100; Loss: 3.3168


100%|██████████| 144/144 [00:05<00:00, 25.68it/s]


Val Accuracy = 93.75%
Epoch 80/100; Loss: 3.3507


100%|██████████| 144/144 [00:05<00:00, 25.67it/s]


Val Accuracy = 93.06%
Epoch 81/100; Loss: 3.3137


100%|██████████| 144/144 [00:05<00:00, 25.75it/s]


Val Accuracy = 93.75%
Epoch 82/100; Loss: 3.3129


100%|██████████| 144/144 [00:05<00:00, 25.71it/s]


Val Accuracy = 93.06%
Epoch 83/100; Loss: 3.3360


100%|██████████| 144/144 [00:05<00:00, 25.83it/s]


Val Accuracy = 93.06%
Epoch 84/100; Loss: 3.3496


100%|██████████| 144/144 [00:05<00:00, 25.75it/s]


Val Accuracy = 93.06%
Epoch 85/100; Loss: 3.3114


100%|██████████| 144/144 [00:05<00:00, 25.70it/s]


Val Accuracy = 93.75%
Epoch 86/100; Loss: 3.3102


100%|██████████| 144/144 [00:05<00:00, 25.84it/s]


Val Accuracy = 93.06%
Epoch 87/100; Loss: 3.3109


100%|██████████| 144/144 [00:05<00:00, 25.68it/s]


Val Accuracy = 94.44%
Epoch 88/100; Loss: 3.3110


100%|██████████| 144/144 [00:05<00:00, 25.76it/s]


Val Accuracy = 93.75%
Epoch 89/100; Loss: 3.3112


100%|██████████| 144/144 [00:05<00:00, 25.77it/s]


Val Accuracy = 93.75%
Epoch 90/100; Loss: 3.3246


100%|██████████| 144/144 [00:05<00:00, 25.82it/s]


Val Accuracy = 94.44%
Epoch 91/100; Loss: 3.3099


100%|██████████| 144/144 [00:05<00:00, 25.57it/s]


Val Accuracy = 93.06%
Epoch 92/100; Loss: 3.3115


100%|██████████| 144/144 [00:05<00:00, 25.69it/s]


Val Accuracy = 93.75%
Epoch 93/100; Loss: 3.3252


100%|██████████| 144/144 [00:05<00:00, 25.81it/s]


Val Accuracy = 93.75%
Epoch 94/100; Loss: 3.3341


100%|██████████| 144/144 [00:05<00:00, 25.67it/s]


Val Accuracy = 93.75%
Epoch 95/100; Loss: 3.3118


100%|██████████| 144/144 [00:05<00:00, 25.72it/s]


Val Accuracy = 93.75%
Epoch 96/100; Loss: 3.3326


100%|██████████| 144/144 [00:05<00:00, 25.72it/s]


Val Accuracy = 94.44%
Epoch 97/100; Loss: 3.3092


100%|██████████| 144/144 [00:05<00:00, 25.79it/s]


Val Accuracy = 94.44%
Epoch 98/100; Loss: 3.3063


100%|██████████| 144/144 [00:05<00:00, 25.78it/s]


Val Accuracy = 93.75%
Epoch 99/100; Loss: 3.3047


100%|██████████| 144/144 [00:05<00:00, 25.67it/s]

Val Accuracy = 94.44%
Epoch 100/100; Loss: 3.3221





In [18]:
model.load_state_dict(best_weights)
model.eval()

GRU(
  (gru): GRU(12, 128, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=256, out_features=12, bias=True)
)

In [19]:
y = librosa.load('./data/F#-Dim-1.wav')[0]
predict(model, y)

'F#dim'

In [20]:
y = librosa.load('./data/A#-Aug-1.wav')[0]
predict(model, y)

'F#aug'

In [13]:
count = 0
with torch.no_grad():
    for y in tqdm(val_data):
        model.eval()
        pred = predict(model, y[0].unsqueeze(0), False)
        min_val = 120
        min_key = ''
        for key, val in chord_templates.items():
            out = torch.norm(torch.Tensor(val) - y[1])
            if min_val >= out:
                min_val = out
                min_key = key
        count+=(pred==min_key)
print(f"Val Accuracy = {100*count/len(val_data):.2f}%")

100%|██████████| 144/144 [00:05<00:00, 27.34it/s]

Val Accuracy = 95.83%





In [24]:
torch.save(model.state_dict(), './models/chord_detector.pth')

In [10]:
model = GRU()
model.load_state_dict(torch.load('./models/chord_detector.pth'))
model.eval()

GRU(
  (gru): GRU(12, 128, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=256, out_features=12, bias=True)
)

In [18]:
y = librosa.load('./data/G-Minor-12.wav')[0]
predict(model, y)

'Gm'

In [12]:
count = 0
with torch.no_grad():
    for y in tqdm(train_data):
        model.eval()
        pred = predict(model, y[0].unsqueeze(0), False)
        min_val = 120
        min_key = ''
        for key, val in chord_templates.items():
            out = torch.norm(torch.Tensor(val) - y[1])
            if min_val >= out:
                min_val = out
                min_key = key
        count+=(pred==min_key)
print(f"Train Accuracy = {100*count/len(train_data):.2f}%")

100%|██████████| 576/576 [00:21<00:00, 27.37it/s]

Train Accuracy = 98.44%



