# Getting the dataset

In [None]:
#@title Get the full version of the Lakh MIDI Dataset v0.1
# Download the archive, extract it into the current working directory, then
# delete the tarball.
!wget http://hog.ee.columbia.edu/craffel/lmd/lmd_full.tar.gz
!tar xvf lmd_full.tar.gz
!rm lmd_full.tar.gz

# Root folder used by the playback cells at the end of the notebook.
# NOTE(review): assumes the working directory is /content and that the archive
# unpacks into a folder named lmd_full — confirm.
dataset_path = "/content/lmd_full"

In [1]:
#@title Get a smaller version of the Lakh MIDI Dataset v0.1
%%capture
# Download the archive, extract it into the current working directory, then
# delete the tarball. %%capture hides the (long) command output.
!wget http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz
!tar xvf clean_midi.tar.gz
!rm clean_midi.tar.gz

# Rebinds dataset_path: whichever download cell ran last decides its value.
# NOTE(review): assumes the working directory is /content — confirm.
dataset_path = "/content/clean_midi"

In [None]:
#@title Get the NESMDB dataset
# Fetch the zip by its Google Drive file id, unpack it, then delete the archive.
# NOTE(review): a later cell reads the folder "nesmdb_midi"; presumably that is
# what the zip unpacks to — confirm.
!gdown 1gIli7G1wu0QWDLzRc-CPWB8C4Hu0XVn3
!unzip nesmdb_midi.zip
!rm nesmdb_midi.zip

# Tokenization

In [2]:
#@title Install libraries to manage MIDI files and their tokenization
%%capture
# NOTE(review): miditok and pretty_midi are installed without version pins, so
# a fresh runtime may pick up different releases than the ones this notebook
# was written against.
!pip install miditok
!pip install pretty_midi
!pip install --upgrade "transformers>=4.45"

# Utility functions

In [3]:
def normalize_to_range(arr, range_min=0, range_max=1):
    """Linearly rescale ``arr`` so that its values span ``[range_min, range_max]``.

    Args:
        arr: Array of numbers (must be non-empty).
        range_min: Value the smallest element of ``arr`` is mapped to.
        range_max: Value the largest element of ``arr`` is mapped to.

    Returns:
        Tuple ``(scaled_arr, arr_min, arr_max)``. ``arr_min`` and ``arr_max``
        are the extrema of the input, which ``de_normalize`` needs in order to
        invert the mapping.
    """
    arr_min = np.min(arr)
    arr_max = np.max(arr)
    span = arr_max - arr_min

    if span == 0:
        # Constant input: there is no spread to rescale. Map every element to
        # range_min instead of computing 0 / 0, which would yield NaN.
        normalized_arr = np.zeros_like(arr, dtype=float)
    else:
        # Normalize to [0, 1]
        normalized_arr = (arr - arr_min) / span

    # Scale [0, 1] to [range_min, range_max]
    scaled_arr = normalized_arr * (range_max - range_min) + range_min

    return scaled_arr, arr_min, arr_max

def de_normalize(arr, original_min, original_max, range_min=0, range_max=1):
    """Map values from ``[range_min, range_max]`` back to ``[original_min, original_max]``.

    This is the inverse of ``normalize_to_range`` when called with the extrema
    that function returned and the same ``range_min`` / ``range_max``.
    """
    # Fraction of the way through the normalised range, i.e. a value in [0, 1].
    unit_fraction = (arr - range_min) / (range_max - range_min)
    original_span = original_max - original_min
    return unit_fraction * original_span + original_min

# Data preparation (using miditok)

In [4]:
from miditok import REMI
from pathlib import Path

# Absolute paths of every *.mid file found recursively under ./clean_midi.
# NOTE(review): the folder name is hard-coded and does not use `dataset_path`
# from the download cells — confirm this is intended.
midi_paths = list(Path("clean_midi").resolve().glob("**/*.mid"))

## Option 1: use a pre-trained tokenizer

In [None]:
#@title Download tokenizer trained params
%%capture
# Saves clean_midi_remi_params.json into the current working directory.
!wget https://raw.githubusercontent.com/roostico/NesGen/refs/heads/main/tokenizer/clean_midi_remi_params.json

In [None]:
# Build the REMI tokenizer from the parameter file downloaded in the previous cell.
tokenizer = REMI(params="clean_midi_remi_params.json")

## Option 2: train the tokenizer

In [5]:
# REMI tokenizer with its default configuration. The training and saving steps
# below are commented out, so as written this cell creates an untrained
# tokenizer.
tokenizer = REMI()
# tokenizer.train(vocab_size=30000, files_paths=midi_paths)
# tokenizer.save("tokenizer.json")

## Setup `DatasetMIDI`

In [7]:
from random import shuffle
from miditok.pytorch_data import DatasetMIDI, DataCollator
from miditok.utils import split_files_for_training

# Split MIDI paths into train/valid/test sets (70% / 15% / 15%).
# shuffle() reorders midi_paths in place and is not seeded, so the split (and
# any later midi_paths[i] lookup) differs from run to run.
total_num_files = len(midi_paths)
num_files_valid = round(total_num_files * 0.15)
num_files_test = round(total_num_files * 0.15)
shuffle(midi_paths)
midi_paths_valid = midi_paths[:num_files_valid]
midi_paths_test = midi_paths[num_files_valid:num_files_valid + num_files_test]
midi_paths_train = midi_paths[num_files_valid + num_files_test:]

