# Getting the dataset

In [None]:
#@title Get the full version of the Lakh MIDI Dataset v0.1
# Download the full LMD archive, extract it and delete the archive.
# `tar xf` (no `v`) avoids printing one line per extracted file, which floods the cell
# output for a large archive.
!wget http://hog.ee.columbia.edu/craffel/lmd/lmd_full.tar.gz
!tar xf lmd_full.tar.gz
!rm lmd_full.tar.gz

# NOTE(review): absolute path assumes the notebook runs in Colab with cwd /content.
dataset_path = "/content/lmd_full"

In [1]:
#@title Get a smaller version of the Lakh MIDI Dataset v0.1
%%capture
# Download the LMD "clean_midi" subset, extract it and delete the archive.
# %%capture (above) hides the download/extraction output of this cell.
!wget http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz
!tar xvf clean_midi.tar.gz
!rm clean_midi.tar.gz

# NOTE(review): absolute path assumes the notebook runs in Colab with cwd /content.
dataset_path = "/content/clean_midi"

In [None]:
#@title Get the NESMDB dataset
# Download the NES Music Database (MIDI version) from Google Drive and extract it
# (presumably into ./nesmdb_midi, the folder used by the "random NESMDB MIDI" cell).
# `unzip -o` overwrites existing files without prompting, so re-running this cell does not
# block on an interactive "replace?" question.
!gdown 1gIli7G1wu0QWDLzRc-CPWB8C4Hu0XVn3
!unzip -o nesmdb_midi.zip
!rm nesmdb_midi.zip

# Data preparation (using pretty_midi)

In [2]:
%%capture
# pretty_midi is used below to parse MIDI files and to write reconstructed ones.
!pip install pretty_midi

In [3]:
from pathlib import Path

# Paths to the .mid files of the dataset downloaded above. `dataset_path` is set by whichever
# "Get ... Lakh MIDI Dataset" cell was run, so this also works with the full dataset.
# Sorted so that the file order (and the numbering assigned in the next cell) is reproducible.
midi_paths = sorted(Path(dataset_path).resolve().glob("**/*.mid"))

# Data cleaning

In [4]:
import os
import shutil

midis_dir = "midis"
os.makedirs(midis_dir, exist_ok=True)

# Flatten the dataset into midis/<index>.mid; index-based names are unique, so files that share
# a name in different folders cannot overwrite each other. Files are moved (not copied).
# Sources that no longer exist (e.g. this cell is being re-run) are skipped instead of raising.
for i, midi_path in enumerate(midi_paths):
    if not os.path.exists(midi_path):
        continue
    new_midi_path = os.path.join(midis_dir, f"{i}.mid")
    shutil.move(str(midi_path), new_midi_path)

# Sorted so that slices such as midis[:50] select the same files on every run
# (glob order is filesystem-dependent).
midis = sorted(Path(midis_dir).resolve().glob("**/*.mid"))

# Tokenization

## MIDI2Tokens

In [5]:
import copy
import pretty_midi

