In [1]:
# External imports
import os
import pretty_midi
import numpy as np
import tensorflow as tf
import note_seq
import IPython

# Internal imports
from data.helpers.midi import MidiEventProcessor
from models.midi_transformer import MIDITransformer

## Required file paths and common variables

In [2]:
## File paths
# The project root defaults to the original author's checkout but can be
# overridden with the VIRTUAL_MUSICIANS_BASE_DIR environment variable, so the
# notebook runs on other machines without editing this cell.
BASE_DIR = os.environ.get(
    "VIRTUAL_MUSICIANS_BASE_DIR",
    "/home/richhiey/Desktop/workspace/projects/virtual_musicians",
)

# Raw POP909 dataset directory.
DATA_DIR = os.path.join(BASE_DIR, "data", "POP909-Dataset", "POP909")

# Preprocessed artifacts: the event-token array and the TFRecord file built from it.
PREPROCESSED_DIR = os.path.join(BASE_DIR, "data", "preprocessed")
MIDI_EVENTS_PATH = os.path.join(PREPROCESSED_DIR, "pop909-event-token.npy")
DATASET_PATH = os.path.join(PREPROCESSED_DIR, "pop909.tfrecords")

# Model checkpoint directory and model configuration file.
MODEL_SAVE_PATH = os.path.join(BASE_DIR, "cache", "checkpoints", "model3")
MODEL_CONFIG_PATH = os.path.join(BASE_DIR, "model-store", "models", "configs", "default.json")

# Common variables
BATCH_SIZE = 2

## Loading POP909 MIDI Dataset

In [3]:
# Load the preprocessed event-token array and report how many songs it holds.
# NOTE: allow_pickle=True lets numpy unpickle arbitrary Python objects, which
# can execute code - only load files from a trusted source.
pop909 = np.load(MIDI_EVENTS_PATH, allow_pickle=True)
print(np.shape(pop909))

# Pull the three tracks out of every song record, one list per track.
melodies, rhythms, bridges = (
    [song[track] for song in pop909]
    for track in ('MELODY', 'PIANO', 'BRIDGE')
)

# Build the TFRecord cache once; later runs reuse the existing file.
if not os.path.exists(DATASET_PATH):
    # Write to a temporary sibling file and rename it only after every record
    # has been written. Otherwise an interrupted run would leave a truncated
    # file at DATASET_PATH that the existence check above accepts forever.
    tmp_dataset_path = DATASET_PATH + ".tmp"
    with tf.io.TFRecordWriter(tmp_dataset_path) as file_writer:
        for melody, rhythm, bridge in zip(melodies, rhythms, bridges):
            # One Example per song, each track stored as an int64 list.
            tracks = {"melody": melody, "rhythm": rhythm, "bridge": bridge}
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        name: tf.train.Feature(
                            int64_list=tf.train.Int64List(value=events)
                        )
                        for name, events in tracks.items()
                    }
                )
            )
            file_writer.write(example.SerializeToString())
    # The context manager has already flushed and closed the writer here,
    # so no explicit close() is needed before publishing the file.
    os.replace(tmp_dataset_path, DATASET_PATH)
    

# Parsing schema for the serialized examples: each track is a
# variable-length list of int64 values.
feature_description = {
    track_name: tf.io.VarLenFeature(tf.int64)
    for track_name in ('melody', 'rhythm', 'bridge')
}

def _parse_function(example_proto):
    """Parse one serialized example using the module-level ``feature_description``.

    Args:
        example_proto: A single serialized example, as yielded by the
            TFRecord dataset.

    Returns:
        The result of ``tf.io.parse_single_example`` for that record.
    """
    parsed_example = tf.io.parse_single_example(
        serialized=example_proto,
        features=feature_description,
    )
    return parsed_example

# Read the TFRecord file, parse every record, then group them into batches.
raw_dataset = tf.data.TFRecordDataset(DATASET_PATH)
parsed_dataset = raw_dataset.map(map_func=_parse_function)
dataset = parsed_dataset.batch(batch_size=BATCH_SIZE)

# Helpers referenced by the (commented-out) MIDI reconstruction code below.
event_processor = MidiEventProcessor()
piano = pretty_midi.instrument_name_to_program('Acoustic Grand Piano')

# Reconstruct MIDI data from TFRecord (sanity check, currently disabled: the block below is commented out)
#for i, data in enumerate(dataset.take(3)):
#    print('----------------------------------------------------------')
#    full_midi = pretty_midi.PrettyMIDI()
    
