In [1]:
from glob import glob
from tqdm import tqdm
from mido import MidiFile
import string
import numpy as np
from nltk.tokenize import sent_tokenize
from itertools import tee
from numpy import array, random
from sys import getsizeof
from pickle import dump

In [2]:
def msg2dict(msg):
    result = dict()
    if 'note_on' in msg:
        on_ = True
    elif 'note_off' in msg:
        on_ = False
    else:
        on_ = None
    result['time'] = int(msg[msg.rfind('time'):].split(' ')[0].split('=')[1].translate(
        str.maketrans({a: None for a in string.punctuation})))
    if on_ is not None:
        for k in ['note', 'velocity']:
            result[k] = int(msg[msg.rfind(k):].split(' ')[0].split('=')[1].translate(
                str.maketrans({a: None for a in string.punctuation})))
    return [result, on_]

def switch_note(last_state, note, velocity, on_=True):
    # piano has 88 notes, corresponding to note id 21 to 108, any note out of this range will be ignored
    result = [0] * 88 if last_state is None else last_state.copy()
    if 21 <= note <= 108:
        result[note-21] = velocity if on_ else 0
    return result

def get_new_state(new_msg, last_state):
    new_msg, on_ = msg2dict(str(new_msg))
    new_state = switch_note(last_state, note=new_msg['note'], velocity=new_msg['velocity'], on_=on_) if on_ is not None else last_state
    return [new_state, new_msg['time']]

def track2seq(track):
    # piano has 88 notes, corresponding to note id 21 to 108, any note out of the id range will be ignored
    result = []
    last_state, last_time = get_new_state(str(track[0]), [0]*88)
    for i in range(1, len(track)):
        new_state, new_time = get_new_state(track[i], last_state)
        if new_time > 0:
            result += [last_state]*new_time
        last_state, last_time = new_state, new_time
    return result

In [3]:
def update(d, keys, value):
    for key in keys:
        if key in d:
            d = d[key]
        else:
            newd = dict()
            d[key] = newd
            d = newd
    d[value] = d.get(value, 0) + 1


def marginalize(text, window, d=dict()):
    for w in slide(text.split(' '), window):
        features = w[:-1]
        target = w[-1]
        update(d, features, target)
    return d


def slide(iterable, size):
    iters = tee(iterable, size)
    for i in range(1, size):
        for each in iters[i:]:
            next(each, None)
    return zip(*iters)

In [4]:
window = 2
model = dict()

pbar = tqdm(glob('midi/*.midi'))
for path in pbar:
    midi = MidiFile(path, clip=True)
    corpus = ''
    for track in midi.tracks:
        arr = np.array(track2seq(track))
        for i in range(arr.shape[0]):
            note = np.where(arr[i] > 0)
            if len(note[0]):
                corpus += str(note[0][0]) + ' '
    model = marginalize(corpus, window, model)
    pbar.set_postfix({'model (kb)': getsizeof(model)/1e3})

dump(model, open('model.pkl', 'wb'))

  9%|▉         | 2619/29664 [5:08:06<53:01:41,  7.06s/it, model (kb)=4.7] 


KeySignatureError: Could not decode key with 2 flats and mode 255

In [5]:
model

{'0': {'0': 514638,
  '12': 461,
  '1': 16,
  '40': 3,
  '3': 42,
  '44': 1,
  '2': 90,
  '47': 2,
  '6': 115,
  '4': 90,
  '9': 18,
  '10': 118,
  '5': 68,
  '56': 1,
  '7': 104,
  '11': 22,
  '13': 16,
  '20': 8,
  '25': 4,
  '22': 1,
  '23': 4,
  '15': 15,
  '18': 2,
  '16': 1,
  '48': 13,
  '30': 4,
  '32': 34,
  '24': 97,
  '36': 2,
  '31': 4,
  '14': 12,
  '': 2,
  '17': 2,
  '77': 1,
  '60': 3,
  '19': 2,
  '26': 1,
  '8': 20,
  '38': 2,
  '28': 4,
  '53': 2,
  '39': 1,
  '55': 1,
  '46': 1},
 '12': {'12': 17885099,
  '10': 6765,
  '7': 2240,
  '5': 856,
  '39': 216,
  '49': 33,
  '19': 3273,
  '11': 1541,
  '14': 4818,
  '21': 1523,
  '17': 4166,
  '31': 436,
  '30': 11,
  '27': 222,
  '13': 3992,
  '15': 3644,
  '24': 7785,
  '16': 544,
  '26': 78,
  '55': 156,
  '': 30,
  '34': 308,
  '22': 940,
  '48': 160,
  '43': 211,
  '33': 335,
  '8': 1014,
  '40': 68,
  '3': 59,
  '18': 379,
  '1': 14,
  '9': 2022,
  '20': 408,
  '23': 301,
  '25': 288,
  '36': 757,
  '29': 371,
  '46'