# Imports

In [57]:
import tensorflow as tf
import keras_tuner as kt

from midiutil import MIDIFile
from io import BytesIO
import pygame
import pygame.mixer
from time import sleep

from pathlib import Path

# Load the Data

In [2]:
# Fetch and unpack the JSB chorales archive. get_file stores it under cache_dir/"datasets"
# (the default cache_subdir) and reuses the local copy if it is already present.
filepath = tf.keras.utils.get_file(
    origin="https://homl.info/bach",
    extract=True,
    cache_dir=".",
)

In [73]:
MIN_VAL: int = 36  # Smallest (non-zero) note value; scale_notes maps it to 1

In [47]:
def bach_dataset(
    dataset_type: str,
    window_size: int = 16,
    window_shift: int = 1,
    cache: bool = False,
    shuffle_buffer_size: int = None,
    batch_size: int = 32,
    seed: int = 42,
    root_dir: str = "./datasets/jsb_chorales",
    min_val: int = 36,
) -> tf.data.Dataset:
    """Build a next-note-prediction dataset from the JSB chorale CSV files.

    Every row of a ``chorale_*.csv`` file is read as a 4-value chord. Within each
    file, consecutive chords are grouped into windows of ``window_size + 1`` chords
    and each window is flattened into one note sequence. Non-zero notes are
    shifted so that ``min_val`` becomes 1; zeros are left unchanged. Batches are
    then split into ``(inputs, targets)``, where the targets are the inputs shifted
    by one note.

    Args:
        dataset_type: sub-directory of ``root_dir`` to read (e.g. "train", "valid", "test").
        window_size: number of input chords per window (one extra chord is read for the target).
        window_shift: stride, in chords, between consecutive windows.
        cache: cache the windowed sequences (before shuffling and batching).
        shuffle_buffer_size: shuffle buffer size; ``None`` disables shuffling.
        batch_size: number of windows per batch.
        seed: seed for the shuffle.
        root_dir: directory containing the per-split chorale folders.
        min_val: note value that is mapped to 1 by the rescaling.

    Returns:
        A dataset of ``(inputs, targets)`` integer tensor pairs, each of shape
        ``(batch, (window_size + 1) * 4 - 1)``.

    Raises:
        FileNotFoundError: if no ``chorale_*.csv`` file exists for ``dataset_type``.
    """
    chorale_dir = Path(root_dir) / dataset_type
    filepaths = sorted(str(path) for path in chorale_dir.glob("chorale_*.csv"))
    if not filepaths:
        # Without this check an empty list surfaces as an obscure tf.data type error.
        raise FileNotFoundError(f"No chorale_*.csv files found in {chorale_dir}")

    def read_file(chorale_file_path: tf.Tensor) -> tf.data.Dataset:
        # Four integer columns; the header row is skipped. Integer defaults make the
        # columns int32 with 0 used for empty fields (CsvDataset record_defaults).
        return tf.data.experimental.CsvDataset(
            chorale_file_path, record_defaults=[0] * 4, header=True
        )

    def group_notes(*notes: tf.Tensor) -> tf.Tensor:
        # CsvDataset yields one scalar per column; stack them into a length-4 chord.
        return tf.stack(notes, axis=-1)

    def create_arpeggio(chord_batch: tf.Tensor) -> tf.Tensor:
        # Shift notes so min_val -> 1, leave 0 as 0, then flatten the chords into one sequence.
        rescaled = tf.where(chord_batch == 0, chord_batch, chord_batch - min_val + 1)
        return tf.reshape(rescaled, [-1])

    def chorale_windows(chorale_file_path: tf.Tensor) -> tf.data.Dataset:
        # Windows never cross file boundaries because they are built per file.
        ds = read_file(chorale_file_path).map(group_notes)
        ds = ds.window(size=window_size + 1, shift=window_shift, drop_remainder=True)
        ds = ds.flat_map(lambda window_ds: window_ds.batch(window_size + 1))
        return ds.map(create_arpeggio)

    # Files are consumed one after another in sorted order.
    dataset = tf.data.Dataset.from_tensor_slices(filepaths).flat_map(chorale_windows)

    if cache:
        dataset = dataset.cache()
    if shuffle_buffer_size is not None:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size, seed=seed)

    dataset = dataset.batch(batch_size)
    # Next-note prediction: the target sequence is the input shifted left by one note.
    dataset = dataset.map(lambda seqs: (seqs[:, :-1], seqs[:, 1:]))

    return dataset.prefetch(tf.data.AUTOTUNE)


In [48]:
train_ds = bach_dataset("train", cache=True, shuffle_buffer_size=1000, seed=42)
valid_ds = bach_dataset("valid", cache=True)
test_ds = bach_dataset("test")

