In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import matplotlib.pyplot as plt
import librosa
import librosa.display
from tqdm import tqdm
import json
import os
import pickle
import sys
from pathlib import Path
from sklearn.model_selection import train_test_split
import soundfile as sf
from utils import GRU

In [2]:
f = '../../../../Music Technology/Datasets/musdb18hq/'
os.listdir(f)

['.DS_Store',
 'chunks_vocal',
 'test',
 'train_data.pt',
 'test_data.pt',
 'chunks_chord',
 'model_train',
 'train',
 'model_test']

## Creating the chunks

In [3]:
SR = 44100
HOP = 256
FRAMES = 6

In [4]:
chord_detector = GRU()
chord_detector.load_state_dict(torch.load('./models/chord_detector.pth'))
chord_detector.eval()

GRU(
  (gru): GRU(12, 256, num_layers=2, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=512, out_features=12, bias=True)
)

In [5]:
def predict(model, audio, chroma_req = True, chord_templates:dict = json.load(open('./chord_templates.json')), sr = SR, hop = HOP):
    if chroma_req:
        chroma = torch.Tensor(librosa.feature.chroma_cens(y=audio, sr = sr, hop_length=hop)).T.unsqueeze(0)
    else:
        chroma = audio
    with torch.no_grad():
        outputs = nn.functional.softmax(model(chroma), 1)[0]
    min_val = 120
    min_key = ''
    for key, val in chord_templates.items():
        out = torch.norm(torch.Tensor(val) - outputs)
        if min_val >= out:
            min_val = out
            min_key = key
    return min_key

In [6]:
def prediction(model, chroma, frame = 6):
    stack = []
    time = []
    model.eval()
    pred = predict(model, chroma[:frame, :].unsqueeze(0), False)
    prev_pred = pred
    dur = 1
    main_sub = 0
    for i in tqdm(range(frame, chroma.shape[0]-frame+1, frame)):
        model.eval()
        pred = predict(model, chroma[i:i+frame, :].unsqueeze(0), False)
        if(pred != prev_pred):
            if(dur>10):
                if(len(stack)==0):
                    stack.append(prev_pred)
                elif(stack[-1]==prev_pred):
                    dur = 0
                    prev_pred = pred
                    continue
                else:
                    stack.append(prev_pred)
                if len(time)!=0:
                    time.append((i)*HOP/SR - main_sub)
                else:
                    main_sub = (i)*HOP/SR
                    time.append(0.0)
            dur = 0
            prev_pred = pred
        dur+=1
    return stack, time

In [None]:
# # !!! Code in this cell is for chunking the dataset into little bits - you'll only need to run this once

# 

# data_location = "../../../../Music Technology/Datasets/musdb18hq/test/"
# out_location = "../../../../Music Technology/Datasets/musdb18hq/"
# folders = os.listdir(data_location)
# count = 0

# # print(chord_templates["F#"])

# for folder in folders:
#     if not os.path.isdir(data_location+folder):
#         continue
#     mixture_y, _ = librosa.load(data_location + '/' + folder + '/mixture.wav', sr=SR)
#     vocals_y, _ = librosa.load(data_location + '/' + folder + '/vocals.wav', sr=SR)
#     mixture_y = mixture_y/np.max(np.abs(mixture_y))
#     vocals_y = vocals_y/np.max(np.abs(vocals_y))

#     mixture_chroma = torch.Tensor(librosa.feature.chroma_cens(y=mixture_y, sr = SR, hop_length=HOP)).T
#     # vocals_chroma = torch.Tensor(librosa.feature.chroma_cens(y=vocals_y, sr = SR, hop_length=HOP)).T
#     chunk_length = FRAMES
#     nchunks = mixture_chroma.shape[0] // chunk_length # no padding

#     if not os.path.isdir(out_location+'chunks_chord_test/'):
#         os.mkdir(out_location+'chunks_chord_test')
#     if not os.path.isdir(out_location+'chunks_vocal_test/'):
#         os.mkdir(out_location+'chunks_vocal_test')
    
#     # Get chords from mixture chroma
#     chord_stack, time = prediction(chord_detector, mixture_chroma)
#     frame_num = np.array([int(i/((HOP/SR)*6)) for i in time])
#     chord_stack = np.array([frame_num, chord_stack]).T
#     chords = []
#     for prev, curr in zip(chord_stack[:-1], chord_stack[1:]):
#         frame_diff = int(curr[0]) - int(prev[0])
#         chords.extend([prev[1] for _ in range(frame_diff)])
#     chords.extend([chord_stack[-1][1] for _ in range(nchunks - len(chords))])
#     with open(out_location+"chunks_chord_test/chord_"+str(count), "wb") as fp:
#         pickle.dump(chords, fp)
#     print(len(chords), nchunks)
#     sf.write(out_location + 'chunks_vocal_test/vocal_' + str(count)+'.wav', vocals_y, SR)
    
