In [284]:
import os
from datetime import datetime
from collections import defaultdict

import numpy as np

from keras import Sequential
from keras.utils import to_categorical
from keras.layers import LSTM, Dense, Dropout, Flatten, Activation

from pyknon.genmidi import Midi
from pyknon.music import NoteSeq, Note
from music21 import midi, stream, converter, note, chord, instrument

import wandb
from wandb.keras import WandbCallback

In [291]:
MIDI_DIR = 'Omnibook/Midi'
SEQ_LENGTH = 30
NUM_EPOCHS = 20

In [3]:
def make_midi(notes, name, filepath):
    notes = [Note(note) for note in notes]
    midi = Midi(1, tempo=90)
    midi.seq_notes(notes, track=0)
    midi.write(filepath)
    
def play_midi(filepath):
    mf = midi.MidiFile()
    mf.open(filepath)
    mf.read()
    mf.close()
    stream = midi.translate.midiFileToStream(mf)
    stream.show('midi')

def load_midi(filepath):
    mf = midi.MidiFile()
    mf.open(filepath)
    mf.read()
    mf.close()
    return mf

def load_midi_dir(path):
    filenames = os.listdir(path)
    filepaths = [os.path.join(path, fn) for fn in filenames]
    return [load_midi(fp) for fp in filepaths]

def get_pitch_range(streams):
    all_pitches = set(pitch for stream in streams for pitch in stream.pitches)
    return min(all_pitches), max(all_pitches)

def get_notes(stream):
    return stream.elements[0].notesAndRests

def get_durations(streams):
    return set(note.duration.quarterLength for stream in streams for note in get_notes(stream))

def build_indexes(pitches):
    ind_to_pitch = dict(enumerate(pitches, 1))
    ind_to_pitch[0] = 'rest'
    pitch_to_ind = {v: k for k, v in ind_to_pitch.items()}
    return pitch_to_ind, ind_to_pitch

In [286]:
class Encoder:
    
    rest_code = 'rest'
    rest_index = 0
    
    def __init__(self, streams, seq_length, augment_transpose=False):
        self.streams = streams
        self.augment_transpose = augment_transpose
        self._seq_length = seq_length
        self._pitches = self._make_pitches()
        self._note_to_index, self._index_to_note = self._build_indexes()
        self.vocab_size = len(self._note_to_index)
        
    def encode_note(self, note):
        if note.isRest:
            return self.rest_index
        return self._note_to_index[note.pitch.midi]
    
    def encode_notes(self, notes):
        return [self.encode_note(note) for note in notes]
    
    def decode_note(self, index):
        note_code = self._index_to_note[index]
        if note_code == self.rest_code:
            return note.Rest(type='eighth')
        return note.Note(note_code, type='eighth')
    
    def make_training_sequences(self):
        training_sequences = []
        labels = []
        for stream in self.streams:
            # get notes for each stream
            notes = get_notes(stream)
            # encode the notes
            note_indices_sequence = self.encode_notes(notes)
            # add augmented sequences
            if self.augment_transpose:
                note_indices_sequences = [
                    self._transpose_encoded_sequence(note_indices_sequence, interval)
                    for interval in range(1, 12)
                ]
            else:
                note_indices_sequences = [note_indices_sequence]
            # chunk sequences into training length
            for note_sequence in note_indices_sequences:
                for index in range(len(note_sequence)-self._seq_length):
                    training_sequences.append(note_sequence[index:index+self._seq_length])
                    labels.append(note_sequence[index+self._seq_length])
        return training_sequences, labels

    def _transpose_encoded_sequence(self, sequence, interval):
        return [
            index + interval if index != self.rest_index else self.rest_index
            for index in sequence
        ]
    
    def _make_pitches(self):
        all_pitches = set(pitch for stream in streams for pitch in stream.pitches)
        min_pitch, max_pitch = min(all_pitches), max(all_pitches)
        if self.augment_transpose:
            return list(range(min_pitch_midi, max_pitch_midi + 12))
        return list(range(min_pitch_midi, max_pitch_midi + 1))
    
    def _build_indexes(self):
        ind_to_note = dict(enumerate(self._pitches, 1))
        ind_to_note[self.rest_index] = self.rest_code
        note_to_ind = {v: k for k, v in ind_to_note.items()}
        return note_to_ind, ind_to_note

