In [1]:
#@title Get a smaller version of the Lakh MIDI Dataset v0.1
%%capture
!wget http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz
!tar xvf clean_midi.tar.gz
!rm clean_midi.tar.gz
!rm -rf sample_data

dataset_path = "/content/clean_midi"

In [2]:
%%capture
!pip install pretty_midi
!pip install miditok
!pip install mido

In [3]:
#@title Installing libraries to hear a MIDI
%%capture
!apt-get update -qq && apt-get install -y fluidsynth
!pip install pretty_midi midi-clip

# GS2
!gdown 1wlpTIS70nQHMrYBjDT0M6nyg07kUejUv
!unzip GeneralUser_GS_v2.0.0--doc_r2.zip
!rm -rf GeneralUser_GS_v2.0.0--doc_r2.zip support documentation demo\ MIDIs
!mv GeneralUser\ GS\ v2.0.0.sf2 guGS.sf2

# PICONICA
!gdown 1uk51T9Gvo1n2JRl3_CHCg2FVGWiNI4qJ

# Utility library
!wget https://raw.githubusercontent.com/roostico/NesGen/refs/heads/main/utility.py

from utility import *

## Move files and rename them

In [4]:
from pathlib import Path
import os
import shutil
import random

# Paths to the files of the dataset
midi_paths = list(Path(dataset_path).resolve().glob("**/*.mid"))

midis_dir = "midis"
os.makedirs(midis_dir, exist_ok=True)

for i, midi_path in enumerate(midi_paths):
  new_midi_path = os.path.join(midis_dir, f"{i}.mid")
  shutil.move(str(midi_path), new_midi_path)

midis = list(Path("midis").resolve().glob("**/*.mid"))

def sample():
  return str(random.choice(midis))

## Select a sample of these files

In [5]:
import os
import random
import shutil

def copy_random_files(source_dir: str, dest_dir: str, num_files: int, is_file_valid) -> list:
    """Copies a specified number of random files from a source directory to a destination directory.

    Args:
        source_dir: The path to the source directory.
        dest_dir: The path to the destination directory.
        num_files: The number of files to move.
    """
    if not os.path.exists(source_dir):
        print(f"Error: Source directory '{source_dir}' not found.")
        return

    if os.path.exists(dest_dir):
      shutil.rmtree(dest_dir)

    os.makedirs(dest_dir, exist_ok=True)
    files = [f for f in os.listdir(source_dir) if os.path.isfile(os.path.join(source_dir, f))]

    if len(files) < num_files:
        print(f"Warning: Only {len(files)} files found in '{source_dir}'. Moving all of them.")
        num_files = len(files)

    files_to_move = []
    i = 0
    while i < num_files:
        random_file = random.choice(files)
        if is_file_valid(random_file):
          files_to_move.append(random_file)
          i = i + 1
        else:
          continue

    result = []
    for file in files_to_move:
        source_path = os.path.join(source_dir, file)
        dest_path = os.path.join(dest_dir, file)
        shutil.copy(source_path, dest_path)
        result.append(dest_path)
    return result

def is_valid(file):
  return True

source_directory = "midis"
destination_directory = "selected"
number_of_files_to_move = 500

sample = copy_random_files(source_directory, destination_directory, number_of_files_to_move, is_valid)

## Pre-processing

In [6]:
import mido
import os
from tqdm import tqdm


def is_excluded(track_name):
    # Define criteria for excluding tracks
    exclude_keywords = ["drum", "effect", "percussion"]
    return any(keyword.lower() in track_name.lower() for keyword in exclude_keywords)

def merge_tracks_to_single_instrument(input_file, output_file, target_channel=0):
    # Load the MIDI file
    midi = mido.MidiFile(input_file)

    # List to hold all messages from all tracks with their absolute times
    all_messages = []
    tempo_messages = []

    # Iterate through each track
    for track in midi.tracks:
        # Get track name from meta messages, if available
        track_name = next((msg.name for msg in track if msg.type == 'track_name'), '')

        # Skip tracks that match the exclusion criteria
        if is_excluded(track_name):
            continue

        # Track the absolute time of messages in this track
        absolute_time = 0
        for msg in track:
            # Track absolute time and store messages
            absolute_time += msg.time

            # Collect tempo messages separately and set their absolute time to 0 (at the start)
            if msg.type == 'set_tempo':
                tempo_messages.append((0, msg))
            elif not msg.is_meta:
                # Convert the message to the target channel
                try:
                  msg = msg.copy(channel=target_channel)
                except:
                  channel = 0
                # Store non-meta messages with their absolute time
                all_messages.append((absolute_time, msg))

    # Sort all non-tempo messages by absolute time to merge them
    all_messages.sort(key=lambda x: x[0])

    # Create a new merged track
    merged_track = mido.MidiTrack()

    # Insert tempo messages at the beginning of the track
    for _, tempo_msg in tempo_messages:
        merged_track.append(tempo_msg)

    # Convert absolute times back to delta times
    last_time = 0
    for absolute_time, msg in all_messages:
        # Calculate delta time since the last message
        msg.time = absolute_time - last_time
        merged_track.append(msg)
        last_time = absolute_time

    # Create a new MIDI file with the merged track
    merged_midi = mido.MidiFile()
    merged_midi.tracks.append(merged_track)
    merged_midi.ticks_per_beat = midi.ticks_per_beat

    # Save the new MIDI file
    merged_midi.save(output_file)



