In [1]:
import numpy as np
import torch
import copy
from torch.utils.data import Dataset, DataLoader
from torch import nn
import matplotlib.pyplot as plt
import librosa
import librosa.display
from tqdm import tqdm
import json
import os
from sklearn.model_selection import train_test_split
from scipy.signal.windows import hann

In [3]:
seed = 1
torch.manual_seed(seed)
chord_templates:dict = json.load(open('./chord_templates.json'))
HOP = 256
SR = 44100
WIN_LENGTH = 2048
WINDOW = hann(WIN_LENGTH)

In [4]:
class ChordDetector(Dataset):
    def __init__(self, train:bool, chord_template:dict = json.load(open('./chord_templates.json')), data_location:str = './data/', sr = SR, hop = HOP):
        super(ChordDetector, self).__init__()
        self.chord_template = chord_template
        self.data = []
        self.sr = sr
        self.hop = hop
        for file in os.listdir(data_location):
            chord_true = torch.Tensor(self.chord_template[self._extract_chord_name(file)])
            y, sr = librosa.load(data_location+file, sr = sr)
            y = y[:sr*2]
            chroma = torch.Tensor(librosa.feature.chroma_cens(y=y, sr = sr, hop_length=hop)).T
            self.data.append((chroma, chord_true))
        X_train, X_test, _, _ = train_test_split(self.data, self.data, test_size=0.2, random_state=seed)
        if train:
            self.data = X_train
        else:
            self.data = X_test

    def _extract_chord_name(self, file):
        main = file[:file.index('-')]
        if file[file.index('-')+1]=='D':
            return main+'dim'
        elif file[file.index('-')+1]=='A':
            return main+'aug'
        elif file[file.index('-')+2]=='i':
            return main+'m'
        return main
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

In [5]:
train_data = ChordDetector(train=True)
val_data = ChordDetector(train=False)



In [6]:
train_loader = DataLoader(
    train_data, 
    batch_size=64,
    shuffle=True
)
val_loader = DataLoader(
    val_data, 
    batch_size=64,
    shuffle=True
)

In [18]:
class GRU(nn.Module):
    def __init__(self, input_size = 12, hidden_size = 256, num_layers = 2, num_classes = 12, bidirectional = True) -> None:
        super(GRU, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first = True, bidirectional=bidirectional)
        if(bidirectional):
            self.fc = nn.Linear(hidden_size*2, num_classes)
        else:
            self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        if(self.bidirectional):
            h0 = torch.zeros(2*self.num_layers, x.size(0), self.hidden_size)
        else:
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.gru(x, h0)
        out = out[:,-1,:] # Since we only want the output of the last cell
        out = self.fc(out)
        return(out)

In [19]:
device = torch.device('cpu')
EPOCHS = 100

In [20]:
model = GRU().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, 0.995)

In [21]:
def predict(model, audio, chroma_req = True, chord_templates:dict = json.load(open('./chord_templates.json')), sr = SR, hop = HOP):
    if chroma_req:
        chroma = torch.Tensor(librosa.feature.chroma_cens(y=audio, sr = sr, hop_length=hop)).T.unsqueeze(0)
    else:
        chroma = audio
    with torch.no_grad():
        outputs = nn.functional.softmax(model(chroma), 1)[0]
    min_val = 120
    min_key = ''
    for key, val in chord_templates.items():
        out = torch.norm(torch.Tensor(val) - outputs)
        if min_val >= out:
            min_val = out
            min_key = key
    return min_key

In [22]:
best_accuracy = 0
best_weights = copy.deepcopy(model.state_dict())

In [23]:
for epoch in range(EPOCHS):
    model.train()
    for i, (chroma, labels) in enumerate(train_loader):
        chroma = chroma.to(device)
        labels = labels.to(device)

        outputs = model(chroma)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()
    with torch.no_grad():
        model.eval()
        count = 0
        for y in tqdm(val_data):
            pred = predict(model, y[0].unsqueeze(0), False)
            min_val = 120
            min_key = ''
            for key, val in chord_templates.items():
                out = torch.norm(torch.Tensor(val) - y[1])
                if min_val >= out:
                    min_val = out
                    min_key = key
            count+=(pred==min_key)
        acc = 100*count/len(val_data)
        if best_accuracy <= acc:
            best_accuracy = acc
            best_weights = copy.deepcopy(model.state_dict())
        print(f"Val Accuracy = {acc:.2f}%")
    print(f'Epoch {epoch+1}/{EPOCHS}; Loss: {loss.item():.4f}')