#     # for i in range(0, mixture_chroma.shape[0]-chunk_length+1,chunk_length):
#     #     vocal_chunk = vocals_chroma[i:i+chunk_length,:]
#     #     print(vocal_chunk.shape)\
#         # torch.save(vocal_chunk, out_location + 'chunks_vocal/vocal_' + str(count) + '_' + str(i//chunk_length) + '.pt')
#     count+=1

In [48]:
# # !!! Code in this cell is for chunking the dataset into little bits - you'll only need to run this once

# from pathlib import Path

# data_location = "../../../../Music Technology/Datasets/musdb18hq/train/"
# out_location = "../../../../Music Technology/Datasets/musdb18hq/"
# folders = os.listdir(data_location)
# count = 0

# # print(chord_templates["F#"])

# for folder in folders:
#     if not os.path.isdir(data_location+folder):
#         continue
#     mixture_y, _ = librosa.load(data_location + '/' + folder + '/mixture.wav', sr=SR)
#     vocals_y, _ = librosa.load(data_location + '/' + folder + '/vocals.wav', sr=SR)
#     mixture_y = mixture_y/np.max(np.abs(mixture_y))
#     vocals_y = vocals_y/np.max(np.abs(vocals_y))

#     mixture_chroma = torch.Tensor(librosa.feature.chroma_cens(y=mixture_y, sr = SR, hop_length=HOP)).T
#     # vocals_chroma = torch.Tensor(librosa.feature.chroma_cens(y=vocals_y, sr = SR, hop_length=HOP)).T
#     chunk_length = FRAMES
#     nchunks = mixture_chroma.shape[0] // chunk_length # no padding

#     if not os.path.isdir(out_location+'chunks_chord/'):
#         os.mkdir(out_location+'chunks_chord')
#     if not os.path.isdir(out_location+'chunks_vocal/'):
#         os.mkdir(out_location+'chunks_vocal')
    
#     # Get chords from mixture chroma
#     chord_stack, time = prediction(chord_detector, mixture_chroma)
#     frame_num = np.array([int(i/((HOP/SR)*6)) for i in time])
#     chord_stack = np.array([frame_num, chord_stack]).T
#     chords = []
#     for prev, curr in zip(chord_stack[:-1], chord_stack[1:]):
#         frame_diff = int(curr[0]) - int(prev[0])
#         chords.extend([prev[1] for _ in range(frame_diff)])
#     chords.extend([chord_stack[-1][1] for _ in range(nchunks - len(chords))])
#     with open(out_location+"chunks_chord/chord_"+str(count), "wb") as fp:
#         pickle.dump(chords, fp)
#     print(len(chords), nchunks)
#     sf.write(out_location + 'chunks_vocal/vocal_' + str(count)+'.wav', vocals_y, SR)
    
#     # for i in range(0, mixture_chroma.shape[0]-chunk_length+1,chunk_length):
#     #     vocal_chunk = vocals_chroma[i:i+chunk_length,:]
#     #     print(vocal_chunk.shape)\
#         # torch.save(vocal_chunk, out_location + 'chunks_vocal/vocal_' + str(count) + '_' + str(i//chunk_length) + '.pt')
#     count+=1

100%|██████████| 5708/5708 [00:08<00:00, 685.72it/s]


5709 5709


100%|██████████| 7017/7017 [00:10<00:00, 668.02it/s]


7018 7018


100%|██████████| 6351/6351 [00:09<00:00, 683.89it/s]


6352 6352


100%|██████████| 6361/6361 [00:09<00:00, 700.20it/s]


6362 6362


100%|██████████| 6528/6528 [00:09<00:00, 691.39it/s]


6529 6529


100%|██████████| 5756/5756 [00:08<00:00, 686.24it/s]


5757 5757


100%|██████████| 5898/5898 [00:08<00:00, 693.93it/s]


5899 5899


100%|██████████| 7683/7683 [00:11<00:00, 698.15it/s]


7684 7684


100%|██████████| 7090/7090 [00:10<00:00, 704.15it/s]


7091 7091


100%|██████████| 8234/8234 [00:11<00:00, 696.97it/s]


8235 8235


100%|██████████| 8776/8776 [00:12<00:00, 706.23it/s]