# Chunk the MIDIs of each subset independently (the augmentation step below is
# currently disabled).
# NOTE(review): the saved output of this cell shows a
# "RuntimeError: File not found" raised while splitting Lakh_train; the cause
# is not visible here — investigate before relying on the chunk folders.
for files_paths, subset_name in (
    (midi_paths_train, "train"), (midi_paths_valid, "valid"), (midi_paths_test, "test")
):

    # Split the MIDIs into chunks of roughly 1024 tokens, written to
    # Lakh_<subset_name>/
    subset_chunks_dir = Path(f"Lakh_{subset_name}")
    split_files_for_training(
        files_paths=files_paths,
        tokenizer=tokenizer,
        save_dir=subset_chunks_dir,
        max_seq_len=1024,
        num_overlap_bars=2,
    )

    # Perform data augmentation (disabled)
    #augment_dataset(
    #    subset_chunks_dir,
    #    pitch_offsets=[-12, 12],
    #    velocity_offsets=[-4, 4],
    #    duration_offsets=[-0.5, 0.5],
    #)

# Create the datasets for training.
# From here on the midi_paths_* names refer to the chunk files on disk, not to
# the original dataset files.
midi_paths_train = list(Path("Lakh_train").glob("**/*.mid"))
midi_paths_valid = list(Path("Lakh_valid").glob("**/*.mid"))
midi_paths_test = list(Path("Lakh_test").glob("**/*.mid"))
# Same sequence limit as the chunking step; BOS/EOS ids are passed to the dataset.
kwargs_dataset = {"max_seq_len": 1024, "tokenizer": tokenizer, "bos_token_id": tokenizer["BOS_None"], "eos_token_id": tokenizer["EOS_None"]}
dataset_train = DatasetMIDI(midi_paths_train, **kwargs_dataset)
dataset_valid = DatasetMIDI(midi_paths_valid, **kwargs_dataset)
dataset_test = DatasetMIDI(midi_paths_test, **kwargs_dataset)

Splitting music files (Lakh_train):   4%|▍         | 474/12080 [00:15<06:20, 30.53it/s]


RuntimeError: File not found

# Data cleaning

## Show some information of the Dataset

In [None]:
# Tokenize one specific chunk file and print the result.
# NOTE(review): the file name is hard-coded; it only exists if that chunk was
# produced by the splitting cell in this run.
tokens = tokenizer(Path("Lakh_train", "BAP", "Verdammt lang her_t0_0.mid"))
print((tokens))

In [None]:
# Show the tokenizer and the first training sample.
print(tokenizer)
print(dataset_train[0])
# Disabled experiment: tokenize a chunk and convert it back to MIDI.
#tokens = tokenizer(Path("Lakh_train", "Asia", "Don't Cry_t0_0.mid"))
#print(type(tokens))
#tokenizer(tokens).dump_midi("test.mid")

## Padding and cleaning of DatasetMIDI

In [None]:
import numpy as np
import torch
import tensorflow as tf

def torch_tensor_to_padded_numpy(tensor: torch.Tensor,
                                 padded_max_length: int,
                                 bos_id=None,
                                 eos_id=None,
                                 pad_id=0) -> np.ndarray:
  """Convert a 1-D tensor of token ids to a fixed-length numpy array.

  The sequence is wrapped in BOS/EOS when those markers are missing and is then
  right-padded with ``pad_id`` up to ``padded_max_length``.

  Args:
    tensor: 1-D tensor of token ids (may be empty).
    padded_max_length: Length of the returned array.
    bos_id: Id of the begin-of-sequence token. Defaults to the notebook-level
      ``tokenizer["BOS_None"]``.
    eos_id: Id of the end-of-sequence token. Defaults to the notebook-level
      ``tokenizer["EOS_None"]``.
    pad_id: Value used for right padding (0, as in the original implementation).

  Returns:
    A numpy array of length ``padded_max_length``.

  Raises:
    ValueError: If the sequence, once BOS/EOS are added, is longer than
      ``padded_max_length``.
  """
  if bos_id is None:
    bos_id = tokenizer["BOS_None"]
  if eos_id is None:
    eos_id = tokenizer["EOS_None"]

  # detach()/cpu() make the conversion work for tensors that track gradients or
  # live on an accelerator; for a plain CPU tensor they are no-ops.
  array = tensor.detach().cpu().numpy()

  # The size checks keep an empty sequence from raising IndexError.
  if array.size == 0 or array[0] != bos_id:
    array = np.insert(array, 0, bos_id)
  if array[-1] != eos_id:
    array = np.append(array, eos_id)

  missing = padded_max_length - len(array)
  if missing < 0:
    # np.pad would fail here with an unrelated-looking message about negative
    # values; report the real problem instead.
    raise ValueError(
        f"Sequence of length {len(array)} (including BOS/EOS) does not fit in "
        f"padded_max_length={padded_max_length}"
    )

  return np.pad(array, (0, missing), 'constant', constant_values=pad_id)


# NOTE(review): `count` is not read anywhere in this cell.
count = 0
# Accumulators for the padded sequences of each split.
train_x = []
val_x = []
test_x = []

# Longest token sequence in each split; the overall maximum becomes the common
# padded length. Each max() below iterates over a whole dataset once.
# NOTE(review): lengths are measured before torch_tensor_to_padded_numpy may
# add BOS/EOS; if the dataset items do not already contain them, the padded
# length would be too short — confirm.
max_len_train = max(len(arr["input_ids"]) for arr in dataset_train)
max_len_val = max(len(arr["input_ids"]) for arr in dataset_valid)
max_len_test = max(len(arr["input_ids"]) for arr in dataset_test)
max_len = max(max_len_train, max_len_val, max_len_test)
print("Max length of sequence in train_x is: " + str(max_len_train))
print("Max length of sequence in val_x is: " + str(max_len_val))
print("Max length of sequence in test_x is: " + str(max_len_test))

