# Jazz melody generation using LSTMs

This notebook uses data from the Weimar Jazz Database and is based on Jason Brownlee's LSTM text generation tutorial.

Currently this only takes in a single MIDI file containing the melody track; further notebooks will explore multiple MIDI files, harmony mappings, and who knows what else!

### Imports

In [165]:
import datetime
import re

import h5py
import keras
import mido
import numpy as np

### Load the data

In [272]:
# Read the quantized MIDI file and take its first track as the melody track.
# (The unquantized version is "../data/midi/ArtPepper_Anthropology_FINAL.mid".)
midi_path = "../data/midi_quantized/ArtPepper_Anthropology_FINAL.mid"
midi_file = mido.MidiFile(midi_path)
midi_track = midi_file.tracks[0]

### Clean the data

In [273]:
# Keep only the note_on / note_off messages of the track; every other
# message type is filtered out.
note_message_types = ("note_on", "note_off")
midi_notes = [msg for msg in midi_track if msg.type in note_message_types]
len(midi_notes)
midi_notes[:10]

[<message note_on channel=0 note=65 velocity=104 time=0>,
 <message note_off channel=0 note=65 velocity=104 time=192>,
 <message note_on channel=0 note=63 velocity=109 time=0>,
 <message note_off channel=0 note=63 velocity=109 time=144>,
 <message note_on channel=0 note=58 velocity=103 time=0>,
 <message note_off channel=0 note=58 velocity=103 time=48>,
 <message note_on channel=0 note=61 velocity=104 time=0>,
 <message note_off channel=0 note=61 velocity=104 time=192>,
 <message note_on channel=0 note=63 velocity=114 time=0>,
 <message note_off channel=0 note=63 velocity=114 time=192>]

In [232]:
# len([msg for msg in midi_track if msg.type=="note_on" and msg.time>0])

In [274]:
# Pair each note_on with the message that directly follows it, but only when
# that message is a note_off for the same pitch. Adjacent messages that do not
# form such a pair are skipped.
midi_note_pairs = [
    (first, second)
    for first, second in zip(midi_notes, midi_notes[1:])
    if first.type == "note_on"
    and second.type == "note_off"
    and first.note == second.note
]
len(midi_note_pairs)

530

In [275]:
# Round every note_on velocity down to a multiple of 10 (in place on the
# message objects) so that fewer distinct velocities appear in the vocabulary.
# TODO: Play with normalizing other parameters
for on_msg, off_msg in midi_note_pairs:
    on_msg.velocity -= on_msg.velocity % 10
{on_msg.velocity for on_msg, off_msg in midi_note_pairs}

{70, 80, 90, 100, 110, 120}

In [276]:
# Inspect the first ten (note_on, note_off) pairs; the note_on velocities were
# rounded down to multiples of 10 above, the note_off velocities were not touched
midi_note_pairs[:10]

