In [1]:
#@title Get a smaller version of the Lakh MIDI Dataset v0.1
%%capture
!wget http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz
!tar xvf clean_midi.tar.gz
!rm clean_midi.tar.gz

dataset_path = "/content/clean_midi"

In [2]:
%%capture
!pip install pretty_midi

Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/5.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━[0m [32m2.8/5.6 MB[0m [31m85.2 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━[0m [32m5.3/5.6 MB[0m [31m77.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m53.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (se

In [3]:
from pathlib import Path

# Paths to the files of the dataset
midi_paths = list(Path("clean_midi").resolve().glob("**/*.mid"))

## Cleaning of files

In [6]:
import os
import shutil

midis_dir = "midis"
os.makedirs(midis_dir, exist_ok=True)

for i, midi_path in enumerate(midi_paths):
  new_midi_path = os.path.join(midis_dir, f"{i}.mid")
  shutil.move(str(midi_path), new_midi_path)

midis = list(Path("midis").resolve().glob("**/*.mid"))

## MIDI2Tokens

In [41]:
import copy
import pretty_midi

class midi_to_tokens():
    def __init__(self, path, steps_per_beat=12, limit=None):
        self.steps_per_beat = steps_per_beat
        try:
          self.pm = pretty_midi.PrettyMIDI(path)
        except OSError as e:
          print(f"Error reading MIDI file: {e}")
          return
        self.dbs = self.pm.get_downbeats().tolist() + [self.pm.get_end_time()] # dbs := downbeats
        self.tokens = self._tokenize(limit=limit)

    def __call__(self):
        return ' '.join(self.tokens)

    def _time_to_step(self, time):
        return round(self.pm.time_to_tick(time) / self.pm.resolution * self.steps_per_beat)

    def _event_to_tokens(self, event):
        if event in ('bar', 'beat'):
            return [event]
        elif isinstance(event, pretty_midi.containers.Note):
            return [f'note_{event.pitch}', f'len_{self._time_to_step(event.end) - self._time_to_step(event.start)}']

    def _trim_note(self, note, start, end):
        n = copy.copy(note)
        n.start, n.end = max(n.start, start), min(n.end, end)
        return n

    def _tokenize(self, start_measure=1, end_measure=None, limit=None):
        start, end = self.dbs[start_measure - 1], self.dbs[end_measure or -1]

        notes = []
        for inst in self.pm.instruments:
            notes += inst.notes
        notes.sort(key=lambda x: (x.start, -x.pitch))

        events = []
        events += [(self._time_to_step(db), 'bar') for db in self.dbs if start <= db < end]
        events += [(self._time_to_step(b), 'beat') for b in set(self.pm.get_beats()) - set(self.dbs) if start <= b < end] # beats without downbeats
        events += [(self._time_to_step(max(n.start, start)), self._trim_note(n, start, end)) for n in notes if start <= n.start < end or start < n.end <= end]
        events.sort(key=lambda x: x[0])

        tokens = []
        last_beat = 0
        if limit is not None:
            events = events[:limit]
        for step, event in events:
            if event in ('bar', 'beat'):
                last_beat = step
            if step - last_beat:
                tokens.append(f'pos_{step - last_beat}')
            tokens += self._event_to_tokens(event)

        return tokens

    def measures(self, start_measure=1, end_measure=None):
        return self._tokenize(start_measure, end_measure)

## Tokens2MIDI

In [106]:
class TokensToMidi:
    def __init__(self, tokens, steps_per_beat=12, ticks_per_beat=960, tempo=120):
        self.tokens = tokens
        self.steps_per_beat = steps_per_beat
        self.ticks_per_step = ticks_per_beat // steps_per_beat
        self.tempo = tempo
        self.ticks_per_beat = ticks_per_beat

    def _ticks_to_time(self, ticks):
        return ticks * 60 / (self.tempo * self.ticks_per_beat)

    def generate_midi(self, path):
        pm = pretty_midi.PrettyMIDI(initial_tempo=self.tempo)
        instrument = pretty_midi.Instrument(program=38)

        time = 0
        last_beat = 0

        i = 0
        while i < len(self.tokens):
            token = self.tokens[i]

            if token == "bar":
                time += self._ticks_to_time(self.ticks_per_step * self.steps_per_beat)
                last_beat = time
            elif token == "beat":
                time = last_beat
                last_beat = time
            elif token.startswith("pos_"):
                position = int(token.split("_")[1])
                time = last_beat + self._ticks_to_time(self.ticks_per_step * position)
            elif token.startswith("note_"):
                pitch = int(token.split("_")[1])
                try:
                  length_token = self.tokens[i + 1]
                  length = int(length_token.split("_")[1])
                except IndexError:
                  length = 1
                duration = self._ticks_to_time(self.ticks_per_step * length)

                note = pretty_midi.Note(
                    velocity=100,
                    pitch=pitch,
                    start=time,
                    end=time + duration
                )
                instrument.notes.append(note)

                i += 1
            i += 1

        pm.instruments.append(instrument)
        pm.write(path)
        return pm

## Define a utility class "Tokenizer"

In [107]:
from sklearn.preprocessing import LabelEncoder
import numpy as np

class Tokenizer():
    def __init__(self, limit=None):
        self._encoder = LabelEncoder()
        self.PAD_id = 0
        self.BOS_id = 1
        self.EOS_id = 2
        self.limit = limit

    def _tokenize(self, midi_paths):
      """
      midi_paths: list of paths to MIDI files
      returns: list of lists of string tokens
      """
      tokens = []
      for path in midi_paths:
        try:
          toAdd = midi_to_tokens(str(path), steps_per_beat=12, limit=self.limit).tokens
          toAdd.insert(0, "Start")
          toAdd.append("End")
          tokens.append(toAdd)
        except AttributeError:
          print(f"Error reading MIDI file: {path}")
          continue
      return tokens

    def fit_and_encode(self, midi_paths):
      tokens = self._tokenize(midi_paths)
      flattened_array = np.concatenate([np.array(sublist).flatten() for sublist in tokens])
      self._encoder.fit(flattened_array)
      transformed = [self._encoder.transform(i) for i in tokens]
      self.EOS_id = self._encoder.transform(["End"])[0]
      self.BOS_id = self._encoder.transform(["Start"])[0]
      self.PAD_id = self._encoder.classes_.shape[0]
      self._encoder.classes_ = np.append(self._encoder.classes_, ["Pad"])
      return transformed

    def encode(self, midi_paths):
      tokens = self._tokenize(midi_paths)
      return [self._encoder.transform(i) for i in tokens]

    def decode(self, encoded_tokens, path="reconstructed_midi.mid"):
      string_tokens = [self._encoder.inverse_transform(i) for i in encoded_tokens]
      for i in range(len(string_tokens)):
        midi_reconstructor = TokensToMidi(string_tokens[i])
        midi_reconstructor.generate_midi(str(i) + path)

    def pad(self, encoded_tokens):
      self._seq_length = max(len(arr) for arr in encoded_tokens)
      return np.array([np.pad(arr, (self._seq_length - len(arr), 0), mode='constant', constant_values=self.PAD_id) for arr in encoded_tokens])

    @property
    def encoder(self):
      return self._encoder

    @property
    def vocab_size(self):
      return self._encoder.classes_.shape[0]

    @property
    def seq_length(self):
      return self._seq_length

    @property
    def pad_id(self):
      return self.PAD_id

    @property
    def bos_id(self):
      return self.BOS_id

    @property
    def eos_id(self):
      return self.EOS_id


## Fit the tokenizer

In [108]:
tok = Tokenizer(limit=500)
encoded_tokens = tok.fit_and_encode(midis[:100])

print(f"PAD_id is {tok.pad_id}")
print(f"BOS_id is {tok.bos_id}")
print(f"EOS_id is {tok.eos_id}")
print(f"Vocab size is {tok.vocab_size}")



Error reading MIDI file: running status without last_status
Error reading MIDI file: /content/midis/16907.mid
Error reading MIDI file: data byte must be in range 0..127
Error reading MIDI file: /content/midis/2651.mid
PAD_id is 285
BOS_id is 1
EOS_id is 0
Vocab size is 286


## Padding

In [109]:
padded_tokens = tok.pad(encoded_tokens)
print(f"Maximum sequence len is {tok.seq_length}")

Maximum sequence len is 1412


---

# Training

In [None]:
#@title Install `keras_nlp`
%%capture
!pip install keras_nlp

In [44]:
import keras_nlp.layers as nlp_layers

def create_transformer(vocab_size, seq_len, embedding_dim, num_heads, dff, num_layers):
  # Input
    inputs = tf.keras.Input(shape=(seq_len,))

    # Embedding
    embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)(inputs)

    # Encoder
    encoder = nlp_layers.TransformerEncoder(num_heads=num_heads, intermediate_dim=dff)(embedding)

    # Decoder
    decoder = nlp_layers.TransformerDecoder(num_heads=num_heads, intermediate_dim=dff)(embedding, encoder)

    # Output
    outputs = tf.keras.layers.Dense(vocab_size, activation='softmax')(decoder)

    # Crea il modello
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    return model, encoder, decoder

In [110]:
vocab_size = tok.vocab_size
seq_len = tok.seq_length
import tensorflow as tf


model, encoder, decoder = create_transformer(vocab_size=vocab_size,
                                             seq_len=seq_len,
                                             embedding_dim=256,
                                             num_heads=8,
                                             dff=1024,
                                             num_layers=6)

In [111]:
normalized_train_x = padded_tokens[:70]
normalized_val_x = padded_tokens[70:]

normalized_train_y = np.roll(normalized_train_x, shift=-1, axis=1)
normalized_val_y = np.roll(normalized_val_x, shift=-1, axis=1)


In [112]:
from tensorflow.keras.callbacks import EarlyStopping

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

model.fit(normalized_train_x, normalized_train_y,
          epochs=20,
          validation_data=(normalized_val_x, normalized_val_y),
          callbacks=[early_stopping],
          batch_size=32
          )

model.save("NesGen_v1.keras")

Epoch 1/20
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 5s/step - accuracy: 0.1985 - loss: 5.0963 - val_accuracy: 0.2399 - val_loss: 3.6791
Epoch 2/20
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 555ms/step - accuracy: 0.2731 - loss: 3.5083 - val_accuracy: 0.3113 - val_loss: 3.2561
Epoch 3/20
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 525ms/step - accuracy: 0.3155 - loss: 3.1637 - val_accuracy: 0.3134 - val_loss: 3.0335
Epoch 4/20
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 181ms/step - accuracy: 0.3116 - loss: 2.9556 - val_accuracy: 0.3120 - val_loss: 2.9702
Epoch 5/20
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 191ms/step - accuracy: 0.3335 - loss: 2.8419 - val_accuracy: 0.3663 - val_loss: 2.8402
Epoch 6/20
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 235ms/step - accuracy: 0.3618 - loss: 2.7185 - val_accuracy: 0.3731 - val_loss: 2.7678
Epoch 7/20
[1m3/3[0m [32m━━━━━━━━━━━━━━

---

# Generation

In [77]:
from tqdm import tqdm

def get_random_ids_from_dataset(dataset):
  return dataset[np.random.choice(dataset.shape[0])]

def sample_next_token(probabilities) -> int:
  # Ensure the probabilities sum to 1 (they should, but numerical issues might affect it)
  last_token_probs = probabilities[0, -1]
  last_token_probs /= last_token_probs.sum()
  return np.random.choice(len(last_token_probs), p=last_token_probs)

def next_token(model, seed_ids):
    probabilities = model.predict(seed_ids, verbose=0)
    next_token = sample_next_token(probabilities)
    return next_token

def generate_ids(model, seed_ids, eos_id, pad_id, bos_id, max_len=None, show_progress=True):
  if max_len is None:
    max_len = seed_ids.shape[1]
  seed = seed_ids
  generated_ids = []
  if not show_progress:
    iterations = range(max_len)
  else:
    iterations = tqdm(range(max_len))

  for _ in iterations:
    next_token_id = next_token(model, seed)
    generated_ids.append(next_token_id)
    if next_token_id == eos_id:
      break
    elif next_token_id == pad_id:
      continue

    seed = np.roll(seed, -1, axis=1)
    seed[0, -1] = next_token_id

  result = np.array(generated_ids)
  result[0] = bos_id
  result[-1] = eos_id
  return result

In [113]:
seed = get_random_ids_from_dataset(normalized_train_x).reshape((1, seq_len))
generated_ids = generate_ids(
    model,
    seed,
    eos_id=tok.eos_id,
    pad_id=tok.pad_id,
    bos_id=tok.bos_id,
    max_len=500
)
print("\nGenerated\n" + str(generated_ids))

100%|██████████| 500/500 [00:40<00:00, 12.35it/s]


Generated
[  1 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285
 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285
 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285
 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285
 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285 285
   1   2 229   4 281 201   4   3   2 228   4 283 239   4 199   4 233  91
 227  53 271 230   6 281 235  41 199   4 260 229   4 226  82 194   4 194
   4 283 213   4 260 229   4 187 112 263 225  72 276 226   5 279 208   4
 281 216   4 260 239   4 199   4 280 213   4 259 211  99 281 187  72 281
 222  72 268 228   4 281 219   5 281 187 213   5 281 229  53 271 204   4
 283 220  47 260 229   4 281 194   4 144 259 219 106   3 229   4 282   4
   3 234   4 207  99 281 194   4 199   4 201   4 282   4 235  86 281 229
  72 259 201   4 260 219  72 281 241  88 260 217 125 261 236   4 281 213
   4 260 223   5   2 190  99 281 194 112




In [114]:
def ids_to_midi(
    ids: np.ndarray,
    tokenizer: Tokenizer,
    file_name: str ="result.mid",
  ):
  tokenizer.decode([ids], file_name)

In [115]:
ids_to_midi(generated_ids, tok)

['Start' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad'
 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad'
 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad'
 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad'
 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad'
 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad'
 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad'
 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Pad' 'Start' 'bar' 'note_70' 'len_0'
 'pos_6' 'note_42' 'len_0' 'beat' 'bar' 'note_69' 'len_0' 'pos_8'
 'note_80' 'len_0' 'note_40' 'len_0' 'note_74' 'len_44' 'note_68' 'len_23'
 'pos_24' 'note_71' 'len_10' 'pos_6' 'note_76' 'len_19' 'note_40' 'len_0'
 'pos_10' 'note_70' 'len_0' 'note_67' 'len_38' 'note_35' 'len_0' 'note_35'
 'len_0' 'pos_8' 'note_54' 'len_0' 'pos_10' 'note_70' 'len_0' 'note_28'
 'len_6' 'pos_13' 'note_66' 'len_3' 'pos_3' 'note_67' 'l