print("Using max_length: " + str(max_len))

# Second pass over every dataset: pad each sample to max_len and collect it.
# (`input` shadows the Python builtin of the same name for the rest of the session.)
for (result, input) in \
[(train_x, dataset_train), (val_x, dataset_valid), (test_x, dataset_test)]:
  for i in input:
    ids = i['input_ids']
    array = torch_tensor_to_padded_numpy(ids, max_len)
    result.append(array)

# Stack the per-sample arrays into one 2-D array per split.
train_x = np.array(train_x)
print("Shape of train_x is " + str(train_x.shape))
val_x = np.array(val_x)
print("Shape of val_x is " + str(val_x.shape))
test_x = np.array(test_x)
print("Shape of test_x is " + str(test_x.shape))

## Saving the generated train, valid and test arrays (if necessary)

In [None]:
# Write each split as plain text, one sequence per line, ids formatted as integers.
np.savetxt('train_x.txt', train_x, fmt='%d')
np.savetxt('val_x.txt', val_x, fmt='%d')
np.savetxt('test_x.txt', test_x, fmt='%d')

## Loading previous train, valid and test arrays (if necessary)

In [None]:
# Reload the splits written by the save cell.
# NOTE(review): `max_len` is not restored here, but later cells use it; it is
# only defined if the padding cell ran in this session.
train_x = np.loadtxt('train_x.txt', dtype=int)
val_x = np.loadtxt('val_x.txt', dtype=int)
test_x = np.loadtxt('test_x.txt', dtype=int)

## Normalization...

In [None]:
# Sanity check: no NaNs in any split before rescaling.
assert not np.any(np.isnan(train_x))
assert not np.any(np.isnan(val_x))
assert not np.any(np.isnan(test_x))

# Rescale every split to [0, 1]. Each split is normalized with its OWN min/max,
# so the same token id can map to different values in different splits.
# NOTE(review): the model cells below use an Embedding layer and
# sparse_categorical_crossentropy, which appear to expect integer token ids —
# confirm that normalization is really wanted before running this cell.
normalized_train_x, original_min_train, original_max_train = normalize_to_range(train_x, 0, 1)
assert (np.max(normalized_train_x)) == 1
assert (np.min(normalized_train_x)) == 0
normalized_val_x, original_min_val, original_max_val = normalize_to_range(val_x, 0, 1)
assert (np.max(normalized_val_x)) == 1
assert (np.min(normalized_val_x)) == 0
normalized_test_x, original_min_test, original_max_test = normalize_to_range(test_x, 0, 1)
assert (np.max(normalized_test_x)) == 1
assert (np.min(normalized_test_x)) == 0


## ...or skip normalization

In [None]:
# Keep the raw integer token ids under the `normalized_*` names that the
# following cells read. These are aliases of the original arrays, not copies.
normalized_train_x = train_x
normalized_val_x = val_x
normalized_test_x = test_x

## Preparing labels

In [None]:
# Next-token labels: y[:, t] = x[:, t + 1].
# np.roll wraps around, so on its own it would make the LAST label of every row
# equal to that row's FIRST token. The final position has no successor, so it
# is overwritten with the padding id instead.
# NOTE(review): this assumes the integer token ids are in use (the "skip
# normalization" path); confirm the fill value if the normalized arrays are used.
pad_label = tokenizer["PAD_None"]

normalized_train_y = np.roll(normalized_train_x, shift=-1, axis=1)
normalized_val_y = np.roll(normalized_val_x, shift=-1, axis=1)
normalized_test_y = np.roll(normalized_test_x, shift=-1, axis=1)

# np.roll returns new arrays, so writing into the labels leaves the inputs untouched.
for labels in (normalized_train_y, normalized_val_y, normalized_test_y):
    labels[:, -1] = pad_label

print("Shape of normalized_train_y is " + str(normalized_train_y.shape))
print("Shape of normalized_val_y is " + str(normalized_val_y.shape))
print("Shape of normalized_test_y is " + str(normalized_test_y.shape))

# Model creation

In [None]:
#@title Install `keras_nlp`
# NOTE(review): installed without a version pin.
!pip install keras_nlp

## Creating a transformer

In [None]:
import keras_nlp.layers as nlp_layers

def create_transformer(vocab_size, seq_len, embedding_dim, num_heads, dff, num_layers):
    """Build a Keras transformer that outputs a token distribution per position.

    The token embedding feeds a stack of ``num_layers`` encoder layers; a stack
    of ``num_layers`` decoder layers then processes the same embedding while
    attending to the encoder output. A softmax Dense layer maps every position
    to ``vocab_size`` probabilities. With ``num_layers=1`` the graph is the
    single encoder + single decoder model this function built before
    ``num_layers`` was honoured.

    Args:
        vocab_size: Number of distinct token ids (embedding rows / output units).
        seq_len: Fixed input sequence length.
        embedding_dim: Size of the token embedding.
        num_heads: Attention heads per layer.
        dff: Hidden size of each layer's feed-forward block.
        num_layers: Number of encoder layers and of decoder layers (>= 1).

    Returns:
        Tuple ``(model, encoder, decoder)``: the Keras model plus the output
        tensors of the last encoder and last decoder layer.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """
    if num_layers < 1:
        raise ValueError(f"num_layers must be >= 1, got {num_layers}")

    # Input
    inputs = tf.keras.Input(shape=(seq_len,))

    # Embedding
    # NOTE(review): only token embeddings are used; no positional information
    # is added anywhere in this function — confirm that is intended.
    embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)(inputs)

    # Encoder stack
    # NOTE(review): the encoder is called without any mask and the decoder
    # attends to its output; for next-token training, confirm this does not
    # let the model see the tokens it is asked to predict.
    encoder = embedding
    for _ in range(num_layers):
        encoder = nlp_layers.TransformerEncoder(num_heads=num_heads, intermediate_dim=dff)(encoder)

    # Decoder stack: every layer cross-attends to the final encoder output.
    decoder = embedding
    for _ in range(num_layers):
        decoder = nlp_layers.TransformerDecoder(num_heads=num_heads, intermediate_dim=dff)(decoder, encoder)

    # Output
    outputs = tf.keras.layers.Dense(vocab_size, activation='softmax')(decoder)

    # Create the model
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    return model, encoder, decoder

