In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import keras_nlp
import json
import string
import re
import music21
import os
import glob

In [None]:
# Hyper-parameters and run configuration for the notebook.
# The two vocabulary sizes below are placeholders: both names are reassigned
# from the adapted TextVectorization vocabularies in the vectorization cell.
NOTES_VOCAB_SIZE = 200
DURATION_VOCAB_SIZE = 200
MAX_LEN = 50  # model input length; training windows hold MAX_LEN + 1 tokens
EMBEDDING_DIM = 128  # width of the concatenated note + duration embedding
KEY_DIM = 256  # key_dim handed to the MultiHeadAttention layer
N_HEADS = 2  # number of attention heads
FEED_FORWARD_DIM = 256  # units of the first feed-forward Dense layer
VALIDATION_SPLIT = 0.2  # NOTE(review): not referenced by any visible cell
SEED = 42  # NOTE(review): declared, but no visible cell seeds TF / NumPy with it
LOAD_MODEL = False  # when True, a saved model replaces the freshly built one
BATCH_SIZE = 128  # windows per dataset batch
EPOCHS = 5  # training epoch count — NOTE(review): confirm the fit call uses it
GENERATE_LEN = 50  # NOTE(review): not referenced by any visible cell

In [None]:
# dataset preparation
# Parse every MIDI file into two parallel token streams: `notes` (pitches,
# rests, key and time-signature events) and `durations` (quarter lengths as
# strings). All files are concatenated; each one is introduced by a
# "START" / "0.0" marker pair.
notes_list = []
duration_list = []
notes = []
durations = []
parser = music21.converter

# glob returns matches in arbitrary order; sorting keeps the concatenated
# token stream (and therefore every training window) identical across runs
# and machines.
file_list = sorted(glob.glob("bach_cello_suite_data/*.mid"))
if not file_list:
    # Fail here with a clear message instead of much later with an
    # IndexError / empty-dataset error.
    raise FileNotFoundError(
        "No MIDI files found matching 'bach_cello_suite_data/*.mid'"
    )

for i, file in enumerate(file_list):
    print(i + 1, "Parsing %s" % file)
    score = parser.parse(file).chordify()

    notes.append("START")
    durations.append("0.0")

    for element in score.flat:
        note_name = None
        duration_name = None

        if isinstance(element, music21.key.Key):
            # Key events carry no length, e.g. "C:major" / "0.0".
            note_name = str(element.tonic.name) + ":" + str(element.mode)
            duration_name = "0.0"

        elif isinstance(element, music21.meter.TimeSignature):
            # Time signatures are tagged with a "TS" suffix, e.g. "3/4TS".
            note_name = str(element.ratioString) + "TS"
            duration_name = "0.0"

        elif isinstance(element, music21.chord.Chord):
            # Reduce a chord to the last entry of its pitch list.
            note_name = element.pitches[-1].nameWithOctave
            duration_name = str(element.duration.quarterLength)

        elif isinstance(element, music21.note.Rest):
            note_name = str(element.name)
            duration_name = str(element.duration.quarterLength)

        elif isinstance(element, music21.note.Note):
            note_name = str(element.nameWithOctave)
            duration_name = str(element.duration.quarterLength)

        # Elements of any other type leave both names as None and are skipped.
        if note_name and duration_name:
            notes.append(note_name)
            durations.append(duration_name)


In [None]:
# Cut both token streams into overlapping windows of MAX_LEN + 1 tokens
# (MAX_LEN inputs plus the one-step-shifted targets) and store each window as
# a space-separated string for the TextVectorization layers.
#
# A stream of n tokens contains n - MAX_LEN full windows (start offsets
# 0 .. n - MAX_LEN - 1). The previous bound of n - MAX_LEN - 1 dropped the
# final window. A stream shorter than a window gives an empty range.
notes_list = [
    " ".join(notes[start : start + MAX_LEN + 1])
    for start in range(len(notes) - MAX_LEN)
]
duration_list = [
    " ".join(durations[start : start + MAX_LEN + 1])
    for start in range(len(notes) - MAX_LEN)
]

In [None]:
# Sanity check: print one note window and its matching duration window, then
# the lengths of the two raw token streams.
# NOTE(review): index 658 is hard-coded and needs at least 659 windows.
print(notes_list[658])
print(duration_list[658])
print(len(notes))
print(len(durations))