class midi_to_tokens():
    """Tokenize a MIDI file into a sequence of string tokens.

    Token vocabulary:
        'bar'      -- a downbeat (first beat of a measure)
        'beat'     -- any other beat
        'pos_<k>'  -- the next event starts k steps after the most recent bar/beat
        'note_<p>' -- a note with MIDI pitch p, always followed by
        'len_<k>'  -- its duration in steps (quantized end - quantized start)
    A step is 1/steps_per_beat of a beat.

    Args:
        path: path of the MIDI file to read.
        steps_per_beat: resolution of the quantization grid.
        limit: if None, `tokens_seqs` is a one-element list holding the whole token list;
            otherwise it is a 2-D array of chunks of exactly `limit` tokens (a trailing
            remainder shorter than `limit` is dropped).

    If the file cannot be opened (OSError) the error is printed and the instance is left
    without `pm`/`tokens_seqs`; accessing them then raises AttributeError.
    """

    def __init__(self, path, steps_per_beat=12, limit=None):
        self.steps_per_beat = steps_per_beat
        try:
            self.pm = pretty_midi.PrettyMIDI(path)
        except OSError as e:
            print(f"Error reading MIDI file: {e}")
            return
        # Downbeat times (seconds) plus the end of the song, so dbs[k] is where measure k+1 starts.
        self.dbs = self.pm.get_downbeats().tolist() + [self.pm.get_end_time()]  # dbs := downbeats
        self.tokens_seqs = self._tokenize(limit=limit)

    def __call__(self):
        """Return all tokens (of every chunk) as one space-separated string."""
        return ' '.join(str(tok) for seq in self.tokens_seqs for tok in seq)

    def _time_to_step(self, time):
        """Convert a time in seconds to a step index on the steps_per_beat grid."""
        return round(self.pm.time_to_tick(time) / self.pm.resolution * self.steps_per_beat)

    def _event_to_tokens(self, event):
        """Tokens for one event: ['bar'] / ['beat'], or ['note_<pitch>', 'len_<steps>'] for a Note."""
        if event in ('bar', 'beat'):
            return [event]
        elif isinstance(event, pretty_midi.containers.Note):
            return [f'note_{event.pitch}', f'len_{self._time_to_step(event.end) - self._time_to_step(event.start)}']

    def _trim_note(self, note, start, end):
        """Return a copy of `note` clipped to the window [start, end] (seconds)."""
        n = copy.copy(note)
        n.start, n.end = max(n.start, start), min(n.end, end)
        return n

    def _tokenize(self, start_measure=1, end_measure=None, limit=None):
        """Tokenize measures start_measure..end_measure (1-based, inclusive; None = to the end).

        Notes of all instruments that start or end inside the window are included, clipped
        to it. Returns `[tokens]` when `limit` is None, else a (n_chunks, limit) array.
        """
        import numpy as np  # local import: the cell defining this class does not import numpy

        start, end = self.dbs[start_measure - 1], self.dbs[end_measure or -1]

        notes = []
        for inst in self.pm.instruments:
            notes += inst.notes
        notes.sort(key=lambda x: (x.start, -x.pitch))

        events = []
        events += [(self._time_to_step(db), 'bar') for db in self.dbs if start <= db < end]
        events += [(self._time_to_step(b), 'beat') for b in set(self.pm.get_beats()) - set(self.dbs) if start <= b < end]  # beats without downbeats
        events += [(self._time_to_step(max(n.start, start)), self._trim_note(n, start, end)) for n in notes if start <= n.start < end or start < n.end <= end]
        # Stable sort: at equal steps, bar/beat markers (appended first) precede notes.
        events.sort(key=lambda x: x[0])

        tokens = []
        last_beat = 0

        try:
            for step, event in events:
                if event in ('bar', 'beat'):
                    last_beat = step
                # Positions are relative to the latest bar/beat and only emitted when non-zero.
                if step - last_beat:
                    tokens.append(f'pos_{step - last_beat}')
                tokens += self._event_to_tokens(event)
        except Exception as e:
            print(f"Error while translating events to tokens: {e}")

        if limit is None:
            return [tokens]
        tokens = np.array(tokens)
        num_chunks = len(tokens) // limit
        return tokens[:num_chunks * limit].reshape(-1, limit)

    def measures(self, start_measure=1, end_measure=None):
        """Tokenize only measures start_measure..end_measure (1-based, inclusive), unchunked."""
        return self._tokenize(start_measure, end_measure)

## Tokens2MIDI

