In [2]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import matplotlib.pyplot as plt
import librosa
import librosa.display
from tqdm import tqdm
import json
import os
from sklearn.model_selection import train_test_split
import soundfile as sf

In [32]:
# !!! Code in this cell is for chunking the dataset into little bits - you'll only need to run this once

from pathlib import Path

data_location = "../../Datasets/musdb18-wav/train/"
out_location = "../../Datasets/musdb18-wav/"
folders = os.listdir(data_location)

for folder in folders:
    other_y, _ = librosa.load(data_location + '/' + folder + '/other.wav', sr=22050)
    vocals_y, _ = librosa.load(data_location + '/' + folder + '/vocals.wav', sr=22050)
    chunk_length = 22050 * 2
    nchunks = int(other_y.shape[0] / chunk_length)

    # Path(data_location + folder + '/chunks').mkdir(exist_ok=True)
    for chunk in range(nchunks):
        other_chunk = other_y[chunk * chunk_length:(chunk * chunk_length) + chunk_length]
        vocal_chunk = vocals_y[chunk * chunk_length:(chunk * chunk_length) + chunk_length]
        sf.write(out_location + 'chunks_other/chunked_other_' + folder + '_' + str(chunk) + '.wav', other_chunk, 22050)
        sf.write(out_location + 'chunks_vocal/chunked_vocal_' + folder + '_' + str(chunk) + '.wav', vocal_chunk, 22050)


In [31]:
def chroma_predict(model, audio, chroma_req = True, chord_templates:dict = json.load(open('./chord_templates.json')), sr = 44100, hop = 256):
    if chroma_req:
        chroma = torch.Tensor(librosa.feature.chroma_cens(y=audio, sr = sr, hop_length=hop)).T.unsqueeze(0)
    else:
        chroma = audio
    with torch.no_grad():
        outputs = nn.functional.softmax(model(chroma), 1)[0]
    min_val = 120
    min_key = ''
    for key, val in chord_templates.items():
        out = torch.norm(torch.Tensor(val) - outputs)
        if min_val >= out:
            min_val = out
            min_key = key
    return min_key


In [89]:
class GRU(nn.Module):
    def __init__(self, input_size = 12, hidden_size = 256, num_layers = 2, num_classes = 12, bidirectional = True) -> None:
        super(GRU, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional

        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first = True, bidirectional=bidirectional)
        if(bidirectional):
            self.fc = nn.Linear(hidden_size*2, num_classes)
        else:
            self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        if(self.bidirectional):
            h0 = torch.zeros(2*self.num_layers, x.size(0), self.hidden_size)
        else:
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.gru(x, h0)
        out = out[:,-1,:] # Since we only want the output of the last cell
        out = self.fc(out)
        return(out)


In [85]:
class Musdb18Dataset(Dataset):
    def __init__(self, data_location):
        super().__init__()
        self.path = data_location
        self.vocal_files = os.listdir(data_location + '/chunks_vocal')
        self.other_files = os.listdir(data_location + '/chunks_other')
        self.chroma_gru = GRU()
        self.chroma_gru.load_state_dict(torch.load('models/chord_detector.pth'))
        
    def __len__(self):
        return len(self.vocal_files)
    
    def __getitem__(self, index):
        other_y, sr = librosa.load(self.path + '/chunks_other/' + self.other_files[index], sr=22050)
        vocals_y, sr = librosa.load(self.path + '/chunks_vocal/' + self.vocal_files[index], sr=22050)
        
        other_chroma = librosa.feature.chroma_cens(y=other_y, sr = sr, hop_length=256)
        vocals_chroma = librosa.feature.chroma_cens(y=vocals_y, sr = sr, hop_length=256)
        
        ground_truth = chroma_predict(self.chroma_gru, other_y)
        return torch.tensor(vocals_chroma), ground_truth

In [121]:
class Predictor(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.f_dim = 8
        
        self.conv1 = nn.Conv1d(12, 24, 3, 1)
        self.conv2 = nn.Conv1d(24, 48, 3, 1)
        self.conv3 = nn.Conv1d(48, 96, 2, 1)
        self.FC = nn.Linear(self.f_dim * 12, 24)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()
        
    def forward(self, x):
        h1 = self.relu(self.conv1(x))
        h2 = self.relu(self.conv2(h1))
        h3 = self.relu(self.conv3(h2))
        flat = torch.flatten(h3, 1)
        h4 = self.FC(flat)
        return self.softmax(h4) 


In [92]:
def truth_label_to_int(gt):
    template = json.load(open('./chord_templates.json'))
    int_gt = torch.zeros(len(gt))
    for i in range(len(gt)):
        idx = list(template.keys()).index(gt[i])
        int_gt[i] = idx
    return int_gt.long()
    
def train():
    dataset = Musdb18Dataset('../../Datasets/musdb18-wav')
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=8,
                                               shuffle=True)
    pred = Predictor()
    adam = torch.optim.Adam(pred.parameters(), lr=0.0001)
    loss_fn = nn.CrossEntropyLoss()
    
    for epoch in range(100):
        epoch_loss = 0
        count = 0
        for i, (X, y) in enumerate(train_loader):
            ground_truth = truth_label_to_int(y)
            out = pred(X)
            loss = loss_fn(out, ground_truth)
            loss.backward()
            adam.step()
            adam.zero_grad()
            epoch_loss += loss.detach()
            count += 1
        print(epoch_loss / count)
        
    torch.save(pred.state_dict(), 'models/predictor.model')

In [93]:
train()

  return self.softmax(h4)


tensor(3.1462)
tensor(3.1351)
tensor(3.1345)
tensor(3.1344)
tensor(3.1342)
tensor(3.1338)
tensor(3.1338)
tensor(3.1333)
tensor(3.1333)
tensor(3.1331)
tensor(3.1329)
tensor(3.1325)
tensor(3.1321)
tensor(3.1314)
tensor(3.1312)


KeyboardInterrupt: 