In [None]:
# 0.0 import packages

import sys
import string

import time
import numpy as np
import pandas as pd

import matplotlib.pyplot as plot

In [None]:
# MIDI note numbers emitted by the Groove MIDI Dataset drum kit
# (each drum can have several articulations: head/rim, bow/edge/bell, ...)
BASS = 36                                           # kick drum
SNARE_HEAD, SNARE_RIM, SNARE_X_STICK = 38, 40, 37   # snare drum
TOM_1_HEAD, TOM_1_RIM = 48, 50                      # toms
TOM_2_HEAD, TOM_2_RIM = 45, 47
TOM_3_HEAD, TOM_3_RIM = 43, 58
HH_OPEN_BOW, HH_OPEN_EDGE = 46, 26                  # hi-hats
HH_CLOSED_BOW, HH_CLOSED_EDGE = 42, 22
HH_PEDAL = 44
CRASH_1_BOW, CRASH_1_EDGE = 49, 55                  # crash cymbals
CRASH_2_BOW, CRASH_2_EDGE = 57, 52
RIDE_BOW, RIDE_EDGE, RIDE_BELL = 51, 59, 53         # ride cymbal

# column index of each of the 9 drum classes in the training labels
(KICK, SNARE, HH_OPEN, HH_CLOSED, RIDE,
 TOM_1, TOM_2, TOM_3, CRASH) = range(9)

# which Groove notes are folded into each training class
KICK_LIST = [BASS]
SNARE_LIST = [SNARE_HEAD, SNARE_RIM, SNARE_X_STICK]
HH_OPEN_LIST = [HH_OPEN_BOW, HH_OPEN_EDGE]
HH_CLOSED_LIST = [HH_CLOSED_BOW, HH_CLOSED_EDGE, HH_PEDAL]
RIDE_LIST = [RIDE_BOW, RIDE_EDGE, RIDE_BELL]
TOM_1_LIST = [TOM_1_HEAD, TOM_1_RIM]
TOM_2_LIST = [TOM_2_HEAD, TOM_2_RIM]
TOM_3_LIST = [TOM_3_HEAD, TOM_3_RIM]
CRASH_LIST = [CRASH_1_BOW, CRASH_1_EDGE, CRASH_2_BOW, CRASH_2_EDGE]


In [None]:
# #1.1 GPU stuff

# print ("cuda: ", torch.cuda.is_available())
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# print ("current device: ", device)
# print ("count: ", torch.cuda.device_count())

# if torch.cuda.is_available():
#     print ("device name: ", torch.cuda.get_device_name(0))
#     torch.cuda.set_device(0)

In [None]:
# 1.2 load groove dataset
import math

groove_csv = pd.read_csv('groove/info.csv')
print("groove dataset:", len(groove_csv))

# get train, test, and validation sets; rows without an audio file are skipped
train_csv = []
test_csv = []
validation_csv = []
rows_by_split = {"train": train_csv, "test": test_csv, "validation": validation_csv}

for _, row in groove_csv.iterrows():
    # pd.isna treats both NaN and None as missing (a string comparison with "nan" lets None through)
    if pd.isna(row.audio_filename):
        continue
    bucket = rows_by_split.get(row['split'])
    if bucket is not None:
        bucket.append(row)

print ("train: ", len(train_csv))
print ("test: ", len(test_csv))
print ("validation: ", len(validation_csv))

# peek at one example, if the split is non-empty
if train_csv:
    print (train_csv[0].midi_filename)

In [None]:
import mido

# code to convert midi file to array
# https://medium.com/analytics-vidhya/convert-midi-file-to-numpy-array-in-python-7d00531890c
def msg2dict(msg):
    """Parse the string form of a mido message into integer fields.

    Args:
        msg: ``str(message)`` of a mido message, e.g.
            ``'note_on channel=9 note=36 velocity=80 time=12'``.

    Returns:
        ``[fields, on_]``. ``fields`` always holds ``'time'``; for note messages
        it also holds ``'note'`` and ``'velocity'``. ``on_`` is True for
        ``note_on``, False for ``note_off`` and None for any other message.
    """
    drop_punct = str.maketrans({ch: None for ch in string.punctuation})

    def read_int(key):
        # take the last "key=value" token and strip punctuation, e.g. "time=0)" -> 0
        token = msg[msg.rfind(key):].split(' ')[0]
        return int(token.split('=')[1].translate(drop_punct))

    on_ = True if 'note_on' in msg else (False if 'note_off' in msg else None)

    fields = {'time': read_int('time')}
    if on_ is not None:
        fields['note'] = read_int('note')
        fields['velocity'] = read_int('velocity')
    return [fields, on_]

