In [None]:
import glob
from multiprocessing.dummy import Pool as ThreadPool
import os
import random
import sys

import mido
from mido import Message, MetaMessage, MidiFile, MidiTrack
import numpy
import sklearn.utils
import tqdm

In [None]:
# Always print arrays in full instead of summarising them with "...".
# NaN is rejected as a threshold by current NumPy releases (ValueError);
# sys.maxsize is the value the NumPy documentation recommends for this.
numpy.set_printoptions(threshold=sys.maxsize)

In [None]:
# Seed both the standard-library and the NumPy global random generators
# with a fixed value so random draws repeat from run to run.
random.seed(0)
numpy.random.seed(0)

In [None]:
# midi = MidiFile()
# track = MidiTrack()
# midi.tracks.append(track)
# track.append(Message('note_on', note=50, velocity=100, time=0))
# track.append(Message('note_off', note=50, velocity=100, time=960))
# track.append(Message('note_on', note=51, velocity=100, time=0))
# track.append(Message('note_off', note=51, velocity=100, time=960))
# track.append(Message('note_on', note=52, velocity=100, time=0))
# track.append(Message('note_off', note=52, velocity=100, time=960))
# midi.save('/home/santiago/Projects/ProjectEuterpe/data/test/test1.mid')
# #midi.ticks_per_beat

In [None]:
# Directory that is globbed for input MIDI files (absolute, machine-specific path).
midi_dir = '/home/santiago/Projects/ProjectEuterpe/data/midi/classical/'

In [None]:
# Every *.mid and *.midi path directly inside midi_dir, in sorted order.
files = sorted(
    path
    for pattern in ('*.mid', '*.midi')
    for path in glob.glob(os.path.join(midi_dir, pattern))
)

In [None]:
# Number of candidate MIDI files found by the glob above.
len(files)

In [None]:
#files

In [None]:
# midis = []
# for file in files:
#     try:
#         midis.append(MidiFile(file))
#     except:
#         print(file)

In [None]:
# Bits per encoded event, as laid out by encode():
# 1 on/off flag + 8 note bits + 8 velocity bits + 16 delta-time bits.
INPUT_WIDTH = 1 + 8 + 8 + 16
# Number of preceding events in each input window built by prepare().
LOOKBACK = 128
# Ticks-per-beat resolution that load_midi() rescales every file's times to.
DEFAULT_TICKS = 480
# Not used by any live code in the visible cells; presumably a MIDI tempo in
# microseconds per beat (500000 would be 120 BPM) -- TODO confirm.
DEFAULT_TEMPO = 500000
# Only referenced by the commented-out ThreadPool experiment further down.
CORES = 4

In [None]:
def filter_files(midi_files):
    """Return the sorted paths that parse as MIDI and contain exactly one track.

    Args:
        midi_files: iterable of file paths to try.

    Returns:
        A sorted list holding the subset of ``midi_files`` for which
        ``MidiFile(path)`` succeeded and ``len(midi.tracks) == 1``.

    Files that fail to load are skipped silently on purpose (best-effort
    filtering).  Only ``Exception`` subclasses are swallowed, so
    ``KeyboardInterrupt`` and ``SystemExit`` still propagate without needing
    a dedicated re-raise clause.
    """
    keep = []
    for file in midi_files:
        try:
            midi = MidiFile(file)
            if len(midi.tracks) == 1:
                keep.append(file)
        except Exception:
            # Unreadable or corrupt file: leave it out and carry on.
            continue
    return sorted(keep)

In [None]:
# Keep only the files that load successfully and contain a single track.
filtered = filter_files(files)

In [None]:
# Number of files that survived the single-track filter.
len(filtered)

In [None]:
# def load_midi(midi_file):
#     data = []
#     midi = mido.MidiFile(midi_file)
#     print(midi.ticks_per_beat)
#     for track in midi.tracks:
#         track_data = []
#         for message in track:
#             if message.type in ['note_on', 'note_off']:
#                 track_data.append([1 if message.type == 'note_on' else 0, message.note, message.velocity, int(message.time * 4 * DEFAULT_TICKS / midi.ticks_per_beat)])
#         if track_data:
#             data.append(track_data)
#     assert data
#     return data