if os.path.exists("pre-processed"):
  shutil.rmtree("pre-processed")

os.makedirs("pre-processed")

for file in tqdm(sample):
  try:
    merge_tracks_to_single_instrument(file, f"pre-processed/{os.path.basename(file)}")
  except Exception as e:
    print(f"There was an error: {e}")
    continue

processed = list(Path("pre-processed").resolve().glob("**/*.mid"))

 16%|█▌        | 81/500 [00:56<05:04,  1.38it/s]

There was an error: Message length 2093056 exceeds maximum length 1000000


 19%|█▉        | 94/500 [01:01<02:25,  2.79it/s]

There was an error: data byte must be in range 0..127


 56%|█████▌    | 279/500 [02:10<03:10,  1.16it/s]

There was an error: data byte must be in range 0..127


100%|██████████| 500/500 [03:32<00:00,  2.36it/s]


## Encoding

In [7]:
!rm -rf tokenized

In [8]:
from miditok import REMI

def midi_valid(midi) -> bool:
    if any(ts.numerator != 4 for ts in midi.time_signature_changes):
        return False  # time signature different from 4/*, 4 beats per bar
    return True

tokenizer = REMI()
tokenizer.tokenize_dataset(        # 2 velocity and 1 duration values
    Path("/content", "pre-processed"),
    Path("/content", "tokenized"),
    midi_valid,
)

Tokenizing music files (content/tokenized): 100%|██████████| 490/490 [01:14<00:00,  6.54it/s]


In [9]:
import json
from pathlib import Path

def read_json_files(json_file_paths):
    """Reads a list of JSON files and returns a list of objects.

    Args:
        json_file_paths: A list of file paths to JSON files.

    Returns:
        A list of objects, where each object represents the data from a JSON file.
        Returns an empty list if any error occurs during file processing.
    """
    objects = []
    for file_path in json_file_paths:
        try:
            with open(file_path, 'r') as f:
                try:
                    objects.append(json.load(f))
                except json.JSONDecodeError:
                    print(f"Error decoding JSON in file: {file_path}")
                    return [] # Return empty list on error
        except FileNotFoundError:
            print(f"Error: File not found - {file_path}")
            return [] # Return empty list on error
    return objects

# Example usage (assuming 'tokenized' directory contains JSON files):
tokenized_files = list(Path("/content", "tokenized").resolve().glob("**/*.json"))
data_objects = read_json_files(tokenized_files)

if data_objects:
    print(f"Successfully read {len(data_objects)} JSON files.")
    # Now you can work with the 'data_objects' list
    # For example, print the first object:
    # print(data_objects[0])
else:
    print("Error reading JSON files.")

Successfully read 490 JSON files.


In [10]:
encoded = [song["ids"][0] for song in data_objects]

## Creating a Tensorflow dataset with all IDs

In [11]:
import tensorflow as tf
import numpy as np

all_ids = np.concatenate(encoded)
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)

### Convert into sequences

In [12]:
seq_length = 400 #@param {type: 'slider', max: 500, min: 50, step: 50}

sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)

## Preparing labels

In [13]:
def split_input_target(sequence):
    input_seq = sequence[:-1]
    target_seq = sequence[1:]
    return input_seq, target_seq

dataset = sequences.map(split_input_target)

### Creating training batches

In [14]:
# Batch size
BATCH_SIZE = 128 #@param {type: 'slider', max: 256, min: 32, step: 32}

# Buffer size to shuffle the dataset
# (TF data is designed to work with possibly infinite sequences,
# so it doesn't attempt to shuffle the entire sequence in memory. Instead,
# it maintains a buffer in which it shuffles elements).
BUFFER_SIZE = 10000

dataset = (
    dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE))

dataset

<_PrefetchDataset element_spec=(TensorSpec(shape=(128, 400), dtype=tf.int64, name=None), TensorSpec(shape=(128, 400), dtype=tf.int64, name=None))>

## Splitting in Train, Validation and Test

