<a href="https://colab.research.google.com/github/roostico/NesGen/blob/main/NESGen.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Getting the dataset

In [None]:
#@title Get the full version of the Lakh MIDI Dataset v0.1
# Download the archive, unpack it in the current directory, then delete the
# archive so only the extracted files remain.
!wget http://hog.ee.columbia.edu/craffel/lmd/lmd_full.tar.gz
!tar xvf lmd_full.tar.gz
!rm lmd_full.tar.gz

# Root folder that later cells scan for MIDI files.
# NOTE(review): assumes the notebook runs from /content (Colab default) and
# that the archive unpacks into lmd_full/ - confirm.
dataset_path = "/content/lmd_full"

In [None]:
#@title Get a smaller version of the Lakh MIDI Dataset v0.1
# Download the archive, unpack it in the current directory, then delete the
# archive so only the extracted files remain.
!wget http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz
!tar xvf clean_midi.tar.gz
!rm clean_midi.tar.gz

# Root folder that later cells scan for MIDI files. Running this cell after
# the "full version" cell overrides the dataset_path set there.
# NOTE(review): assumes the notebook runs from /content (Colab default) and
# that the archive unpacks into clean_midi/ - confirm.
dataset_path = "/content/clean_midi"

In [None]:
#@title Get the NESMDB dataset
# Fetch the archive by its Google Drive file id, unpack it, then delete it.
# The download is expected to be named nesmdb_midi.zip (the next line unzips
# a file with that name); a later cell reads MIDI files from "nesmdb_midi".
!gdown 1gIli7G1wu0QWDLzRc-CPWB8C4Hu0XVn3
!unzip nesmdb_midi.zip
!rm nesmdb_midi.zip

# Tokenization

## Installing MidiTok

In [None]:
# Install MidiTok; the tokenizer, dataset splitting, augmentation and PyTorch
# data helpers imported in the following cells come from this package.
!pip install miditok

## Setting up the tokenizer

In [None]:
from miditok import REMI, TokenizerConfig  # here we choose to use REMI
from pathlib import Path

# Create the REMI tokenizer. No file paths are listed in this cell; the MIDI
# files are collected in the "Augmenting and splitting the dataset" cell.
tokenizer = REMI()  # default parameters (the original note points at constants.py, presumably MidiTok's)


## Augmenting and splitting the dataset

In [None]:
from random import seed, shuffle

from miditok.data_augmentation import augment_dataset
from miditok.utils import split_files_for_training

# Seed the shuffle so the train/valid/test split is identical after a
# "Restart & Run All"; without it every run trains on a different split.
SPLIT_SEED = 42
seed(SPLIT_SEED)

# Collect every MIDI file. The paths are sorted first because glob order depends
# on the filesystem, which would make the seeded shuffle irreproducible.
midi_paths = sorted(Path(dataset_path).glob("**/*.mid"))
if not midi_paths:
    raise FileNotFoundError(f"No .mid files found under {dataset_path!r}")

# Shuffle BEFORE keeping 1000 files so the subset is a random sample of the
# dataset rather than whatever the directory listing happens to return first.
shuffle(midi_paths)
midi_paths = midi_paths[:1000]

# Split the dataset into train/valid/test subsets, with 15% of the data for
# each of the two latter
total_num_files = len(midi_paths)
num_files_valid = round(total_num_files * 0.15)
num_files_test = round(total_num_files * 0.15)
midi_paths_valid = midi_paths[:num_files_valid]
midi_paths_test = midi_paths[num_files_valid:num_files_valid + num_files_test]
midi_paths_train = midi_paths[num_files_valid + num_files_test:]

# Chunk MIDIs and perform data augmentation on each subset independently
for files_paths, subset_name in (
    (midi_paths_train, "train"), (midi_paths_valid, "valid"), (midi_paths_test, "test")
):

    # Split the MIDIs into chunks of approximately 1024 tokens each; the
    # chunks are written to dataset_<subset_name>/
    subset_chunks_dir = Path(f"dataset_{subset_name}")
    split_files_for_training(
        files_paths=files_paths,
        tokenizer=tokenizer,
        save_dir=subset_chunks_dir,
        max_seq_len=1024,
        num_overlap_bars=2,
    )

    # Perform data augmentation on the chunk directory
    augment_dataset(
        subset_chunks_dir,
        pitch_offsets=[-12, 12],
        velocity_offsets=[-4, 4],
        duration_offsets=[-0.5, 0.5],
    )