100%|██████████| 144/144 [00:04<00:00, 30.41it/s]


Val Accuracy = 51.39%
Epoch 1/100; Loss: 5.4483


100%|██████████| 144/144 [00:04<00:00, 30.53it/s]


Val Accuracy = 70.83%
Epoch 2/100; Loss: 4.6017


100%|██████████| 144/144 [00:04<00:00, 30.56it/s]


Val Accuracy = 77.08%
Epoch 3/100; Loss: 4.9686


100%|██████████| 144/144 [00:04<00:00, 30.32it/s]


Val Accuracy = 75.69%
Epoch 4/100; Loss: 4.6010


100%|██████████| 144/144 [00:04<00:00, 30.63it/s]


Val Accuracy = 81.94%
Epoch 5/100; Loss: 4.3062


100%|██████████| 144/144 [00:04<00:00, 30.47it/s]


Val Accuracy = 81.94%
Epoch 6/100; Loss: 4.2101


100%|██████████| 144/144 [00:04<00:00, 31.44it/s]


Val Accuracy = 81.25%
Epoch 7/100; Loss: 4.0474


100%|██████████| 144/144 [00:04<00:00, 30.60it/s]


Val Accuracy = 88.89%
Epoch 8/100; Loss: 4.1330


100%|██████████| 144/144 [00:04<00:00, 30.91it/s]


Val Accuracy = 86.81%
Epoch 9/100; Loss: 3.8774


100%|██████████| 144/144 [00:04<00:00, 30.94it/s]


Val Accuracy = 87.50%
Epoch 10/100; Loss: 3.9955


100%|██████████| 144/144 [00:04<00:00, 30.81it/s]


Val Accuracy = 89.58%
Epoch 11/100; Loss: 3.8517


100%|██████████| 144/144 [00:04<00:00, 30.90it/s]


Val Accuracy = 89.58%
Epoch 12/100; Loss: 3.7005


100%|██████████| 144/144 [00:04<00:00, 31.53it/s]


Val Accuracy = 90.97%
Epoch 13/100; Loss: 3.6766


100%|██████████| 144/144 [00:04<00:00, 31.05it/s]


Val Accuracy = 89.58%
Epoch 14/100; Loss: 3.6302


100%|██████████| 144/144 [00:04<00:00, 30.60it/s]


Val Accuracy = 93.06%
Epoch 15/100; Loss: 3.5285


100%|██████████| 144/144 [00:04<00:00, 30.63it/s]


Val Accuracy = 92.36%
Epoch 16/100; Loss: 3.5902


100%|██████████| 144/144 [00:04<00:00, 31.02it/s]


Val Accuracy = 93.75%
Epoch 17/100; Loss: 3.5092


100%|██████████| 144/144 [00:04<00:00, 31.34it/s]


Val Accuracy = 93.75%
Epoch 18/100; Loss: 3.5505


100%|██████████| 144/144 [00:04<00:00, 30.77it/s]


Val Accuracy = 91.67%
Epoch 19/100; Loss: 3.4863


100%|██████████| 144/144 [00:04<00:00, 31.14it/s]


Val Accuracy = 95.14%
Epoch 20/100; Loss: 3.4844


100%|██████████| 144/144 [00:04<00:00, 31.63it/s]


Val Accuracy = 94.44%
Epoch 21/100; Loss: 3.5137


100%|██████████| 144/144 [00:04<00:00, 31.37it/s]


Val Accuracy = 95.83%
Epoch 22/100; Loss: 3.4773


100%|██████████| 144/144 [00:04<00:00, 30.80it/s]


Val Accuracy = 94.44%
Epoch 23/100; Loss: 3.4266


100%|██████████| 144/144 [00:04<00:00, 31.04it/s]


Val Accuracy = 96.53%
Epoch 24/100; Loss: 3.4268


100%|██████████| 144/144 [00:04<00:00, 30.93it/s]


Val Accuracy = 93.06%
Epoch 25/100; Loss: 3.3806


100%|██████████| 144/144 [00:04<00:00, 30.76it/s]


Val Accuracy = 95.83%
Epoch 26/100; Loss: 3.3912


100%|██████████| 144/144 [00:04<00:00, 31.14it/s]


Val Accuracy = 95.14%
Epoch 27/100; Loss: 3.4341


100%|██████████| 144/144 [00:04<00:00, 30.80it/s]


