In [1]:
from wandb.keras import WandbCallback
import wandb
from glob import glob
from tqdm import tqdm
from typing import List
from tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, BatchNormalization, Embedding, Dropout, concatenate, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.callbacks import ModelCheckpoint
import pretty_midi
import numpy as np
import tensorflow as tf
import yaml
import random
import os

config = yaml.safe_load(open('config.yaml', 'r'))
pathes = config['path']
params = config['params']
outputs = config['output']
params['beat_duration'] = 1/(params['bpm']/60)

In [12]:
class NoteData(pretty_midi.Note):

    def __init__(self, velocity, pitch, start, end, sustain):
        super().__init__(velocity, pitch, start, end)
        self.sustain = 1 if sustain else 0

    @staticmethod
    def decode(data, start, velocity=90):
        '''
        Return (`pretty_midi.Note` object, sustain)
        '''
        data = data.split('-')
        pitch = int(data[0])
        duration = float(data[1])
        end = start+duration

        return pretty_midi.Note(velocity, pitch, start, end), int(data[2])

    def encode(self):
        return f'{self.pitch}-{self.get_duration()}-{self.sustain}'

    def __repr__(self):
        # return f'Note(start={self.start}, end={self.end}, pitch={self.pitch}, velocity={self.velocity}, sustain={self.sustain})'
        return f'NoteData(pitch={self.pitch}, duration={self.get_duration()}, sustain={self.sustain}, enc={self.encode()})'


def checkIntervals(beat: List[NoteData], start: float, end: float):
    '''
    Pass a list of encoded beat string. Returns a list of note off intervals.
    '''

    beat = beat.copy()
    beat.insert(0, NoteData(90, 0, None, start, 1))
    beat.append(NoteData(90, 0, end, None, 1))
    intervals = []
    for i in range(1, len(beat)):
        note_time_interval = beat[i].start-beat[i-1].end
        if(note_time_interval > 0):
            intervals.append([beat[i-1].end, beat[i].start])
        elif(note_time_interval < 0):
            print("Error! Minus note interval, is there any overlapping note exist?")

    return intervals


def expandNoteList(notes):

    # First note on time, in second
    midiStartTime = notes[0].start
    # Last note off time, in second
    midiEndTime = notes[-1].end
    # How many beats in this note list. One in 4/4.
    beatNum = int((midiEndTime-midiStartTime)/params['beat_duration'])

    lastPitch = None
    totalNoteDataList = []
    for beatIndex in range(beatNum):
        # Get start and end time in this beat.
        start = midiStartTime + beatIndex*params['beat_duration']
        end = midiStartTime + (beatIndex+1)*params['beat_duration']

        # Find notes that cover this beat interval.
        coveredNotes = [note for note in notes if (note.start < end)]
        coveredNotes = [note for note in coveredNotes if (note.end > start)]

        # encodedBeatList = []
        noteDataList = []
        for note in coveredNotes:
            sustain = (note.pitch != lastPitch or note.start < start)

            # Clip the note start and end time.
            noteStart = start if note.start <= start else note.start
            noteEnd = end if note.end >= end else note.end

            noteData = NoteData(90, note.pitch, noteStart, noteEnd, sustain)
            # encodedNote = encode(note.pitch, noteEnd-noteStart, sustain)
            noteDataList.append(noteData)
            lastPitch = note.pitch

        intervals = checkIntervals(noteDataList, start, end)
        if intervals:
            for interval in intervals:
                noteData = NoteData(90, 0, interval[0], interval[1], 1)
                noteDataList.append(noteData)

        noteDataList.sort(key=lambda x: x.start)

        totalNoteDataList.append(noteDataList)

    return totalNoteDataList


def encodeExpandedNoteList(expandedNoteDataList: List[List[NoteData]]):
    encodedBeatList = []
    for expandedNoteData in expandedNoteDataList:
        encodedBeatList.append('#'.join([noteData.encode() for noteData in expandedNoteData]))
    return encodedBeatList