## Preparing data loading

In [None]:
from pathlib import Path
from miditok.pytorch_data import DatasetMIDI, DataCollator
from torch.utils.data import DataLoader

# Gather the chunked MIDI files of all three subsets, in train -> valid -> test
# order, into a single list.
midi_paths = [
    chunk_path
    for subset_dir in ("dataset_train", "dataset_valid", "dataset_test")
    for chunk_path in Path(subset_dir).glob("**/*.mid")
]

# Filter used to reject MIDIs we do not want to keep.
# The same hook could host custom pre-processing, e.g. merging tracks before
# a MIDI file is tokenized.
def midi_valid(midi) -> bool:
    """Return True only when every time signature change has numerator 4.

    ``midi`` must expose ``time_signature_changes``, an iterable of objects
    with a ``numerator`` attribute. A file without any time signature change
    is accepted.
    """
    # 4/* means 4 beats per bar; any other numerator disqualifies the file
    return all(ts.numerator == 4 for ts in midi.time_signature_changes)

# Builds the vocabulary with BPE
# tokenizer.train(vocab_size=30000, files_paths=midi_paths)

# Dataset that tokenizes the chunked MIDI files on the fly.
# The sequence delimiters must be the tokenizer's real BOS/EOS tokens: the
# previous values (PAD id as BOS, BOS id as EOS) made every sequence start with
# a padding token and end with a "beginning of sequence" token.
dataset = DatasetMIDI(
    files_paths=midi_paths,
    tokenizer=tokenizer,
    max_seq_len=128,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
)

# The collator pads the sequences of a batch with the PAD token id; the
# training cell reads 'input_ids' and 'attention_mask' from each batch.
collator = DataCollator(tokenizer.pad_token_id)
data_loader = DataLoader(dataset=dataset, collate_fn=collator)

In [None]:
# Install KerasNLP; the example model below takes its TransformerEncoder layer
# from keras_nlp.layers.
!pip install keras_nlp

## Building an example model

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from keras_nlp import layers as nlp_layers

def build_transformer_model(vocab_size, seq_length, num_heads, ff_dim):
    """Build and compile a small token-level Transformer model.

    Parameters:
    - vocab_size: Number of distinct token ids (embedding rows / output classes)
    - seq_length: Fixed length of the input sequences
    - num_heads: Number of attention heads of the encoder layer
    - ff_dim: Size of the encoder's intermediate feed-forward layer

    Returns:
    - A compiled Keras model taking ``[input_ids, attention_mask]`` and
      returning per-position probabilities over the vocabulary
    """
    # Input layers for token ids and attention mask
    input_ids = layers.Input(shape=(seq_length,), dtype='int32', name='input_ids')
    attention_mask = layers.Input(shape=(seq_length,), dtype='int32', name='attention_mask')

    # Embedding layer for token ids
    embedding_layer = layers.Embedding(input_dim=vocab_size, output_dim=128)(input_ids)

    # Transformer Encoder. The mask is forwarded as the padding mask so padded
    # positions are excluded from attention; it used to be accepted as a model
    # input but never connected to any layer.
    # NOTE(review): an encoder attends in both directions, so each position
    # can see later tokens - confirm this is intended for next-token training.
    transformer_layer = nlp_layers.TransformerEncoder(
        num_heads=num_heads,
        intermediate_dim=ff_dim,
        dropout=0.1
    )(embedding_layer, padding_mask=attention_mask)

    # Output layer (probabilities for each token in the sequence; the softmax
    # matches the non-logit sparse categorical cross-entropy loss below)
    output = layers.Dense(vocab_size, activation='softmax')(transformer_layer)

    # Define the model
    model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=output)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

# Example usage
# The vocabulary size comes from the tokenizer itself instead of a hard-coded
# number, so the embedding and output layers always match the token ids the
# data loader produces (also after training the tokenizer with BPE).
vocab_size = len(tokenizer)
seq_length = 128    # Maximum sequence length
num_heads = 8       # Number of attention heads
ff_dim = 512        # Feed-forward layer size

model = build_transformer_model(vocab_size, seq_length, num_heads, ff_dim)

## Example training