def switch_note(last_state, note, velocity, on_=True):
    """Return a copy of an 88-key state with one key switched on or off.

    Keys cover MIDI notes 21-108 (the piano range); any other note leaves the
    state unchanged. The key is set to 1 when ``velocity > 0`` and ``on_`` is
    truthy, otherwise to 0.

    Args:
        last_state: previous 88-element state, or None to start with all keys off.
        note: MIDI note number.
        velocity: note velocity.
        on_: True for a note-on event, False for a note-off event.
    """
    state = last_state.copy() if last_state is not None else [0] * 88
    if not 21 <= note <= 108:
        return state
    state[note - 21] = 1 if (velocity > 0 and on_) else 0
    return state

def get_new_state(new_msg, last_state):
    """Apply one mido message to an 88-key state.

    Returns:
        ``[state, time]``: the updated state (``last_state`` itself when the
        message is neither note_on nor note_off) and the message's ``time`` value.
    """
    fields, on_ = msg2dict(str(new_msg))
    if on_ is None:
        state = last_state
    else:
        state = switch_note(last_state, note=fields['note'],
                            velocity=fields['velocity'], on_=on_)
    return [state, fields['time']]

def track2seq(track):
    """Expand one MIDI track into a list of 88-key states, one per time unit.

    Each message's ``time`` is a delta that elapses *before* the message takes
    effect. The state that was current during that delta is therefore repeated
    ``time`` times. Notes outside MIDI 21-108 are ignored.

    Args:
        track: non-empty sequence of mido messages (e.g. ``mid.tracks[i]``).

    Returns:
        list of 88-element lists.
    """
    silent = [0] * 88
    last_state, first_time = get_new_state(track[0], silent)
    # the first message's delta precedes it as well; keep it as all-off time
    # instead of dropping it, so the sequence stays aligned with the audio
    result = [silent] * first_time
    for msg in track[1:]:
        new_state, new_time = get_new_state(msg, last_state)
        if new_time > 0:
            result += [last_state] * new_time
        last_state = new_state
    return result

def mid2array(mid, min_msg_pct=0.1):
    """Convert a mido ``MidiFile`` into a (time, 88) piano-roll array.

    Tracks with at most ``max_track_len * min_msg_pct`` messages are skipped.
    The remaining tracks are expanded with ``track2seq`` and zero-padded to a
    common length. They are then overlaid with an element-wise max, so a key is
    on whenever it is on in any track. Leading/trailing silence is not trimmed.
    """
    threshold = max(len(track) for track in mid.tracks) * min_msg_pct
    sequences = [track2seq(track) for track in mid.tracks if len(track) > threshold]
    longest = max(len(seq) for seq in sequences)
    padded = [seq + [[0] * 88] * (longest - len(seq)) for seq in sequences]
    return np.array(padded).max(axis=0)

In [None]:
# convert audio files into tensors
import imp
from scipy import signal
import audiosegment
import librosa
import torch
import shutil
import os