In [None]:
def load_midi(midi_file):
    """Load the note events of a single-track MIDI file as an integer array.

    Args:
        midi_file: path of the MIDI file to read.

    Returns:
        A ``numpy.uint16`` array with one row per ``note_on``/``note_off``
        message and four columns: on/off flag (1 for ``note_on``), note
        number, velocity, and delta time rescaled to ``DEFAULT_TICKS`` ticks
        per beat.

    Raises:
        AssertionError: if the file does not have exactly one track or holds
            no note messages.

    Two details matter for correctness:

    * Message times in a track are deltas relative to the previous message,
      so the deltas of skipped (non-note) messages are carried over to the
      next note event instead of being dropped.
    * The time column is later encoded on 16 bits by ``encode``; it is
      therefore stored as ``uint16`` (not ``uint8``, which would wrap or
      overflow for any delta above 255) and clamped to the 16-bit maximum.
    """
    midi = mido.MidiFile(midi_file)
    assert len(midi.tracks) == 1
    data = []
    pending_ticks = 0  # delta time accumulated since the last emitted note event
    for message in midi.tracks[0]:
        pending_ticks += message.time
        if message.type in ('note_on', 'note_off'):
            ticks = int(pending_ticks * DEFAULT_TICKS / midi.ticks_per_beat)
            data.append([
                1 if message.type == 'note_on' else 0,
                message.note,
                message.velocity,
                min(ticks, 0xFFFF),  # must fit the 16-bit time field
            ])
            pending_ticks = 0
    assert data
    return numpy.array(data, dtype=numpy.uint16)

In [None]:
# def merge_tracks(data):
#     pass

In [None]:
# def augment(data):
#     augmented = []
#     events = len(data)
#     maximum = data.max(axis=0)[1]
#     minimum = data.min(axis=0)[1]
#     transpositions = 128 - (maximum - minimum)
#     for i in range(transpositions):
#         sequence = numpy.copy(data)
#         for j in range(events):
#             sequence[j, 1] = data[j, 1] - minimum + i
#         augmented.append(sequence)
#     return augmented

In [None]:
def augment(data):
    """Return every transposition of ``data`` that keeps notes within 0..127.

    Args:
        data: 2-D numpy array of events whose column 1 is the note number.

    Returns:
        A list of copies of ``data``.  In copy number ``k`` the note column is
        shifted so that the lowest note of the sequence equals ``k``; there
        are ``128 - (highest - lowest)`` copies.  ``data`` is not modified.
    """
    lowest = data[:, 1].min()
    highest = data[:, 1].max()
    span = highest - lowest
    augmented = []
    for shift in range(128 - span):
        transposed = numpy.copy(data)
        # Re-base the note column on `shift`; the other columns stay as copied.
        transposed[:, 1] = data[:, 1] - lowest + shift
        augmented.append(transposed)
    return augmented

In [None]:
def encode(data):
    """Turn event sequences into per-event bit vectors.

    Args:
        data: iterable of sequences; each event is indexable with
            ``[flag, note, velocity, time]`` at positions 0..3.

    Returns:
        A list with one ``uint8`` array of shape
        ``(len(sequence), INPUT_WIDTH)`` per input sequence.  Column 0 is the
        on/off flag, columns 1-8 the note, columns 9-16 the velocity and
        columns 17 onwards the time, each written most-significant bit first.
    """
    def bits(value, spec):
        # Binary digits of `value`, zero-padded according to `spec`.
        return [int(digit) for digit in format(value, spec)]

    encoded = []
    for sequence in data:
        matrix = numpy.zeros((len(sequence), INPUT_WIDTH), dtype=numpy.uint8)
        # Rows of `matrix` are views, so writing to `row` fills `matrix`.
        for row, event in zip(matrix, sequence):
            row[0] = event[0]
            row[1:9] = bits(event[1], '08b')
            row[9:17] = bits(event[2], '08b')
            row[17:] = bits(event[3], '016b')
        encoded.append(matrix)
    return encoded