Load midi files and convert them to streams.

In [292]:
midi_files = load_midi_dir(MIDI_DIR)
streams = [midi.translate.midiFileToStream(mf) for mf in midi_files]
encoder = Encoder(streams, seq_length=SEQ_LENGTH, augment_transpose=True)
training_data, training_labels = encoder.make_training_sequences()
training_data = to_categorical(training_data, num_classes=encoder.vocab_size)
training_labels = to_categorical(training_labels, num_classes=encoder.vocab_size)

In [293]:
print("Training data shape:", training_data.shape)
print("Training labels shape:", training_labels.shape)

Training data shape: (239129, 30, 44)
Training labels shape: (239129, 44)


In [294]:
model = Sequential()
model.add(LSTM(128, input_shape=training_data.shape[1:], return_sequences=True))
model.add(Dropout(0.5))
model.add(LSTM(128))
model.add(Dense(512))
model.add(Dropout(0.5))
model.add(Dense(encoder.vocab_size))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')

wandb.init()

W&B Run: https://app.wandb.ai/pvarsh/bop-net/runs/oeuuowov
Call `%%wandb` in the cell containing your training loop to display live results.


W&B Run https://app.wandb.ai/pvarsh/bop-net/runs/oeuuowov

In [None]:
model.fit(
    training_data,
    training_labels,
    epochs=NUM_EPOCHS,
    batch_size=64,
    validation_split=0.1,
    callbacks=[WandbCallback()]
)
model.save(
    'model_{timestamp}_length_{seq_length}_epochs_{num_epochs}'
    .format(
        timestamp=datetime.now().isoformat(),
        seq_length=SEQ_LENGTH,
        num_epochs=NUM_EPOCHS,
    )
)

Train on 215216 samples, validate on 23913 samples
Epoch 1/20
Resuming run: https://app.wandb.ai/pvarsh/bop-net/runs/oeuuowov
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20

In [251]:
def improvise(model, start_input, index_to_note, note_to_index, sequence_length, solo_length):
    start_input_indices = encoder.encode_notes(start_input)
    
    solo = []
    solo.extend(start_input_indices)
    
    vocab_size = len(index_to_note)
    
    for _ in range(solo_length):
        network_input = solo[-sequence_length:]
        network_input = to_categorical(network_input, num_classes=vocab_size)
        network_input = np.reshape(network_input, (1,) + network_input.shape)
        prediction = model.predict(network_input, verbose=False)
        prediction_note_index = int(np.random.choice(prediction.shape[1], 1, p=prediction[0]))
        solo.append(prediction_note_index)
    solo = [encoder.decode_note[ind] for ind in solo]
    return solo

In [256]:
def improvise(model, start_input, solo_length, encoder):
    start_input_indices = encoder.encode_notes(start_input)
    
    solo = []
    solo.extend(start_input_indices)
        
    for _ in range(solo_length):
        network_input = solo[-encoder._seq_length:]
        network_input = to_categorical(network_input, num_classes=encoder.vocab_size)
        network_input = np.reshape(network_input, (1,) + network_input.shape)
        prediction = model.predict(network_input, verbose=False)
        prediction_note_index = int(np.random.choice(prediction.shape[1], 1, p=prediction[0]))
        solo.append(prediction_note_index)
    solo = [encoder.decode_note(ind) for ind in solo]
    return solo

In [261]:
def create_stream(notes):
    s = stream.Stream()
    for note in notes:
        s.append(note)
    return s

In [262]:
solo = improvise(
    model,
    start_input=get_notes(streams[7])[:SEQ_LENGTH*4],
    solo_length=100,
    encoder=encoder,
)

In [263]:
solo_stream = create_stream(solo)
solo_stream.show('midi')

In [265]:
mf = midi.translate.streamToMidiFile(solo_stream)
mf.open('bloomdido_with_rests_augmentation_16_epochs.mid', 'wb')
mf.write()
mf.close()