# converts an audio file to tensor
def audio_to_melspec_tensor(wav_file_path, sample_rate=44_100):
    """Load an audio file and return its log-mel spectrogram as a tensor.

    Pipeline:
        load (resampled to ``sample_rate``) -> remove mean -> pre-emphasis (0.97)
        -> power STFT with a Hamming window -> 128-band mel filterbank
        (fmin 20 Hz) -> dB relative to the maximum.

    Args:
        wav_file_path: path of the audio file to load.
        sample_rate: rate librosa loads at. It also sets the 25 ms FFT size
            and the 10 ms hop.

    Returns:
        torch.Tensor of shape (n_frames, 128), with time on the first axis.
    """
    window_size = 0.025   # seconds per FFT
    window_stride = 0.01  # seconds between frames
    n_dft = int(sample_rate * window_size)
    n_mels = 128
    # librosa requires win_length <= n_fft; only matters for sample rates below ~41 kHz
    win_length = min(1024, n_dft)
    hop_length = int(sample_rate * window_stride)
    # load in wav file and remove the mean of the signal
    y, _ = librosa.load(wav_file_path, sr=sample_rate)
    y = y - y.mean()
    # pre-emphasis filter: y[n] - 0.97 * y[n-1]
    y = np.append(y[0], y[1:] - .97 * y[:-1])
    # compute mel spectrogram; `signal.hamming` is deprecated and no longer in the
    # top-level scipy.signal namespace, `signal.windows.hamming` is the same window
    stft = librosa.stft(y, n_fft=n_dft, hop_length=hop_length,
                        win_length=win_length, window=signal.windows.hamming)
    spec = np.abs(stft) ** 2
    mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=n_dft, n_mels=n_mels, fmin=20)
    melspec = np.dot(mel_basis, spec)
    logspec = librosa.power_to_db(melspec, ref=np.max)
    # (n_mels, frames) -> (frames, n_mels)
    logspec = np.transpose(logspec)
    return torch.tensor(logspec)

def reformat_midi_for_training(midi_array, class_map=None):
    """Collapse an 88-key piano roll into drum-class columns for training.

    Column ``k`` of ``midi_array`` holds MIDI note ``k + 21``. Output column
    ``c`` is 1 at every time step where any note in ``class_map[c]`` is
    non-zero, and 0 elsewhere.

    Args:
        midi_array: 2-D array of shape (time_steps, n_keys), e.g. from ``mid2array``.
        class_map: optional ``{output_column: iterable of MIDI notes}``.
            Defaults to the Groove-to-training mapping built from the ``*_LIST``
            constants, which gives 9 output columns. A note listed under
            several columns sets all of them.

    Returns:
        float ndarray of shape (time_steps, max(class_map) + 1) holding 0.0/1.0.
    """
    if class_map is None:
        class_map = {
            KICK: KICK_LIST, SNARE: SNARE_LIST, HH_OPEN: HH_OPEN_LIST,
            HH_CLOSED: HH_CLOSED_LIST, RIDE: RIDE_LIST, TOM_1: TOM_1_LIST,
            TOM_2: TOM_2_LIST, TOM_3: TOM_3_LIST, CRASH: CRASH_LIST,
        }
    n_steps = midi_array.shape[0]
    n_keys = midi_array.shape[1]
    reformatted = np.zeros((n_steps, max(class_map) + 1))
    active = midi_array != 0
    for column, notes in class_map.items():
        # translate MIDI notes to piano-roll columns, ignoring notes outside the roll
        keys = [note - 21 for note in notes if 0 <= note - 21 < n_keys]
        if keys:
            reformatted[:, column] = active[:, keys].any(axis=1)
    return reformatted

# shrinks an array down keeping the distance between values proportional
def shrink_array_proportionally(midi_array, target_resize):
    """Resample a 2-D array to ``target_resize`` rows by rescaling row positions.

    Every positive entry at row ``t`` is copied to row ``int(t / ratio)`` of
    the result, where ``ratio = rows_in / target_resize``. All other cells are 0.
    When several source rows land on the same target cell, the latest source
    row wins.
    """
    resized = np.zeros((target_resize, midi_array.shape[1]))
    ratio = midi_array.shape[0] / target_resize
    # np.nonzero yields indices in row-major order (same as a row-then-column
    # loop), so later rows still overwrite earlier ones at a shared target cell
    for t, col in zip(*np.nonzero(midi_array > 0)):
        resized[int(t / ratio), col] = midi_array[t, col]
    return resized

# converts a midi array to a list of tensors
def midi_to_tensors(midi_array, num_tensors):
    """Split an array along its first axis into ``num_tensors`` torch tensors.

    Uses ``np.array_split``, so pieces may differ in length by one row when
    the row count is not divisible by ``num_tensors``.
    """
    return [torch.tensor(piece) for piece in np.array_split(midi_array, num_tensors)]