Val Accuracy = 95.14%
Epoch 28/100; Loss: 3.4180


100%|██████████| 144/144 [00:04<00:00, 31.04it/s]


Val Accuracy = 94.44%
Epoch 29/100; Loss: 3.4624


100%|██████████| 144/144 [00:04<00:00, 31.48it/s]


Val Accuracy = 95.83%
Epoch 30/100; Loss: 3.4977


100%|██████████| 144/144 [00:04<00:00, 30.95it/s]


Val Accuracy = 94.44%
Epoch 31/100; Loss: 3.4210


100%|██████████| 144/144 [00:04<00:00, 30.10it/s]


Val Accuracy = 93.75%
Epoch 32/100; Loss: 3.4368


100%|██████████| 144/144 [00:04<00:00, 31.08it/s]


Val Accuracy = 95.14%
Epoch 33/100; Loss: 3.4771


100%|██████████| 144/144 [00:04<00:00, 30.48it/s]


Val Accuracy = 93.75%
Epoch 34/100; Loss: 3.4816


100%|██████████| 144/144 [00:04<00:00, 31.00it/s]


Val Accuracy = 94.44%
Epoch 35/100; Loss: 3.4223


100%|██████████| 144/144 [00:04<00:00, 31.47it/s]


Val Accuracy = 94.44%
Epoch 36/100; Loss: 3.4088


100%|██████████| 144/144 [00:04<00:00, 31.51it/s]


Val Accuracy = 95.83%
Epoch 37/100; Loss: 3.3835


100%|██████████| 144/144 [00:04<00:00, 30.86it/s]


Val Accuracy = 94.44%
Epoch 38/100; Loss: 3.4569


100%|██████████| 144/144 [00:04<00:00, 30.71it/s]


Val Accuracy = 92.36%
Epoch 39/100; Loss: 3.5227


100%|██████████| 144/144 [00:04<00:00, 30.63it/s]


Val Accuracy = 94.44%
Epoch 40/100; Loss: 3.5630


100%|██████████| 144/144 [00:04<00:00, 30.65it/s]


Val Accuracy = 93.75%
Epoch 41/100; Loss: 3.5425


100%|██████████| 144/144 [00:04<00:00, 31.20it/s]


Val Accuracy = 94.44%
Epoch 42/100; Loss: 3.4154


100%|██████████| 144/144 [00:04<00:00, 31.36it/s]


Val Accuracy = 95.14%
Epoch 43/100; Loss: 3.4311


100%|██████████| 144/144 [00:04<00:00, 31.24it/s]


Val Accuracy = 94.44%
Epoch 44/100; Loss: 3.3726


100%|██████████| 144/144 [00:04<00:00, 29.92it/s]


Val Accuracy = 95.14%
Epoch 45/100; Loss: 3.3711


100%|██████████| 144/144 [00:04<00:00, 31.43it/s]


Val Accuracy = 96.53%
Epoch 46/100; Loss: 3.3697


100%|██████████| 144/144 [00:04<00:00, 31.63it/s]


Val Accuracy = 95.83%
Epoch 47/100; Loss: 3.3753


100%|██████████| 144/144 [00:04<00:00, 31.50it/s]


Val Accuracy = 97.22%
Epoch 48/100; Loss: 3.3526


100%|██████████| 144/144 [00:04<00:00, 31.41it/s]


Val Accuracy = 97.22%
Epoch 49/100; Loss: 3.3425


100%|██████████| 144/144 [00:04<00:00, 30.80it/s]


Val Accuracy = 95.83%
Epoch 50/100; Loss: 3.3395


100%|██████████| 144/144 [00:04<00:00, 31.06it/s]


Val Accuracy = 95.83%
Epoch 51/100; Loss: 3.3251


100%|██████████| 144/144 [00:04<00:00, 31.24it/s]


Val Accuracy = 96.53%
Epoch 52/100; Loss: 3.3374


100%|██████████| 144/144 [00:04<00:00, 31.58it/s]


Val Accuracy = 95.83%
Epoch 53/100; Loss: 3.3214


100%|██████████| 144/144 [00:04<00:00, 31.07it/s]


Val Accuracy = 97.92%
Epoch 54/100; Loss: 3.3226


100%|██████████| 144/144 [00:04<00:00, 31.27it/s]


Val Accuracy = 96.53%
Epoch 55/100; Loss: 3.3391


100%|██████████| 144/144 [00:04<00:00, 31.06it/s]


Val Accuracy = 97.92%
Epoch 56/100; Loss: 3.3196