8777 8777


100%|██████████| 9125/9125 [00:13<00:00, 698.77it/s]


9126 9126


100%|██████████| 5092/5092 [00:07<00:00, 694.61it/s]


5093 5093


100%|██████████| 9108/9108 [00:12<00:00, 701.97it/s]


9109 9109


100%|██████████| 7813/7813 [00:11<00:00, 701.54it/s]


7814 7814


100%|██████████| 7267/7267 [00:10<00:00, 704.91it/s]


7268 7268


100%|██████████| 12355/12355 [00:17<00:00, 707.51it/s]


12356 12356


100%|██████████| 6298/6298 [00:08<00:00, 701.16it/s]


6299 6299


100%|██████████| 6092/6092 [00:08<00:00, 701.34it/s]


6093 6093


100%|██████████| 7751/7751 [00:11<00:00, 691.58it/s]


7752 7752


100%|██████████| 6749/6749 [00:09<00:00, 715.77it/s]


6750 6750


100%|██████████| 5048/5048 [00:07<00:00, 713.03it/s]


5049 5049


100%|██████████| 7264/7264 [00:10<00:00, 713.61it/s]


7265 7265


100%|██████████| 9675/9675 [00:13<00:00, 718.23it/s]


9676 9676


100%|██████████| 9195/9195 [00:13<00:00, 696.41it/s]


9196 9196


100%|██████████| 6737/6737 [00:09<00:00, 697.49it/s]


6738 6738


100%|██████████| 6031/6031 [00:08<00:00, 704.38it/s]


6032 6032


100%|██████████| 9515/9515 [00:13<00:00, 705.62it/s]


9516 9516


100%|██████████| 7082/7082 [00:10<00:00, 701.09it/s]


7083 7083


100%|██████████| 5468/5468 [00:07<00:00, 696.20it/s]


5469 5469


100%|██████████| 7918/7918 [00:11<00:00, 704.54it/s]


7919 7919


100%|██████████| 8084/8084 [00:11<00:00, 700.62it/s]


8085 8085


100%|██████████| 6728/6728 [00:09<00:00, 695.71it/s]


6729 6729


100%|██████████| 4066/4066 [00:05<00:00, 688.20it/s]


4067 4067


100%|██████████| 7295/7295 [00:10<00:00, 703.99it/s]


7296 7296


100%|██████████| 8976/8976 [00:12<00:00, 706.31it/s]


8977 8977


100%|██████████| 6995/6995 [00:10<00:00, 698.22it/s]


6996 6996


100%|██████████| 11359/11359 [00:16<00:00, 709.25it/s]


11360 11360


100%|██████████| 9867/9867 [00:13<00:00, 706.08it/s]


9868 9868


100%|██████████| 8406/8406 [00:11<00:00, 705.52it/s]


8407 8407


100%|██████████| 5155/5155 [00:07<00:00, 702.14it/s]


5156 5156


100%|██████████| 7317/7317 [00:10<00:00, 698.43it/s]


7318 7318


100%|██████████| 6319/6319 [00:09<00:00, 700.92it/s]


6320 6320


100%|██████████| 7183/7183 [00:10<00:00, 705.43it/s]


7184 7184


100%|██████████| 7208/7208 [00:10<00:00, 699.41it/s]


7209 7209


100%|██████████| 5968/5968 [00:08<00:00, 696.25it/s]


5969 5969


100%|██████████| 5367/5367 [00:07<00:00, 692.20it/s]


5368 5368


100%|██████████| 7854/7854 [00:10<00:00, 714.88it/s]


7855 7855


100%|██████████| 4670/4670 [00:06<00:00, 720.92it/s]


4671 4671


100%|██████████| 2186/2186 [00:03<00:00, 709.51it/s]

2187 2187