In [None]:
import numpy as np
from tqdm import tqdm
# Number of batches trained so far; stays 0 when the loader yields nothing
i = 0
# Each batch supplies the input tokens; the labels are the same tokens shifted
# one step to the left
for i, batch in enumerate(tqdm(data_loader), start=1):
    input_ids_np = batch['input_ids'].cpu().numpy()
    attention_mask_np = batch['attention_mask'].cpu().numpy()

    # Next-token targets: roll left by one, then overwrite the wrapped-around
    # last column with 0
    labels_np = np.roll(input_ids_np, shift=-1, axis=1)
    labels_np[:, -1] = 0

    # One optimisation step of the Keras model on this batch
    model.train_on_batch([input_ids_np, attention_mask_np], labels_np)

    # Example run only: stop after 50 batches
    if i == 50:
        break

## Utility function to generate a new sequence of tokens

In [None]:
def _model_input_length(model):
    """Return the fixed sequence length the model expects, or None if unknown."""
    shape = getattr(model, "input_shape", None)
    # Multi-input models report one shape per input; the first is input_ids
    if isinstance(shape, (list, tuple)) and shape and isinstance(shape[0], (list, tuple)):
        shape = shape[0]
    if isinstance(shape, (list, tuple)) and len(shape) >= 2 and isinstance(shape[1], int):
        return shape[1]
    return None


def generate_sequence(model, tokenizer, max_length, prompt, attention_mask, vocab_size):
    """
    Generate a sequence using the Transformer model with greedy decoding.

    The caller's ``prompt`` and ``attention_mask`` lists are left untouched.
    When the model declares a fixed input length, the running sequence is
    right-padded to that length (mask 0 on the padding), or cut down to its
    most recent tokens once it grows past it.

    Parameters:
    - model: Object with a Keras-style ``predict([input_ids, attention_mask], verbose=0)``
      returning an array of shape (batch, sequence, vocabulary)
    - tokenizer: Tokenizer; only its ``pad_token_id`` is read (0 is used when absent)
    - max_length: Maximum length of the generated sequence
    - prompt: Initial token IDs as a batch of one sequence, e.g. ``[[4, 5, 100]]``
    - attention_mask: Initial attention mask, same layout as ``prompt``
    - vocab_size: The size of the vocabulary (kept for compatibility; unused)

    Returns:
    - Generated sequence of token IDs as a list of Python ints

    Raises:
    - ValueError: if the prompt is empty, since there is no position to predict from
    """
    # Copy the first (and only) row so the caller's lists are never mutated
    generated_sequence = [int(token) for token in prompt[0]]
    mask = [int(flag) for flag in attention_mask[0]]
    if not generated_sequence:
        raise ValueError("prompt must contain at least one token")

    fixed_length = _model_input_length(model)
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = 0

    while len(generated_sequence) < max_length:
        if fixed_length is None:
            window_ids, window_mask = generated_sequence, mask
        else:
            # Keep only the most recent tokens that fit, then right-pad
            window_ids = generated_sequence[-fixed_length:]
            window_mask = mask[-fixed_length:]
            num_padding = fixed_length - len(window_ids)
            window_ids = window_ids + [pad_id] * num_padding
            window_mask = window_mask + [0] * num_padding
        # Index of the last real (non-padding) token inside the window
        last_position = min(len(generated_sequence), len(window_ids)) - 1

        input_ids_np = np.array([window_ids], dtype="int32")
        attention_mask_np = np.array([window_mask], dtype="int32")

        # Predict per-position scores over the vocabulary
        logits = model.predict([input_ids_np, attention_mask_np], verbose=0)

        # Greedy decoding: highest-scoring token at the last real position.
        # int() keeps the result a list of plain Python ints.
        next_token_id = int(np.argmax(logits[0, last_position, :]))

        generated_sequence.append(next_token_id)
        mask.append(1)  # the new token is a real token

    return generated_sequence

## Generation

In [None]:
prompt = [[4, 5, 100, 56, 49, 10]]  # Initial prompt in token IDs
attention_mask = [[1] * len(prompt[0])]  # Initial attention mask (1s for the tokens in the prompt)

max_length = 50  # Set the maximum length for the generated sequence
# vocab_size is deliberately not redefined here: the value set when the model
# was built is reused, so the two can never drift apart.

# Generate a sequence
generated_ids = generate_sequence(model, tokenizer, max_length, prompt, attention_mask, vocab_size)