def _normalize_wav(audio_file_path):
    """Load audio as 16-bit mono, exporting the converted audio back over the source file."""
    wav_file = audiosegment.from_file(audio_file_path)
    # NOTE(review): both conversions overwrite the dataset file in place
    if wav_file.sample_width != 2:
        wav_file = wav_file.set_sample_width(2)
        wav_file.export(audio_file_path, format="wav")
    if wav_file.channels != 1:
        wav_file = wav_file.set_channels(1)
        wav_file.export(audio_file_path, format="wav")
    return wav_file


def _chunk_features(wav_file, chunk_seconds, file_id):
    """Dice audio into fixed-length chunks and return one log-mel tensor per chunk."""
    sample_rate = wav_file.frame_rate
    target_len = int(round(chunk_seconds * sample_rate))
    feats = []
    for i, chunk in enumerate(wav_file.dice(chunk_seconds, zero_pad=True)):
        samples = chunk.to_numpy_array()
        if len(samples) != target_len:
            # force the exact length: zero-pad short chunks, trim long ones
            samples = np.pad(samples, (0, max(0, target_len - len(samples))))[:target_len]
            chunk = audiosegment.from_numpy_array(samples, framerate=sample_rate)
        # export a temp wav so librosa can load it
        chunk_path = "temp/" + (str(file_id) + "-" + str(i) + ".wav").replace('/', '-')
        chunk.export(chunk_path, format="wav")
        feats.append(audio_to_melspec_tensor(chunk_path, sample_rate))
    return feats


def _label_chunks(csv_index, n_chunks, frames_per_chunk):
    """Build ``n_chunks`` label tensors of ``frames_per_chunk`` rows each from the row's MIDI file."""
    midi = mido.MidiFile("groove/" + csv_index.midi_filename)
    midi_array = reformat_midi_for_training(mid2array(midi))
    # resample so the label array has 100 rows per second of audio duration
    midi_array = shrink_array_proportionally(midi_array, int(csv_index.duration * 100))
    # zero-pad (or trim) so the rows split evenly across the audio chunks
    total = frames_per_chunk * n_chunks
    pad = max(0, total - len(midi_array))
    midi_array = np.concatenate((midi_array, np.zeros((pad, midi_array.shape[1]))))[:total]
    return midi_to_tensors(midi_array, n_chunks)


def get_feats_and_labels_from_csv(csv_index):
    """Build aligned feature/label tensor lists for one Groove dataset row.

    The audio is normalized to 16-bit mono and diced into 9.99 s chunks, each
    converted to a log-mel tensor. The MIDI is converted to a 100-rows-per-second
    drum-class roll and split into one 1000-row label tensor per audio chunk.

    Args:
        csv_index: a row of the Groove ``info.csv``. Reads ``audio_filename``,
            ``midi_filename``, ``id`` and ``duration``.

    Returns:
        ``(feats_tensors_list, label_tensors_list)``, two lists of equal length.

    NOTE(review): each label chunk covers 10 s of MIDI (1000 rows at 100/s)
    while each audio chunk covers 9.99 s, so labels drift ~10 ms per chunk.
    """
    chunk_seconds = 9.99  # a bit under 10 s so each chunk yields 1000 mel frames instead of 1001
    frames_per_chunk = 1000
    wav_file = _normalize_wav("groove/" + csv_index.audio_filename)
    feats_tensors_list = _chunk_features(wav_file, chunk_seconds, csv_index.id)
    label_tensors_list = _label_chunks(csv_index, len(feats_tensors_list), frames_per_chunk)
    return feats_tensors_list, label_tensors_list

# reset temp folder: remove chunk files from a previous run, then recreate it
# ('temp/' is a directory, so it must be checked with isdir, not isfile)
if os.path.isdir('temp/'):
    shutil.rmtree('temp/', ignore_errors=True)
os.makedirs('temp/', exist_ok=True)

# run feature/label extraction over every training example and report tensor shapes
for i, index in enumerate(train_csv):
    print (i, " ", index.audio_filename)
    feats_tensors, label_tensors = get_feats_and_labels_from_csv(index)
    # iterate chunks by value so the per-file counter `i` is not overwritten
    for feat, label in zip(feats_tensors, label_tensors):
        print ("\tfeat tensor: ", feat.shape, " label tensor: ", label.shape)