In [6]:
class TokensToMidi:
    """Rebuild a single-instrument MIDI file from string tokens produced by `midi_to_tokens`.

    Decoding rules:
      * 'bar' / 'beat' -- each marks the next beat: the first one is at time 0 and every
        following one is exactly one beat later; the current time resets to that beat.
      * 'pos_<k>'      -- current time = last beat + k steps.
      * 'note_<p>'     -- a note of pitch p at the current time. Its duration comes from an
        immediately following 'len_<k>' token (k steps), otherwise 1 step. Durations are at
        least one step so the note-on and note-off never fall on the same instant.
      * anything else ('Start', 'End', 'Pad', malformed tokens) is ignored.

    Timing assumes a constant `tempo` (BPM); a step is 1/steps_per_beat of a beat.
    """

    def __init__(self, tokens, steps_per_beat=12, ticks_per_beat=960, tempo=120):
        self.tokens = tokens
        self.steps_per_beat = steps_per_beat
        self.ticks_per_step = ticks_per_beat // steps_per_beat
        self.tempo = tempo
        self.ticks_per_beat = ticks_per_beat

    def _ticks_to_time(self, ticks):
        """Convert ticks to seconds at the constant tempo."""
        return ticks * 60 / (self.tempo * self.ticks_per_beat)

    @staticmethod
    def _token_value(token):
        """Integer after the first underscore ('pos_3' -> 3), or None if malformed."""
        try:
            return int(str(token).split("_", 1)[1])
        except (IndexError, ValueError):
            return None

    def _decode_notes(self):
        """Decode `self.tokens` into a list of (pitch, start_seconds, end_seconds) tuples."""
        beat_duration = self._ticks_to_time(self.ticks_per_step * self.steps_per_beat)
        notes = []
        beat_index = -1
        last_beat = 0.0
        time = 0.0
        n_tokens = len(self.tokens)

        i = 0
        while i < n_tokens:
            token = str(self.tokens[i])

            if token in ("bar", "beat"):
                # Every bar/beat marker is one beat after the previous marker.
                beat_index += 1
                last_beat = beat_index * beat_duration
                time = last_beat
            elif token.startswith("pos_"):
                position = self._token_value(token)
                if position is not None:
                    time = last_beat + self._ticks_to_time(self.ticks_per_step * position)
            elif token.startswith("note_"):
                pitch = self._token_value(token)
                length = 1
                # Only consume the next token if it really is this note's length.
                if i + 1 < n_tokens and str(self.tokens[i + 1]).startswith("len_"):
                    parsed = self._token_value(self.tokens[i + 1])
                    if parsed is not None:
                        length = parsed
                    i += 1
                length = max(length, 1)
                if pitch is not None:
                    duration = self._ticks_to_time(self.ticks_per_step * length)
                    notes.append((pitch, time, time + duration))
            i += 1

        return notes

    def generate_midi(self, path):
        """Write the decoded notes (program 38, velocity 100) to `path`; return the PrettyMIDI object."""
        pm = pretty_midi.PrettyMIDI(initial_tempo=self.tempo)
        instrument = pretty_midi.Instrument(program=38)

        for pitch, start, end in self._decode_notes():
            instrument.notes.append(
                pretty_midi.Note(velocity=100, pitch=pitch, start=start, end=end)
            )

        pm.instruments.append(instrument)
        pm.write(path)
        return pm

In [7]:
from sklearn.preprocessing import LabelEncoder
import numpy as np
from tqdm import tqdm