## Instantiate a transformer

In [None]:
# Size the model from the tokenizer vocabulary and the padded sequence length
# computed in the padding cell.
vocab_size = tokenizer.vocab_size
seq_len = max_len

model, encoder, decoder = create_transformer(vocab_size=vocab_size,
                                             seq_len=seq_len,
                                             embedding_dim=256,
                                             num_heads=8,
                                             dff=1024,
                                             num_layers=6)

## Training the transformer

In [None]:
from tensorflow.keras.callbacks import EarlyStopping

# Integer labels + per-position softmax output -> sparse categorical cross-entropy.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# Stop when val_loss has not improved for 3 epochs and restore the best weights.
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

model.fit(normalized_train_x, normalized_train_y,
          epochs=5,
          validation_data=(normalized_val_x, normalized_val_y),
          callbacks=[early_stopping],
          batch_size=32
          )

# Persist the trained model in the current working directory.
model.save("NesGen_v1.keras")

## Model loading (optional)

In [None]:
from tensorflow.keras.models import load_model

# Reload the model file written by the training cell.
model = load_model("NesGen_v1.keras")

# Model usage

Here we use the trained model to generate new MIDI

## Utility functions

In [None]:
from tqdm import tqdm

def get_random_ids_from_dataset(dataset):
  """Return one entry of ``dataset``, picked uniformly at random along axis 0."""
  row_index = np.random.choice(dataset.shape[0])
  return dataset[row_index]

def sample_next_token(probabilities) -> int:
  """Sample a token id from the distribution at the last position of the first batch item.

  Args:
    probabilities: Array indexed as ``[batch, position, token]``; only
      ``probabilities[0, -1]`` is used. The array is not modified.

  Returns:
    The sampled token id as a Python ``int``.

  Raises:
    ValueError: If the selected probabilities do not sum to a finite, positive value.
  """
  # Work on a float64 copy: the original code divided a view in place, which
  # silently rewrote the caller's array.
  last_token_probs = np.array(probabilities[0, -1], dtype=np.float64)

  total = last_token_probs.sum()
  if not np.isfinite(total) or total <= 0:
    raise ValueError("Cannot sample: probabilities must sum to a finite, positive value")

  # Re-normalise so the probabilities sum to 1 despite floating-point error.
  last_token_probs /= total
  return int(np.random.choice(len(last_token_probs), p=last_token_probs))

def next_token(model, seed_ids):
    """Run ``model`` on ``seed_ids`` (silently) and sample the next token id from its output."""
    return sample_next_token(model.predict(seed_ids, verbose=0))

def generate_ids(model, seed_ids, eos_id, pad_id, bos_id, max_len=None, show_progress=True):
  """Autoregressively sample a token sequence from ``model``.

  Starting from ``seed_ids``, one token is sampled per step; the context window
  is then shifted left by one and the new token is written to its last slot.

  Args:
    model: Object whose ``predict`` is used by ``next_token``.
    seed_ids: 2-D array of shape ``(1, seq_len)`` used as the initial context.
      It is not modified.
    eos_id: Token id that ends generation.
    pad_id: Padding token id; sampled paddings are discarded.
    bos_id: Token id placed at the start of the result.
    max_len: Maximum number of sampling steps. Defaults to ``seed_ids.shape[1]``.
    show_progress: Wrap the loop in a tqdm progress bar.

  Returns:
    1-D numpy array ``[bos_id, <generated tokens>, eos_id]``. BOS and EOS are
    added around the generated tokens (earlier versions overwrote the first and
    last generated tokens with them), so the result holds at most
    ``max_len + 2`` ids and is ``[bos_id, eos_id]`` when nothing was generated.
  """
  if max_len is None:
    max_len = seed_ids.shape[1]

  steps = tqdm(range(max_len)) if show_progress else range(max_len)
  seed = seed_ids
  generated_ids = []

  for _ in steps:
    next_token_id = next_token(model, seed)
    if next_token_id == eos_id:
      break
    if next_token_id == pad_id:
      # Padding is not part of the music: drop it and sample again from the
      # unchanged context.
      continue

    generated_ids.append(next_token_id)

    # np.roll returns a new array, so the caller's seed_ids is never written to.
    seed = np.roll(seed, -1, axis=1)
    seed[0, -1] = next_token_id

  return np.array([bos_id, *generated_ids, eos_id])


## Generate the IDs for the new sequence

In [None]:
# Use a random test sequence, reshaped to a batch of one, as the starting context.
# NOTE(review): the seed is drawn from the label array (normalized_test_y)
# rather than from normalized_test_x — confirm this is intended.
seed = get_random_ids_from_dataset(normalized_test_y).reshape((1, max_len))
generated_ids = generate_ids(
    model,
    seed,
    eos_id=tokenizer["EOS_None"],
    pad_id=tokenizer["PAD_None"],
    bos_id=tokenizer["BOS_None"],
    max_len=100
)
print("\nGenerated\n" + str(generated_ids))