In [None]:
def prepare(data):
    """Build (window, target) training pairs from encoded sequences.

    Args:
        data: iterable of 2-D arrays of encoded events.

    Returns:
        ``(X, Y)`` as ``uint8`` arrays.  For every event ``i`` of every
        sequence, ``Y`` holds the event itself and ``X`` the up to
        ``LOOKBACK`` events that precede it, zero-padded at the front to
        ``LOOKBACK`` rows.  The first event of a sequence gets an all-zero
        window.
    """
    windows = []
    targets = []
    for sequence in data:
        for index in range(len(sequence)):
            if index:
                start = max(index - LOOKBACK, 0)
                window = sequence[start:index, :]
            else:
                # Nothing precedes the first event: start from one blank row.
                window = numpy.zeros((1, INPUT_WIDTH), dtype=numpy.uint8)
            missing = LOOKBACK - len(window)
            if missing > 0:
                # Left-pad with zero rows so every window has LOOKBACK rows.
                window = numpy.pad(window, [(missing, 0), (0, 0)], mode='constant')
            windows.append(window)
            targets.append(sequence[index, :])
    X = numpy.array(windows, dtype=numpy.uint8)
    Y = numpy.array(targets, dtype=numpy.uint8)
    return X, Y

In [None]:
# def load_all(midi_dir, track_name):
#     X = []
#     Y = []
#     midi_files = sorted(glob.glob(os.path.join(midi_dir, '*.mid')) + glob.glob(os.path.join(midi_dir, '*.midi')))
#     for midi_file in midi_files:
#         try:
#             data = prepare(encode(augment(load_midi(midi_file, track_name))))
#             X.extend(data[0])
#             Y.extend(data[1])
#         except (KeyboardInterrupt, SystemExit):
#             raise
#         except:
#             print("Skipping", midi_file)
#     #random.shuffle(all_data)
#     X = numpy.array(X)
#     Y = numpy.array(Y)
#     return X, Y

In [None]:
def load_all(midi_files):
    """Load, encode and window every file, concatenating the results.

    Args:
        midi_files: iterable of MIDI file paths accepted by ``load_midi``.

    Returns:
        ``(X, Y)`` as ``uint8`` arrays made of the windows and targets that
        ``prepare(encode([load_midi(path)]))`` yields for each file, in file
        order.  Any error raised while processing a file propagates.
    """
    windows, targets = [], []
    for midi_file in midi_files:
        file_windows, file_targets = prepare(encode([load_midi(midi_file)]))
        windows.extend(file_windows)
        targets.extend(file_targets)
    X = numpy.array(windows, dtype=numpy.uint8)
    Y = numpy.array(targets, dtype=numpy.uint8)
    return X, Y

In [None]:
# data = load_midi('/home/santiago/Projects/ProjectEuterpe/data/midi/classical/bali.mid')

In [None]:
# data

In [None]:
# new_midi = MidiFile()
# for track in data:
#     new_track = MidiTrack()
#     new_midi.tracks.append(new_track)
#     for event in track:
#         new_track.append(Message('note_on' if event[0] == 1 else 'note_off', note=event[1], velocity=event[2], time=event[3]))

In [None]:
# new_midi.save('/home/santiago/test.mid')

In [None]:
# new_midis = []
# for midi in midis:
#     if len(midi.tracks) == 1:
#         new_midis.append(midi)

In [None]:
# new_midis

In [None]:
# loaded = []
# for midi in new_midis:
#     loaded.append(load_midi(midi.filename))

In [None]:
# filenames = sorted(list(map(lambda midi: midi.filename, new_midis)))

In [None]:
# filenames

In [None]:
# Display the list of files that will be loaded.
filtered

In [None]:
# Abandoned thread-pool variant, kept for reference:
# pool = ThreadPool(CORES)
# data = pool.map(load_all, filtered)
# pool.close() 
# pool.join()
# Sequential load: `data` is the (X, Y) tuple returned by load_all().
data = load_all(filtered)