def decodeBeatList(encodedBeatList):

    timestamp = 0
    decodedNoteDataList = []
    for beat in encodedBeatList:
        encodedNotes = beat.split('#')
        for encodedNote in encodedNotes:
            decodedNote, sustain = NoteData.decode(encodedNote, timestamp)
            decodedNoteDataList.append((decodedNote, sustain))
            timestamp += decodedNote.get_duration()

    # print(decodedNoteDataList)
    # return

    decodedNoteList = []
    lastPitchAndSustain = (-1, -1)
    for decodedNote, sustain in decodedNoteDataList:
        if decodedNote.pitch == 0:
            pass
        elif decodedNote.pitch == lastPitchAndSustain[0] and sustain:
            decodedNoteList[-1].end = decodedNote.end
        else:
            decodedNoteList.append(decodedNote)

        lastPitchAndSustain = (decodedNote.pitch, sustain)

    return decodedNoteList


def generateMidi(filename: str, tracks: List[List[str]], programName: str = 'Acoustic Grand Piano'):
    '''
    tracks List[List[str]]:
        Encoded beat sequence. [['70-0.5-1', '68-0.25-1#70-0.25-1'..], ...]

    Write the midi file from encoded tracks.

    '''
    midi = pretty_midi.PrettyMIDI()
    for index, track in enumerate(tracks):
        _instrument = pretty_midi.Instrument(
            pretty_midi.instrument_name_to_program(programName), False, f"Track {index+1}")
        _instrument.pitch_bends.append(pretty_midi.PitchBend(0, 0))
        _instrument.notes = decodeBeatList(track)
        midi.instruments.append(_instrument)

    midi.write(filename)


def loadMidiInFolder(path):
    '''
    path:
        path for glob to grab midi. Must include last file name like `*.mid`.

    Return tracks read from midi, Note has been encoded to string.
    '''
    wordsList = []
    for mid in tqdm(glob(path, recursive=True)):
        try:
            midiData = pretty_midi.PrettyMIDI(mid)
        except:
            print("Cannot load this midi")
            continue

        tracks = [instrument.notes for instrument in midiData.instruments]
        encodedTracks = [encodeExpandedNoteList(expandNoteList(track)) for track in tracks]
        wordsList.append(encodedTracks)

    return wordsList


def concatMidi(data):
    '''
    data:
        trainDataRaw or testDataRaw, will concatenate each track.

    Returns np.array (4, total_length)
    '''

    concatData = None
    for mid in data:
        if concatData == None:
            concatData = mid
        else:
            for i in range(len(concatData)):
                concatData[i] += mid[i]

    return np.array(concatData)


def __split_input_target(chunk):
    input_text = chunk[:-1]
    target_text = tf.one_hot(chunk[1:], params['current_vocab_size'])
    return input_text, target_text


def makeDataset(data, track):
    dataset = tf.data.Dataset.from_tensor_slices(data)
    dataset = dataset.batch(params['sequence_length']+1, drop_remainder=True)
    params['current_vocab_size'] = params['vocab_size'][track]
    dataset = dataset.map(__split_input_target)
    del params['current_vocab_size']
    dataset = dataset.shuffle(10000).batch(params['batch_size'], drop_remainder=True)
    return dataset


def makeSingleTrackModel(track, batch_size=params['batch_size']):
    # Model for single track

    model = Sequential()
    model.add(Embedding(params['vocab_size'][track], output_dim=params['embed_size'],
                        batch_input_shape=(batch_size, None)))
    for _ in range(params['layers']):
        model.add(LSTM(params['unit'], return_sequences=True, stateful=params['stateful'],
                       dropout=params['dropout'], recurrent_dropout=params['dropout']))

    model.add(Dense(params['unit'], activation='relu'))
    model.add(BatchNormalization())
    model.add(Dropout(params['dropout']))
    model.add(Dense(params['vocab_size'][track], activation='softmax'))
    model.compile(loss=CategoricalCrossentropy(from_logits=False),
                  optimizer=Adam(learning_rate=0.0001), metrics=['acc'])

    return model