## Conversion from IDs to MIDI

### Utility function

In [None]:
import miditok

def ids_to_midi(
    ids: np.ndarray,
    tokenizer: miditok.tokenizations.remi.REMI,
    file_name: str ="result.mid",
    output_dir: str = "/content/"
  ):
  """Decode token ids with ``tokenizer`` and write the result to ``output_dir/file_name``.

  The ids are cast to int32 and handed to the tokenizer as a one-element list.
  """
  token_sequences = [ids.astype(np.int32)]
  destination = Path(output_dir, file_name)
  decoded = tokenizer(token_sequences)
  decoded.dump_midi(destination)

### Actual conversion

In [None]:
# Name of the output file; ids_to_midi writes it into its default output_dir.
file_name = "result.mid"

print(f"Converting IDs to MIDI file: {file_name}...")

ids_to_midi(generated_ids, tokenizer, file_name=file_name)

print("DONE!")

### End to End utility

For generating multiple files in one call

In [None]:
def generate_midi(
    dataset,
    model,
    tokenizer,
    output_folder="/content/gen_midi",
    num_files=1,
    max_len=100
):
  """Generate ``num_files`` MIDI files named ``0.mid``, ``1.mid``, ... in ``output_folder``.

  Args:
    dataset: 2-D numpy array of token sequences to draw seeds from. For
      backward compatibility any other object (e.g. a ``DatasetMIDI``) is
      ignored and the notebook-level ``normalized_test_y`` array is used, as
      this function always did before.
    model: Trained model passed to ``generate_ids``.
    tokenizer: Tokenizer providing the special token ids and the MIDI decoding.
    output_folder: Destination folder; created if it does not exist.
    num_files: Number of files to generate.
    max_len: Maximum number of sampling steps per file.
  """
  # Path.mkdir replaces the os.path.exists / os.makedirs pair, so this function
  # no longer relies on `os`, which is only imported by later cells.
  Path(output_folder).mkdir(parents=True, exist_ok=True)

  seed_pool = dataset if isinstance(dataset, np.ndarray) else normalized_test_y

  for i in tqdm(range(num_files)):
    seed = get_random_ids_from_dataset(seed_pool).reshape((1, seed_pool.shape[1]))
    generated_ids = generate_ids(
        model,
        seed,
        eos_id=tokenizer["EOS_None"],
        pad_id=tokenizer["PAD_None"],
        bos_id=tokenizer["BOS_None"],
        max_len=max_len,
        show_progress=False
    )
    # The original line ended with a stray "/" and did not parse.
    file_name = str(i) + ".mid"
    ids_to_midi(generated_ids, tokenizer, file_name=file_name, output_dir=output_folder)


#### Generate some files

In [None]:
# Number of MIDI files to write into /content/gen_midi/ (0.mid ... 9.mid).
n_files = 10
generate_midi(dataset_test,
               model,
               tokenizer,
               output_folder="/content/gen_midi/",
               num_files=n_files,
               max_len=100
               )

# 2nd PART - Playing the MIDI

## Utility functions

In [None]:
def random_file(root, keyword=None):
    """Return the path of a random ``*.mid`` file found recursively under ``root``.

    Args:
        root: Folder to search.
        keyword: Optional case-insensitive substring that the file path must
            contain.

    Returns:
        Path of the chosen file, as a string.

    Raises:
        IndexError: If no file matches (the exception type ``random.choice``
            raised here before, now with an explanatory message).
    """
    import glob
    import os
    import random

    mid_files = glob.glob(os.path.join(root, "**", "*.mid"), recursive=True)
    if keyword is not None:
        # Lower-case both sides: previously only the path was lowered, so a
        # keyword containing capitals could never match.
        needle = keyword.lower()
        mid_files = [file for file in mid_files if needle in file.lower()]
    if not mid_files:
        raise IndexError(f"No .mid file found under {root!r} (keyword={keyword!r})")
    return random.choice(mid_files)

## Libraries

In [None]:
#@title Installing the required libraries
%%capture
# fluidsynth is the command-line synthesizer invoked by the FluidSynth wrapper
# defined below; midi-clip provides the module used by trim_midi.
!apt-get update -qq && apt-get install -y fluidsynth
!pip install pretty_midi midi-clip

In [None]:
#@title Download example Soundfonts (GeneralUser GS v2 and PICONICA)
%%capture
# GeneralUser GS: download, unzip, drop everything but the .sf2 file and rename
# it to guGS.sf2 (the default soundfont of playMidi).
!gdown 1wlpTIS70nQHMrYBjDT0M6nyg07kUejUv
!unzip GeneralUser_GS_v2.0.0--doc_r2.zip
!rm -rf GeneralUser_GS_v2.0.0--doc_r2.zip support documentation demo\ MIDIs
!mv GeneralUser\ GS\ v2.0.0.sf2 guGS.sf2

# PICONICA
# NOTE(review): later cells expect a file named PICONICA.sf2; presumably that
# is the name of the downloaded file — confirm.
!gdown 1uk51T9Gvo1n2JRl3_CHCg2FVGWiNI4qJ

In [None]:
#@title Optional: download other soundfonts (pDPP)
%%capture
# Pokemon
# NOTE(review): the playback cell offers "PokeDP.sf2"; presumably that is the
# name of the file downloaded here — confirm.
!gdown 1vDK_xH7WeAqQrrBFXfh4Q205x6oNhTQt

## Utility function to generate the audio on Colab