100%|██████████| 144/144 [00:04<00:00, 31.21it/s]


Val Accuracy = 97.22%
Epoch 57/100; Loss: 3.3613


100%|██████████| 144/144 [00:04<00:00, 30.74it/s]


Val Accuracy = 95.83%
Epoch 58/100; Loss: 3.3374


100%|██████████| 144/144 [00:04<00:00, 30.37it/s]


Val Accuracy = 97.22%
Epoch 59/100; Loss: 3.3361


100%|██████████| 144/144 [00:04<00:00, 30.34it/s]


Val Accuracy = 96.53%
Epoch 60/100; Loss: 3.3282


100%|██████████| 144/144 [00:04<00:00, 30.73it/s]


Val Accuracy = 95.83%
Epoch 61/100; Loss: 3.3148


100%|██████████| 144/144 [00:04<00:00, 31.38it/s]


Val Accuracy = 97.22%
Epoch 62/100; Loss: 3.3428


100%|██████████| 144/144 [00:04<00:00, 29.74it/s]


Val Accuracy = 96.53%
Epoch 63/100; Loss: 3.3543


100%|██████████| 144/144 [00:04<00:00, 30.63it/s]


Val Accuracy = 97.22%
Epoch 64/100; Loss: 3.3154


100%|██████████| 144/144 [00:04<00:00, 30.11it/s]


Val Accuracy = 96.53%
Epoch 65/100; Loss: 3.3337


100%|██████████| 144/144 [00:04<00:00, 31.28it/s]


Val Accuracy = 96.53%
Epoch 66/100; Loss: 3.3659


100%|██████████| 144/144 [00:04<00:00, 31.05it/s]


Val Accuracy = 97.22%
Epoch 67/100; Loss: 3.3243


100%|██████████| 144/144 [00:04<00:00, 30.18it/s]


Val Accuracy = 97.22%
Epoch 68/100; Loss: 3.3591


100%|██████████| 144/144 [00:04<00:00, 29.74it/s]


Val Accuracy = 97.22%
Epoch 69/100; Loss: 3.3331


100%|██████████| 144/144 [00:04<00:00, 30.43it/s]


Val Accuracy = 95.83%
Epoch 70/100; Loss: 3.3365


100%|██████████| 144/144 [00:04<00:00, 30.79it/s]


Val Accuracy = 97.22%
Epoch 71/100; Loss: 3.3149


100%|██████████| 144/144 [00:04<00:00, 31.13it/s]


Val Accuracy = 97.22%
Epoch 72/100; Loss: 3.3107


100%|██████████| 144/144 [00:04<00:00, 29.79it/s]


Val Accuracy = 95.83%
Epoch 73/100; Loss: 3.3287


100%|██████████| 144/144 [00:04<00:00, 31.11it/s]


Val Accuracy = 97.22%
Epoch 74/100; Loss: 3.3342


100%|██████████| 144/144 [00:04<00:00, 30.44it/s]


Val Accuracy = 97.92%
Epoch 75/100; Loss: 3.3432


100%|██████████| 144/144 [00:04<00:00, 31.30it/s]


Val Accuracy = 97.22%
Epoch 76/100; Loss: 3.3331


100%|██████████| 144/144 [00:04<00:00, 31.05it/s]


Val Accuracy = 97.22%
Epoch 77/100; Loss: 3.3106


100%|██████████| 144/144 [00:04<00:00, 30.61it/s]


Val Accuracy = 97.22%
Epoch 78/100; Loss: 3.3246


100%|██████████| 144/144 [00:04<00:00, 30.98it/s]


Val Accuracy = 96.53%
Epoch 79/100; Loss: 3.3416


100%|██████████| 144/144 [00:04<00:00, 31.21it/s]


Val Accuracy = 97.92%
Epoch 80/100; Loss: 3.3069


100%|██████████| 144/144 [00:04<00:00, 30.92it/s]


Val Accuracy = 96.53%
Epoch 81/100; Loss: 3.3052


100%|██████████| 144/144 [00:04<00:00, 30.89it/s]


Val Accuracy = 96.53%
Epoch 82/100; Loss: 3.3389


100%|██████████| 144/144 [00:04<00:00, 31.08it/s]


Val Accuracy = 97.22%
Epoch 83/100; Loss: 3.3044


100%|██████████| 144/144 [00:04<00:00, 30.72it/s]


Val Accuracy = 97.22%
Epoch 84/100; Loss: 3.3432