def makeLSTMLayer(name, track, batch_size=params['batch_size']):
    inputLayer = Input(batch_input_shape=([batch_size, None]))
    embedLayer = Embedding(params['vocab_size'][track], params['embed_size'], name=name+'-embed')(inputLayer)
    lstmLayer = LSTM(params['unit'], return_sequences=True, stateful=params['stateful'],
                     dropout=params['dropout'], recurrent_dropout=params['dropout'], name=name+'-LSTM1')(embedLayer)
    lstmLayer = LSTM(params['unit'], return_sequences=True, stateful=params['stateful'],
                     dropout=params['dropout'], recurrent_dropout=params['dropout'], name=name+'-LSTM2')(lstmLayer)

    return inputLayer, lstmLayer

def makeOutputLayer(name, layer, track):
    outputLayer = layer
    # outputLayer = Dense(params['unit'], activation='relu', name=name+'-linear2')(layer)
    # outputLayer = BatchNormalization()(outputLayer)
    # outputLayer = Dropout(params['dropout'])(outputLayer)
    outputLayer = Dense(params['vocab_size'][track], activation='softmax', name=name+'-output')(outputLayer)

    return outputLayer

def makeMultiTrackModel(batch_size=params['batch_size']):

    inputLayer1, lstmLayer1 = makeLSTMLayer('input1', 0, batch_size)
    inputLayer2, lstmLayer2 = makeLSTMLayer('input2', 1, batch_size)
    inputLayer3, lstmLayer3 = makeLSTMLayer('input3', 2, batch_size)
    inputLayer4, lstmLayer4 = makeLSTMLayer('input4', 3, batch_size)

    concatLayers = concatenate([lstmLayer1, lstmLayer2, lstmLayer3, lstmLayer4])

    x = LSTM(params['unit'], return_sequences=True, name='concatLSTMLayer1')(concatLayers)
    # x = Dropout(params['dropout'])(x)
    # x = LSTM(params['unit'], return_sequences=True, name='concatLSTMLayer2')(x)
    x = Dense(params['unit'], activation='relu', name='concatLinear1')(x)
    x = BatchNormalization()(x)
    x = Dropout(params['dropout'])(x)

    outputLayers = [makeOutputLayer('output'+str(i), x, i) for i in range(params['track_size'])]

    model = Model(inputs=[inputLayer1, inputLayer2, inputLayer3, inputLayer4], outputs=outputLayers)

    model.compile(loss=CategoricalCrossentropy(from_logits=False),
                  optimizer=Adam(learning_rate=0.0001), metrics=['acc'])
    
    return model


In [3]:
trainDataRaw = loadMidiInFolder(pathes['data_dir'])
testDataRaw = loadMidiInFolder(pathes['test_dir'])

100%|██████████| 500/500 [00:03<00:00, 143.07it/s]
100%|██████████| 150/150 [00:01<00:00, 127.90it/s]


In [4]:
trainData = concatMidi(trainDataRaw)
testData = concatMidi(testDataRaw)

In [5]:
# Create dictionaries for each track or collect all (now).

beat2idx = []
idx2beat = []
params['vocab_size'] = []

vocabularies = set()

for track in range(len(trainData)):
    vocabularies = vocabularies | set(trainData[track]) | set(testData[track])
for track in range(len(trainData)):
    params['vocab_size'].append(len(vocabularies))
    beat2idx.append({beat: i for i, beat in enumerate(vocabularies)})
    idx2beat.append({idx:beat for beat, idx in beat2idx[track].items()})

In [None]:
params['vocab_size']

In [6]:
for track in range(len(trainData)):

    trainData[track] = np.array([beat2idx[track][beat] for beat in trainData[track]])
    testData[track] = np.array([beat2idx[track][beat] for beat in testData[track]])

trainData = trainData.astype(int)
testData = testData.astype(int)

In [7]:
trainDatasets = []
testDatasets = []
for track in range(len(trainData)):
    trainDatasets.append(makeDataset(trainData[track], track))
    testDatasets.append(makeDataset(testData[track], track))

In [8]:
trainDatasets[0]

<BatchDataset shapes: ((16, 16), (16, 16, 1499)), types: (tf.int64, tf.float32)>

## Single track training

In [None]:
name = 'single-track-lstm-training-2'
track = 0
model = makeSingleTrackModel(track)