In [15]:
ds_size = dataset.cardinality().numpy()
train_size = int(0.8 * ds_size)
val_size = int(0.1 * ds_size)
test_size = int(0.1 * ds_size)

In [16]:
train_ds = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid_ds = remaining.take(val_size)
test_ds = remaining.skip(val_size)

# Building the model

In [17]:
# Length of the vocabulary in StringLookup Layer
vocab_size = tokenizer.vocab_size

# The embedding dimension
embedding_dim = 256

# Number of RNN units
rnn_units = 1024

## RNN creation

In [None]:
def create_rnn(seq_len, vocab_size, embedding_dim, rnn_units):
  inputs = tf.keras.Input(shape=(seq_len,))
  x = tf.keras.layers.Embedding(vocab_size, embedding_dim)(inputs, training=True)
  x = tf.keras.layers.GRU(rnn_units, return_sequences=True)(x)
  x = tf.keras.layers.Dense(vocab_size)(x)
  model = tf.keras.Model(inputs, x)
  return model


In [None]:
model = create_rnn(seq_length, vocab_size, embedding_dim, rnn_units)
model.summary()

## Transformer

In [18]:
!pip install keras-nlp

Collecting keras-nlp
  Downloading keras_nlp-0.17.0-py3-none-any.whl.metadata (1.2 kB)
Collecting keras-hub==0.17.0 (from keras-nlp)
  Downloading keras_hub-0.17.0-py3-none-any.whl.metadata (7.4 kB)
Collecting tensorflow-text (from keras-hub==0.17.0->keras-nlp)
  Downloading tensorflow_text-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting tensorflow<2.19,>=2.18.0 (from tensorflow-text->keras-hub==0.17.0->keras-nlp)
  Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting tensorboard<2.19,>=2.18 (from tensorflow<2.19,>=2.18.0->tensorflow-text->keras-hub==0.17.0->keras-nlp)
  Downloading tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Collecting keras>=3.5.0 (from tensorflow<2.19,>=2.18.0->tensorflow-text->keras-hub==0.17.0->keras-nlp)
  Downloading keras-3.6.0-py3-none-any.whl.metadata (5.8 kB)
Downloading keras_nlp-0.17.0-py3-none-any.whl (2.0 kB)
Downloading keras_hub-0.1

In [19]:
#todo
from keras import layers
from keras_nlp import nlp_layers

def build_transformer_classifier(seq_len,vocab_len):
  input=layers.Input(shape=(seq_len))
  pos_enc = nlp_layers.SinePositionEncoding()(input)
  add = layers.Add()([input, pos_enc])
  transf_enc=nlp_layers.TransformerEncoder(intermediate_dim=128,num_heads=2)(add)
  gl_avg_pool=layers.GlobalAveragePooling1D()(transf_enc)
  dense=layers.Dense(vocab_len*2, activation='relu')(gl_avg_pool)
  drop=layers.Dropout(rate=0.1)(dense)
  out=layers.Dense(vocab_len, activation='sigmoid')(drop)

  model = keras.models.Model(inputs=input, outputs=out)

  return model

ImportError: cannot import name 'nlp_layers' from 'keras_nlp' (/usr/local/lib/python3.10/dist-packages/keras_nlp/__init__.py)

In [None]:
transformer_encoder_classifier=build_transformer_classifier(seq_length,vocab_size)

In [None]:
transformer_encoder_classifier.summary()

In [None]:
keras.utils.plot_model(transformer_encoder_classifier,show_shapes=True,show_layer_names=False,dpi=80)

In [None]:
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
transformer_encoder_classifier.compile(optimizer='adam',loss=loss, metrics=['accuracy'])

In [None]:
epoch_count = 20
batch_size = BATCH_SIZE
patience=5

# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}.weights.h5")
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True
)

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True
)

history = transformer_encoder_classifier.fit(train_ds,validation_data=valid_ds,epochs=epoch_count,batch_size=batch_size,callbacks=[checkpoint_callback,early_stopping])

## Try generating with un-trained model

In [None]:
from utility import playMidi

# Use the model
for input_example_batch, target_example_batch in dataset.take(1):
  example_batch_predictions = model(input_example_batch)
  print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

# Sample indices from predictions
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

# Decode
midi = encoding.encoder.decode([sampled_indices])
midi.dump_midi("example.mid")
playMidi("example.mid")

(64, 400, 282) # (batch_size, sequence_length, vocab_size)


In [None]:
# Calculate the loss
loss = tf.keras.losses.sparse_categorical_crossentropy(target_example_batch, example_batch_predictions, from_logits=True)
# Reduce mean to get a single scalar loss value
loss = tf.reduce_mean(loss)

print("Loss:", loss.numpy())

