In [52]:
import pandas as pd
import torch
import torch.nn.functional as F
import numpy as np
import random
import sklearn.metrics as metrics
import json
import pretty_midi
import librosa

# import transformers and matplotlib in the cell when you call them, not globally

In [53]:
json_path = '/home/hice1/sgoel83/scratch/Maestro/maestro-v3.0.0/maestro-v3.0.0.json'
data_path = '/home/hice1/sgoel83/scratch/Maestro/maestro-v3.0.0/'

In [54]:
def load_train_data(json_path):
    with open(json_path, 'r') as file:
        train_data = json.load(file)
    return train_data

# train_data = load_train_data(json_path)ces once pickle file is complete 
# indices_2018 = [index for index, year in train_data.items() if year == 2018]

# MIDI2018 and CQT2018 final dicts

In [107]:
# MIDI Pickle file

import pickle
midi_final_dict = {}
with open('/home/hice1/amardia6/scratch/midi2018.pickle', 'rb') as f:
    while True:
        try:
            data = pickle.load(f)
            midi_final_dict = {**midi_final_dict, **data}
        except EOFError:
            break

In [108]:
print(len(midi_final_dict.keys()))
print(midi_final_dict.keys())

93
dict_keys([0, 18, 19, 21, 22, 23, 26, 35, 61, 69, 74, 75, 87, 88, 89, 97, 110, 112, 114, 187, 223, 234, 255, 257, 280, 317, 347, 377, 392, 393, 413, 414, 443, 444, 445, 446, 453, 454, 455, 456, 476, 479, 672, 678, 679, 691, 711, 717, 724, 729, 778, 782, 852, 873, 888, 920, 931, 934, 936, 939, 945, 972, 1003, 1009, 1014, 1015, 1017, 1022, 1031, 1051, 1054, 1094, 1095, 1106, 1109, 1110, 1138, 1141, 1142, 1144, 1159, 1165, 1166, 1181, 1194, 1199, 1200, 1216, 1224, 1240, 1253, 1255, 1274])


In [109]:
# CQT WAV Pickle file

import pickle
cqt_final_dict = {}
with open('/home/hice1/sgoel83/scratch/wav_pickle/88_bins/cqt2018_88.pickle', 'rb') as f:
    while True:
        try:
            cqt_data = pickle.load(f)
            cqt_final_dict = {**cqt_final_dict, **cqt_data}
        except EOFError:
            break

In [110]:
print(len(cqt_final_dict.keys()))
print(cqt_final_dict.keys())

93
dict_keys([0, 18, 19, 21, 22, 23, 26, 35, 61, 69, 74, 75, 87, 88, 89, 97, 110, 112, 114, 187, 223, 234, 255, 257, 280, 317, 347, 377, 392, 393, 413, 414, 443, 444, 445, 446, 453, 454, 455, 456, 476, 479, 672, 678, 679, 691, 711, 717, 724, 729, 778, 782, 852, 873, 888, 920, 931, 934, 936, 939, 945, 972, 1003, 1009, 1014, 1015, 1017, 1022, 1031, 1051, 1054, 1094, 1095, 1106, 1109, 1110, 1138, 1141, 1142, 1144, 1159, 1165, 1166, 1181, 1194, 1199, 1200, 1216, 1224, 1240, 1253, 1255, 1274])


# Stripping MIDI to 88 keys

In [111]:
# A0 is 21, C8 is 108
def strip_to_88_keys_in_place(midi_final_dict):
    for piece in midi_final_dict:
        if midi_final_dict[piece].shape[0] == 88:
            continue
        else:
            midi_final_dict[piece] = midi_final_dict[piece][21:109, :]

strip_to_88_keys_in_place(midi_final_dict=midi_final_dict)

for key, array in midi_final_dict.items():
    print(f"{key}: Shape {array.shape}")

0: Shape (88, 70398)
18: Shape (88, 70375)
19: Shape (88, 68555)
21: Shape (88, 112083)
22: Shape (88, 119524)
23: Shape (88, 111878)
26: Shape (88, 70149)
35: Shape (88, 52089)
61: Shape (88, 84826)
69: Shape (88, 34968)
74: Shape (88, 25165)
75: Shape (88, 29749)
87: Shape (88, 105461)
88: Shape (88, 102212)
89: Shape (88, 63386)
97: Shape (88, 29407)
110: Shape (88, 14140)
112: Shape (88, 19476)
114: Shape (88, 21659)
187: Shape (88, 161182)
223: Shape (88, 79511)
234: Shape (88, 170061)
255: Shape (88, 23549)
257: Shape (88, 27070)
280: Shape (88, 99133)
317: Shape (88, 56993)
347: Shape (88, 64512)
377: Shape (88, 224451)
392: Shape (88, 137190)
393: Shape (88, 184680)
413: Shape (88, 256228)
414: Shape (88, 242350)
443: Shape (88, 174353)
444: Shape (88, 174903)
445: Shape (88, 163145)
446: Shape (88, 175912)
453: Shape (88, 234858)
454: Shape (88, 208137)
455: Shape (88, 189899)
456: Shape (88, 199127)
476: Shape (88, 30586)
479: Shape (88, 20640)
672: Shape (88, 146129)
678: Sh