[(<message note_on channel=0 note=65 velocity=100 time=0>,
  <message note_off channel=0 note=65 velocity=104 time=192>),
 (<message note_on channel=0 note=63 velocity=100 time=0>,
  <message note_off channel=0 note=63 velocity=109 time=144>),
 (<message note_on channel=0 note=58 velocity=100 time=0>,
  <message note_off channel=0 note=58 velocity=103 time=48>),
 (<message note_on channel=0 note=61 velocity=100 time=0>,
  <message note_off channel=0 note=61 velocity=104 time=192>),
 (<message note_on channel=0 note=63 velocity=110 time=0>,
  <message note_off channel=0 note=63 velocity=114 time=192>),
 (<message note_on channel=0 note=58 velocity=100 time=0>,
  <message note_off channel=0 note=58 velocity=106 time=192>),
 (<message note_on channel=0 note=58 velocity=90 time=0>,
  <message note_off channel=0 note=58 velocity=98 time=384>),
 (<message note_on channel=0 note=50 velocity=90 time=1344>,
  <message note_off channel=0 note=50 velocity=90 time=192>),
 (<message note_on channel

In [277]:
# Build the note vocabulary.
# Each note is encoded as one tuple taken from its note_on/note_off pair.
# The note_off velocity is left out to shrink the number of possibilities, and
# the note_off pitch is left out because it is the same as the note_on pitch.
# (An earlier variant encoded every single message as
# (type, pitch, velocity, duration) instead of using pairs.)
note_events_keys = ("noteon_pitch", "noteon_velocity", "noteon_time", "noteoff_time")
note_events = [
    (on_msg.note, on_msg.velocity, on_msg.time, off_msg.time)
    for on_msg, off_msg in midi_note_pairs
]
num_note_events = len(note_events)

# The sorted distinct tuples are the "alphabet" the network predicts over
note_set = sorted(set(note_events))
num_unique_notes = len(note_set)

print("{} unique notes in note set (vs. {} note events in MIDI file)".format(num_unique_notes, num_note_events))
note_set[:10]

301 unique notes in note set (vs. 530 note events in MIDI file)


[(50, 90, 1344, 192),
 (50, 100, 576, 192),
 (51, 100, 0, 96),
 (51, 100, 0, 112),
 (51, 100, 0, 384),
 (52, 110, 0, 192),
 (53, 90, 0, 96),
 (53, 100, 0, 64),
 (53, 100, 0, 96),
 (53, 100, 0, 384)]

In [278]:
# len([note for note in note_set if note[0] == "note_off"])

In [279]:
# Lookup table: note tuple -> its position in the sorted note set
note_to_int = {note: index for index, note in enumerate(note_set)}
note_to_int

{(50, 90, 1344, 192): 0,
 (50, 100, 576, 192): 1,
 (51, 100, 0, 96): 2,
 (51, 100, 0, 112): 3,
 (51, 100, 0, 384): 4,
 (52, 110, 0, 192): 5,
 (53, 90, 0, 96): 6,
 (53, 100, 0, 64): 7,
 (53, 100, 0, 96): 8,
 (53, 100, 0, 384): 9,
 (53, 110, 576, 48): 10,
 (54, 100, 0, 96): 11,
 (54, 100, 0, 160): 12,
 (55, 90, 0, 64): 13,
 (55, 100, 0, 64): 14,
 (55, 100, 0, 96): 15,
 (55, 100, 0, 144): 16,
 (55, 100, 0, 256): 17,
 (55, 110, 0, 48): 18,
 (55, 110, 528, 96): 19,
 (56, 100, 0, 48): 20,
 (56, 100, 0, 96): 21,
 (56, 110, 0, 96): 22,
 (57, 100, 0, 64): 23,
 (57, 100, 0, 96): 24,
 (57, 100, 0, 160): 25,
 (57, 110, 0, 96): 26,
 (57, 110, 0, 192): 27,
 (58, 90, 0, 96): 28,
 (58, 90, 0, 384): 29,
 (58, 100, 0, 48): 30,
 (58, 100, 0, 64): 31,
 (58, 100, 0, 96): 32,
 (58, 100, 0, 160): 33,
 (58, 100, 0, 192): 34,
 (58, 100, 0, 288): 35,
 (58, 100, 0, 384): 36,
 (58, 110, 0, 48): 37,
 (58, 110, 0, 96): 38,
 (58, 110, 0, 128): 39,
 (58, 110, 0, 160): 40,
 (58, 110, 0, 192): 41,
 (59, 80, 0, 48): 42,

In [280]:
# Reverse lookup table: position in the sorted note set -> note tuple
# (needed in the generation phase to decode the network's predictions)
int_to_note = dict(enumerate(note_set))
int_to_note

{0: (50, 90, 1344, 192),
 1: (50, 100, 576, 192),
 2: (51, 100, 0, 96),
 3: (51, 100, 0, 112),
 4: (51, 100, 0, 384),
 5: (52, 110, 0, 192),
 6: (53, 90, 0, 96),
 7: (53, 100, 0, 64),
 8: (53, 100, 0, 96),
 9: (53, 100, 0, 384),
 10: (53, 110, 576, 48),
 11: (54, 100, 0, 96),
 12: (54, 100, 0, 160),
 13: (55, 90, 0, 64),
 14: (55, 100, 0, 64),
 15: (55, 100, 0, 96),
 16: (55, 100, 0, 144),
 17: (55, 100, 0, 256),
 18: (55, 110, 0, 48),
 19: (55, 110, 528, 96),
 20: (56, 100, 0, 48),
 21: (56, 100, 0, 96),
 22: (56, 110, 0, 96),
 23: (57, 100, 0, 64),
 24: (57, 100, 0, 96),
 25: (57, 100, 0, 160),
 26: (57, 110, 0, 96),
 27: (57, 110, 0, 192),
 28: (58, 90, 0, 96),
 29: (58, 90, 0, 384),
 30: (58, 100, 0, 48),
 31: (58, 100, 0, 64),
 32: (58, 100, 0, 96),
 33: (58, 100, 0, 160),
 34: (58, 100, 0, 192),
 35: (58, 100, 0, 288),
 36: (58, 100, 0, 384),
 37: (58, 110, 0, 48),
 38: (58, 110, 0, 96),
 39: (58, 110, 0, 128),
 40: (58, 110, 0, 160),
 41: (58, 110, 0, 192),
 42: (59, 80, 0, 48),

In [281]:
# Slide a fixed-size window over the note events: each window of seq_length
# notes is one input sample ("X") and the note right after it is the target ("y").
# TODO: Play with sequence lengths (for both input and outputs)
seq_length = 10
window_starts = range(num_note_events - seq_length)
data_input = [
    [note_to_int[note] for note in note_events[start:start + seq_length]]
    for start in window_starts
]
data_output = [note_to_int[note_events[start + seq_length]] for start in window_starts]
num_seqs = len(data_input)
print("{} sequences".format(num_seqs))
print("{} ==> {}".format(data_input[0], data_output[0]))
data_input[:5]

520 sequences
[114, 85, 30, 63, 95, 34, 29, 0, 26, 55] ==> 38


[[114, 85, 30, 63, 95, 34, 29, 0, 26, 55],
 [85, 30, 63, 95, 34, 29, 0, 26, 55, 38],
 [30, 63, 95, 34, 29, 0, 26, 55, 38, 16],
 [63, 95, 34, 29, 0, 26, 55, 38, 16, 32],
 [95, 34, 29, 0, 26, 55, 38, 16, 32, 63]]

In [282]:
# Reshape input sequences into form [samples, time steps, features]
X = np.reshape(data_input, (num_seqs, seq_length, 1))

# Scale the integer note indices by the vocabulary size (values end up in [0, 1))
X = X / float(num_unique_notes)

# Convert output to one-hot encoding.
# num_classes is passed explicitly so the one-hot width always equals the
# vocabulary size; without it the width is inferred from the largest label
# present in data_output, which can be smaller than the vocabulary.
# The public keras.utils.to_categorical path is used because the
# keras.utils.np_utils module is not available in newer Keras releases.
y = keras.utils.to_categorical(data_output, num_classes=num_unique_notes)

In [283]:
# Show the first training sample: scaled input window, then its one-hot target
for shown in (X[0], "==>", y[0]):
    print(shown)

[[ 0.37873754]
 [ 0.28239203]
 [ 0.09966777]
 [ 0.20930233]
 [ 0.31561462]
 [ 0.11295681]
 [ 0.09634551]
 [ 0.        ]
 [ 0.08637874]
 [ 0.18272425]]
==>
[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  

### Define the LSTM model

In [284]:
# Reminder of the array shapes before building the model:
# X was reshaped to (samples, time steps, features) and y holds one one-hot row per sample
"X.shape = {}, y.shape = {}".format(X.shape, y.shape)

'X.shape = (520, 10, 1), y.shape = (520, 301)'

In [285]:
# Two stacked 256-unit LSTM layers, each followed by 20% dropout, feeding a
# softmax layer with one output per column of y.
time_steps, num_features = X.shape[1], X.shape[2]
model = keras.models.Sequential([
    # The first LSTM returns the full sequence so the second LSTM can consume it
    keras.layers.LSTM(256, input_shape=(time_steps, num_features), return_sequences=True),
    keras.layers.Dropout(0.2),
    keras.layers.LSTM(256),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(y.shape[1], activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy")

In [286]:
# Save the weights whenever the training loss reaches a new minimum; the epoch
# number and loss value are substituted into the file name.
# TODO: Add datetime to this
checkpoint_path = "weights_{epoch:02d}_{loss:.4f}.hdf5"
checkpoint = keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor="loss",
    mode="min",
    save_best_only=True,
    verbose=1,
)
callbacks = [checkpoint]

In [287]:
# Train the network; the checkpoint callback writes weight files along the way.
# TODO: Play with these parameters, of course
num_epochs = 100
batch_size = 32
model.fit(
    X,
    y,
    batch_size=batch_size,
    epochs=num_epochs,
    callbacks=callbacks,
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

<keras.callbacks.History at 0x12a58a160>

### Generate output notes

In [288]:
# Load network weights and recompile.
#
# Checkpoints from earlier experiments, kept for reference (not loaded):
#   weights_99_0.9724.hdf5 - using only note ons
#   weights_99_1.3571.hdf5 - using both note ons and note offs
#   weights_95_1.4241.hdf5 - using note on/off pairs
#
# Only one file name is assigned so it is unambiguous which checkpoint is
# loaded (previously four assignments overwrote each other and only the last
# one took effect).
weights_filename = "weights_97_1.4300.hdf5" # Using note on/off pairs without note off velocity
model.load_weights(weights_filename)
model.compile(loss="categorical_crossentropy", optimizer="adam")
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm_9 (LSTM)                (None, 10, 256)           264192    
_________________________________________________________________
dropout_9 (Dropout)          (None, 10, 256)           0         
_________________________________________________________________
lstm_10 (LSTM)               (None, 256)               525312    
_________________________________________________________________
dropout_10 (Dropout)         (None, 256)               0         
_________________________________________________________________
dense_5 (Dense)              (None, 301)               77357     
Total params: 866,861
Trainable params: 866,861
Non-trainable params: 0
_________________________________________________________________


In [293]:
# Start with a random seed: pick one of the training input windows.
# A copy is taken so that later in-place changes to seq_in (the generation
# loop appends to it) cannot modify the window stored in data_input.
seq_in = list(data_input[np.random.randint(num_seqs)])
[int_to_note[i] for i in seq_in]

[(77, 100, 0, 64),
 (80, 110, 0, 96),
 (79, 110, 0, 96),
 (77, 110, 0, 96),
 (75, 110, 0, 96),
 (72, 100, 0, 96),
 (73, 100, 0, 96),
 (74, 120, 0, 96),
 (70, 110, 0, 96),
 (65, 100, 0, 96)]

In [294]:
# Decode the seed window back into note tuples, then show the first note as a
# dict keyed by the field names in note_events_keys
seq_in_notes = [int_to_note[index] for index in seq_in]
dict(zip(note_events_keys, seq_in_notes[0]))

{'noteoff_time': 64,
 'noteon_pitch': 77,
 'noteon_time': 0,
 'noteon_velocity': 100}

In [295]:
# Generate the notes!
# Repeatedly predict the next note from the current window, record it, and
# slide the window forward by one note so the prediction feeds the next step.
num_notes_to_generate = 100
notes_out = []

for i in range(num_notes_to_generate):
    # Reshape to [samples, time steps, features] and scale exactly like the training data
    x = np.reshape(seq_in, (1, len(seq_in), 1))
    x = x / float(num_unique_notes)

    # Make the prediction
    pred = model.predict(x, batch_size=batch_size, verbose=0)

    # Take the most probable class; cast to a plain int so the window keeps
    # holding ordinary Python integers rather than numpy scalars
    note_idx = int(np.argmax(pred))
    note = int_to_note[note_idx]

    # Add output note to list
    notes_out.append(note)

    # Slide the window forward by one note. A NEW list is built instead of
    # appending in place: seq_in may be the very list object stored in
    # data_input, and an in-place append would lengthen that training window.
    seq_in = seq_in[1:] + [note_idx]

notes_out[:20]

[(76, 110, 528, 192),
 (78, 120, 288, 96),
 (76, 100, 0, 96),
 (72, 110, 0, 96),
 (76, 110, 0, 192),
 (77, 110, 0, 48),
 (74, 120, 0, 96),
 (75, 110, 0, 192),
 (74, 110, 0, 96),
 (77, 110, 0, 128),
 (65, 100, 0, 96),
 (64, 110, 0, 96),
 (64, 110, 0, 96),
 (74, 110, 688, 144),
 (63, 110, 0, 64),
 (67, 110, 0, 192),
 (66, 110, 0, 192),
 (65, 90, 0, 96),
 (72, 110, 480, 192),
 (69, 100, 0, 96)]

In [296]:
# Convert the sequence of note tuples into a sequence of MIDI notes, and then write to MIDI file

# Create MIDI file and track.
# The source file's tick resolution is reused: the generated times are tick
# counts taken from that file, so writing them into a file with a different
# ticks_per_beat (mido's default) would change how long they last on playback.
midi_file_out = mido.MidiFile(ticks_per_beat=midi_file.ticks_per_beat)
midi_track_out = mido.MidiTrack()
midi_file_out.tracks.append(midi_track_out)

# Append "headers" (track name, tempo, key, time signature):
# the first four messages of the source track are copied over as they are
for message in midi_track[:4]:
    midi_track_out.append(message)

# Add notes as note on/off pairs.
# The tuple's times are written unchanged: they were read from the `time`
# attribute of the source track's messages, so they are already in the form a
# track message expects (mido documents this as the delta to the previous message).
for note in notes_out:
    note_fields = dict(zip(note_events_keys, note))
    message_noteon = mido.Message(
        "note_on",
        note=note_fields["noteon_pitch"],
        velocity=note_fields["noteon_velocity"],
        time=note_fields["noteon_time"],
    )
    # The note off velocity is not part of the note tuple, so the note on velocity is reused
    message_noteoff = mido.Message(
        "note_off",
        note=note_fields["noteon_pitch"],
        velocity=note_fields["noteon_velocity"],
        time=note_fields["noteoff_time"],
    )
    midi_track_out.append(message_noteon)
    midi_track_out.append(message_noteoff)

# Save file to disk, named after the current timestamp with every non-word
# character stripped (raw string so that \W reaches the regex engine intact)
filename_out = str(datetime.datetime.now())
filename_out = re.sub(r"\W+", "", filename_out)
filename_out = "../data/out_{}.mid".format(filename_out)
midi_file_out.save(filename_out)

# Show the first generated messages (the four copied header messages are skipped)
for message in midi_track_out[4:20]:
    print(message)

note_on channel=0 note=76 velocity=110 time=528
note_off channel=0 note=76 velocity=110 time=192
note_on channel=0 note=78 velocity=120 time=288
note_off channel=0 note=78 velocity=120 time=96
note_on channel=0 note=76 velocity=100 time=0
note_off channel=0 note=76 velocity=100 time=96
note_on channel=0 note=72 velocity=110 time=0
note_off channel=0 note=72 velocity=110 time=96
note_on channel=0 note=76 velocity=110 time=0
note_off channel=0 note=76 velocity=110 time=192
note_on channel=0 note=77 velocity=110 time=0
note_off channel=0 note=77 velocity=110 time=48
note_on channel=0 note=74 velocity=120 time=0
note_off channel=0 note=74 velocity=120 time=96
note_on channel=0 note=75 velocity=110 time=0
note_off channel=0 note=75 velocity=110 time=192