### Taken from https://github.com/bzamecnik/midi2audio/blob/master/midi2audio.py

In [None]:
import argparse
import os
import subprocess

__all__ = ['FluidSynth']

DEFAULT_SOUND_FONT = '~/.fluidsynth/default_sound_font.sf2'
DEFAULT_SAMPLE_RATE = 44100
DEFAULT_GAIN = 0.2

class FluidSynth():
    def __init__(self, sound_font=DEFAULT_SOUND_FONT, sample_rate=DEFAULT_SAMPLE_RATE, gain=DEFAULT_GAIN):
        self.sample_rate = sample_rate
        self.sound_font = os.path.expanduser(sound_font)
        self.gain = gain

    def midi_to_audio(self, midi_file: str, audio_file: str, verbose=True):
        if verbose:
            stdout = None
        else:
            stdout = subprocess.DEVNULL
        subprocess.call(
            ['fluidsynth', '-ni', '-g', str(self.gain), self.sound_font, midi_file, '-F', audio_file, '-r', str(self.sample_rate)],
            stdout=stdout,
        )

    def play_midi(self, midi_file):
        subprocess.call(['fluidsynth', '-i', '-g', str(self.gain), self.sound_font, midi_file, '-r', str(self.sample_rate)])

### Other utility functions

In [None]:
import pretty_midi
import os
import librosa.display
import matplotlib.pyplot as plt

def show_midi_info(midi_path, print_notes=False):
  """Print the instrument names and duration of a MIDI file.

  With ``print_notes=True`` every note of every instrument is listed as well,
  as ``start end pitch velocity``.
  """
  midi_data = pretty_midi.PrettyMIDI(midi_path)
  instrument_names = [track.name for track in midi_data.instruments]
  print("Instruments: ", instrument_names)
  print("MIDI duration: {duration:.2f} seconds".format(duration=midi_data.get_end_time()))
  if not print_notes:
    return
  for track in midi_data.instruments:
    print(track.name)
    for played in track.notes:
      print(played.start, played.end, played.pitch, played.velocity)

def piano_roll(midi_path):
  """Plot the piano roll of ``midi_path`` for MIDI pitches 24 to 84 on a 12x4 figure."""
  plt.figure(figsize=(12, 4))
  # Use the argument: the original read a notebook-level variable named `path`,
  # which plotted whatever file that variable last pointed to (or raised
  # NameError if it was never set).
  plot_piano_roll(midi_path, 24, 84)

def plot_piano_roll(path, start_pitch, end_pitch, fs=100):
    """Draw the piano roll of the MIDI file at ``path`` for pitches ``[start_pitch, end_pitch)``.

    ``fs`` is the sampling frequency used to build the roll and is also passed
    to librosa as the sample rate for the time axis.
    """
    midi_data = pretty_midi.PrettyMIDI(path)
    # Keep only the rows of the requested pitch range.
    roll_slice = midi_data.get_piano_roll(fs)[start_pitch:end_pitch]
    lowest_hz = pretty_midi.note_number_to_hz(start_pitch)
    # librosa's specshow renders the roll with a time axis and note-name labels.
    librosa.display.specshow(roll_slice,
                             hop_length=1, sr=fs, x_axis='time', y_axis='cqt_note',
                             fmin=lowest_hz)

def change_midi_velocity(midi_path, output_path, delta=0):
  """Write a copy of ``midi_path`` to ``output_path`` with ``delta`` added to every note velocity.

  Velocities are clamped to 1..127: MIDI velocities are 7-bit values, and 0 is
  avoided because a note-on with velocity 0 acts as a note-off.
  """
  midi_data = pretty_midi.PrettyMIDI(midi_path)
  for instrument in midi_data.instruments:
    for note in instrument.notes:
      note.velocity = max(1, min(127, note.velocity + delta))
  midi_data.write(output_path)

def convert_midi_to_wav(soundfont_path, midi_path, output_path, gain=None, velocity_change=0):
  """Render ``midi_path`` to the audio file ``output_path`` using ``soundfont_path``.

  Args:
    soundfont_path: Path to the .sf2 soundfont.
    midi_path: MIDI file to render.
    output_path: Audio file to write.
    gain: fluidsynth gain; ``None`` means ``DEFAULT_GAIN``.
    velocity_change: Offset added to every note velocity before rendering.
  """
  import tempfile

  if gain is None:
    # Forwarding None would put the string "None" on the fluidsynth command line.
    gain = DEFAULT_GAIN

  # A unique temporary file (instead of a fixed "temp.mid" in the working
  # directory) cannot clash with a user file and is removed even when the
  # velocity change or the rendering fails.
  handle, temp_midi_path = tempfile.mkstemp(suffix=".mid")
  os.close(handle)
  try:
    change_midi_velocity(midi_path, temp_midi_path, delta=velocity_change)
    FluidSynth(soundfont_path, gain=gain).midi_to_audio(temp_midi_path, output_path)
  finally:
    os.remove(temp_midi_path)


def trim_midi(midi_path, start, end):
  """Clip ``midi_path`` to the span ``start``..``end`` and save it next to the original.

  The clipped file is named ``trimmed_<original file name>`` and its path is
  returned. ``start`` and ``end`` are passed unchanged to ``midi_clip.midi_clip``.
  """
  import mido
  import midi_clip

  source = mido.MidiFile(midi_path)
  clipped = midi_clip.midi_clip(source, start, end)

  folder, file_name = os.path.split(midi_path)
  output_path = os.path.join(folder, "trimmed_" + file_name)
  clipped.save(output_path)
  return output_path