In [None]:
# Shape and dtype of the input windows (data[0]) and the targets (data[1]).
data[0].shape, data[0].dtype, data[1].shape, data[1].dtype

In [None]:
#del data

In [None]:
def bits_to_int(bits):
    """Convert a most-significant-bit-first sequence of 0/1 values to an int.

    Args:
        bits: iterable of values equal to 0 or 1 (Python ints or NumPy
            integer scalars, e.g. a slice of an encoded ``uint8`` row).

    Returns:
        The value as a Python ``int``; ``0`` for an empty sequence.

    Each bit is converted with ``int()`` before being merged.  Without it, a
    ``numpy.uint8`` bit can turn the accumulator into a ``uint8`` as well, and
    the left shift then silently drops everything above 8 bits, which
    corrupts the 16-bit time field.
    """
    out = 0
    for bit in bits:
        out = (out << 1) | int(bit)
    return out

In [None]:
def generator(X, Y, batch_size=32, augment=False, shuffle=True, random_seed=0):
    """Yield ``(X_batch, Y_batch)`` pairs forever, cycling over the data.

    Args:
        X: 3-D array of input windows, indexed by sample on the first axis.
        Y: 2-D array of targets with the same number of samples as ``X``.
        batch_size: maximum number of samples per batch; the last batch of a
            pass may be smaller.
        augment: currently ignored -- on-the-fly augmentation is not
            implemented (see the TODO below).
        shuffle: if true, reshuffle ``X`` and ``Y`` together before each pass.
        random_seed: ``random_state`` handed to ``sklearn.utils.shuffle``;
            the same value is used on every pass.

    Yields:
        Tuples ``(X[i:i + batch_size], Y[i:i + batch_size])``.

    Raises:
        AssertionError: if ``X`` and ``Y`` differ in length.
        ValueError: if the dataset is empty.  Without this check the
            endless loop would spin without ever yielding a batch.

    Because this is a generator function, both checks only run when the
    first batch is requested.
    """
    assert len(X) == len(Y)
    if len(X) == 0:
        raise ValueError('generator() needs at least one sample')
    while True:
        # TODO: unfinished augmentation idea -- per window, transpose notes by
        # random.randrange(-6, 7) semitones and scale times by
        # 2 ** random.uniform(-1, 1), re-encoding the affected bit fields.
        if shuffle:
            X, Y = sklearn.utils.shuffle(X, Y, random_state=random_seed)
        for i in range(0, len(X), batch_size):
            yield X[i:i + batch_size, :, :], Y[i:i + batch_size, :]

In [None]:
# Batch generator over the loaded windows/targets, 256 samples per batch.
gen = generator(data[0], data[1], batch_size=256)

In [None]:
# Pull one batch from the generator.
n = next(gen)

In [None]:
def to_midi(data):
    """Decode bit-encoded events into a single-track MIDI file object.

    Args:
        data: iterable of encoded events laid out as produced by ``encode``:
            element 0 is the on/off flag, elements 1-8 the note bits,
            9-16 the velocity bits and 17 onwards the time bits.

    Returns:
        A ``MidiFile`` whose only track holds one ``note_on``/``note_off``
        message per event.  No ``set_tempo`` meta message is written.
    """
    track = MidiTrack()
    for event in data:
        kind = 'note_on' if event[0] == 1 else 'note_off'
        message = Message(
            kind,
            note=bits_to_int(event[1:9]),
            velocity=bits_to_int(event[9:17]),
            time=bits_to_int(event[17:]),
        )
        track.append(message)
    midi = MidiFile()
    midi.tracks.append(track)
    return midi

In [None]:
# Decode all target events (data[1]) back into one MIDI object.
midi = to_midi(data[1])

In [None]:
# Write the decoded events to disk (absolute, machine-specific path).
midi.save('/home/santiago/Projects/ProjectEuterpe/data/midi/all_classical.mid')