# Making the Model

In [54]:
def build_model(
        n_conv_layers:int=1, n_starting_filters:int=32, max_dilations:int=4,
        n_recurrent_layers:int=1, n_tokens:int=47, embedding_dim:int=5
):
    """Build an (uncompiled) causal Conv1D + GRU next-token classifier.

    Architecture: Embedding -> ``n_conv_layers`` x (causal Conv1D + BatchNorm)
    -> ``n_recurrent_layers`` x GRU -> per-timestep softmax over ``n_tokens``.

    Args:
        n_conv_layers: number of Conv1D/BatchNormalization pairs.
        n_starting_filters: filters per Conv1D layer, also used as the GRU width.
        max_dilations: dilation rates cycle through 1, 2, ..., 2**(max_dilations - 1).
        n_recurrent_layers: number of stacked GRU layers.
        n_tokens: vocabulary size for both the Embedding input and the softmax
            output. The default 47 presumably covers 0 plus the scaled note range;
            TODO confirm against the data.
        embedding_dim: size of the token embedding.

    Returns:
        A ``tf.keras.Sequential`` model mapping ``(batch, time)`` token ids to
        ``(batch, time, n_tokens)`` probabilities.

    Raises:
        ValueError: if ``max_dilations`` < 1 while convolution layers are requested.
    """
    if n_conv_layers > 0 and max_dilations < 1:
        raise ValueError(f"max_dilations must be >= 1, got {max_dilations}")

    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=[None]),
        tf.keras.layers.Embedding(input_dim=n_tokens, output_dim=embedding_dim)
    ])

    # Convolutional layers: causal padding keeps each timestep from seeing the future.
    for conv_idx in range(n_conv_layers):
        dilation_rate = 2 ** (conv_idx % max_dilations)
        model.add(
            tf.keras.layers.Conv1D(
                filters=n_starting_filters,
                kernel_size=2, padding="causal",
                activation="relu", dilation_rate=dilation_rate
            )
        )
        model.add(tf.keras.layers.BatchNormalization())

    # Recurrent layers: return_sequences=True so there is one prediction per timestep.
    for _ in range(n_recurrent_layers):
        model.add(tf.keras.layers.GRU(n_starting_filters, return_sequences=True))

    # Output layer: distribution over the next token at every timestep.
    model.add(tf.keras.layers.Dense(n_tokens, activation="softmax"))

    return model

In [55]:
# Baseline model: 4 dilated causal conv layers followed by a single GRU.
model = build_model(n_conv_layers=4, n_recurrent_layers=1)

optimizer = tf.keras.optimizers.legacy.Nadam(learning_rate=1e-3)

model.compile(
    optimizer=optimizer,
    # Targets are integer token ids, hence the sparse variant.
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

model.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding_2 (Embedding)     (None, None, 5)           235       
                                                                 
 conv1d_8 (Conv1D)           (None, None, 32)          352       
                                                                 
 batch_normalization_8 (Bat  (None, None, 32)          128       
 chNormalization)                                                
                                                                 
 conv1d_9 (Conv1D)           (None, None, 32)          2080      
                                                                 
 batch_normalization_9 (Bat  (None, None, 32)          128       
 chNormalization)                                                
                                                                 
 conv1d_10 (Conv1D)          (None, None, 32)         

In [56]:
# Stop after 3 epochs without improvement of the monitored metric (val_loss by
# default) and roll the model back to its best weights.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    restore_best_weights=True,
    patience=3,
)

history = model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=20,
    callbacks=[early_stopping_cb],
)

Epoch 1/20


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20


In [58]:
def build_hyper_model(hp: kt.HyperParameters):
    """Keras Tuner model-building function.

    Tunes the number of conv layers (1-5), the number of GRU layers (1-2) and a
    log-sampled Nadam learning rate in [1e-7, 1.0]. The filter count and dilation
    cycle are fixed. Returns a compiled model.
    """
    # Hyperparameters are registered in this order: conv depth, recurrent depth, learning rate.
    conv_depth = hp.Int("n_conv_layers", min_value=1, max_value=5)
    recurrent_depth = hp.Int("n_recurrent_layers", min_value=1, max_value=2)
    lr = hp.Float("learning_rate", min_value=1e-7, max_value=1.0, sampling="log")

    nadam = tf.keras.optimizers.legacy.Nadam(learning_rate=lr)

    net = build_model(
        n_conv_layers=conv_depth,
        n_starting_filters=32,
        max_dilations=4,
        n_recurrent_layers=recurrent_depth,
    )
    net.compile(
        optimizer=nadam,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return net

In [60]:
# Bayesian search over build_hyper_model's space, maximizing validation accuracy.
baysian_opt_tuner = kt.BayesianOptimization(
    build_hyper_model,
    objective="val_accuracy",
    max_trials=10,
    seed=42,
    directory="my_baysian_choral",
    project_name="baysian_opt",
    overwrite=True,
)

early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=3)