In [None]:
# notes vectorization layer
# The per-stream datasets are deliberately NOT shuffled here: they are zipped
# below and must stay element-for-element aligned. Shuffling each stream on
# its own pairs a batch of note windows with the durations of unrelated windows.
notes_ds = tf.data.Dataset.from_tensor_slices(notes_list).batch(BATCH_SIZE, drop_remainder=True)
notes_vec_layer = layers.TextVectorization(
    output_mode='int',
    standardize=None
)
notes_vec_layer.adapt(notes_ds)
NOTES_VOCAB_SIZE = len(notes_vec_layer.get_vocabulary())

# duration vectorization layer
duration_ds = tf.data.Dataset.from_tensor_slices(duration_list).batch(BATCH_SIZE, drop_remainder=True)
duration_vec_layer = layers.TextVectorization(
    output_mode='int',
    standardize=None
)
duration_vec_layer.adapt(duration_ds)
DURATION_VOCAB_SIZE = len(duration_vec_layer.get_vocabulary())

# input_ds
# Zip first, then shuffle once, so every (notes, durations) batch pair is
# moved as a unit and stays aligned.
input_ds = tf.data.Dataset.zip((notes_ds, duration_ds)).shuffle(1000)

def prepare_dataset(notes, durations):
    """Tokenize one batch of window strings and split it into inputs and targets.

    Args:
        notes: batch of space-separated note-token strings.
        durations: batch of space-separated duration-token strings.

    Returns:
        ``(x, y)`` where ``x`` is ``(note_tokens[:, :-1], duration_tokens[:, :-1])``
        and ``y`` is the same pair shifted one step ahead (``[:, 1:]``).
    """
    # The vectorization layers are called on a trailing axis of size 1.
    notes = tf.expand_dims(notes, -1)
    durations = tf.expand_dims(durations, -1)
    tokenized_notes = notes_vec_layer(notes)
    tokenized_durations = duration_vec_layer(durations)
    x = (tokenized_notes[:, :-1], tokenized_durations[:, :-1])
    y = (tokenized_notes[:, 1:], tokenized_durations[:, 1:])
    return x, y

full_ds = input_ds.map(prepare_dataset)


In [None]:
# Pull a single batch and print the shapes of its four tensors, in order:
# input notes, input durations, target notes, target durations.
for i in full_ds.take(1):
    print(i[0][0].shape, i[0][1].shape, i[1][0].shape, i[1][1].shape)

In [None]:
# causal attention