def playMidi(midi_file_path,
             soundfont_path="/content/guGS.sf2",
             output_path="audio.wav",
             start=None,
             end=None,
             gain=DEFAULT_GAIN,
             velocity_change=0
             ):
    """Render a MIDI file to audio and return an IPython ``Audio`` player for it.

    Args:
        midi_file_path: MIDI file to play.
        soundfont_path: Soundfont used for rendering.
        output_path: Audio file that is written and then wrapped in ``Audio``.
        start, end: When BOTH are given, only that span of the file is
            rendered; otherwise the whole file is.
        gain: fluidsynth gain.
        velocity_change: Offset added to every note velocity before rendering.
    """
    from IPython.display import Audio

    if start is None or end is None:
        convert_midi_to_wav(soundfont_path, midi_file_path, output_path, gain=gain, velocity_change=velocity_change)
        return Audio(output_path)

    trimmed_path = trim_midi(midi_file_path, start, end)
    try:
        convert_midi_to_wav(soundfont_path, trimmed_path, output_path, gain=gain, velocity_change=velocity_change)
    finally:
        # Remove the temporary trimmed copy even if rendering fails.
        os.remove(trimmed_path)
    return Audio(output_path)

## Using generated MIDI

In [None]:
#@title Show info

# Path is relative to the current working directory.
generated_midi_path = "gen_midi/3.mid" # @param {type:"string"}

print("Midi info:")
show_midi_info(generated_midi_path)

In [None]:
#@title Play the MIDI

# Both paths are relative to the current working directory.
generated_midi_path = "gen_midi/4.mid" # @param {type:"string"}
soundfont = "PICONICA.sf2" # @param ["PICONICA.sf2", "guGS.sf2", "PokeDP.sf2"]

# Last expression of the cell: the returned Audio widget is displayed.
playMidi(generated_midi_path, soundfont_path=soundfont)

# Extras (playing other MIDIs)

## Play a random MIDI of the Lakh dataset

In [None]:
# Pick any .mid file under the dataset folder chosen by the download cells.
path = random_file(dataset_path)
print("Converting: " + path)
print("Midi info:")
show_midi_info(path)
print("Synthetized:")
# Rendered with playMidi's default soundfont and gain.
playMidi(path)

## Play a random MIDI of the NESMDB dataset

In [None]:
# Pick any .mid file under ./nesmdb_midi.
path = random_file("nesmdb_midi")
print("Converting: " + path)
print("Midi info:")
show_midi_info(path)
print("Synthetized:")
# Rendered with the PICONICA soundfont, every velocity raised by 30 and gain 1.
playMidi(path, soundfont_path="PICONICA.sf2", velocity_change=30, gain=1)

In [8]:
import copy
import pretty_midi

class midi_to_tokens():
    """Turn a MIDI file into a flat list of text tokens.

    Token vocabulary produced by this class:
      * ``bar``    - emitted at every downbeat inside the tokenized range
      * ``beat``   - emitted at every other beat inside the range
      * ``pos_N``  - the next event lies N steps after the latest bar/beat
      * ``note_P`` - a note with MIDI pitch P, always followed by
      * ``len_L``  - that note's duration in steps

    A "step" is 1/``steps_per_beat`` of the file's ``resolution`` unit
    (presumably ticks per quarter note — confirm against pretty_midi).
    """

    def __init__(self, path, steps_per_beat=12):
        """Load the MIDI file at ``path`` and tokenize all of it into ``self.tokens``."""
        self.steps_per_beat = steps_per_beat
        self.pm = pretty_midi.PrettyMIDI(path)
        # Downbeat times plus the end time of the file, so that index -1 is
        # always a valid upper bound for _tokenize.
        self.dbs = self.pm.get_downbeats().tolist() + [self.pm.get_end_time()] # dbs := downbeats
        self.tokens = self._tokenize()

    def __call__(self):
        """Return the tokens of the whole file as one space-separated string."""
        return ' '.join(self.tokens)

    def _time_to_step(self, time):
        """Convert a time in seconds to the nearest step index."""
        return round(self.pm.time_to_tick(time) / self.pm.resolution * self.steps_per_beat)

    def _event_to_tokens(self, event):
        """Return the token list for one event.

        ``'bar'`` / ``'beat'`` map to themselves; a note maps to its pitch and
        length tokens. Any other event type falls through and yields None.
        """
        if event in ('bar', 'beat'):
            return [event]
        elif isinstance(event, pretty_midi.containers.Note):
            return [f'note_{event.pitch}', f'len_{self._time_to_step(event.end) - self._time_to_step(event.start)}']

    def _trim_note(self, note, start, end):
        """Return a shallow copy of ``note`` clipped to ``[start, end]``; the original is untouched."""
        n = copy.copy(note)
        n.start, n.end = max(n.start, start), min(n.end, end)
        return n

    def _tokenize(self, start_measure=1, end_measure=None):
        """Tokenize the span from downbeat ``start_measure`` (1-based) up to ``self.dbs[end_measure]``.

        A falsy ``end_measure`` (None or 0) selects the last entry of
        ``self.dbs``, i.e. the end of the file.
        """
        start, end = self.dbs[start_measure - 1], self.dbs[end_measure or -1]

        # Merge the notes of all instruments; order by onset, highest pitch first.
        notes = []
        for inst in self.pm.instruments:
            notes += inst.notes
        notes.sort(key=lambda x: (x.start, -x.pitch))

        # Build (step, event) pairs. The sort below is stable, so at an equal
        # step the insertion order is kept: bars, then beats, then notes.
        events = []
        events += [(self._time_to_step(db), 'bar') for db in self.dbs if start <= db < end]
        events += [(self._time_to_step(b), 'beat') for b in set(self.pm.get_beats()) - set(self.dbs) if start <= b < end] # beats without downbeats
        # Notes that start or end inside the range, clipped to it. A note that
        # starts before `start` AND ends after `end` satisfies neither test and
        # is left out.
        events += [(self._time_to_step(max(n.start, start)), self._trim_note(n, start, end)) for n in notes if start <= n.start < end or start < n.end <= end]
        events.sort(key=lambda x: x[0])

        tokens = []
        last_beat = 0
        for step, event in events:
            if event in ('bar', 'beat'):
                last_beat = step
            # Emit a position token only when the event is off the latest bar/beat.
            if step - last_beat:
                tokens.append(f'pos_{step - last_beat}')
            tokens += self._event_to_tokens(event)

        return tokens

    def measures(self, start_measure=1, end_measure=None):
        """Return the tokens for a range of measures (see ``_tokenize`` for the bounds)."""
        return self._tokenize(start_measure, end_measure)