# Data preprocessing

In [112]:
# Shaping parameters
batch_size = 32
slices_per_batch = 8
time_slices = 256
values_per_slice_per_batch = 431
frequency_bins_x = 88
duration_x = 10
channels_x = 1
keys_y = 88

# time per slice in seconds = 9.995 s

In [113]:
x_parameters = {
    'batch_size': batch_size,
    'time_slices': time_slices,
    'frequency_bins': frequency_bins_x,
    'duration': duration_x,
    'channels': channels_x}

y_parameters = {
    'batch_size': batch_size,
    'time_slices': time_slices,
    'num_keys': keys_y
}

In [114]:
# Audio CQT stats

import matplotlib.pyplot as plt
# x : (88, 32, 360, 5, 1)
# y : (88, 32, 360)

x = []
xlen = []
count = 0
count_0 = 0
max_ret = 0
min_ret = float("inf")
for idx, audio in cqt_final_dict.items():
    # print("Shape 0:" , audio.shape[0])
    # print("Shape 1:" , audio.shape[1])
    count += audio.shape[1]
    count_0 += audio.shape[0]
    # print(audio.shape[1]*audio.shape[0])
    max_ret = max(max_ret, audio.shape[1])
    min_ret = min(min_ret, audio.shape[1])
    xlen.append((idx, audio.shape[1]))
print("min_ret", min_ret)
print("max_ret", max_ret)
print(count/93)
print(count_0/93)
xlen.sort(key=lambda a: a[1])
print(xlen)
# xs = [x for x in range(len(xlen))]

# plt.plot(xs, xlen)
# plt.show()