def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    """Build a causal attention mask tiled over the batch.

    Entry ``[b, i, j]`` is true (cast to ``dtype``) exactly when
    ``i >= j - n_src + n_dest``.

    Args:
        batch_size: scalar (Python int or tensor) number of mask copies.
        n_dest: number of destination (query) positions.
        n_src: number of source (key) positions.
        dtype: dtype the boolean mask is cast to.

    Returns:
        Tensor of shape ``(batch_size, n_dest, n_src)``.
    """
    dest_positions = tf.range(n_dest)[:, None]
    src_positions = tf.range(n_src)
    visible = dest_positions >= src_positions - n_src + n_dest
    single_mask = tf.reshape(tf.cast(visible, dtype), [1, n_dest, n_src])
    # Repeat only along the leading (batch) axis.
    tile_multiples = tf.concat(
        [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
        0,
    )
    return tf.tile(single_mask, tile_multiples)

# Quick visual check: transposed 10 x 10 mask for a batch of one.
np.transpose(causal_attention_mask(1, 10, 10, dtype=tf.int32)[0])

In [None]:
# transformer block

class TransformerBlock(layers.Layer):
    """Single causal self-attention block that also returns its attention scores.

    Args:
        num_heads: number of attention heads.
        key_dim: key size passed to ``MultiHeadAttention``.
        embed_dim: width of the block's input and output features.
        ff_dim: units of the first (ReLU) feed-forward layer.
        dropout_rate: rate used by both dropout layers.
        name: layer name.
        **kwargs: forwarded to ``layers.Layer`` (for example ``trainable`` and
            ``dtype``). Accepting them lets ``from_config(get_config())``
            rebuild the layer, since the base config includes those keys.
    """

    def __init__(self, num_heads, key_dim, embed_dim, ff_dim, dropout_rate=0.1, name="transformer_block", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.dropout_rate = dropout_rate

        self.attn = layers.MultiHeadAttention(num_heads, key_dim, output_shape=embed_dim)
        self.dropout1 = layers.Dropout(self.dropout_rate)
        self.ln1 = layers.LayerNormalization(epsilon=1e-6)
        self.ffn1 = layers.Dense(self.ff_dim, activation='relu')
        self.ffn2 = layers.Dense(self.embed_dim)
        self.dropout2 = layers.Dropout(self.dropout_rate)
        self.ln2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        """Apply masked self-attention followed by the feed-forward sub-layer.

        Returns:
            Tuple ``(block_output, attention_scores)``.
        """
        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        # Square mask: queries and keys are the same sequence.
        causal_mask = causal_attention_mask(
            batch_size, seq_len, seq_len, tf.bool
        )
        attention_output, attention_scores = self.attn(
            inputs,
            inputs,
            attention_mask=causal_mask,
            return_attention_scores=True,
        )
        attention_output = self.dropout1(attention_output)
        # Residual connection + layer norm around the attention sub-layer.
        out1 = self.ln1(inputs + attention_output)
        ffn1 = self.ffn1(out1)
        ffn2 = self.ffn2(ffn1)
        ffn_out = self.dropout2(ffn2)
        # Residual connection + layer norm around the feed-forward sub-layer.
        return (self.ln2(out1 + ffn_out), attention_scores)

    def get_config(self):
        """Return the base layer config extended with this block's constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "key_dim": self.key_dim,
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "ff_dim": self.ff_dim,
                "dropout_rate": self.dropout_rate,
            }
        )
        return config

In [None]:
# token and position embedding

class TokenAndPositionEmbedding(layers.Layer):
    """Token embedding summed with a ``keras_nlp`` sine position encoding.

    Args:
        vocab_size: number of distinct token ids (embedding ``input_dim``).
        embed_dim: embedding width (embedding ``output_dim``).
        **kwargs: forwarded to ``layers.Layer`` (for example ``name``,
            ``trainable`` and ``dtype``). Accepting them lets
            ``from_config(get_config())`` rebuild the layer, since the base
            config includes those keys.
    """

    def __init__(self, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

        self.token_embed = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.pos_embed = keras_nlp.layers.SinePositionEncoding()

    def call(self, x):
        """Embed the token ids in ``x`` and add the position encoding."""
        embeddings = self.token_embed(x)
        # The position layer is called on the embeddings, not on the raw ids.
        positions = self.pos_embed(embeddings)
        return embeddings + positions

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config

In [None]:
# building model

# Two variable-length integer token streams: note events and their durations.
note_inputs = layers.Input(shape=(None,), dtype=tf.int32)
duration_inputs = layers.Input(shape=(None,), dtype=tf.int32)

# Each stream is embedded at half the model width; concatenating the two
# yields EMBEDDING_DIM features per position.
note_embed = TokenAndPositionEmbedding(
    vocab_size=NOTES_VOCAB_SIZE,
    embed_dim=EMBEDDING_DIM // 2,
)(note_inputs)
duration_embed = TokenAndPositionEmbedding(
    vocab_size=DURATION_VOCAB_SIZE,
    embed_dim=EMBEDDING_DIM // 2,
)(duration_inputs)
x = layers.Concatenate()([note_embed, duration_embed])

# The block is named 'attention' so it can be looked up by name later.
x, attention_scores = TransformerBlock(
    num_heads=N_HEADS,
    key_dim=KEY_DIM,
    embed_dim=EMBEDDING_DIM,
    ff_dim=FEED_FORWARD_DIM,
    name='attention',
)(x)

# One softmax head per stream, both fed by the same transformer output.
notes_outputs = layers.Dense(units=NOTES_VOCAB_SIZE, activation='softmax')(x)
duration_outputs = layers.Dense(units=DURATION_VOCAB_SIZE, activation='softmax')(x)


museNet = keras.Model(
    inputs=[note_inputs, duration_inputs],
    outputs=[notes_outputs, duration_outputs],
)
museNet.compile(
    "adam",
    loss=[
        tf.keras.losses.SparseCategoricalCrossentropy(),
        tf.keras.losses.SparseCategoricalCrossentropy(),
    ],
)

In [None]:
# Print the model's layer-by-layer summary.
museNet.summary()

In [None]:
# Optionally replace the freshly built model with a previously saved one.
# NOTE(review): SaveModelCallback in this notebook writes to
# 'checkpoint/MuseNet.keras', whereas this loads from "museNet" — confirm the
# intended path. The model also contains the custom TransformerBlock and
# TokenAndPositionEmbedding layers; loading may require passing them through
# `custom_objects` — TODO confirm for the saved format in use.
if LOAD_MODEL:
    museNet = tf.keras.models.load_model("museNet", compile=True)

In [None]:
import pickle 
from fractions import Fraction
import time

class SaveModelCallback(tf.keras.callbacks.Callback):
    """Save the full model to 'checkpoint/MuseNet.keras' after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Create the target directory first: on a fresh checkout it does not
        # exist, and model.save is not relied upon to create missing parents.
        os.makedirs('checkpoint', exist_ok=True)
        self.model.save('checkpoint/MuseNet.keras')

def get_midi_note(sample_note, sample_duration):
    """Convert one (note token, duration token) pair into a music21 object.

    Args:
        sample_note: note-stream token. Recognised forms are ``"<ratio>TS"``
            (time signature), ``"<tonic>:<mode>"`` containing "major" or
            "minor" (key), ``"rest"``, dot-separated pitch names (chord), or a
            single pitch name.
        sample_duration: duration token parsed with ``Fraction`` (so both
            ``"0.5"`` and ``"1/3"`` are accepted). Ignored for time-signature
            and key tokens.

    Returns:
        The music21 object, or ``None`` for tokens with no musical content:
        ``"START"`` and the vectorizer's padding (``""``) and
        out-of-vocabulary (``"[UNK]"``) tokens.
    """
    # Marker / padding / OOV tokens have no musical meaning. Returning None
    # lets callers skip them instead of failing while parsing a pitch name.
    if sample_note in ("", "[UNK]", "START"):
        return None

    if "TS" in sample_note:
        return music21.meter.TimeSignature(sample_note.split("TS")[0])

    if "major" in sample_note or "minor" in sample_note:
        tonic, mode = sample_note.split(":")
        return music21.key.Key(tonic, mode)

    def _timed(element):
        # Shared by rests, single notes and chord members: attach the parsed
        # duration and the cello instrument.
        element.duration = music21.duration.Duration(
            float(Fraction(sample_duration))
        )
        element.storedInstrument = music21.instrument.Violoncello()
        return element

    if sample_note == "rest":
        return _timed(music21.note.Rest())

    if "." in sample_note:
        # Chord token: one Note per dot-separated pitch name.
        chord_notes = [
            _timed(music21.note.Note(pitch_name))
            for pitch_name in sample_note.split(".")
        ]
        return music21.chord.Chord(chord_notes)

    return _timed(music21.note.Note(sample_note))


class MusicGenerator(tf.keras.callbacks.Callback):
    """Callback that samples a piece from the model after each epoch and writes it as MIDI.

    Args:
        index_to_note: list mapping note token index -> note token string.
        index_to_duration: list mapping duration token index -> duration string.
        top_k: accepted for interface compatibility; the sampler does not use it.
    """

    # Token ids that are never emitted. This assumes the vocabularies come
    # from TextVectorization.get_vocabulary() (as in this notebook), where
    # index 0 is the padding token "" and index 1 is "[UNK]".
    _RESERVED_TOKEN_IDS = (0, 1)

    def __init__(self, index_to_note, index_to_duration, top_k=10):
        # Initialise the base Callback state before adding our own.
        super().__init__()
        self.index_to_note = index_to_note
        self.note_to_index = {
            note: index for index, note in enumerate(index_to_note)
        }
        self.index_to_duration = index_to_duration
        self.duration_to_index = {
            duration: index for index, duration in enumerate(index_to_duration)
        }

    def sample_from(self, probs, temperature):
        """Draw one index from ``probs`` after temperature scaling.

        Returns:
            ``(sampled_index, rescaled_probs)``.
        """
        probs = probs ** (1 / temperature)
        probs = probs / np.sum(probs)
        return np.random.choice(len(probs), p=probs), probs

    def _sample_token(self, probs, index_to_token, temperature):
        """Resample until a non-reserved token id is drawn."""
        token_idx = self._RESERVED_TOKEN_IDS[0]
        while token_idx in self._RESERVED_TOKEN_IDS:
            token_idx, scaled_probs = self.sample_from(probs, temperature)
        return token_idx, index_to_token[token_idx], scaled_probs

    def get_note(self, notes, durations, temperature):
        """Sample the next note and duration from the model's last-step outputs.

        Args:
            notes: note probabilities; ``notes[0][-1]`` is sampled.
            durations: duration probabilities; ``durations[0][-1]`` is sampled.
            temperature: sampling temperature.

        Returns:
            Tuple ``(new_note, sample_note_idx, sample_note, note_probs,
            sample_duration_idx, sample_duration, duration_probs)``.
        """
        sample_note_idx, sample_note, note_probs = self._sample_token(
            notes[0][-1], self.index_to_note, temperature
        )
        sample_duration_idx, sample_duration, duration_probs = self._sample_token(
            durations[0][-1], self.index_to_duration, temperature
        )

        new_note = get_midi_note(sample_note, sample_duration)

        return (
            new_note,
            sample_note_idx,
            sample_note,
            note_probs,
            sample_duration_idx,
            sample_duration,
            duration_probs,
        )

    def generate(self, start_notes, start_durations, max_tokens, temperature):
        """Autoregressively extend a prompt until ``max_tokens`` tokens or a "START" token.

        Args:
            start_notes: prompt note tokens (not modified).
            start_durations: prompt duration tokens (not modified).
            max_tokens: stop once the token sequence reaches this length.
            temperature: sampling temperature.

        Returns:
            List with one dict per generated step, holding the prompt so far,
            the MIDI stream, the chosen (note, duration), both probability
            vectors and the attention slice ``att[0, :, -1, :]``. Every entry
            references the same stream object, which keeps growing.
        """
        attention_model = keras.models.Model(
            inputs=self.model.input,
            outputs=self.model.get_layer("attention").output,
        )

        # Work on copies so the caller's prompt lists are not extended.
        start_notes = list(start_notes)
        start_durations = list(start_durations)

        # Tokens missing from the vocabulary fall back to index 1.
        start_note_tokens = [self.note_to_index.get(x, 1) for x in start_notes]
        start_duration_tokens = [
            self.duration_to_index.get(x, 1) for x in start_durations
        ]
        sample_note = None
        sample_duration = None
        info = []
        midi_stream = music21.stream.Stream()

        midi_stream.append(music21.clef.BassClef())

        for sample_note, sample_duration in zip(start_notes, start_durations):
            new_note = get_midi_note(sample_note, sample_duration)
            if new_note is not None:
                midi_stream.append(new_note)

        sounding_types = (
            music21.chord.Chord,
            music21.note.Note,
            music21.note.Rest,
        )

        while len(start_note_tokens) < max_tokens:
            x1 = np.array([start_note_tokens])
            x2 = np.array([start_duration_tokens])
            notes, durations = self.model.predict([x1, x2], verbose=0)

            # Resample while a chord / note / rest is paired with "0.0".
            while True:
                (
                    new_note,
                    sample_note_idx,
                    sample_note,
                    note_probs,
                    sample_duration_idx,
                    sample_duration,
                    duration_probs,
                ) = self.get_note(notes, durations, temperature)

                zero_length_sound = (
                    isinstance(new_note, sounding_types)
                    and sample_duration == "0.0"
                )
                if not zero_length_sound:
                    break

            if new_note is not None:
                midi_stream.append(new_note)

            _, att = attention_model.predict([x1, x2], verbose=0)

            info.append(
                {
                    "prompt": [start_notes.copy(), start_durations.copy()],
                    "midi": midi_stream,
                    "chosen_note": (sample_note, sample_duration),
                    "note_probs": note_probs,
                    "duration_probs": duration_probs,
                    "atts": att[0, :, -1, :],
                }
            )
            start_note_tokens.append(sample_note_idx)
            start_duration_tokens.append(sample_duration_idx)
            start_notes.append(sample_note)
            start_durations.append(sample_duration)

            # A sampled "START" marks the beginning of a new piece: stop here.
            if sample_note == "START":
                break

        return info

    def on_epoch_end(self, epoch, logs=None):
        """Generate from a bare "START" prompt and write 'outputs/output-NNNN.mid'."""
        info = self.generate(
            ["START"], ["0.0"], max_tokens=200, temperature=0.7
        )
        midi_stream = info[-1]["midi"].chordify()
        print(info[-1]["prompt"])
        # Writing fails if the directory is missing, so create it first.
        os.makedirs("outputs/", exist_ok=True)
        midi_stream.write(
            "midi",
            fp=os.path.join(
                "outputs/",
                "output-" + str(epoch).zfill(4) + ".mid",
            ),
        )


# Callbacks run at the end of every epoch: first the checkpoint, then a
# generated sample built from the two adapted vocabularies.
callbacks = [
    SaveModelCallback(),
    MusicGenerator(
        index_to_note=notes_vec_layer.get_vocabulary(),
        index_to_duration=duration_vec_layer.get_vocabulary(),
    ),
]

In [None]:
# Train on the zipped (notes, durations) dataset. The epoch count comes from
# the EPOCHS constant in the configuration cell rather than a hard-coded
# literal, so the configured value is the one that actually applies.
museNet.fit(
    full_ds,
    epochs=EPOCHS,
    callbacks=callbacks
)