class Tokenizer():
    """Convert MIDI files to integer token-id sequences (and back) for model training.

    Files are tokenized with `midi_to_tokens`, every sequence is framed with "Start"/"End"
    tokens, and strings are mapped to ids with a scikit-learn LabelEncoder. After fitting,
    a "Pad" class is appended to the vocabulary for padding.

    Args:
        limit: forwarded to `midi_to_tokens`; if set, each file is cut into chunks of exactly
            `limit` tokens (so every sequence has limit + 2 ids), else one sequence per file.
    """

    def __init__(self, limit=None):
        self._encoder = LabelEncoder()
        # Placeholder ids; replaced by the real vocabulary indices in fit_and_encode() / load().
        self.PAD_id = 0
        self.BOS_id = 1
        self.EOS_id = 2
        self.limit = limit

    def _tokenize(self, midi_paths):
        """Tokenize every file in `midi_paths`.

        Returns a list of 1-D string arrays framed as ["Start", ..., "End"]. Files that cannot
        be read or tokenized are reported and skipped.
        """
        tokens = []
        for path in tqdm(midi_paths):
            try:
                seqs = midi_to_tokens(str(path), steps_per_beat=12, limit=self.limit).tokens_seqs
                for seq in seqs:
                    # `seq` is an ndarray row when limit is set but a plain list when limit is
                    # None; unpacking works for both (lists have no .tolist()).
                    tokens.append(np.array(["Start", *seq, "End"]))
            except AttributeError:
                # midi_to_tokens leaves `tokens_seqs` unset when the file cannot be opened.
                print(f"Error reading MIDI file: {path}")
                continue
            except Exception as e:
                print(f"There was an unexpected error: {e}")
                continue
        return tokens

    def _sync_special_ids(self):
        """Set BOS/EOS/PAD ids to the indices of "Start"/"End"/"Pad" in the vocabulary (if present)."""
        classes = list(self._encoder.classes_)
        for attr, name in (("BOS_id", "Start"), ("EOS_id", "End"), ("PAD_id", "Pad")):
            if name in classes:
                setattr(self, attr, classes.index(name))

    def fit_and_encode(self, midi_paths):
        """Fit the vocabulary on `midi_paths` and return their id sequences.

        Raises:
            ValueError: if none of the files could be tokenized.
        """
        tokens = self._tokenize(midi_paths)
        if not tokens:
            raise ValueError("No MIDI file could be tokenized; nothing to fit the tokenizer on.")
        self._encoder.fit(np.concatenate(tokens))
        transformed = [self._encoder.transform(seq) for seq in tokens]
        # "Pad" never occurs in the data; append it so padded ids can be decoded.
        self._encoder.classes_ = np.append(self._encoder.classes_, ["Pad"])
        self._sync_special_ids()
        return transformed

    def encode(self, midi_paths):
        """Tokenize `midi_paths` and map them to ids with the already fitted vocabulary."""
        tokens = self._tokenize(midi_paths)
        return [self._encoder.transform(seq) for seq in tokens]

    def decode(self, encoded_tokens, path="reconstructed_midi.mid"):
        """Write one MIDI file per id sequence in `encoded_tokens`.

        Sequence i is written next to `path` with its file name prefixed by i
        (e.g. "out/song.mid" -> "out/0song.mid", "out/1song.mid", ...).
        """
        import os

        directory, file_name = os.path.split(path)
        # Ids may be floats (e.g. read with np.loadtxt); the encoder needs integer indices.
        string_tokens = [
            self._encoder.inverse_transform(np.asarray(ids).astype(np.int64))
            for ids in encoded_tokens
        ]
        for i, tokens in enumerate(string_tokens):
            TokensToMidi(tokens).generate_midi(os.path.join(directory, f"{i}{file_name}"))

    def pad(self, encoded_tokens):
        """Left-pad every sequence with PAD_id to the length of the longest one (2-D array)."""
        self._seq_length = max((len(arr) for arr in encoded_tokens), default=0)
        return np.array([
            np.pad(arr, (self._seq_length - len(arr), 0), mode='constant', constant_values=self.PAD_id)
            for arr in encoded_tokens
        ])

    def save(self, path: str):
        """Save the vocabulary (one token per line, in id order)."""
        np.savetxt(path, self._encoder.classes_, fmt="%s")

    def load(self, path: str):
        """Load a vocabulary written by `save` and restore the BOS/EOS/PAD ids from it."""
        # ndmin=1 keeps a one-line file as a 1-D array.
        self._encoder.classes_ = np.loadtxt(path, dtype="str", ndmin=1)
        self._sync_special_ids()

    @property
    def encoder(self):
        return self._encoder

    @property
    def vocab_size(self):
        return self._encoder.classes_.shape[0]

    @property
    def seq_length(self):
        return self._seq_length

    @property
    def pad_id(self):
        return self.PAD_id

    @property
    def bos_id(self):
        return self.BOS_id

    @property
    def eos_id(self):
        return self.EOS_id


## Download a pre-fitted Tokenizer and pre-tokenized data

In [13]:
# Download a saved tokenizer vocabulary and pre-tokenized id sequences from Google Drive,
# then rename them to the file names read by the next cells.
# NOTE(review): the names suggest limit=1000 tokens per chunk and 2500 source files — confirm.
!gdown 1QrMzoYewqyxv2TJEgwS53-8l6rlpi4RG
!gdown 1dwNkvRopC8gIpDzD2iEFZADe4aeF-w_F