# Each trial trains and validates on only the first 100 batches of each split to stay short.
baysian_opt_tuner.search(
    train_ds.take(100),
    validation_data=valid_ds.take(100),
    epochs=10,
    callbacks=[early_stopping_cb],
)

Trial 10 Complete [00h 00m 23s]
val_accuracy: 0.7042490839958191

Best val_accuracy So Far: 0.7241138219833374
Total elapsed time: 00h 04m 12s


In [65]:
baysian_opt_tuner.get_best_hyperparameters(num_trials=1)[0].values  # best trial's hyperparameter dict

{'n_conv_layers': 3,
 'n_recurrent_layers': 1,
 'learning_rate': 0.000303917691197622}

In [66]:
# Rebuild the best architecture from the tuner's result instead of hand-copying
# values from the printed output, so this cell stays correct if the search is rerun.
best_hps = baysian_opt_tuner.get_best_hyperparameters(1)[0]
learning_rate = best_hps.get("learning_rate")

model = build_model(
    n_conv_layers=best_hps.get("n_conv_layers"),
    n_recurrent_layers=best_hps.get("n_recurrent_layers"),
)

optimizer = tf.keras.optimizers.legacy.Nadam(learning_rate=learning_rate)
model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

In [67]:
class OneCycleScheduler(tf.keras.callbacks.Callback):
    """1cycle learning-rate schedule applied before every training batch.

    With ``half = (iterations - last_iterations) // 2`` the learning rate goes:
      * batches [0, half):          linearly start_lr -> max_lr
      * batches [half, 2*half):     linearly max_lr -> start_lr
      * batches [2*half, iterations]: linearly start_lr -> last_lr
    and stays at ``last_lr`` if training runs longer than ``iterations`` batches.

    Args:
        iterations: total number of training batches the schedule spans.
        max_lr: peak learning rate.
        start_lr: initial learning rate; defaults to ``max_lr / 10``.
        last_iterations: length of the final annealing phase; defaults to
            ``iterations // 10 + 1``.
        last_lr: final learning rate; defaults to ``start_lr / 1000``.
    """

    def __init__(self, iterations:int, max_lr:float=1e-3, start_lr:float=None, last_iterations:int=None, last_lr:float=None):
        super().__init__()
        self.iterations = iterations
        self.max_lr = max_lr
        # `is None` checks so explicit zeros are honored rather than replaced by defaults.
        self.start_lr = max_lr / 10.0 if start_lr is None else start_lr
        self.last_iterations = (iterations // 10) + 1 if last_iterations is None else last_iterations
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_lr = self.start_lr / 1e3 if last_lr is None else last_lr
        self.iteration = 0  # batches seen so far

    def _interpolate(self, iter1:int, iter2:int, lr1:float, lr2:float) -> float:
        """Linear interpolation from (iter1, lr1) to (iter2, lr2) at the current iteration.

        The iteration is clamped to ``iter2`` so the rate never extrapolates past
        the end of a phase (previously the final phase kept decreasing and could go
        negative). A zero-length phase returns ``lr2``.
        """
        if iter2 == iter1:
            return lr2
        step = min(self.iteration, iter2)
        slope = (lr2 - lr1) / (iter2 - iter1)
        return slope * (step - iter1) + lr1

    def on_batch_begin(self, batch, logs=None):
        if self.iteration < self.half_iteration:
            lr = self._interpolate(0, self.half_iteration, self.start_lr, self.max_lr)
        elif self.iteration < 2 * self.half_iteration:
            lr = self._interpolate(self.half_iteration, 2 * self.half_iteration, self.max_lr, self.start_lr)
        else:
            lr = self._interpolate(2 * self.half_iteration, self.iterations, self.start_lr, self.last_lr)

        self.iteration += 1
        tf.keras.backend.set_value(self.model.optimizer.learning_rate, lr)

In [68]:
n_epochs = 10

# Count the training batches instead of hard-coding them. The schedule must span
# exactly the number of optimizer steps, or the LR curve is cut short or overrun.
# Dataset.reduce makes one full pass, which also fills train_ds's cache.
steps_per_epoch = int(train_ds.reduce(0, lambda count, _: count + 1))

oneCycle = OneCycleScheduler(
    iterations=steps_per_epoch * n_epochs,
    max_lr=learning_rate
)

history = model.fit(
    train_ds, epochs=n_epochs,
    validation_data=valid_ds,
    callbacks=[oneCycle]
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [69]:
model.save(filepath="my_bach_rnn", save_format="tf")  # TensorFlow SavedModel directory

INFO:tensorflow:Assets written to: my_bach_rnn/assets


INFO:tensorflow:Assets written to: my_bach_rnn/assets


In [123]:
def revert_scaled_notes(scaled_notes: tf.Tensor):
    """Undo the note rescaling: map scaled value 1 back to MIN_VAL, leaving 0 untouched."""
    keep_as_is = tf.equal(scaled_notes, 0)
    return tf.where(keep_as_is, scaled_notes, scaled_notes + (MIN_VAL - 1))

In [124]:
def generate_chorale(model: tf.keras.Model, seed_notes: tf.Tensor, n_chords: int, temperature: float=1.0):
    """Autoregressively extend a seed note sequence and return it as chords.

    Args:
        model: maps a ``(1, time)`` sequence of scaled notes to per-timestep
            probabilities; only the last timestep is used.
        seed_notes: 1-D sequence of scaled notes. Its length must be a multiple of
            4 so the result reshapes into whole chords.
        n_chords: number of 4-note chords to generate after the seed.
        temperature: sampling temperature; < 1 sharpens and > 1 flattens the
            distribution. Must be positive.

    Returns:
        An int64 tensor of shape ``(len(seed_notes) // 4 + n_chords, 4)`` with the
        seed followed by the generated notes, converted back by ``revert_scaled_notes``.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    N_NOTES_PER_CHORD = 4
    # int64 so the sampled ids (tf.random.categorical returns int64) can be concatenated.
    arpeggio = tf.reshape(tf.cast(seed_notes, tf.int64), shape=[1, -1])
    for _ in range(n_chords * N_NOTES_PER_CHORD):
        # Probabilities for the note after the last position, shaped (1, vocab).
        next_note_probas = model(arpeggio, training=False)[0, -1:]
        rescaled_logits = tf.math.log(next_note_probas) / temperature
        next_note = tf.random.categorical(rescaled_logits, num_samples=1)
        arpeggio = tf.concat([arpeggio, next_note], axis=1)

    return tf.reshape(revert_scaled_notes(arpeggio), [-1, N_NOTES_PER_CHORD])

In [74]:
def scale_notes(notes: tf.Tensor):
    """Rescale notes so MIN_VAL becomes 1, leaving 0 untouched (inverse of revert_scaled_notes)."""
    offset = MIN_VAL - 1
    return tf.where(tf.equal(notes, 0), notes, notes - offset)

In [120]:
def play_chorale(chorale: tf.Tensor, tempo: int = 240, volume: int = 100):
    """Render a chorale to an in-memory MIDI file and play it, blocking until done.

    Args:
        chorale: sequence of chords, each a sequence of MIDI pitches (e.g. the
            ``(n_chords, 4)`` tensor returned by ``generate_chorale``). Each chord
            lasts one beat. A pitch of 0 is skipped rather than played.
        tempo: playback tempo in beats per minute.
        volume: MIDI note velocity (0-127).
    """
    track = 0
    channel = 0
    duration = 1  # beats per chord

    midi = MIDIFile(1)
    midi.addTrackName(track, 0, "Sample Track")
    midi.addTempo(track, 0, tempo)

    # Write the chords one after another, all voices on the same track/channel.
    for chord_idx, chord in enumerate(chorale):
        start = chord_idx * duration
        for note in chord:
            pitch = int(note)  # tensors -> plain ints for MIDIUtil
            if pitch == 0:
                # The scaling pipeline keeps 0 as a special value (presumably "no note";
                # TODO confirm). Emitting it would play a real MIDI pitch 0.
                continue
            midi.addNote(track, channel, pitch, start, duration, volume)

    mem_file = BytesIO()
    midi.writeFile(mem_file)
    mem_file.seek(0)

    # Playback: poll once per second until the mixer reports it is done.
    pygame.init()
    pygame.mixer.init()
    pygame.mixer.music.load(mem_file)
    pygame.mixer.music.play()
    while pygame.mixer.music.get_busy():
        sleep(1)

In [137]:
for X, Y in test_ds.shuffle(100, seed=41).take(1):
    n_chords = 32
    # Each input window holds (window_size + 1) * 4 - 1 notes (the targets take the
    # last one), so dropping the trailing 3 leaves a whole number of 4-note chords.
    N_EXTRA_NOTES = 3
    seed_notes = X[0][:-N_EXTRA_NOTES]
    new_chorale = generate_chorale(model, seed_notes, n_chords=n_chords, temperature=0.9)
    play_chorale(new_chorale)