## Try Tokenization and Generation

In [64]:
# Show which file the next cell tokenizes. midi_paths is shuffled in place by
# the dataset-setup cell, so index 8 is not a fixed file across runs.
str(midi_paths[8])

'/content/clean_midi/Genesis/Watcher of the Skies.mid'

In [72]:
# Tokenize the whole example file with 12 steps per beat.
tokens = midi_to_tokens(str(midi_paths[8]), steps_per_beat=12).tokens

In [73]:
import pretty_midi

class TokensToMidi:
    """Rebuild a MIDI file from the token stream produced by ``midi_to_tokens``.

    Understood tokens: ``bar``, ``beat``, ``pos_N``, ``note_P`` followed by
    ``len_L``. Timing uses a fixed tempo and one bass instrument (program 38),
    with every note at velocity 100.

    Args:
        tokens: List of token strings.
        steps_per_beat: Steps per beat used when the tokens were created.
        ticks_per_beat: Tick resolution used for the internal time grid.
        tempo: Tempo in beats per minute used to convert ticks to seconds.
    """

    def __init__(self, tokens, steps_per_beat=12, ticks_per_beat=960, tempo=120):
        self.tokens = tokens
        self.steps_per_beat = steps_per_beat
        self.ticks_per_step = ticks_per_beat // steps_per_beat
        self.tempo = tempo
        self.ticks_per_beat = ticks_per_beat

    def _ticks_to_time(self, ticks):
        """Convert a tick count to seconds at the configured tempo."""
        return ticks * 60 / (self.tempo * self.ticks_per_beat)

    @staticmethod
    def _token_value(token, prefix):
        """Return the integer after ``prefix`` in ``token``, or None if it does not parse."""
        if not isinstance(token, str) or not token.startswith(prefix):
            return None
        try:
            return int(token[len(prefix):])
        except ValueError:
            return None

    def _decode_notes(self):
        """Decode the tokens into ``(pitch, start_seconds, end_seconds)`` tuples.

        ``midi_to_tokens`` measures ``pos_N`` from the most recent ``bar`` or
        ``beat`` token, and emits one such token per beat. Decoding mirrors
        that: the first bar/beat token is time zero and every later one starts
        one beat after the previous one. (The earlier implementation never
        advanced time on ``beat`` and advanced ``bar`` from the current note
        position, which piled notes on top of each other.)

        Positions are tracked in integer ticks and converted to seconds once
        per note, so no floating-point error accumulates. Malformed tokens and
        a ``note_`` token that is not followed by a ``len_`` token (e.g. a
        truncated generated sequence) are skipped.
        """
        ticks_per_grid_beat = self.ticks_per_step * self.steps_per_beat
        beat_tick = None   # tick of the latest bar/beat; None until the first one
        cursor_tick = 0    # tick at which the next note starts
        notes = []

        total = len(self.tokens)
        i = 0
        while i < total:
            token = self.tokens[i]

            if token in ("bar", "beat"):
                beat_tick = 0 if beat_tick is None else beat_tick + ticks_per_grid_beat
                cursor_tick = beat_tick
            elif token.startswith("pos_"):
                position = self._token_value(token, "pos_")
                if position is not None:
                    cursor_tick = (beat_tick or 0) + self.ticks_per_step * position
            elif token.startswith("note_"):
                pitch = self._token_value(token, "note_")
                length = self._token_value(self.tokens[i + 1], "len_") if i + 1 < total else None
                if pitch is not None and length is not None:
                    end_tick = cursor_tick + self.ticks_per_step * length
                    notes.append((pitch, self._ticks_to_time(cursor_tick), self._ticks_to_time(end_tick)))
                    i += 1  # the len_ token has been consumed together with the note
            i += 1

        return notes

    def generate_midi(self, output_path="reconstructed_output.mid"):
        """Write the reconstructed MIDI to ``output_path`` and return the ``PrettyMIDI`` object."""
        pm = pretty_midi.PrettyMIDI(initial_tempo=self.tempo)
        instrument = pretty_midi.Instrument(program=38)

        for pitch, start, end in self._decode_notes():
            instrument.notes.append(
                pretty_midi.Note(velocity=100, pitch=pitch, start=start, end=end)
            )

        pm.instruments.append(instrument)
        pm.write(output_path)
        return pm

In [74]:
# Rebuild a MIDI file from the tokens of the example file; generate_midi()
# writes the file and returns the PrettyMIDI object shown as the cell output.
midi_reconstructor = TokensToMidi(tokens)
midi_reconstructor.generate_midi()

<pretty_midi.pretty_midi.PrettyMIDI at 0x7cff74d04160>