In [1]:
import tensorflow as tf 
from tensorflow import keras
import numpy as np
import os
import time
from pathlib import Path
import tensorflow_datasets as tfds
from matplotlib import pyplot as plt
import math
import glob
import pandas as pd
import numpy as np
import mido
import pygame.midi
from mido import MidiFile, MidiTrack, Message

pygame 2.6.1 (SDL 2.28.4, Python 3.10.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
DOWNLOAD_ROOT = "https://github.com/ageron/handson-ml2/raw/master/datasets/jsb_chorales/"
FILENAME = "jsb_chorales.tgz"
filepath = keras.utils.get_file(FILENAME,
                                DOWNLOAD_ROOT + FILENAME,
                                cache_subdir="datasets/jsb_chorales",
                                extract=True)

In [3]:
jsb_chorales_dir = Path(filepath).parent
train_files = sorted(jsb_chorales_dir.glob("train/chorale_*.csv"))
valid_files = sorted(jsb_chorales_dir.glob("valid/chorale_*.csv"))
test_files = sorted(jsb_chorales_dir.glob("test/chorale_*.csv"))

In [4]:
def load_chorales(filepaths):
    return [pd.read_csv(filepath).values.tolist() for filepath in filepaths]

train_chorales = load_chorales(train_files)
valid_chorales = load_chorales(valid_files)
test_chorales = load_chorales(test_files)

notes = set()
for chorales in (train_chorales, valid_chorales, test_chorales):
    for chorale in chorales:
        for chord in chorale:
            notes |= set(chord)

n_notes = len(notes)
min_note = min(notes - {0})
max_note = max(notes)

assert min_note == 36
assert max_note == 81

In [5]:
def play_chorale(chorale_array, tempo=100000):
    """
    Play a Bach chorale using the given array of notes.
    
    Parameters:
    chorale_array (numpy array): A 2D array where each row is a time step and each column is a note index.
    tempo (int): Tempo of the MIDI playback, default is 500000 (microseconds per beat).
    """

    # Initialize the MIDI file
    mid = MidiFile()
    track = MidiTrack()
    mid.tracks.append(track)

    # Set the tempo (microseconds per beat)
    track.append(mido.MetaMessage('set_tempo', tempo=tempo))

    # MIDI note on/off settings
    note_on = 144  # MIDI message for note on
    note_off = 128  # MIDI message for note off

    for time_step in chorale_array:
        # For each time step (row), send note-on messages for the active notes
        for note in time_step:
            if note != 0:  # If the note is not 0 (0 means no note is played)
                track.append(Message('note_on', note=note, velocity=64, time=0))
        
        # Duration of each time step (you can adjust this)
        time_per_step = 480  # time per step in MIDI ticks (adjustable)
        
        # Send note-off messages after the duration
        for note in time_step:
            if note != 0:
                track.append(Message('note_off', note=note, velocity=64, time=time_per_step))
    
    # Save the generated MIDI to a file
    midi_filename = "bach_chorale.mid"
    mid.save(midi_filename)
    print(f"Chorale saved as {midi_filename}")
    
    # Initialize pygame for MIDI playback
    pygame.midi.init()

    # Manually select Microsoft GS Wavetable Synth (device 1)
    output_device_id = 1  # Based on your device listing, use Device 1
    
    if output_device_id >= 0:
        player = pygame.midi.Output(output_device_id)
    else:
        print("No valid MIDI output device found.")
        return

    try:
        # Set instrument to Acoustic Grand Piano (General MIDI instrument 0)
        player.set_instrument(0)
        
        # Parse the MIDI file using mido and send the MIDI messages to the player
        for msg in mido.MidiFile(midi_filename).play():
            if not msg.is_meta:
                if msg.type == 'note_on':
                    player.note_on(msg.note, msg.velocity)
                elif msg.type == 'note_off':
                    player.note_off(msg.note, msg.velocity)
        
        print("Playing chorale...")
    except Exception as e:
        print(f"Error playing MIDI: {e}")
    finally:
        player.close()
        pygame.midi.quit()

In [12]:
def create_target(batch):
    X = batch[:, :-1]
    Y = batch[:, 1:]
    return X, Y

def preprocess(window):
    window = tf.where(window == 0, window, window - min_note + 1)
    return tf.reshape(window, [-1])

def load_chorales_dataset(files, batch_size=16, shuffle_buffer_size=None, 
                 window_size=32, window_shift=8, cache=True):
    
    def batch_window(window):
        return window.batch(window_size + 1)

    def to_windows(chorale):
        dataset = tf.data.Dataset.from_tensor_slices(chorale)
        dataset = dataset.window(window_size + 1, window_shift, drop_remainder=True)
        return dataset.flat_map(batch_window)
    
    csv_files = glob.glob(files + '/*.csv')
    chorales_list = [pd.read_csv(f, header=0) for f in csv_files]
    chorales_list = [np.array(chorale) for chorale in chorales_list]

    chorales_dataset = tf.ragged.constant(chorales_list, ragged_rank=1)
    chorales_dataset = tf.data.Dataset.from_tensor_slices(chorales_dataset)
    chorales_dataset = chorales_dataset.flat_map(to_windows)
    chorales_dataset = chorales_dataset.map(preprocess)
    
    if cache:
        chorales_dataset = chorales_dataset.cache()
    if shuffle_buffer_size:
        chorales_dataset = chorales_dataset.shuffle(shuffle_buffer_size)
    chorales_dataset = chorales_dataset.batch(batch_size)
    chorales_dataset = chorales_dataset.map(create_target)

    return chorales_dataset.prefetch(tf.data.AUTOTUNE)

train = load_chorales_dataset('jsb_chorales/train')
valid = load_chorales_dataset('jsb_chorales/valid')
test = load_chorales_dataset('jsb_chorales/test')

print(train.element_spec)
print(valid.element_spec)
print(test.element_spec)

for X_batch, Y_batch in train.take(1):
    print(X_batch.shape, Y_batch.shape)
    print(X_batch[0])
    print(Y_batch[0])



(TensorSpec(shape=(None, None), dtype=tf.int32, name=None), TensorSpec(shape=(None, None), dtype=tf.int32, name=None))
(TensorSpec(shape=(None, None), dtype=tf.int32, name=None), TensorSpec(shape=(None, None), dtype=tf.int32, name=None))
(TensorSpec(shape=(None, None), dtype=tf.int32, name=None), TensorSpec(shape=(None, None), dtype=tf.int32, name=None))
(16, 131) (16, 131)
tf.Tensor(
[39 35 30 23 39 35 30 23 39 35 30 23 39 35 30 23 40 35 23 20 40 35 23 20
 40 35 25 20 40 35 25 20 42 34 27 15 42 34 27 15 42 34 27 15 42 34 27 15
 42 35 27 20 42 35 27 20 42 34 27 20 42 34 27 20 40 32 28 13 40 32 28 13
 40 34 28 13 40 34 28 13 39 35 30 11 39 35 30 11 39 35 30 11 39 35 30 11
 37 34 30 18 37 34 30 18 37 34 30 18 37 34 30 18 37 34 30 18 37 34 30 18
 37 34 30 18 37 34 30 18 39 35 30], shape=(131,), dtype=int32)
tf.Tensor(
[35 30 23 39 35 30 23 39 35 30 23 39 35 30 23 40 35 23 20 40 35 23 20 40
 35 25 20 40 35 25 20 42 34 27 15 42 34 27 15 42 34 27 15 42 34 27 15 42
 35 27 20 42 35 27 20 42 34

In [39]:
n_embedding_dims = 10

model = keras.models.Sequential([
    keras.layers.Embedding(input_dim=n_notes, output_dim=n_embedding_dims,
                           input_shape=[None]),
    keras.layers.Conv1D(32, kernel_size=2, padding="causal", activation="relu"),
    keras.layers.BatchNormalization(),
    keras.layers.Conv1D(48, kernel_size=2, padding="causal", activation="relu", dilation_rate=2),
    keras.layers.BatchNormalization(),
    keras.layers.Conv1D(64, kernel_size=2, padding="causal", activation="relu", dilation_rate=4),
    keras.layers.BatchNormalization(),
    keras.layers.Conv1D(96, kernel_size=2, padding="causal", activation="relu", dilation_rate=8),
    keras.layers.BatchNormalization(),
    keras.layers.GRU(256, return_sequences=True),
    keras.layers.Dense(n_notes, activation="softmax")
])

optimizer = keras.optimizers.Nadam(learning_rate=0.01)
model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

  super().__init__(**kwargs)


In [10]:
def get_run_logdir():
    root_logdir = os.path.join(os.curdir, "my_logs")
    run_id = time.strftime("run_%Y_%m_%d-%H_%M")
    return os.path.join(root_logdir, run_id)

In [40]:
early_stopping_cb = keras.callbacks.EarlyStopping(patience=4, restore_best_weights=True)
lr_scheduler = keras.callbacks.ReduceLROnPlateau(factor=0.75, patience=2)
tensorboard_cb = keras.callbacks.TensorBoard(get_run_logdir())

history = model.fit(train, epochs=20, validation_data=valid, callbacks=[early_stopping_cb, lr_scheduler, tensorboard_cb])

model.save('bach_chorales_model.keras')

Epoch 1/20


    380/Unknown [1m47s[0m 110ms/step - accuracy: 0.6228 - loss: 1.4756

  self.gen.throw(typ, value, traceback)


[1m380/380[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m53s[0m 126ms/step - accuracy: 0.6230 - loss: 1.4746 - val_accuracy: 0.7637 - val_loss: 0.8694 - learning_rate: 0.0100
Epoch 2/20
[1m380/380[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 154ms/step - accuracy: 0.7206 - loss: 1.0676 - val_accuracy: 0.1080 - val_loss: 4.1395 - learning_rate: 0.0100
Epoch 3/20
[1m380/380[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 212ms/step - accuracy: 0.1813 - loss: 3.3299 - val_accuracy: 0.1974 - val_loss: 3.1370 - learning_rate: 0.0100
Epoch 4/20
[1m380/380[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 186ms/step - accuracy: 0.2226 - loss: 3.0002 - val_accuracy: 0.2237 - val_loss: 2.9716 - learning_rate: 0.0075
Epoch 5/20
[1m380/380[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 154ms/step - accuracy: 0.2311 - loss: 2.9473 - val_accuracy: 0.2339 - val_loss: 2.9286 - learning_rate: 0.0075


In [22]:
def postprocess(sequence):
    sequence = tf.where(sequence == 0, 0, sequence + min_note - 1)
    sequence = tf.cast(sequence, tf.int32)
    return tf.reshape(sequence, [-1, 4]).numpy()

def generate_chorale(model, seed_chords, length):
    arpegio = preprocess(tf.constant(seed_chords, dtype=tf.int64))
    arpegio = tf.reshape(arpegio, [1, -1])
    for chord in range(length):
        for note in range(4):
            #next_note = model.predict_classes(arpegio)[:1, -1:]
            next_note = np.argmax(model.predict(arpegio), axis=-1)[:1, -1:]
            arpegio = tf.concat([arpegio, next_note], axis=1)
    arpegio = tf.where(arpegio == 0, arpegio, arpegio + min_note - 1)
    return tf.reshape(arpegio, shape=[-1, 4])

def generate_chorale_v2(model, seed_chords, length, temperature=1):
    arpegio = preprocess(tf.constant(seed_chords, dtype=tf.int64))
    arpegio = tf.reshape(arpegio, [1, -1])
    for chord in range(length):
        for note in range(4):
            next_note_probas = model.predict(arpegio, verbose=False)[0, -1:]
            rescaled_logits = tf.math.log(next_note_probas) / temperature
            next_note = tf.random.categorical(rescaled_logits, num_samples=1)
            arpegio = tf.concat([arpegio, next_note], axis=1)
    arpegio = tf.where(arpegio == 0, arpegio, arpegio + min_note - 1)
    return tf.reshape(arpegio, shape=[-1, 4]).numpy()

In [36]:
model = keras.models.load_model('bach_chorales_model.keras')

seed_chords = test_chorales[0][:12]

print(seed_chords)


generated_chorale = generate_chorale_v2(model, seed_chords, 64, 0.8)
# generated_chorale = postprocess(generated_chorale)

for chord in generated_chorale:
    print(chord)

# Play the processed chorale (assuming play_chorale is already defined)
play_chorale(generated_chorale)

[[65, 60, 57, 53], [65, 60, 57, 53], [65, 60, 57, 53], [65, 60, 57, 53], [72, 60, 55, 52], [72, 60, 55, 52], [70, 60, 55, 52], [70, 60, 55, 52], [69, 60, 53, 53], [69, 60, 53, 53], [67, 60, 55, 52], [67, 60, 55, 52]]
[65 60 57 53]
[65 60 57 53]
[65 60 57 53]
[65 60 57 53]
[72 60 55 52]
[72 60 55 52]
[70 60 55 52]
[70 60 55 52]
[69 60 53 53]
[69 60 53 53]
[67 60 55 52]
[67 60 55 52]
[67 60 57 52]
[67 60 57 52]
[69 60 53 53]
[69 60 53 53]
[69 60 53 53]
[69 60 53 53]
[67 60 55 48]
[67 60 55 48]
[67 60 55 48]
[67 60 55 48]
[65 60 57 45]
[65 60 57 45]
[65 60 57 45]
[65 60 57 45]
[65 62 58 46]
[65 62 58 46]
[65 62 58 46]
[65 62 58 46]
[67 62 58 46]
[67 62 58 46]
[67 62 58 46]
[67 62 58 46]
[69 63 57 45]
[69 63 57 45]
[69 63 57 45]
[69 63 57 45]
[69 62 57 50]
[69 62 57 50]
[69 62 57 50]
[69 62 57 50]
[67 62 55 50]
[67 62 55 50]
[67 62 55 50]
[67 62 55 50]
[69 64 55 48]
[69 64 55 48]
[69 64 55 48]
[69 64 55 48]
[69 62 54 50]
[69 62 54 50]
[69 62 54 50]
[69 62 54 50]
[69 62 54 50]
[69 62 54 50]