In [1]:
#!/usr/bin/env python
# coding: utf-8

# In[1]:


from glob import glob
from tqdm import tqdm
from mido import MidiFile
import string
import numpy as np
from nltk.tokenize import sent_tokenize
from itertools import tee
from numpy import array, random
from sys import getsizeof
from pickle import dump, load


# In[2]:


def msg2dict(msg):
    """Parse the textual form of a MIDI message into its numeric fields.

    Parameters
    ----------
    msg : str
        Text of a single message, for example
        ``'note_on channel=0 note=60 velocity=64 time=120'``.

    Returns
    -------
    list
        ``[fields, on_]``. ``fields`` always contains ``'time'``; for note
        messages it also contains ``'note'`` and ``'velocity'``. ``on_`` is
        ``True`` if the text contains ``'note_on'``, ``False`` if it contains
        ``'note_off'`` and ``None`` for anything else.
    """
    # Deleting punctuation removes stray characters (e.g. a trailing ')')
    # that may be glued to a value.
    strip_punct = str.maketrans({ch: None for ch in string.punctuation})

    def field(name):
        # Start at the last occurrence of `name`, keep the text up to the
        # next space, and read the integer on the right-hand side of '='.
        token = msg[msg.rfind(name):].split(' ')[0]
        return int(token.split('=')[1].translate(strip_punct))

    on_ = True if 'note_on' in msg else (False if 'note_off' in msg else None)

    # Only note messages carry a pitch and a velocity.
    wanted = ['time'] if on_ is None else ['time', 'note', 'velocity']
    return [{name: field(name) for name in wanted}, on_]

def switch_note(last_state, note, velocity, on_=True):
    """Return a copy of the keyboard state with one key switched on or off.

    The state has 88 slots for the piano keys with note ids 21..108; slot
    ``note - 21`` receives ``velocity`` when ``on_`` is truthy and ``0``
    otherwise. Notes outside that range leave the copy unchanged.

    Parameters
    ----------
    last_state : sequence or None
        Previous state (anything with a ``.copy()`` method); ``None`` means
        an all-zero keyboard.
    note : int
        Note id of the key to switch.
    velocity : int
        Value stored for the key when it is switched on.
    on_ : bool, optional
        Whether the key is being pressed (default) or released.

    Returns
    -------
    The new state; ``last_state`` itself is never modified.
    """
    if last_state is None:
        keys = [0] * 88
    else:
        keys = last_state.copy()

    # Anything that is not one of the 88 piano keys is ignored.
    if not (21 <= note <= 108):
        return keys

    keys[note - 21] = velocity if on_ else 0
    return keys

def get_new_state(new_msg, last_state):
    """Apply one MIDI message to the keyboard state.

    Parameters
    ----------
    new_msg
        The message; it is converted with ``str()`` and parsed by
        ``msg2dict``.
    last_state
        Keyboard state before the message.

    Returns
    -------
    list
        ``[state, time]`` where ``state`` is the updated keyboard for note
        messages (or ``last_state`` itself, uncopied, for any other message)
        and ``time`` is the message's ``time`` field.
    """
    fields, on_ = msg2dict(str(new_msg))
    if on_ is None:
        # Not a note message: the keyboard is unchanged.
        state = last_state
    else:
        state = switch_note(
            last_state,
            note=fields['note'],
            velocity=fields['velocity'],
            on_=on_,
        )
    return [state, fields['time']]

def track2seq(track):
    """Expand a MIDI track into a sequence of keyboard states.

    Each message's ``time`` field says how long the previous keyboard state
    was held before the message took effect, so that state is repeated
    ``time`` times in the output.

    Parameters
    ----------
    track : iterable
        The messages of one track, in order.

    Returns
    -------
    list
        One 88-slot state (note ids 21..108; other notes are ignored) per
        unit of time. Repeated entries reference the same state object.
        An empty track yields an empty list.
    """
    messages = iter(track)
    first = next(messages, None)
    if first is None:
        # Nothing to expand; indexing track[0] here used to raise IndexError.
        return []

    # The first message only seeds the state; its own time is not expanded.
    last_state, _ = get_new_state(str(first), [0] * 88)

    result = []
    for message in messages:
        new_state, new_time = get_new_state(message, last_state)
        if new_time > 0:
            # `last_state` was held for `new_time` units before this message.
            result.extend([last_state] * new_time)
        last_state = new_state
    return result

