# Multi-layer LSTM for Music Generation (with Magenta)

In [41]:
import os
import pickle
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Bidirectional, BatchNormalization, Dropout, Dense, Activation, Lambda, Softmax
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint
from magenta.scripts.convert_dir_to_note_sequences import convert_directory
from note_seq import music_pb2

# TFRecord outputs written by convert_files(); small_file is also the dataset
# read by get_vocab() and get_notes_nvocab().
small_file = 'small_sequences.tfrecord'
large_file = 'large_sequences.tfrecord'
# Time-step resolution: step k is sampled at time k / note_divisions.
note_divisions = 16

In [42]:
def create_network(network_input, n_vocab):
    """Build and compile the stacked bidirectional-LSTM classifier.

    Args:
        network_input: Array whose ``shape[1]`` and ``shape[2]`` are passed
            as the ``input_shape`` of the first LSTM layer.
        n_vocab: Width of the final ``Dense`` layer, i.e. the number of
            classes in the softmax output.

    Returns:
        A ``Sequential`` model compiled with the Adam optimizer and
        categorical cross-entropy loss.
    """
    units = 256
    drop_rate = 0.4
    temp = 0.6
    first_input_shape = (network_input.shape[1], network_input.shape[2])

    net = Sequential()

    # Recurrent stack: the first two layers return full sequences so the next
    # LSTM can consume them; the last one returns only its final output.
    net.add(Bidirectional(LSTM(
        units,
        dropout=drop_rate,
        return_sequences=True,
        input_shape=first_input_shape,
    )))
    net.add(Bidirectional(LSTM(units, dropout=drop_rate, return_sequences=True)))
    net.add(Bidirectional(LSTM(units, dropout=drop_rate)))

    # Dense head. The logits are divided by ``temp`` before the softmax.
    head = (
        BatchNormalization(),
        Dropout(drop_rate),
        Dense(units // 2),
        Activation('relu'),
        BatchNormalization(),
        Dropout(drop_rate),
        Dense(n_vocab),
        Lambda(lambda logits: logits / temp),
        Softmax(),
    )
    for layer in head:
        net.add(layer)

    net.compile(optimizer='adam', loss='categorical_crossentropy')
    return net
    

In [43]:
def convert_files():
    """Convert the sample directories into the two TFRecord files.

    Runs ``convert_directory`` on ``<cwd>/samples/small`` and
    ``<cwd>/samples/large``, writing to ``small_file`` and ``large_file``
    respectively. The third positional argument is ``True`` in both calls
    (presumably Magenta's ``recursive`` flag -- confirm against its
    signature).
    """
    for folder, destination in (('small', small_file), ('large', large_file)):
        source_dir = os.path.join(os.getcwd(), 'samples', folder)
        convert_directory(source_dir, destination, True)

In [44]:
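# Vocabulary tokens are strings of the form "<flag> <pitch>".
# Flags: 0 = next channel, 1 = new note, 2 = next step, 3 = continued note.
# NOTE(review): get_notes_nvocab() appends '0 0' once per time step and '2 0'
# on each instrument change, i.e. the opposite of the flag 0 / flag 2 meanings
# listed above -- confirm which one is intended.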

def is_start(target, start):
    """Return True when ``start`` lies strictly within 0.02 of ``target``."""
    earliest = target - 0.02
    latest = target + 0.02
    return earliest < start < latest

def is_note_valid(target, note, instrument_infos):
    """Tell whether ``note`` should be encoded at time ``target``.

    A note qualifies when it overlaps ``target`` (with 0.02 of slack on both
    sides) and the entry of ``instrument_infos`` at index ``note.instrument``
    is not named ``'no'``.

    Args:
        target: Time being sampled.
        note: Object exposing ``start_time``, ``end_time`` and ``instrument``.
        instrument_infos: Sequence indexed by ``note.instrument`` whose items
            expose ``name``.

    Returns:
        True if the note qualifies, False otherwise.
    """
    # Checks run in this order so the instrument lookup only happens for
    # notes that actually overlap the target time.
    if not target - 0.02 < note.end_time:
        return False
    if not note.start_time < target + 0.02:
        return False
    return instrument_infos[note.instrument].name != 'no'

def get_vocab():
    """Collect every token that can occur in ``small_file`` and pickle it.

    Each record of ``small_file`` is parsed as a ``NoteSequence`` and sampled
    at times ``k / note_divisions``. For every note, ``'1 <pitch>'`` is added
    when the note starts at the sampled time (``is_start``), otherwise
    ``'3 <pitch>'`` is added when the note is valid at that time
    (``is_note_valid``). The marker tokens ``'0 0'`` and ``'2 0'`` are always
    present.

    Side effects:
        Writes the resulting set to ``vocab.p`` with ``pickle``.

    Returns:
        The set of token strings.
    """
    vocab = {'0 0', '2 0'}
    for sequence in tf.data.TFRecordDataset(small_file):
        data = music_pb2.NoteSequence.FromString(sequence.numpy())
        # The time of the *second* time-signature entry bounds the sampling
        # loop (presumably an end-of-piece marker in this dataset -- confirm).
        total_time = data.time_signatures[1].time
        iteration = 0
        while iteration / note_divisions < total_time:
            now = iteration / note_divisions
            for note in data.notes:
                if is_start(now, note.start_time):
                    vocab.add('1 ' + str(note.pitch))
                elif is_note_valid(now, note, data.instrument_infos):
                    vocab.add('3 ' + str(note.pitch))
            iteration += 1
    # Use a context manager so the file is flushed and closed deterministically.
    with open('vocab.p', 'wb') as vocab_file:
        pickle.dump(vocab, vocab_file)
    return vocab

def get_notes_nvocab():
    """Encode ``small_file`` as one flat list of vocabulary indices.

    For every sampled time ``k / note_divisions`` of every record, the
    encoder emits ``'0 0'``, then the notes valid at that time ordered by
    instrument (ascending) and pitch (descending). Whenever the instrument
    changes, ``'2 0'`` is emitted first -- twice when moving straight from
    the instrument named ``'p1'`` to the one named ``'tr'``. Each note is
    written as ``'1 <pitch>'`` if it starts at the sampled time, otherwise
    as ``'3 <pitch>'``.

    Returns:
        A tuple ``(notes, n_vocab)`` where ``notes`` is the list of integer
        token indices and ``n_vocab`` is the vocabulary size.
    """
    vocab = get_vocab()
    # Enumerate in sorted order: plain set iteration order for strings varies
    # between interpreter runs (hash randomization), which would make the
    # indices impossible to reproduce when decoding later. Any consumer of
    # vocab.p must rebuild the mapping with the same sorted order.
    vocab_dict = {token: index for index, token in enumerate(sorted(vocab))}
    per_step_marker = vocab_dict['0 0']
    instrument_change_marker = vocab_dict['2 0']

    notes = []
    for sequence in tf.data.TFRecordDataset(small_file):
        data = music_pb2.NoteSequence.FromString(sequence.numpy())
        infos = data.instrument_infos
        # The time of the *second* time-signature entry bounds the sampling
        # loop (presumably an end-of-piece marker in this dataset -- confirm).
        total_time = data.time_signatures[1].time
        iteration = 0
        while iteration / note_divisions < total_time:
            now = iteration / note_divisions
            notes.append(per_step_marker)
            active = [
                note for note in data.notes
                if is_note_valid(now, note, infos)
            ]
            # Instrument ascending, then pitch descending; ties keep file order.
            active.sort(key=lambda note: (note.instrument, -note.pitch))
            prev_instrument = 0
            for note in active:
                if note.instrument != prev_instrument:
                    # note.instrument is an integer id, so the names have to
                    # be looked up; comparing the id itself with 'p1' / 'tr'
                    # could never match and the extra marker was never sent.
                    skipped_a_channel = (
                        infos[prev_instrument].name == 'p1' and
                        infos[note.instrument].name == 'tr'
                    )
                    if skipped_a_channel:
                        notes.append(instrument_change_marker)
                    notes.append(instrument_change_marker)
                    prev_instrument = note.instrument
                flag = '1' if is_start(now, note.start_time) else '3'
                notes.append(vocab_dict[flag + ' ' + str(note.pitch)])
            iteration += 1
    return (notes, len(vocab))

def prepare_sequences():
    """Build the training windows consumed by the neural network.

    Slides a window of ``sequence_length`` tokens over the encoded note
    stream; each window is one input sample and the token right after it is
    the target.

    Returns:
        A tuple ``(network_input, network_output, n_vocab)``:
        ``network_input`` has shape ``(n_patterns, sequence_length, 1)`` with
        token indices divided by ``n_vocab``; ``network_output`` is the
        one-hot encoding of the targets with ``n_vocab`` columns.

    Raises:
        ValueError: If the note stream is not longer than ``sequence_length``.
    """
    notes, n_vocab = get_notes_nvocab()
    sequence_length = 4

    n_patterns = len(notes) - sequence_length
    if n_patterns <= 0:
        raise ValueError(
            'Need more than %d encoded tokens to build training sequences, '
            'got %d.' % (sequence_length, len(notes))
        )

    network_input = [notes[i:i + sequence_length] for i in range(n_patterns)]
    # Target of window i is notes[i + sequence_length].
    network_output = notes[sequence_length:]

    # reshape the input into a format compatible with LSTM layers
    network_input = np.reshape(network_input, (n_patterns, sequence_length, 1))

    # normalize input
    network_input = network_input / float(n_vocab)

    # Pin the one-hot width to the vocabulary size. Otherwise it is inferred
    # as max(target) + 1 and can be narrower than the model's Dense(n_vocab)
    # output when the highest index never occurs as a target.
    network_output = to_categorical(network_output, num_classes=n_vocab)

    return (network_input, network_output, n_vocab)

In [45]:
def train_network():
    """Prepare the data, build the model and fit it for 50 epochs.

    A ``ModelCheckpoint`` callback writes to ``weights.hdf5`` each time the
    monitored training ``loss`` reaches a new minimum. The model summary is
    printed once fitting has finished.
    """
    inputs, targets, vocab_size = prepare_sequences()
    print('Done preparing sequences!')

    model = create_network(inputs, vocab_size)

    # save_best_only + mode='min': only overwrite the file on a lower loss.
    callbacks_list = [
        ModelCheckpoint(
            "weights.hdf5",
            monitor='loss',
            verbose=0,
            save_best_only=True,
            mode='min',
        ),
    ]

    model.fit(x=inputs, y=targets, epochs=50, callbacks=callbacks_list)
    model.summary()

# Run the full pipeline (sequence preparation, model construction, fitting)
# when this cell executes.
train_network()

[pitch: 62
velocity: 6
start_time: 0.1458049886621315
end_time: 0.9275283446712018
program: 80
, pitch: 62
velocity: 10
start_time: 0.04702947845804989
end_time: 0.8278004535147393
instrument: 1
program: 81
, pitch: 50
velocity: 1
start_time: 0.44569160997732427
end_time: 1.643764172335601
instrument: 2
program: 38
]
[pitch: 58
velocity: 3
start_time: 0.00022675736961451248
end_time: 0.7987074829931973
program: 80
, pitch: 58
velocity: 2
start_time: 0.39956916099773243
end_time: 1.1983673469387754
instrument: 1
program: 81
, pitch: 54
velocity: 1
start_time: 0.005215419501133787
end_time: 0.7864625850340137
instrument: 2
program: 38
]
[pitch: 93
velocity: 6
start_time: 0.6010204081632653
end_time: 0.6083219954648527
program: 80
, pitch: 89
velocity: 6
start_time: 0.617482993197279
end_time: 0.6249886621315193
program: 80
, pitch: 86
velocity: 6
start_time: 0.6083219954648527
end_time: 0.6166666666666667
program: 80
, pitch: 86
velocity: 6
start_time: 0.6343083900226758
end_time: 0.6416

[pitch: 67
velocity: 15
start_time: 0.004149659863945578
end_time: 0.9321541950113379
program: 80
, pitch: 63
velocity: 15
start_time: 0.004149659863945578
end_time: 0.9325170068027211
instrument: 1
program: 81
]
[pitch: 83
velocity: 15
start_time: 0.004149659863945578
end_time: 1.0041496598639457
program: 80
, pitch: 55
velocity: 1
start_time: 0.6
end_time: 0.7000000000000001
instrument: 1
program: 38
]
[pitch: 66
velocity: 9
start_time: 0.00022675736961451248
end_time: 0.66562358276644
program: 80
, pitch: 66
velocity: 6
start_time: 0.0006802721088435375
end_time: 0.9985034013605443
instrument: 1
program: 81
, pitch: 54
velocity: 1
start_time: 0.005215419501133787
end_time: 2.663061224489796
instrument: 2
program: 38
]
[pitch: 59
velocity: 1
start_time: 0.55
end_time: 0.637482993197279
instrument: 2
program: 38
]
[pitch: 50
velocity: 1
start_time: 0.4159637188208617
end_time: 3.1951473922902496
program: 80
, pitch: 50
velocity: 1
start_time: 0.01668934240362812
end_time: 3.1957369614

KeyboardInterrupt: 