!mv tokens_lim1000_files2500 tokens.txt
!mv tokenizer_lim1000_files2500 tokenizer.txt

## Load a previously fitted tokenizer and...

In [14]:
# Restore the vocabulary written by Tokenizer.save(); limit=1000 mirrors the "lim1000"
# in the downloaded file names.
tok = Tokenizer(limit=1000)
tok.load(path="tokenizer.txt")

## ...load already tokenized data...

In [15]:
import numpy as np

# Number of pre-tokenized sequences (rows of tokens.txt) to load.
n_sequences = 7000
# Read only the first `n_sequences` rows instead of parsing the whole file and slicing.
# np.loadtxt parses values as float64, so cast back to integer token ids; ndmin=2 keeps the
# result 2-D even for a single row.
padded_tokens = np.loadtxt("tokens.txt", max_rows=n_sequences, ndmin=2).astype(np.int64)
print(f"Loaded {padded_tokens.shape[0]} tokenized sequences")

Loaded tokenized version of 15389 files


## ...**OR** fit the tokenizer...

In [9]:
# Fit a fresh tokenizer on a small sample of files, cut into 500-token chunks.
tok = Tokenizer(limit=500)
encoded_tokens = tok.fit_and_encode(midis[:50])

print("\n".join([
    f"PAD_id is {tok.pad_id}",
    f"BOS_id is {tok.bos_id}",
    f"EOS_id is {tok.eos_id}",
    f"Vocab size is {tok.vocab_size}",
]))

100%|██████████| 50/50 [00:19<00:00,  2.58it/s]


PAD_id is 297
BOS_id is 1
EOS_id is 0
Vocab size is 298


## ...and add the Padding

In [10]:
# Left-pad every sequence with the PAD id up to the longest sequence length.
padded_tokens = tok.pad(encoded_tokens)
print("Maximum sequence len is", tok.seq_length)

Maximum sequence len is 502


# Model creation

In [None]:
#@title Install `keras_nlp`
# NOTE(review): unpinned install — consider pinning the version the model was trained with,
# since the saved model depends on keras_nlp's layer classes.
!pip install keras_nlp

In [None]:
import tensorflow as tf
import keras_nlp.layers as nlp_layers


def create_transformer(vocab_size, seq_len, embedding_dim, num_heads, dff, num_layers):
    """Build a decoder-only (GPT-style) Transformer for next-token prediction.

    Why decoder-only: the training targets are the inputs shifted by one position, so any
    un-masked view of the input (e.g. an encoder the decoder cross-attends to) would expose
    the very token to predict. Each block therefore uses only causal self-attention, and a
    token + position embedding provides the ordering information.

    Args:
        vocab_size: number of token ids (size of the softmax output).
        seq_len: length of every input sequence.
        embedding_dim: width of the token/position embeddings.
        num_heads: attention heads per decoder block.
        dff: hidden size of each block's feed-forward network.
        num_layers: number of stacked decoder blocks.

    Returns:
        (model, embedded, decoded): the (uncompiled) Keras model plus the symbolic embedding
        output and final decoder output, so `model, encoder, decoder = ...` still unpacks.
    """
    inputs = tf.keras.Input(shape=(seq_len,))

    embedded = nlp_layers.TokenAndPositionEmbedding(
        vocabulary_size=vocab_size,
        sequence_length=seq_len,
        embedding_dim=embedding_dim,
    )(inputs)

    decoded = embedded
    for _ in range(num_layers):
        # Without an encoder sequence, TransformerDecoder applies causal self-attention only.
        decoded = nlp_layers.TransformerDecoder(num_heads=num_heads, intermediate_dim=dff)(decoded)

    outputs = tf.keras.layers.Dense(vocab_size, activation="softmax")(decoded)

    # Build the model.
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    return model, embedded, decoded

In [None]:
import tensorflow as tf

# Model dimensions follow the tokenizer vocabulary and the padded sequence length.
vocab_size = tok.vocab_size
seq_len = padded_tokens.shape[1]