100%|██████████| 144/144 [00:04<00:00, 30.86it/s]


Val Accuracy = 97.22%
Epoch 85/100; Loss: 3.3042


100%|██████████| 144/144 [00:04<00:00, 30.65it/s]


Val Accuracy = 97.22%
Epoch 86/100; Loss: 3.3228


100%|██████████| 144/144 [00:04<00:00, 31.28it/s]


Val Accuracy = 97.22%
Epoch 87/100; Loss: 3.3058


100%|██████████| 144/144 [00:04<00:00, 31.27it/s]


Val Accuracy = 97.22%
Epoch 88/100; Loss: 3.3036


100%|██████████| 144/144 [00:04<00:00, 31.13it/s]


Val Accuracy = 97.22%
Epoch 89/100; Loss: 3.3442


100%|██████████| 144/144 [00:04<00:00, 29.93it/s]


Val Accuracy = 97.92%
Epoch 90/100; Loss: 3.3215


100%|██████████| 144/144 [00:04<00:00, 31.13it/s]


Val Accuracy = 97.92%
Epoch 91/100; Loss: 3.3263


100%|██████████| 144/144 [00:04<00:00, 31.07it/s]


Val Accuracy = 97.22%
Epoch 92/100; Loss: 3.3078


100%|██████████| 144/144 [00:04<00:00, 30.58it/s]


Val Accuracy = 96.53%
Epoch 93/100; Loss: 3.3199


100%|██████████| 144/144 [00:04<00:00, 30.64it/s]


Val Accuracy = 97.22%
Epoch 94/100; Loss: 3.3106


100%|██████████| 144/144 [00:04<00:00, 31.19it/s]


Val Accuracy = 97.22%
Epoch 95/100; Loss: 3.3137


100%|██████████| 144/144 [00:04<00:00, 30.81it/s]


Val Accuracy = 96.53%
Epoch 96/100; Loss: 3.3199


100%|██████████| 144/144 [00:04<00:00, 30.25it/s]


Val Accuracy = 96.53%
Epoch 97/100; Loss: 3.3063


100%|██████████| 144/144 [00:04<00:00, 30.91it/s]


Val Accuracy = 96.53%
Epoch 98/100; Loss: 3.3425


100%|██████████| 144/144 [00:04<00:00, 31.16it/s]


Val Accuracy = 97.92%
Epoch 99/100; Loss: 3.3216


100%|██████████| 144/144 [00:04<00:00, 30.01it/s]

Val Accuracy = 97.22%
Epoch 100/100; Loss: 3.3263





In [24]:
model.load_state_dict(best_weights)
model.eval()

GRU(
  (gru): GRU(12, 256, num_layers=2, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=512, out_features=12, bias=True)
)

In [25]:
count = 0
with torch.no_grad():
    for y in tqdm(train_data):
        model.eval()
        pred = predict(model, y[0].unsqueeze(0), False)
        min_val = 120
        min_key = ''
        for key, val in chord_templates.items():
            out = torch.norm(torch.Tensor(val) - y[1])
            if min_val >= out:
                min_val = out
                min_key = key
        count+=(pred==min_key)
print(f"Train Accuracy = {100*count/len(train_data):.2f}%")

100%|██████████| 576/576 [00:19<00:00, 30.28it/s]

Train Accuracy = 99.48%





In [26]:
y = librosa.load('./data/F#-Aug-1.wav')[0]
predict(model, y)

'F#aug'

In [27]:
count = 0
with torch.no_grad():
    for y in tqdm(val_data):
        model.eval()
        pred = predict(model, y[0].unsqueeze(0), False)
        min_val = 120
        min_key = ''
        for key, val in chord_templates.items():
            out = torch.norm(torch.Tensor(val) - y[1])
            if min_val >= out:
                min_val = out
                min_key = key
        count+=(pred==min_key)
print(f"Val Accuracy = {100*count/len(val_data):.2f}%")

100%|██████████| 144/144 [00:04<00:00, 30.00it/s]

Val Accuracy = 97.92%





In [28]:
torch.save(model.state_dict(), './models/chord_detector.pth')

In [29]:
model = GRU()
model.load_state_dict(torch.load('./models/chord_detector.pth'))
model.eval()

GRU(
  (gru): GRU(12, 256, num_layers=2, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=512, out_features=12, bias=True)
)

In [31]:
y = librosa.load('./data/G-Aug-12.wav')[0]
predict(model, y)

'Gaug'