In [None]:
callbacks = []
filepath = config['path']['model_dir'] + \
    "model-"+name+"-{epoch:02d}-{loss:.4f}.hdf5"
callbacks.append(ModelCheckpoint(
    filepath=filepath,
    monitor='loss',
    mode='min',
    save_weights_only=True,
    save_best_only=True))

if config['output']['wandb']:
    wandb.init(config=params, project='lstm-singletrack-js-fake')
    callbacks.append(WandbCallback(
        log_weights=True, log_evaluation=False, validation_steps=5))

history = model.fit(trainDatasets[track], validation_data=testDatasets[track], batch_size=params['batch_size'],
                    epochs=params['epochs'], verbose=1, callbacks=callbacks)


## Multi track training

In [9]:
x_train = []
y_train = []
x_val = []
y_val = []

for track in range(params['track_size']):

    x = list(trainDatasets[track].map(lambda x, y: x))[0].numpy()
    y = list(trainDatasets[track].map(lambda x, y: y))[0].numpy()

    val_x = list(testDatasets[track].map(lambda x, y: x))[0].numpy()
    val_y = list(testDatasets[track].map(lambda x, y: y))[0].numpy()

    x_train.append(x)
    y_train.append(y)
    x_val.append(val_x)
    y_val.append(val_y)

In [10]:
x_train[0].shape, y_train[0].shape, x_val[0].shape, y_val[0].shape

((16, 16), (16, 16, 1499), (16, 16), (16, 16, 1499))

In [14]:
name = 'multi-track-lstm-training-5'
model = makeMultiTrackModel()



In [15]:
callbacks = []
filepath = config['path']['model_dir'] + \
    "model-"+name+"-{epoch:02d}-{loss:.4f}.hdf5"
callbacks.append(ModelCheckpoint(
    filepath=filepath,
    monitor='loss',
    mode='min',
    save_weights_only=True,
    save_best_only=True))

if outputs['wandb']:
    wandb.init(config=params, project='lstm-multi-track-js-fake')
    callbacks.append(WandbCallback(
        log_weights=True, log_evaluation=False, validation_steps=5))

history = model.fit(x=x_train, y=y_train, validation_data=(x_val, y_val), batch_size=params['batch_size'],
                    epochs=params['epochs'], verbose=1, callbacks=callbacks)