model, encoder, decoder = create_transformer(
    vocab_size=vocab_size,
    seq_len=seq_len,
    embedding_dim=512,
    num_heads=8,
    dff=2048,
    num_layers=12,
)

## Split into training and validation sets

In [None]:
# 80/20 split in sequence order (no shuffling). The "normalized_" names are used by later
# cells; no normalization is applied to the ids.
train_perc = 0.8
train_size = int(train_perc * padded_tokens.shape[0])

normalized_train_x = padded_tokens[:train_size]
normalized_val_x = padded_tokens[train_size:]

# Next-token targets: y[:, t] = x[:, t + 1]. np.roll alone would wrap each sequence's first
# token into the last target slot (a spurious "last -> first" transition), so that slot is
# set to the PAD id instead. np.roll returns a copy, so the inputs are not modified.
normalized_train_y = np.roll(normalized_train_x, shift=-1, axis=1)
normalized_train_y[:, -1] = tok.pad_id
normalized_val_y = np.roll(normalized_val_x, shift=-1, axis=1)
normalized_val_y[:, -1] = tok.pad_id

## Training the transformer

In [None]:
from tensorflow.keras.callbacks import EarlyStopping

# Next-token classification over the vocabulary with integer targets -> sparse cross-entropy.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Stop after 3 epochs without val_loss improvement and roll back to the best weights.
early_stopping = EarlyStopping(
    monitor="val_loss",
    patience=3,
    restore_best_weights=True,
)

model.fit(
    x=normalized_train_x,
    y=normalized_train_y,
    validation_data=(normalized_val_x, normalized_val_y),
    batch_size=32,
    epochs=10,
    callbacks=[early_stopping],
)

model.save("NesGen_v1.keras")

## Model loading (optional)

In [None]:
from tensorflow.keras.models import load_model
# The saved model contains keras_nlp layers; importing the package registers their classes
# so load_model can deserialize them even if the model-creation cells were skipped.
import keras_nlp.layers  # noqa: F401

model = load_model("NesGen_v1.keras")

# Model usage

Here we use the trained model to generate new MIDI files.

## Utility functions

In [None]:
from tqdm import tqdm

def get_random_ids_from_dataset(dataset):
  """Return one randomly chosen row (token-id sequence) of the 2-D array `dataset`."""
  return dataset[np.random.choice(dataset.shape[0])]

def sample_next_token(probabilities) -> int:
  """Sample a token id from the distribution at the last position of batch item 0.

  Args:
    probabilities: model output indexed as [batch, position, vocab].
  """
  # Work on a float64 copy: re-normalizing in float64 keeps np.random.choice's "p must sum
  # to 1" check satisfied, and the caller's array is not modified.
  last_token_probs = np.asarray(probabilities[0, -1], dtype=np.float64)
  last_token_probs = last_token_probs / last_token_probs.sum()
  return int(np.random.choice(len(last_token_probs), p=last_token_probs))

def next_token(model, seed_ids):
    """Run `model.predict` on `seed_ids` and sample the next token id."""
    probabilities = model.predict(seed_ids, verbose=0)
    return sample_next_token(probabilities)

def generate_ids(model, seed_ids, eos_id, pad_id, bos_id, max_len=None, show_progress=True):
  """Autoregressively generate token ids from a sliding seed window.

  After each sampled token the window (shape (1, window_len)) shifts left by one and the new
  token is written into its last slot.

  Args:
    model: object with a Keras-style `predict(x, verbose=0)`.
    seed_ids: initial window, shape (1, window_len); it is not modified.
    eos_id / pad_id / bos_id: special token ids.
    max_len: maximum number of sampling steps (defaults to the window length).
    show_progress: wrap the loop in a tqdm progress bar.

  Returns:
    1-D int array [bos_id, *generated, eos_id]. Sampled PAD tokens are discarded (and do not
    enter the window); generation stops at the first sampled EOS.
  """
  if max_len is None:
    max_len = seed_ids.shape[1]
  steps = tqdm(range(max_len)) if show_progress else range(max_len)

  seed = seed_ids
  generated_ids = []
  for _ in steps:
    next_token_id = next_token(model, seed)
    if next_token_id == eos_id:
      break
    if next_token_id == pad_id:
      # Padding carries no musical content: drop it and resample from the same window.
      continue
    generated_ids.append(next_token_id)
    # np.roll returns a copy, so the caller's seed_ids stays untouched.
    seed = np.roll(seed, -1, axis=1)
    seed[0, -1] = next_token_id

  # Frame with BOS/EOS instead of overwriting the first/last generated tokens.
  return np.array([bos_id, *generated_ids, eos_id])