min_ret 6114
max_ret 110392
45800.76344086022
88.0
[(110, 6114), (1200, 7172), (112, 8432), (479, 8914), (114, 9328), (255, 10181), (1216, 10527), (1109, 10740), (74, 10839), (1199, 11047), (257, 11662), (97, 12708), (1181, 12740), (75, 12843), (476, 13183), (782, 13536), (691, 14231), (69, 15103), (945, 15560), (1110, 19083), (778, 19670), (35, 22447), (939, 22793), (317, 24551), (1240, 25119), (852, 26876), (89, 27304), (711, 27466), (347, 27795), (19, 29554), (1224, 30214), (26, 30254), (0, 30331), (18, 30351), (1106, 31389), (724, 31610), (1003, 32786), (223, 34286), (1255, 34646), (920, 34925), (873, 35988), (936, 36210), (61, 36536), (934, 36894), (931, 40758), (1159, 42115), (280, 42736), (88, 44062), (1015, 44136), (87, 45462), (1017, 45732), (1274, 46071), (1253, 46182), (729, 47059), (23, 48214), (1054, 48238), (21, 48310), (1014, 50813), (22, 51518), (1031, 54052), (972, 56333), (717, 57097), (1194, 58526), (392, 59126), (1051, 62550), (672, 62972), (1138, 65593), (1165, 658

In [115]:
# MIDI stats

import matplotlib.pyplot as plt
# x : (88, 32, 360, 5, 1)
# y : (88, 32, 360)

y = []
ylen = []
county = 0
county_0 = 0
max_ret_y = 0
min_ret_y = float("inf")
for idx, midi in midi_final_dict.items():
    # print("Shape 0:" , audio.shape[0])
    # print("Shape 1:" , audio.shape[1])
    county += midi.shape[1]
    county_0 += midi.shape[0]
    # print(audio.shape[1]*audio.shape[0])
    max_ret_y = max(max_ret_y, midi.shape[1])
    min_ret_y = min(min_ret_y, midi.shape[1])
    ylen.append(midi.shape[1])
print("min_ret", min_ret_y)
print("max_ret", max_ret_y)
print(county/93)
print(county_0/93)
# ylen = sorted(ylen)
# print(ylen)
# ys = [y for y in range(len(ylen))]

# plt.plot(ys, ylen)
# plt.show()

min_ret 14140
max_ret 256228
106282.04301075269
88.0


In [116]:
# audio clipped duration (numbers) = 110336
# midi clipped duration (numbers) = 256000
# midi clipped duration (seconds) = 2558.8309266759
# midi time in each slice (seconds) = 2558.8309266759 / 256 = 9.9954333 s = 9995.4333 ms
# 64ms * 156 values = 9984 ms
# We CAN clip to 9984 ms if we'd like, by removing ~11 values.

# We are only clipping the end, and not adding or clipping any values in the pieces.


longest_piece_idx = xlen[-1][0]
print(cqt_final_dict[longest_piece_idx].shape)
clipped_data = np.delete(cqt_final_dict[longest_piece_idx], np.s_[110336:], axis=1)

cqt_final_dict[longest_piece_idx] = clipped_data
print(cqt_final_dict[longest_piece_idx].shape)

(88, 110392)
(88, 110336)


In [117]:
print(midi_final_dict[longest_piece_idx].shape)
clipped_midi = np.delete(midi_final_dict[longest_piece_idx], np.s_[256000:], axis=1)

midi_final_dict[longest_piece_idx] = clipped_midi
print(midi_final_dict[longest_piece_idx].shape)

(88, 256228)
(88, 256000)


In [118]:
# Must run with a kernel of 64GB

max_ret = 110336

audio_largest_curr_shape = (88, max_ret)

for audio_idx, audio_file in cqt_final_dict.items():
    if audio_file.shape[1] < audio_largest_curr_shape[1]:
        pad_width = audio_largest_curr_shape[1] - audio_file.shape[1]
        # print(pad_width)
        cqt_final_dict[audio_idx] = np.pad(audio_file, ((0, 0), (0, pad_width)), 'constant')

In [119]:
# Must run with a kernel of 64GB

max_ret_y = 256000

midi_largest_curr_shape = (88, max_ret_y)

for midi_idx, midi_file in midi_final_dict.items():
    if midi_file.shape[1] < midi_largest_curr_shape[1]:
        pad_width = midi_largest_curr_shape[1] - midi_file.shape[1]
        # print(pad_width)
        midi_final_dict[midi_idx] = np.pad(midi_file, ((0, 0), (0, pad_width)), 'constant')

# Reshaping

In [120]:
# audio_shape = (88,32,8,431)
audio_shape = (88, 256, 431)

for audio_idx, audio_file in cqt_final_dict.items():
    if len(audio_file.shape) == 2:
        reshaped_audio = audio_file.reshape(audio_shape)
        print(reshaped_audio.shape)
        reshaped_audio_2 = np.moveaxis(reshaped_audio, 0, 1)
        print(reshaped_audio_2.shape)
        cqt_final_dict[audio_idx] = reshaped_audio_2

(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 431)
(256, 88, 431)
(88, 256, 

In [121]:
# midi_shape = (88, 32, 8, 1000)
midi_shape = (88, 256, 1000)

for midi_idx, midi_file in midi_final_dict.items():
    if len(midi_file.shape) == 2:
        reshaped_midi = midi_file.reshape(midi_shape)
        print(reshaped_midi.shape)
        reshaped_midi_2 = np.moveaxis(reshaped_midi, 0, 1)
        print(reshaped_midi_2.shape)
        midi_final_dict[midi_idx] = reshaped_midi_2

(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256, 1000)
(256, 88, 1000)
(88, 256

# Training

In [71]:
# Hyperparameters for the model


In [72]:
import torch
from torch.utils.data import DataLoader, TensorDataset

import random
import keras


keys = list(cqt_final_dict.keys())
random.shuffle(keys)


num_keys = len(keys)
train_size = int(0.5 * num_keys)
valid_size = int(0.25 * num_keys)

train_keys = keys[:train_size]
valid_keys = keys[train_size:train_size + valid_size]
test_keys = keys[train_size + valid_size:]

# train_data = torch.tensor([cqt_final_dict[key] for key in train_keys], dtype=torch.float32)
# valid_data = torch.tensor([cqt_final_dict[key] for key in valid_keys], dtype=torch.float32)
# test_data = torch.tensor([cqt_final_dict[key] for key in test_keys], dtype=torch.float32)

# train_data = torch.tensor([value for key, value in reshaped_cqt_dict.items() if key in train_keys], dtype=torch.float32)
# valid_data = torch.tensor([value for key, value in reshaped_cqt_dict.items() if key in valid_keys], dtype=torch.float32)
# test_data = torch.tensor([value for key, value in reshaped_cqt_dict.items() if key in test_keys], dtype=torch.float32)


train_loader = DataLoader(TensorDataset(train_data), batch_size=32, shuffle=True)
valid_loader = DataLoader(TensorDataset(valid_data), batch_size=32)
test_loader = DataLoader(TensorDataset(test_data), batch_size=32)

NameError: name 'reshaped_cqt_dict' is not defined

In [None]:
import torch
import torch.nn as nn

class RecurrentCNN(nn.Module):
    def __init__(self):
        super(RecurrentCNN, self).__init__()
        
        # Convolutional Layers
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(20, 2), stride=1)
        self.pool1 = nn.MaxPool2d(kernel_size=(4, 2))
        
        # LSTM Layers
        self.lstm1 = nn.LSTM(input_size=32, hidden_size=500, batch_first=True, dropout=0.75)
        self.lstm2 = nn.LSTM(input_size=500, hidden_size=200, batch_first=True)

        # Fully Connected Layer
        self.fc = nn.Linear(200, 1) 

    def forward(self, x):
        # Assuming x shape: (batch_size, 32, 360, 84, 5, 1) with CQT and MIDI dimensions as 84 and 128
        x = x.view(-1, 1, 84, 360)  

        # Apply Convolution
        x = self.pool1(torch.relu(self.conv1(x)))
        # LSTM
        x = x.view(x.size(0), -1, x.size(1))  
        
        # LSTM forward pass
        x, _ = self.lstm1(x)
        x, _ = self.lstm2(x)

        # Fully connected layer 
        x = self.fc(x[:, -1, :])  
        
        return torch.sigmoid(x)