# Decode the generated token IDs back to text
#generated_text = tokenizer.decode(generated_ids)
print(generated_ids)

# Creation of the model

In [None]:
# ...

# Training of the model

In [None]:
# ...

# Utility functions

In [None]:
def random_file(root, keyword=None):
    """Pick a random ``.mid`` file from ``root`` (searched recursively).

    Parameters:
    - root: Directory to search
    - keyword: Optional substring the file path must contain; the comparison
      is case-insensitive

    Returns:
    - Path (str) of one randomly chosen matching file

    Raises:
    - FileNotFoundError: if no ``.mid`` file matches
    """
    import glob
    import os
    import random

    mid_files = glob.glob(os.path.join(root, "**", "*.mid"), recursive=True)
    if keyword is not None:
        # Lowercase both sides: the path was already lowercased, so a keyword
        # containing an uppercase letter could never match before.
        needle = keyword.lower()
        mid_files = [file for file in mid_files if needle in file.lower()]
    if not mid_files:
        # random.choice on an empty list raises an IndexError that says
        # nothing about which search came up empty
        raise FileNotFoundError(
            f"No .mid file found under {root!r}"
            + (f" matching keyword {keyword!r}" if keyword is not None else "")
        )
    return random.choice(mid_files)

def generate_midi_from_tokens(tokens, tokenizer, output_path):
    """Decode ``tokens`` with ``tokenizer`` and save the result as a MIDI file.

    Parameters:
    - tokens: Token ids handed to the tokenizer call as-is
    - tokenizer: Callable that turns tokens into an object with ``dump_midi``
    - output_path: Destination of the MIDI file (wrapped in ``Path``)
    """
    from pathlib import Path

    destination = Path(output_path)
    # Calling the tokenizer on the tokens yields the object that is dumped
    tokenizer(tokens).dump_midi(destination)

In [None]:
# Decode the generated token ids (wrapped in a one-element list) and save
# them as generated.mid, the file the last cells inspect and play.
generate_midi_from_tokens([generated_ids], tokenizer, "generated.mid")

# MIDI playing

## Installing the required libraries

In [None]:
# fluidsynth is the command-line synthesizer invoked by the FluidSynth wrapper
# class below; pretty_midi and midi-clip are used by the MIDI utility
# functions (inspection, velocity change, trimming).
!apt-get update -qq && apt-get install -y fluidsynth
!pip install pretty_midi midi-clip

## Download example Soundfonts (GeneralUser GS v2 and PICONICA)

In [None]:
# GeneralUser GS: download the archive, unpack it, drop the archive together
# with the extra folders, and rename the soundfont to guGS.sf2 (the default
# soundfont path of playMidi is /content/guGS.sf2).
!gdown 1wlpTIS70nQHMrYBjDT0M6nyg07kUejUv
!unzip GeneralUser_GS_v2.0.0--doc_r2.zip
!rm -rf GeneralUser_GS_v2.0.0--doc_r2.zip support documentation demo\ MIDIs
!mv GeneralUser\ GS\ v2.0.0.sf2 guGS.sf2

# PICONICA
# NOTE(review): presumably downloads PICONICA.sf2, the name later cells pass
# as soundfont_path - confirm.
!gdown 1uk51T9Gvo1n2JRl3_CHCg2FVGWiNI4qJ

## Optional: download other soundfonts

In [None]:
# Pokemon soundfont (optional: no other cell in this notebook refers to it)
!gdown 1vDK_xH7WeAqQrrBFXfh4Q205x6oNhTQt

## Utility function to generate the audio on Colab

### Taken from https://github.com/bzamecnik/midi2audio/blob/master/midi2audio.py

In [None]:
import argparse
import os
import subprocess

__all__ = ['FluidSynth']

# Defaults used when the caller does not override them
DEFAULT_SOUND_FONT = '~/.fluidsynth/default_sound_font.sf2'
DEFAULT_SAMPLE_RATE = 44100
DEFAULT_GAIN = 0.2