def midi2note(num: int) -> str:
    """Return the note name (flat spelling) for a MIDI note number.

    The name is the pitch class followed by the octave, where the octave is
    ``num // 12 - 1`` (so 60 is ``"C4"`` and 21 is ``"A0"``).

    Parameters
    ----------
    num : int
        MIDI note number.

    Returns
    -------
    str
        For example ``"Db4"`` for 61.
    """
    # Pitch class 1 is D-flat; it was previously mislabelled "Cb", which is
    # the enharmonic spelling of B.
    notes = ["C", "Db", "D", "Eb", "E", "F", "Gb", "G", "Ab", "A", "Bb", "B"]
    return notes[num % 12] + str(num // 12 - 1)


# In[3]:


def update(d, keys, value):
    """Increment the counter for ``value`` under the nested path ``keys``.

    Walks ``d`` one level per entry of ``keys``, creating empty dicts for
    missing levels, then adds 1 to the count stored for ``value`` in the
    innermost dict (starting from 0 if absent). ``d`` is modified in place.

    Parameters
    ----------
    d : dict
        Root of the nested counter.
    keys : iterable
        Path of keys leading to the innermost dict; may be empty.
    value
        Key whose count is incremented.
    """
    node = d
    for key in keys:
        # Descend, inserting a fresh level when the key is not there yet.
        node = node.setdefault(key, dict())
    node[value] = node.get(value, 0) + 1


def marginalize(text, window, d=None):
    """Count which token follows each run of ``window - 1`` tokens in ``text``.

    ``text`` is split on single spaces and scanned with a sliding window of
    ``window`` tokens. For every window, the leading tokens form a nested
    key path in ``d`` and the count of the final token under that path is
    incremented.

    Parameters
    ----------
    text : str
        Space-separated tokens.
    window : int
        Number of tokens per window (context tokens plus one target).
    d : dict, optional
        Existing nested counter to extend in place. When omitted, a new
        dict is created for this call.

    Returns
    -------
    dict
        The nested counter (``d`` itself when one was passed in).
    """
    # A `d=dict()` default is evaluated once and would be shared by every
    # call that omits `d`, silently accumulating counts across calls.
    if d is None:
        d = dict()
    for w in slide(text.split(' '), window):
        features = w[:-1]
        target = w[-1]
        update(d, features, target)
    return d


def slide(iterable, size):
    """Return an iterator over consecutive windows of ``size`` items.

    ``slide([1, 2, 3, 4], 3)`` yields ``(1, 2, 3)`` then ``(2, 3, 4)``.
    Inputs shorter than ``size`` yield nothing.

    Parameters
    ----------
    iterable
        Source of items; it is consumed through ``itertools.tee``.
    size : int
        Window length.

    Returns
    -------
    zip
        Iterator of ``size``-tuples.
    """
    branches = tee(iterable, size)
    # Stagger the copies: the k-th copy starts k items in, so zipping them
    # lines up each item with its successors.
    for offset, branch in enumerate(branches):
        for _ in range(offset):
            next(branch, None)
    return zip(*branches)


# In[4]:


# Train the note-transition model: every MIDI file is turned into a
# space-separated string of note names, which is then counted into `model`
# with a sliding window of `window` tokens.
window = 3
model = dict()
skipped = []  # (path, error) for every file that could not be processed

pbar = tqdm(glob('selected_midis/*.mid'))
for path in pbar:
    try:
        midi = MidiFile(path, clip=True)
        tokens = []
        for track in midi.tracks:
            arr = np.array(track2seq(track))
            for i in range(arr.shape[0]):
                note = np.where(arr[i] > 0)
                if len(note[0]):
                    # Slot k of a state holds MIDI note k + 21 (see
                    # switch_note), so restore the offset before naming the
                    # lowest sounding note; without it every name comes out
                    # 21 semitones too low.
                    tokens.append(midi2note(note[0][0] + 21))
        # Joining (instead of appending "name ") leaves no trailing space,
        # which otherwise becomes an empty-string token in the model.
        corpus = ' '.join(tokens)
        model = marginalize(corpus, window, model)
        # getsizeof is shallow: this is the size of the top-level dict only.
        pbar.set_postfix({'model (kb)': getsizeof(model)/1e3})
    except Exception as exc:
        # Best effort: keep going past unreadable files, but remember them
        # (and let KeyboardInterrupt/SystemExit through).
        skipped.append((path, repr(exc)))

if skipped:
    print('skipped %d file(s):' % len(skipped))
    for skipped_path, reason in skipped:
        print('  %s: %s' % (skipped_path, reason))

# Close the file deterministically so the pickle is fully flushed to disk.
with open('pokemodel4.pkl', 'wb') as model_file:
    dump(model, model_file)

100%|██████████| 619/619 [1:08:02<00:00,  6.59s/it, model (kb)=4.7]


In [20]:
# Turn successor counts into probabilities, keyed by note name:
# music[name][next_name] = count(next) / total count of the other successors.
# NOTE(review): this cell reads `model` as {note_id: {note_id: count}} with
# integer-like keys (it calls int() on them and sums the inner values). The
# training cell in this notebook uses window = 3 and note-name tokens, which
# gives a deeper, name-keyed structure -- presumably this was run against a
# different `model`; confirm before re-running.
music = {}

for k, v in model.items():
    n1 = midi2note(int(k))
    # Successors other than the note itself, skipping empty-string keys.
    successors = [(note, prob) for note, prob in v.items() if note != k and note]
    if not successors:
        continue
    count = sum(prob for _, prob in successors)
    music[str(n1)] = {
        str(midi2note(int(note))): prob / count for note, prob in successors
    }

In [1]:
from pickle import load

# NOTE: unpickling can execute arbitrary code -- only load files you created.
# NOTE(review): the training cell in this notebook writes 'pokemodel4.pkl',
# while this reads 'pokemodel.pkl'; confirm this is the intended file.
with open('pokemodel.pkl', 'rb') as model_file:
    # The context manager closes the handle once the model has been read.
    model = load(model_file)

In [10]:
def midi2note(num: int) -> str:
    """Return the note name (sharp spelling) for a MIDI note number.

    The result is the pitch class followed by the octave ``num // 12 - 1``,
    e.g. 60 -> ``"C4"`` and 61 -> ``"C#4"``.
    """
    octave_index, pitch_class = divmod(num, 12)
    names = ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B")
    return str(names[pitch_class]) + str(octave_index - 1)

In [3]:
# Echo the loaded model so the notebook renders the nested count dict below.
model

{'D3': {'D3': {'D3': 4217043,
   'Ds3': 2144,
   'Cs3': 1779,
   'F3': 1137,
   'E3': 1902,
   'G3': 983,
   'A3': 433,
   'Fs3': 451,
   'C3': 2074,
   'B2': 807,
   'G2': 649,
   'B3': 92,
   'C4': 68,
   'Gs2': 136,
   'Gs3': 205,
   'D2': 185,
   'A2': 939,
   'As2': 1207,
   'E4': 20,
   'D4': 293,
   'E5': 2,
   'Fs5': 2,
   'Gs1': 7,
   'G0': 20,
   'As3': 128,
   'As0': 9,
   'E2': 188,
   'F2': 254,
   'G4': 6,
   'Ds2': 81,
   'Fs2': 342,
   'Gs4': 4,
   'C2': 25,
   'Cs4': 19,
   'Cs2': 32,
   'G1': 39,
   'F5': 2,
   'D1': 25,
   'A1': 33,
   'A0': 33,
   'Ds4': 18,
   '': 12,
   'Ds1': 7,
   'As1': 96,
   'Fs1': 7,
   'F0': 35,
   'D5': 242,
   'C1': 3,
   'B1': 18,
   'F1': 28,
   'E1': 10,
   'As4': 4,
   'A-1': 14,
   'A4': 2,
   'Fs4': 5,
   'Ds0': 38,
   'Cs1': 10,
   'Gs-1': 1,
   'Fs0': 2,
   'G5': 1,
   'C0': 11,
   'F4': 17,
   'C5': 2,
   'Ds-1': 2,
   'B0': 6,
   'D0': 2,
   'B4': 2,
   'G-1': 2,
   'Cs5': 1,
   'Ds5': 1},
  'Ds3': {'Ds3': 2143, 'C3': 1},
  'Cs3