In [1]:
#@title Get a smaller version of the Lakh MIDI Dataset v0.1
%%capture
!wget http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz
!tar xvf clean_midi.tar.gz
!rm clean_midi.tar.gz

dataset_path = "/content/clean_midi"

In [2]:
%%capture
!pip install pretty_midi

Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/5.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━[0m [32m2.8/5.6 MB[0m [31m85.2 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━[0m [32m5.3/5.6 MB[0m [31m77.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m53.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (se

In [3]:
from pathlib import Path

# Paths to the files of the dataset
midi_paths = list(Path("clean_midi").resolve().glob("**/*.mid"))

## Cleaning of files

In [6]:
import os
import shutil

midis_dir = "midis"
os.makedirs(midis_dir, exist_ok=True)

for i, midi_path in enumerate(midi_paths):
  new_midi_path = os.path.join(midis_dir, f"{i}.mid")
  shutil.move(str(midi_path), new_midi_path)

midis = list(Path("midis").resolve().glob("**/*.mid"))

## MIDI2Tokens

In [12]:
import copy
import pretty_midi

class midi_to_tokens():
    def __init__(self, path, steps_per_beat=12):
        self.steps_per_beat = steps_per_beat
        try:
          self.pm = pretty_midi.PrettyMIDI(path)
        except OSError as e:
          print(f"Error reading MIDI file: {e}")
          return
        self.dbs = self.pm.get_downbeats().tolist() + [self.pm.get_end_time()] # dbs := downbeats
        self.tokens = self._tokenize()

    def __call__(self):
        return ' '.join(self.tokens)

    def _time_to_step(self, time):
        return round(self.pm.time_to_tick(time) / self.pm.resolution * self.steps_per_beat)

    def _event_to_tokens(self, event):
        if event in ('bar', 'beat'):
            return [event]
        elif isinstance(event, pretty_midi.containers.Note):
            return [f'note_{event.pitch}', f'len_{self._time_to_step(event.end) - self._time_to_step(event.start)}']

    def _trim_note(self, note, start, end):
        n = copy.copy(note)
        n.start, n.end = max(n.start, start), min(n.end, end)
        return n

    def _tokenize(self, start_measure=1, end_measure=None):
        start, end = self.dbs[start_measure - 1], self.dbs[end_measure or -1]

        notes = []
        for inst in self.pm.instruments:
            notes += inst.notes
        notes.sort(key=lambda x: (x.start, -x.pitch))

        events = []
        events += [(self._time_to_step(db), 'bar') for db in self.dbs if start <= db < end]
        events += [(self._time_to_step(b), 'beat') for b in set(self.pm.get_beats()) - set(self.dbs) if start <= b < end] # beats without downbeats
        events += [(self._time_to_step(max(n.start, start)), self._trim_note(n, start, end)) for n in notes if start <= n.start < end or start < n.end <= end]
        events.sort(key=lambda x: x[0])

        tokens = []
        last_beat = 0
        for step, event in events:
            if event in ('bar', 'beat'):
                last_beat = step
            if step - last_beat:
                tokens.append(f'pos_{step - last_beat}')
            tokens += self._event_to_tokens(event)

        return tokens

    def measures(self, start_measure=1, end_measure=None):
        return self._tokenize(start_measure, end_measure)

## Tokens2MIDI

In [15]:
class TokensToMidi:
    def __init__(self, tokens, steps_per_beat=12, ticks_per_beat=960, tempo=120):
        self.tokens = tokens
        self.steps_per_beat = steps_per_beat
        self.ticks_per_step = ticks_per_beat // steps_per_beat
        self.tempo = tempo
        self.ticks_per_beat = ticks_per_beat

    def _ticks_to_time(self, ticks):
        return ticks * 60 / (self.tempo * self.ticks_per_beat)

    def generate_midi(self, path):
        pm = pretty_midi.PrettyMIDI(initial_tempo=self.tempo)
        instrument = pretty_midi.Instrument(program=38)

        time = 0
        last_beat = 0

        i = 0
        while i < len(self.tokens):
            token = self.tokens[i]

            if token == "bar":
                time += self._ticks_to_time(self.ticks_per_step * self.steps_per_beat)
                last_beat = time
            elif token == "beat":
                time = last_beat
                last_beat = time
            elif token.startswith("pos_"):
                position = int(token.split("_")[1])
                time = last_beat + self._ticks_to_time(self.ticks_per_step * position)
            elif token.startswith("note_"):
                pitch = int(token.split("_")[1])
                length_token = self.tokens[i + 1]
                length = int(length_token.split("_")[1])
                duration = self._ticks_to_time(self.ticks_per_step * length)

                note = pretty_midi.Note(
                    velocity=100,
                    pitch=pitch,
                    start=time,
                    end=time + duration
                )
                instrument.notes.append(note)

                i += 1
            i += 1

        pm.instruments.append(instrument)
        pm.write(path)
        return pm

## Define a utility class "Tokenizer"

In [19]:
from sklearn.preprocessing import LabelEncoder
import numpy as np

class Tokenizer():
    def __init__(self):
        self._encoder = LabelEncoder()
        self.PAD_id = 0
        self.BOS_id = 1
        self.EOS_id = 2

    def _tokenize(self, midi_paths):
      """
      midi_paths: list of paths to MIDI files
      returns: list of lists of string tokens
      """
      tokens = []
      for path in midi_paths:
        try:
          toAdd = midi_to_tokens(str(path), steps_per_beat=12).tokens
          toAdd.insert(0, "Start")
          toAdd.append("End")
          tokens.append(toAdd)
        except AttributeError:
          print(f"Error reading MIDI file: {path}")
          continue
      return tokens

    def fit_and_encode(self, midi_paths):
      tokens = self._tokenize(midi_paths)
      flattened_array = np.concatenate([np.array(sublist).flatten() for sublist in tokens])
      self._encoder.fit(flattened_array)
      transformed = [self._encoder.transform(i) for i in tokens]
      self.EOS_id = self._encoder.transform(["End"])[0]
      self.BOS_id = self._encoder.transform(["Start"])[0]
      self.PAD_id = self._encoder.classes_.shape[0]
      self._encoder.classes_ = np.append(self._encoder.classes_, ["Pad"])
      return transformed

    def encode(self, midi_paths):
      tokens = self._tokenize(midi_paths)
      return [self._encoder.transform(i) for i in tokens]

    def decode(self, encoded_tokens, path="reconstructed_midi.mid"):
      string_tokens = [self._encoder.inverse_transform(i) for i in encoded_tokens]
      midi_reconstructor = TokensToMidi(string_tokens)
      midi_reconstructor.generate_midi(path)

    def pad(self, encoded_tokens):
      self._seq_length = max(len(arr) for arr in encoded_tokens)
      return np.array([np.pad(arr, (self._seq_length - len(arr), 0), mode='constant', constant_values=self.PAD_id) for arr in encoded_tokens])

    @property
    def encoder(self):
      return self._encoder

    @property
    def vocab_size(self):
      return self._encoder.classes_.shape[0]

    @property
    def seq_length(self):
      return self._seq_length

    @property
    def pad_id(self):
      return self.PAD_id

    @property
    def bos_id(self):
      return self.BOS_id

    @property
    def eos_id(self):
      return self.EOS_id


## Fit the tokenizer

In [20]:
tok = Tokenizer()
encoded_tokens = tok.fit_and_encode(midis[:100])

print(f"PAD_id is {tok.pad_id}")
print(f"BOS_id is {tok.bos_id}")
print(f"EOS_id is {tok.eos_id}")
print(f"Vocab size is {tok.vocab_size}")



Error reading MIDI file: running status without last_status
Error reading MIDI file: /content/midis/16907.mid
Error reading MIDI file: data byte must be in range 0..127
Error reading MIDI file: /content/midis/2651.mid
PAD_id is 423
BOS_id is 1
EOS_id is 0
Vocab size is 424


## Padding

In [22]:
padded_tokens = tok.pad(encoded_tokens)

---

# Training

In [None]:
#@title Install `keras_nlp`
%%capture
!pip install keras_nlp

In [None]:
import keras_nlp.layers as nlp_layers

def create_transformer(vocab_size, seq_len, embedding_dim, num_heads, dff, num_layers):
  # Input
    inputs = tf.keras.Input(shape=(seq_len,))

    # Embedding
    embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)(inputs)

    # Encoder
    encoder = nlp_layers.TransformerEncoder(num_heads=num_heads, intermediate_dim=dff)(embedding)

    # Decoder
    decoder = nlp_layers.TransformerDecoder(num_heads=num_heads, intermediate_dim=dff)(embedding, encoder)

    # Output
    outputs = tf.keras.layers.Dense(vocab_size, activation='softmax')(decoder)

    # Crea il modello
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    return model, encoder, decoder

In [24]:
vocab_size = tok.vocab_size
seq_len = tok.seq_length
import tensorflow as tf


model, encoder, decoder = create_transformer(vocab_size=vocab_size,
                                             seq_len=seq_len,
                                             embedding_dim=256,
                                             num_heads=8,
                                             dff=1024,
                                             num_layers=6)

In [25]:
normalized_train_x = padded_tokens[:70]
normalized_val_x = padded_tokens[70:]

normalized_train_y = np.roll(normalized_train_x, shift=-1, axis=1)
normalized_val_y = np.roll(normalized_val_x, shift=-1, axis=1)


In [29]:
from tensorflow.keras.callbacks import EarlyStopping

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

model.fit(normalized_train_x, normalized_train_y,
          epochs=5,
          validation_data=(normalized_val_x, normalized_val_y),
          callbacks=[early_stopping],
          batch_size=32
          )

model.save("NesGen_v1.keras")

Epoch 1/5


ResourceExhaustedError: Graph execution error:

Detected at node StatefulPartitionedCall defined at (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main

  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code

  File "/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py", line 37, in <module>

  File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 992, in launch_instance

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 619, in start

  File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start

  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever

  File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once

  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run

  File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 685, in <lambda>

  File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 738, in _run_callback

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 825, in inner

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 361, in process_one

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code

  File "<ipython-input-29-8a84f8330f17>", line 6, in <cell line: 6>

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/trainer.py", line 320, in fit

  File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/trainer.py", line 121, in one_step_on_iterator

Out of memory while trying to allocate 1198919140352 bytes.
	 [[{{node StatefulPartitionedCall}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_one_step_on_iterator_9343]