class FluidSynth():
    """Thin wrapper around the ``fluidsynth`` command-line program."""

    def __init__(self, sound_font=DEFAULT_SOUND_FONT, sample_rate=DEFAULT_SAMPLE_RATE, gain=DEFAULT_GAIN):
        """Store the synthesis settings.

        Parameters:
        - sound_font: Path of the .sf2 file; a leading ``~`` is expanded
        - sample_rate: Sample rate passed to fluidsynth with ``-r``
        - gain: Gain passed to fluidsynth with ``-g``; ``None`` selects
          DEFAULT_GAIN instead of being forwarded as the text "None"
        """
        self.sample_rate = sample_rate
        self.sound_font = os.path.expanduser(sound_font)
        self.gain = DEFAULT_GAIN if gain is None else gain

    def midi_to_audio(self, midi_file: str, audio_file: str, verbose=True):
        """Render ``midi_file`` to ``audio_file`` by running fluidsynth.

        With ``verbose=False`` the program's standard output is discarded.
        The exit status of the program is not checked.
        """
        # None lets the child process inherit this process's stdout
        stdout = None if verbose else subprocess.DEVNULL
        subprocess.call(
            ['fluidsynth', '-ni', '-g', str(self.gain), self.sound_font, midi_file, '-F', audio_file, '-r', str(self.sample_rate)],
            stdout=stdout,
        )

    def play_midi(self, midi_file):
        """Run fluidsynth on ``midi_file`` without an output file argument."""
        subprocess.call(['fluidsynth', '-i', '-g', str(self.gain), self.sound_font, midi_file, '-r', str(self.sample_rate)])

### Other utility functions

In [None]:
import pretty_midi
import os
import librosa.display

def show_midi_info(midi_path, print_notes=False):
    """Print the instrument names and the duration of a MIDI file.

    When ``print_notes`` is true, each instrument name is printed again,
    followed by one line per note with its start, end, pitch and velocity.
    """
    score = pretty_midi.PrettyMIDI(midi_path)
    print("Instruments: ", [track.name for track in score.instruments])
    print("MIDI duration: {duration:.2f} seconds".format(duration=score.get_end_time()))

    # Nothing else to show unless the note listing was requested
    if not print_notes:
        return

    for track in score.instruments:
        print(track.name)
        for event in track.notes:
            print(event.start, event.end, event.pitch, event.velocity)

def piano_roll(midi_path):
    """Plot the piano roll of ``midi_path`` for MIDI pitches 24 to 84.

    The figure is 12x4 inches; the drawing is delegated to plot_piano_roll.
    """
    # Local import: pyplot is otherwise only imported in a later cell, so the
    # function no longer depends on that cell having been run first.
    import matplotlib.pyplot as plt

    plt.figure(figsize=(12, 4))
    # Plot the file that was passed in; the global `path` used to be read
    # here, so the argument was silently ignored.
    plot_piano_roll(midi_path, 24, 84)

def plot_piano_roll(path, start_pitch, end_pitch, fs=100):
    """Draw the piano roll of the MIDI file at ``path`` on the current figure.

    Only the rows ``start_pitch`` (inclusive) to ``end_pitch`` (exclusive) of
    the roll are shown; ``fs`` is the sampling frequency of the roll.
    """
    full_roll = pretty_midi.PrettyMIDI(path).get_piano_roll(fs)
    visible_rows = full_roll[start_pitch:end_pitch]
    lowest_frequency = pretty_midi.note_number_to_hz(start_pitch)

    # librosa's specshow is reused to display the roll with note-name labels
    librosa.display.specshow(
        visible_rows,
        hop_length=1,
        sr=fs,
        x_axis='time',
        y_axis='cqt_note',
        fmin=lowest_frequency,
    )

def change_midi_velocity(midi_path, output_path, delta=0):
    """Copy a MIDI file to ``output_path`` with every note velocity shifted.

    Parameters:
    - midi_path: MIDI file to read
    - output_path: Where the (possibly modified) MIDI file is written
    - delta: Amount added to each note velocity; 0 writes the notes unchanged

    The output file is always written. It used to be written only when
    ``delta`` was non-zero, which left callers without the file they expected.
    """
    midi_data = pretty_midi.PrettyMIDI(midi_path)
    if delta != 0:
        for instrument in midi_data.instruments:
            for note in instrument.notes:
                # Keep the velocity inside the valid MIDI range; the lower
                # bound is 1 because a velocity of 0 means "note off"
                note.velocity = int(max(1, min(127, note.velocity + delta)))
    midi_data.write(output_path)