#    melody_instr = pretty_midi.Instrument(program=piano)
#    rhythm_instr = pretty_midi.Instrument(program=piano)
#    bridge_instr = pretty_midi.Instrument(program=piano)
    
#    print('Melody:')
#    melody_events = tf.sparse.to_dense(data['melody']).numpy()
#    print(melody_events)
#    for note in event_processor.decode(melody_events):
#        melody_instr.notes.append(note)
    
#    print('Rhythm:')
#    rhythm_events = tf.sparse.to_dense(data['rhythm']).numpy()
#    print(rhythm_events)
#    for note in event_processor.decode(rhythm_events):
#        rhythm_instr.notes.append(note)
    
#    print('Bridge:')
#    bridge_events = tf.sparse.to_dense(data['bridge']).numpy()
#    print(bridge_events)
#    for note in event_processor.decode(bridge_events):
#        bridge_instr.notes.append(note)
    
#    full_midi.instruments.append(melody_instr)
#    full_midi.instruments.append(rhythm_instr)
#    full_midi.instruments.append(bridge_instr)
#    IPython.display.display(IPython.display.Audio(full_midi.fluidsynth(), rate=44100))
#    filename = 'test_'+str(i)+'.mid'
#    full_midi.write(filename)
    
#    full_midi_ns = note_seq.midi_io.midi_file_to_note_sequence(filename)
#    note_seq.plot_sequence(full_midi_ns)
#    break

(909,)


## Train Transformer XL Model

In [4]:
# Build the model from its config file and checkpoint directory, then reset
# its state before training.
# NOTE(review): the captured output suggests the constructor restores the
# latest encoder/decoder checkpoints from MODEL_SAVE_PATH - confirm in
# models.midi_transformer.
midi_transformer = MIDITransformer(MODEL_CONFIG_PATH, MODEL_SAVE_PATH)
midi_transformer.reset_states()

# Show the model object and the input dataset before training starts.
for inspected in (midi_transformer, dataset):
    print(inspected)

midi_transformer.train(dataset)

Restored Encoder from /home/richhiey/Desktop/workspace/projects/virtual_musicians/cache/checkpoints/model3/encoder/ckpt/ckpt-4
Restored Decoder from /home/richhiey/Desktop/workspace/projects/virtual_musicians/cache/checkpoints/model3/decoder/ckpt/ckpt-4
<models.midi_transformer.MIDITransformer object at 0x7fb0402e0d00>
<BatchDataset shapes: {bridge: (None, None), melody: (None, None), rhythm: (None, None)}, types: {bridge: tf.int64, melody: tf.int64, rhythm: tf.int64}>
Encoder part
-------------------------------------------------------------------
Decoder part
-------------------------------------------------------------------
tf.Tensor(5.963599, shape=(), dtype=float32)
Saved checkpoint for step 6: /home/richhiey/Desktop/workspace/projects/virtual_musicians/cache/checkpoints/model3/encoder/ckpt/ckpt-5, /home/richhiey/Desktop/workspace/projects/virtual_musicians/cache/checkpoints/model3/decoder/ckpt/ckpt-5
Encoder part
------------------------------------------------------------------

Decoder part
-------------------------------------------------------------------
tf.Tensor(5.963374, shape=(), dtype=float32)
Encoder part
-------------------------------------------------------------------
Decoder part
-------------------------------------------------------------------
tf.Tensor(5.9633574, shape=(), dtype=float32)
Encoder part
-------------------------------------------------------------------
Decoder part
-------------------------------------------------------------------
tf.Tensor(5.963005, shape=(), dtype=float32)
Encoder part
-------------------------------------------------------------------
Decoder part
-------------------------------------------------------------------
tf.Tensor(5.963133, shape=(), dtype=float32)
Encoder part
-------------------------------------------------------------------
Decoder part
-------------------------------------------------------------------


tf.Tensor(5.962889, shape=(), dtype=float32)
Encoder part
-------------------------------------------------------------------
Decoder part
-------------------------------------------------------------------
tf.Tensor(5.9631305, shape=(), dtype=float32)
Encoder part
-------------------------------------------------------------------
Decoder part
-------------------------------------------------------------------
tf.Tensor(5.962873, shape=(), dtype=float32)
Encoder part
-------------------------------------------------------------------
Decoder part
-------------------------------------------------------------------


KeyboardInterrupt: 

## Create new Note Sequences with the MIDI Transformer

In [None]:
#midi_transformer.predict(midi_note_sequence)