if outputs['wandb']:
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mstu00608[0m ([33mtku-cilab[0m). Use [1m`wandb login --relogin`[0m to force relogin




Epoch 1/5000
Epoch 2/5000
Epoch 3/5000
Epoch 4/5000
Epoch 5/5000
Epoch 6/5000
Epoch 7/5000
Epoch 8/5000
Epoch 9/5000
Epoch 10/5000
Epoch 11/5000
Epoch 12/5000
Epoch 13/5000
Epoch 14/5000
Epoch 15/5000
Epoch 16/5000
Epoch 17/5000
Epoch 18/5000
Epoch 19/5000
Epoch 20/5000
Epoch 21/5000
Epoch 22/5000
Epoch 23/5000
Epoch 24/5000
Epoch 25/5000
Epoch 26/5000
Epoch 27/5000
Epoch 28/5000
Epoch 29/5000
Epoch 30/5000
Epoch 31/5000
Epoch 32/5000
Epoch 33/5000
Epoch 34/5000
Epoch 35/5000
Epoch 36/5000
Epoch 37/5000
Epoch 38/5000
Epoch 39/5000
Epoch 40/5000
Epoch 41/5000
Epoch 42/5000
Epoch 43/5000
Epoch 44/5000
Epoch 45/5000
Epoch 46/5000
Epoch 47/5000
Epoch 48/5000
Epoch 49/5000
Epoch 50/5000
Epoch 51/5000
Epoch 52/5000
Epoch 53/5000
Epoch 54/5000
Epoch 55/5000
Epoch 56/5000
Epoch 57/5000
Epoch 58/5000
Epoch 59/5000
Epoch 60/5000
Epoch 61/5000
Epoch 62/5000
Epoch 63/5000
Epoch 64/5000
Epoch 65/5000
Epoch 66/5000
Epoch 67/5000
Epoch 68/5000
Epoch 69/5000
Epoch 70/5000
Epoch 71/5000
Epoch 72/5000
E

VBox(children=(Label(value='87.942 MB of 87.942 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,█▆▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
output0-output_acc,▁▁▃▄▆███████████████████████████████████
output0-output_loss,█▆▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
output1-output_acc,▁▂▃▅▇███████████████████████████████████
output1-output_loss,█▆▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
output2-output_acc,▁▂▃▅▇███████████████████████████████████
output2-output_loss,█▆▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
output3-output_acc,▁▂▃▅▇███████████████████████████████████
output3-output_loss,█▆▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
best_epoch,578.0
best_val_loss,28.22991
epoch,4999.0
loss,0.00697
output0-output_acc,1.0
output0-output_loss,0.0016
output1-output_acc,1.0
output1-output_loss,0.00165
output2-output_acc,1.0
output2-output_loss,0.0017


## Generate Single Track

In [None]:
initBeatIndex = random.randint(0, len(trainData[0])-params['sequence_length'])
initBeat = np.array([trainData[0, initBeatIndex:initBeatIndex+params['sequence_length']]])

In [None]:
initBeat.shape

In [None]:
weight = 'model-single-track-lstm-training-2-496-1.3643.hdf5'

model = makeSingleTrackModel(track, 1)
model.build(tf.TensorShape([1, None]))
model.load_weights(os.path.join(pathes['model_dir'], weight))
model.reset_states()
pred = model(initBeat)
pred = pred.numpy().squeeze()

music_sequence = []

for i in tqdm(range(outputs['length'])):

    topk = tf.math.top_k(pred[-1], 3)
    topkChoices = topk[1].numpy().squeeze()
    topkValues = topk[0].numpy().squeeze()

    # Apply random
    next_beat = None
    if np.random.uniform(0, 1) < .5:
        next_beat = topkChoices[0]
    else:
        p_choices = tf.math.softmax(topkValues[1:]).numpy()
        next_beat = np.random.choice(topkChoices[1:], 1, p=p_choices)[0]

    music_sequence.append(idx2beat[track][next_beat])

    pred = model(np.array([[next_beat]]))
    pred = tf.expand_dims(pred, 0)

## Generate Multi Track

In [None]:
initBeatIndex = random.randint(0, len(trainData[0])-params['sequence_length'])
initBeat = []
for track in range(params['track_size']):
    initBeat.append(np.array([trainData[track, initBeatIndex:initBeatIndex+params['sequence_length']]]))

In [None]:
weight = 'model-multi-track-lstm-training-4-998-1.0914.hdf5'

In [None]:
model = makeMultiTrackModel(1)
model.build(tf.TensorShape([1, None]))
model.load_weights(os.path.join(pathes['model_dir'], weight))
model.reset_states()
pred = model(initBeat)
pred = [p.numpy().squeeze() for p in pred]

music_sequence = []
for track in range(params['track_size']):
    music_sequence.append([idx2beat[track][idx] for idx in initBeat[track][0].tolist()])

outputs['length'] = 100
for i in tqdm(range(outputs['length'])):

    nextBeatList = []
    for index, p in enumerate(pred):
        topk = tf.math.top_k(p[-1], 3)
        topkChoices = topk[1].numpy().squeeze()
        topkValues = topk[0].numpy().squeeze()

        # Apply random
        # next_beat = topkChoices[0]
        next_beat = None
        if np.random.uniform(0, 1) < .5:
            next_beat = topkChoices[0]
        else:
            p_choices = tf.math.softmax(topkValues[1:]).numpy()
            next_beat = np.random.choice(topkChoices[1:], 1, p=p_choices)[0]

        music_sequence[index].append(idx2beat[index][next_beat])
        nextBeatList.append(np.array([[next_beat]]))

    pred = model.predict(nextBeatList, verbose=0)
    # pred = tf.expand_dims(pred, 0)


In [None]:
generateMidi('new100-4-2'+'.mid', music_sequence)

## Output

In [None]:
generateMidi(name+'.mid', [music_sequence])