def convert_midi_to_wav(soundfont_path, midi_path, output_path, gain=None, velocity_change=0):
    """Render a MIDI file to audio with FluidSynth.

    Parameters:
    - soundfont_path: Soundfont (.sf2) used for the synthesis
    - midi_path: MIDI file to render
    - output_path: Audio file to write
    - gain: Synthesizer gain; ``None`` selects DEFAULT_GAIN
    - velocity_change: Amount added to every note velocity before rendering
    """
    import tempfile

    # Forwarding None would end up as the text "None" on the command line
    synth = FluidSynth(soundfont_path, gain=DEFAULT_GAIN if gain is None else gain)

    # Nothing to modify: render the original file directly, no temporary copy
    if velocity_change == 0:
        synth.midi_to_audio(midi_path, output_path)
        return

    # A unique temporary file avoids clobbering an existing "temp.mid"
    handle, temp_path = tempfile.mkstemp(suffix=".mid")
    os.close(handle)
    try:
        change_midi_velocity(midi_path, temp_path, delta=velocity_change)
        synth.midi_to_audio(temp_path, output_path)
    finally:
        # Clean up even when the conversion fails
        if os.path.exists(temp_path):
            os.remove(temp_path)


def trim_midi(midi_path, start, end):
    """Clip a MIDI file to ``start``..``end`` and save it next to the original.

    The result is saved in the same directory as ``midi_path`` under the name
    "trimmed_" followed by the original file name.

    Returns:
    - Path of the saved, trimmed file
    """
    import mido
    import midi_clip

    clipped = midi_clip.midi_clip(mido.MidiFile(midi_path), start, end)

    folder, filename = os.path.split(midi_path)
    trimmed_path = os.path.join(folder, "trimmed_" + filename)
    clipped.save(trimmed_path)
    return trimmed_path

def playMidi(midi_file_path,
             soundfont_path="/content/guGS.sf2",
             output_path="audio.wav",
             start=None,
             end=None,
             gain=DEFAULT_GAIN,
             velocity_change=0
             ):
    """Render a MIDI file to audio and return an IPython ``Audio`` player.

    Parameters:
    - midi_file_path: MIDI file to play
    - soundfont_path: Soundfont used for the synthesis
    - output_path: Audio file written by the conversion
    - start, end: When BOTH are given, only that excerpt is rendered; the
      temporary trimmed file is deleted after the conversion
    - gain: Synthesizer gain
    - velocity_change: Amount added to every note velocity before rendering
    """
    from IPython.display import Audio

    # Trimming only happens when both bounds are provided
    trimmed_path = None
    if start is not None and end is not None:
        trimmed_path = trim_midi(midi_file_path, start, end)

    source_path = midi_file_path if trimmed_path is None else trimmed_path
    convert_midi_to_wav(soundfont_path, source_path, output_path, gain=gain, velocity_change=velocity_change)

    # Remove the excerpt created above; the original file is never deleted
    if trimmed_path is not None:
        os.remove(trimmed_path)
    return Audio(output_path)

In [None]:
import matplotlib.pyplot as plt

# Pick a random MIDI file from the selected dataset and print its instruments
# and duration (no audio is produced in this cell).
path = random_file(dataset_path)
print("Converting: " + path)
print("Midi info:")
show_midi_info(path)

## Play a random MIDI of the Lakh dataset

In [None]:
# Pick a random MIDI file from the selected dataset, print its instruments and
# duration, then render it with the default soundfont.
path = random_file(dataset_path)
print("Converting: " + path)
print("Midi info:")
show_midi_info(path)
print("Synthetized:")
# Must stay the last expression of the cell so the audio player is displayed
playMidi(path)

## Play a random MIDI of the NESMDB dataset

In [None]:
# Pick a random MIDI file from the nesmdb_midi folder, print its instruments
# and duration, then render it with the PICONICA soundfont.
path = random_file("nesmdb_midi")
print("Converting: " + path)
print("Midi info:")
show_midi_info(path)
print("Synthetized:")
# Velocities are raised by 30 and the gain set to 1 for this soundfont.
# Must stay the last expression of the cell so the audio player is displayed.
playMidi(path, soundfont_path="PICONICA.sf2", velocity_change=30, gain=1)

In [None]:
# Print the instruments and duration of the file written from the generated
# token ids.
show_midi_info("generated.mid")

In [None]:
# Render the generated file with the PICONICA soundfont; as the last
# expression of the cell, the returned audio player is displayed.
playMidi("generated.mid", soundfont_path="PICONICA.sf2")