## Taken from tensorflow tutorial:

A newly initialized model shouldn't be too sure of itself, the output logits should all have similar magnitudes. To confirm this you can check that the exponential of the mean loss is approximately equal to the vocabulary size. A much higher loss means the model is sure of its wrong answers, and is badly initialized:

In [None]:
print("Checking if it is near to vocabulary size")
print(tf.exp(example_batch_mean_loss).numpy())
print("Vocab size: ", encoding.vocab_size)

Checking if it is near to vocabulary size
281.92435
Vocab size:  282


# Training

In [None]:
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss, metrics=['accuracy'])

In [None]:
# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}.weights.h5")
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True
)

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True
)

In [None]:
EPOCHS = 20

history = model.fit(train_ds,
                    epochs=EPOCHS,
                    validation_data=valid_ds,
                    callbacks=[checkpoint_callback, early_stopping]
                    )

Epoch 1/20
[1m94/94[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 602ms/step - accuracy: 0.1299 - loss: 4.3000 - val_accuracy: 0.2574 - val_loss: 2.8605
Epoch 2/20
[1m94/94[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 747ms/step - accuracy: 0.2767 - loss: 2.7596 - val_accuracy: 0.3228 - val_loss: 2.5034
Epoch 3/20
[1m94/94[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m72s[0m 663ms/step - accuracy: 0.3374 - loss: 2.4159 - val_accuracy: 0.3707 - val_loss: 2.2358
Epoch 4/20
[1m94/94[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 640ms/step - accuracy: 0.3810 - loss: 2.1929 - val_accuracy: 0.3998 - val_loss: 2.1148
Epoch 5/20
[1m94/94[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m85s[0m 662ms/step - accuracy: 0.4126 - loss: 2.0643 - val_accuracy: 0.4294 - val_loss: 1.9934
Epoch 6/20
[1m94/94[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 652ms/step - accuracy: 0.4410 - loss: 1.9599 - val_accuracy: 0.4529 - val_loss: 1.9080
Epoch 7/20
[1m 7/94[

# Generation

In [None]:
class OneStep(tf.keras.Model):
  def __init__(self, model, decoding, encoding, vocab_size, temperature=1.0):
    super().__init__()
    self.temperature = temperature
    self.model = model
    self.decode = decoding
    self.encode = encoding

    # Taken from tensorflow tutorial: useful to skip ids

    #skip_ids = self.ids_from_chars(['[UNK]'])[:, None]
    #sparse_mask = tf.SparseTensor(
        # Put a -inf at each bad index.
    #    values=[-float('inf')]*len(skip_ids),
    #    indices=skip_ids,
        # Match the shape to the vocabulary
    #    dense_shape=[vocab_size])
    #self.prediction_mask = tf.sparse.to_dense(sparse_mask)

  @tf.function
  def generate_one_step(self, input_ids):
    input_ids_ = tf.expand_dims(input_ids, axis=0)

    # Run the model.
    # predicted_logits.shape is [batch, char, next_char_logits]
    predicted_logits = self.model(inputs=input_ids_)
    # Only use the last prediction.
    predicted_logits = predicted_logits[:, -1, :]
    predicted_logits = predicted_logits/self.temperature

    # Taken from tensorflow tutorial: apply prediction mask to prevent certain
    # ids from being generated
    #predicted_logits = predicted_logits + self.prediction_mask

    # Sample the output logits to generate token IDs.
    predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
    predicted_ids = tf.squeeze(predicted_ids, axis=-1)

    # Return the ids and model state.
    return predicted_ids

In [None]:
one_step_model = OneStep(model, tokenizer.decode, tokenizer.encode, tokenizer.vocab_size)

In [None]:
def get_random_file() -> str:
  return random.choice(midis)

seed_ids = np.array(tokenizer.encode(get_random_file())[0].ids)
seed_ids = seed_ids[:seq_length]

In [None]:
import time

start = time.time()
next_ids = seed_ids
result = [next_ids]

for n in range(1000):
  next_ids = one_step_model.generate_one_step(next_ids)
  result.append(next_ids)

end = time.time()
print('\nRun time:', end - start, "\n", '_'*80, "\n")
result = np.concatenate(result[1:])
print("Shape of result: ", result.shape)


Run time: 2.448131799697876 
 ________________________________________________________________________________ 

Shape of result:  (1000,)


## Save the generator

In [None]:
tf.saved_model.save(one_step_model, 'one_step')

## Reload the generator

In [None]:
one_step_reloaded = tf.saved_model.load('one_step')

## Hear the result

In [None]:
from utility import playMidi

midi = encoding.encoder.decode([result])
midi.dump_midi("result.mid")
playMidi("result.mid")