## Generate the IDs for the new sequence

In [None]:
# Seed the generator with a random sequence from the validation split; the special ids come
# from the tokenizer fitted or loaded above.
seed = get_random_ids_from_dataset(normalized_val_x).reshape((1, -1))
generated_ids = generate_ids(
    model,
    seed,
    eos_id=tok.eos_id,
    pad_id=tok.pad_id,
    bos_id=tok.bos_id,
    max_len=100
)
print("\nGenerated\n" + str(generated_ids))

## Conversion from IDs to MIDI

### Utility function

In [None]:
import numpy as np
from pathlib import Path

def ids_to_midi(
    ids: np.ndarray,
    tokenizer,
    file_name: str = "result.mid",
    output_dir: str = "/content/"
  ):
  """Decode token ids with this notebook's `Tokenizer` and write them as a MIDI file.

  Args:
    ids: 1-D sequence of token ids (e.g. the output of `generate_ids`).
    tokenizer: a fitted/loaded `Tokenizer`; its label encoder maps ids back to tokens.
    file_name: name of the MIDI file to write.
    output_dir: destination directory (created if missing).

  Returns:
    Path of the written file.
  """
  out_path = Path(output_dir, file_name)
  out_path.parent.mkdir(parents=True, exist_ok=True)
  # Ids may arrive as floats (e.g. read with np.loadtxt); the encoder needs integer indices.
  tokens = tokenizer.encoder.inverse_transform(np.asarray(ids).astype(np.int64))
  TokensToMidi(tokens).generate_midi(str(out_path))
  return out_path

### Actual conversion

In [None]:
file_name = "result.mid"

print(f"Converting IDs to MIDI file: {file_name}...")

# `tok` is the Tokenizer fitted or loaded above.
ids_to_midi(generated_ids, tok, file_name=file_name)

print("DONE!")

### End to End utility

For generating multiple files in one call

In [None]:
def generate_midi(
    dataset,
    model,
    tokenizer,
    output_folder="/content/gen_midi",
    num_files=1,
    max_len=100
):
  """Generate `num_files` MIDI files named 0.mid, 1.mid, ... in `output_folder`.

  Each file is seeded with a random row of `dataset`, extended with `generate_ids` for up to
  `max_len` tokens and decoded with `ids_to_midi`.

  Args:
    dataset: 2-D array of token-id sequences used as seeds.
    model: trained next-token model.
    tokenizer: the fitted/loaded `Tokenizer` (special ids and decoding).
    output_folder: destination directory (created if missing).
    num_files: number of files to generate.
    max_len: maximum number of generated tokens per file.
  """
  os.makedirs(output_folder, exist_ok=True)
  seed_len = dataset.shape[1]
  for i in tqdm(range(num_files)):
    seed = get_random_ids_from_dataset(dataset).reshape((1, seed_len))
    generated_ids = generate_ids(
        model,
        seed,
        eos_id=tokenizer.eos_id,
        pad_id=tokenizer.pad_id,
        bos_id=tokenizer.bos_id,
        max_len=max_len,
        show_progress=False
    )
    ids_to_midi(generated_ids, tokenizer, file_name=f"{i}.mid", output_dir=output_folder)


#### Generate some files

In [None]:
n_files = 10
# Seeds are drawn from the validation split; `tok` is the tokenizer fitted or loaded above.
generate_midi(normalized_val_x,
              model,
              tok,
              output_folder="/content/gen_midi/",
              num_files=n_files,
              max_len=100
              )