In [63]:
class MelChordDataset(Dataset):
    def __init__(
            self, 
            data_location = "../../../../Music Technology/Datasets/musdb18hq/chunks_vocal/",
            out_location = "../../../../Music Technology/Datasets/musdb18hq/chunks_chord/",
            frames_per_chord = 6
        ):
        super(MelChordDataset).__init__()
        self.frames_per_chord = frames_per_chord
        vocals_y = []
        vocals_chroma = []
        chord_templates:dict = json.load(open('./chord_templates.json'))
        act_chord_data = []

        for i in range(len(os.listdir(data_location))): # 100
            with open(out_location+"chord_"+str(i), "rb") as fp:
                chord_data = pickle.load(fp)
            act_chord_data.append(torch.Tensor(np.array([np.array(chord_templates[i]) for i in chord_data])))
            vocals_y.append(librosa.load(data_location + 'vocal_'+str(i)+'.wav', sr=SR)[0])
            vocals_chroma.append(torch.Tensor(librosa.feature.chroma_cens(y=vocals_y[-1], sr = SR, hop_length=HOP)).T)
        
        # act_chord_data[i]: Shape: (num_chords[i], 12)
        # vocals_chroma[i]: Shape: (num_frames[i], 12)
        # num_chords[i] = (num_frames[i] // frames_per_chord)

        self.data = []
        self._create_data(act_chord_data, vocals_chroma)
    
    def _create_data(self, chord_data, chroma_data):
        for (chroma, chords) in zip(chroma_data, chord_data):
            for i in range(0, chroma.shape[0]-self.frames_per_chord, self.frames_per_chord):
                block_chroma = chroma[i:i+self.frames_per_chord,:]
                block_chord = chords[i//self.frames_per_chord]
                if(block_chroma.any()):
                    self.data.append((block_chroma, block_chord))

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]

In [64]:
train_data = MelChordDataset()
test_data = MelChordDataset(
    data_location="../../../../Music Technology/Datasets/musdb18hq/chunks_vocal_test/",
    out_location="../../../../Music Technology/Datasets/musdb18hq/chunks_chord_test/"
)
torch.save(train_data, './data/final/train_data.pt')
torch.save(test_data, './data/final/test_data.pt')

In [65]:
print(len(train_data))
print(len(test_data))

576514
297423


In [32]:
class GRU(nn.Module):
    def __init__(self, input_size = 12, hidden_size = 64, num_layers = 1, num_classes = 12, bidirectional = True) -> None:
        super(GRU, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional

        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first = True, bidirectional=bidirectional)
        if(bidirectional):
            self.fc = nn.Linear(hidden_size*2, num_classes)
        else:
            self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        if(self.bidirectional):
            h0 = torch.zeros(2*self.num_layers, x.size(0), self.hidden_size)
        else:
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.gru(x, h0)
        out = out[:,-1,:] # Since we only want the output of the last cell
        out = self.fc(out)
        return(out)


In [67]:
class Predictor(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.f_dim = 19
        
        self.conv1 = nn.Conv1d(12, 24, 5, 2)
        self.conv2 = nn.Conv1d(24, 48, 5, 2)
        self.conv3 = nn.Conv1d(48, 12, 5, 2)
        self.FC = nn.Linear(self.f_dim * 12, 24)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()
        
    def forward(self, x):
        h1 = self.relu(self.conv1(x))
        h2 = self.relu(self.conv2(h1))
        h3 = self.relu(self.conv3(h2))
        flat = torch.flatten(h3, 1)
        h4 = self.FC(flat)
        return self.softmax(h4) 


In [74]:
def truth_label_to_int(gt):
    template = json.load(open('./chord_templates.json'))
    int_gt = torch.zeros(len(gt))
    for i in range(len(gt)):
        idx = list(template.keys()).index(gt[i])
        int_gt[i] = idx
    return int_gt.long()
    
def train():
    dataset = Musdb18Dataset('../../Datasets/musdb18-wav')
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=8,
                                               shuffle=True)
    pred = Predictor()
    adam = torch.optim.Adam(pred.parameters(), lr=0.0001)
    loss_fn = nn.CrossEntropyLoss()
    
    for epoch in range(100):
        epoch_loss = 0
        count = 0
        for i, (X, y) in enumerate(train_loader):
            ground_truth = truth_label_to_int(y)
            out = pred(X)
            loss = loss_fn(out, ground_truth)
            loss.backward()
            adam.step()
            adam.zero_grad()
            epoch_loss += loss.detach()
            count += 1
            print(loss.detach())
        print(epoch_loss / count)
        
    torch.save(pred.state_dict(), 'models/predictor.model')

In [75]:
train()

  return self.softmax(h4)


tensor(3.1787)
tensor(3.1783)
tensor(3.1783)
tensor(3.1782)
tensor(3.1793)
tensor(3.1782)
tensor(3.1780)
tensor(3.1789)
tensor(3.1784)
tensor(3.1782)
tensor(3.1790)
tensor(3.1774)
tensor(3.1791)
tensor(3.1780)
tensor(3.1787)
tensor(3.1785)
tensor(3.1775)
tensor(3.1788)
tensor(3.1779)
tensor(3.1780)
tensor(3.1784)
tensor(3.1776)
tensor(3.1784)
tensor(3.1796)
tensor(3.1781)


KeyboardInterrupt: 