# Part 2 – Playing the MIDI files

## Utility functions

In [None]:
def random_file(root, keyword=None):
    """Return the path of a random ``.mid`` file found recursively under `root`.

    Args:
        root: directory to search.
        keyword: if given, only files whose path contains it (case-insensitive) are eligible.

    Raises:
        FileNotFoundError: if no matching file exists.
    """
    import glob
    import os
    import random

    mid_files = glob.glob(os.path.join(root, "**", "*.mid"), recursive=True)
    if keyword is not None:
        # Compare lower-case on both sides so e.g. "Mario" matches ".../mario.mid".
        needle = keyword.lower()
        mid_files = [file for file in mid_files if needle in file.lower()]
    if not mid_files:
        suffix = f" matching {keyword!r}" if keyword is not None else ""
        raise FileNotFoundError(f"No .mid files found under {root!r}{suffix}")
    return random.choice(mid_files)

## Libraries

In [None]:
#@title Installing the required libraries
%%capture
# Synthesizer and Python helpers used to render/play MIDI files.
!apt-get update -qq && apt-get install -y fluidsynth
!pip install pretty_midi midi-clip

# GS2: GeneralUser GS v2 soundfont, renamed to guGS.sf2 (docs/demo files removed).
# `unzip -o` overwrites without prompting, so re-running this cell cannot hang under %%capture.
!gdown 1wlpTIS70nQHMrYBjDT0M6nyg07kUejUv
!unzip -o GeneralUser_GS_v2.0.0--doc_r2.zip
!rm -rf GeneralUser_GS_v2.0.0--doc_r2.zip support documentation demo\ MIDIs
!mv GeneralUser\ GS\ v2.0.0.sf2 guGS.sf2

# PICONICA soundfont (presumably PICONICA.sf2, the file the play cells default to).
!gdown 1uk51T9Gvo1n2JRl3_CHCg2FVGWiNI4qJ

# Utility library providing the MIDI info/playback helpers.
# `-O utility.py` overwrites on re-runs instead of creating utility.py.1 (which would leave
# a stale utility.py being imported).
# NOTE(review): downloaded from the repository's `main` branch and executed on import —
# pin a commit for reproducibility.
!wget -O utility.py https://raw.githubusercontent.com/roostico/NesGen/refs/heads/main/utility.py

from utility import playMidi, show_midi_info

## Using generated MIDI

In [None]:
#@title Show info

# Colab form field: keep the `# @param` annotation on the assignment line.
generated_midi_path = "gen_midi/3.mid" # @param {type:"string"}

# Print a summary of the chosen MIDI file (show_midi_info comes from utility.py).
print("Midi info:")
show_midi_info(generated_midi_path)

In [None]:
#@title Play the MIDI

# Colab form fields: keep the `# @param` annotations on the assignment lines.
generated_midi_path = "gen_midi/4.mid" # @param {type:"string"}
# NOTE(review): "PokeDP.sf2" is offered but never downloaded by this notebook.
soundfont = "PICONICA.sf2" # @param ["PICONICA.sf2", "guGS.sf2", "PokeDP.sf2"]

# Render and play the file with the selected soundfont (playMidi comes from utility.py).
playMidi(generated_midi_path, soundfont_path=soundfont)

# Extras (playing other MIDIs)

## Play a random MIDI of the Lakh dataset

In [None]:
# The data-cleaning cell moves the collected .mid files out of the dataset folder into
# `midis_dir`, so sample from there rather than from `dataset_path`.
path = random_file(midis_dir)
print("Converting: " + path)
print("Midi info:")
show_midi_info(path)
print("Synthetized:")
playMidi(path)

## Play a random MIDI of the NESMDB dataset

In [None]:
# Pick any NESMDB track and play it with the PICONICA soundfont. velocity_change and gain are
# forwarded to utility.playMidi (their exact semantics are defined there).
path = random_file("nesmdb_midi")
print(f"Converting: {path}")
print("Midi info:")
show_midi_info(path)
print("Synthetized:")
playMidi(path, soundfont_path="PICONICA.sf2", gain=1, velocity_change=30)