In [6]:
#@title Get a smaller version of the Lakh MIDI Dataset v0.1
%%capture
!wget http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz
!tar xvf clean_midi.tar.gz
!rm clean_midi.tar.gz
!rm -rf sample_data

dataset_path = "/content/clean_midi"

In [7]:
%%capture
!pip install pretty_midi
!pip install miditok

In [3]:
!pip install tensorflow==2.17
!pip install --upgrade keras
!pip install keras_nlp

Collecting keras
  Downloading keras-3.6.0-py3-none-any.whl.metadata (5.8 kB)
Downloading keras-3.6.0-py3-none-any.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m31.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: keras
  Attempting uninstall: keras
    Found existing installation: keras 3.5.0
    Uninstalling keras-3.5.0:
      Successfully uninstalled keras-3.5.0
Successfully installed keras-3.6.0
Collecting keras_nlp
  Downloading keras_nlp-0.17.0-py3-none-any.whl.metadata (1.2 kB)
Collecting keras-hub==0.17.0 (from keras_nlp)
  Downloading keras_hub-0.17.0-py3-none-any.whl.metadata (7.4 kB)
Collecting tensorflow-text (from keras-hub==0.17.0->keras_nlp)
  Downloading tensorflow_text-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting tensorflow<2.19,>=2.18.0 (from tensorflow-text->keras-hub==0.17.0->keras_nlp)
  Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_6

In [4]:
#@title Installing libraries to hear a MIDI
%%capture
!apt-get update -qq && apt-get install -y fluidsynth
!pip install pretty_midi midi-clip

# GS2
!gdown 1wlpTIS70nQHMrYBjDT0M6nyg07kUejUv
!unzip GeneralUser_GS_v2.0.0--doc_r2.zip
!rm -rf GeneralUser_GS_v2.0.0--doc_r2.zip support documentation demo\ MIDIs
!mv GeneralUser\ GS\ v2.0.0.sf2 guGS.sf2

# PICONICA
!gdown 1uk51T9Gvo1n2JRl3_CHCg2FVGWiNI4qJ

# Utility library
!wget https://raw.githubusercontent.com/roostico/NesGen/refs/heads/main/utility.py
!wget https://raw.githubusercontent.com/roostico/NesGen/refs/heads/main/transformer.py

from utility import *

## Import libraries

In [5]:
import os
import random
import shutil
from tqdm import tqdm
from pathlib import Path
import pretty_midi
import numpy as np
from miditok import REMI, TokenizerConfig
from utility import playMidi, show_midi_info
import json
import keras_nlp.layers as nlp_layers
from tensorflow import keras
import tensorflow as tf

In [8]:
skip = True

### Utility functions

In [9]:
def play_tokens(tokens: np.ndarray, tokenizer: REMI, delete_after: bool = True, show_info: bool = False):
  """
  Plays the given tokens, decoded using the given tokenizer
  """
  dumped_midi = "decoded.mid"
  tokenizer.decode([tokens]).dump_midi(dumped_midi)
  to_play = playMidi(dumped_midi)
  if show_info:
    show_midi_info(dumped_midi)
  if delete_after:
    os.remove(dumped_midi)
  return to_play


def random_filtered(collection: list, predicate):
  """
  Returns a random element from a collection that satisfies the given predicate.
  If no element satisfies the filter, returns None.
  """
  for elem in random.sample(collection, len(collection)):
    if predicate(elem):
      return elem

def lasts_less_than(midi_path: str, time_seconds: int) -> bool:
  """
  Returns true if the last note of the MIDI file is less than the given time in seconds.
  """
  return pretty_midi.PrettyMIDI(midi_path).get_end_time() <= time_seconds

## Move files and rename them

In [10]:
from pathlib import Path
import os
import shutil
import random

# Paths to the files of the dataset
midi_paths = list(Path(dataset_path).resolve().glob("**/*.mid"))

midis_dir = "midis"
os.makedirs(midis_dir, exist_ok=True)

for i, midi_path in enumerate(midi_paths):
  new_midi_path = os.path.join(midis_dir, f"{i}.mid")
  shutil.move(str(midi_path), new_midi_path)

midis = list(Path("midis").resolve().glob("**/*.mid"))

def sample():
  return str(random.choice(midis))

## Select a sample of these files

### Define a filtering function

In [11]:
def is_valid(file: str) -> bool:
  """Checks if a MIDI file is valid. If any of its instruments has no name,
  it is invalid.

  Args:
      file: The path to the MIDI file.

  Returns:
      True if the MIDI file is valid, False otherwise.
  """
  try:
    midi = pretty_midi.PrettyMIDI(file)
    if any([len(instrument.name) == 0 for instrument in midi.instruments]):
      return False
    return True
  except Exception as e:
    print(e)
    return False

### Move files

In [12]:
def copy_random_files(source_dir: str, dest_dir: str, num_files: int, is_file_valid) -> list:
  """Copies a specified number of random files from a source directory to a destination directory.

  Args:
      source_dir: The path to the source directory.
      dest_dir: The path to the destination directory.
      num_files: The number of files to move.
  """
  if not os.path.exists(source_dir):
    print(f"Error: Source directory '{source_dir}' not found.")
    return

  if os.path.exists(dest_dir):
    shutil.rmtree(dest_dir)

  os.makedirs(dest_dir, exist_ok=True)
  files = [f for f in os.listdir(source_dir) if os.path.isfile(os.path.join(source_dir, f))]

  if len(files) < num_files:
    print(f"Warning: Only {len(files)} files found in '{source_dir}'. Moving all of them.")
    num_files = len(files)

  files_to_move = []
  i = 0
  with tqdm(total=num_files, position=0, leave=True) as pbar:
    while i < num_files:
      random_file = random.choice(files)
      if is_file_valid(os.path.join(source_dir, random_file)):
        files_to_move.append(random_file)
        i = i + 1
        pbar.update()
      else:
        continue

  result = []
  for file in files_to_move:
    source_path = os.path.join(source_dir, file)
    dest_path = os.path.join(dest_dir, file)
    shutil.copy(source_path, dest_path)
    result.append(dest_path)
  return result

source_directory = "midis"
destination_directory = "selected"
number_of_files_to_move = 500

sample = copy_random_files(source_directory, destination_directory, number_of_files_to_move, is_valid)
assert len(sample) == 500

 26%|██▌       | 129/500 [00:30<01:42,  3.61it/s]

no MTrk header at start of track


 31%|███▏      | 157/500 [00:37<01:53,  3.01it/s]

MIDI file has a largest tick of 33639950, it is likely corrupt


 37%|███▋      | 184/500 [00:52<01:16,  4.13it/s]

data byte must be in range 0..127


 41%|████      | 204/500 [01:01<02:00,  2.46it/s]

data byte must be in range 0..127


 47%|████▋     | 234/500 [01:07<00:58,  4.56it/s]

data byte must be in range 0..127


 49%|████▉     | 247/500 [01:10<00:45,  5.52it/s]

data byte must be in range 0..127


 71%|███████   | 356/500 [01:37<00:20,  6.87it/s]

running status without last_status


 75%|███████▍  | 373/500 [01:40<00:29,  4.36it/s]




 80%|████████  | 401/500 [01:49<00:24,  4.02it/s]

Could not decode key with 2 flats and mode 255


 82%|████████▏ | 408/500 [01:51<00:25,  3.61it/s]

data byte must be in range 0..127


100%|██████████| 500/500 [02:13<00:00,  3.74it/s]

data byte must be in range 0..127





### Listen one random MIDI from sample

In [13]:
if not skip:
  def valid_example(file: str) -> float:
    """
    Returns true if the midi file lasts less that 120 seconds and has more than 1 instrument
    """
    return lasts_less_than(file, 120) and len(pretty_midi.PrettyMIDI(file).instruments) > 1

  example = random_filtered(sample, valid_example)

  print(f"Showing file {example}")
  show_midi_info(example)
  playMidi(example)

## Pre-processing

In [14]:
TRACK_MIN_DENSITY_PERC = 0.20

def is_excluded(track_name):
  """
  Exclusion criteria for MIDI tracks
  """
  exclude_keywords = ["drum", "effect", "percussion"]
  return any(keyword.lower() in track_name.lower() for keyword in exclude_keywords)

def merge_tracks_to_single_instrument(input_file, output_file, target_channel=0):
    midi_data = pretty_midi.PrettyMIDI(input_file)
    merged_midi = pretty_midi.PrettyMIDI()
    merged_instrument = pretty_midi.Instrument(program=pretty_midi.instrument_name_to_program('Acoustic Grand Piano'), is_drum=False)

    mean_notes = map(lambda x: len(x.notes), midi_data.instruments)
    mean_notes = sum(mean_notes) / len(midi_data.instruments)

    for instrument in midi_data.instruments:
        track_name = instrument.name

        # Exclude drum instruments or effect instruments
        if instrument.is_drum or is_excluded(track_name):
            continue

        # Exclude instruments that have a low number of notes
        if len(instrument.notes) / mean_notes < TRACK_MIN_DENSITY_PERC:
            continue

        for note in instrument.notes:
            note.velocity = max(1, note.velocity)  # Ensure velocity is within MIDI range
            merged_instrument.notes.append(note)

    merged_instrument.notes.sort(key=lambda note: note.start)

    tempo_times, tempi = midi_data.get_tempo_changes()
    if len(tempi) > 0:
        merged_midi._tick_scales = midi_data._tick_scales  # Copy tempo-related timing

    merged_midi.instruments.append(merged_instrument)
    merged_midi.time_signature_changes = midi_data.time_signature_changes
    merged_midi.key_signature_changes = midi_data.key_signature_changes

    merged_midi.write(output_file)

if os.path.exists("pre-processed"):
  shutil.rmtree("pre-processed")

os.makedirs("pre-processed")

for file in tqdm(sample):
  try:
    merge_tracks_to_single_instrument(file, f"pre-processed/{os.path.basename(file)}")
  except Exception as e:
    print(f"There was an error: {e}")
    continue

processed = list(Path("pre-processed").resolve().glob("**/*.mid"))

100%|██████████| 500/500 [03:44<00:00,  2.23it/s]


### Listen the same file from before, but with merged MIDI tracks

In [15]:
if not skip:
  processed_example = os.path.join("pre-processed", os.path.basename(example))
  print(f"Showing file {processed_example}")
  show_midi_info(processed_example)
  playMidi(processed_example)

## Tokenizer



In [16]:
processed = list(Path("pre-processed").resolve().glob("**/*.mid"))

In [17]:
tok_config = {
    "use_pitchdrum_tokens": False
}

tok_config = TokenizerConfig(**tok_config)
tokenizer = REMI(tok_config)

### (Optional): train the tokenizer

In [18]:
tokenizer.train(vocab_size=1000, files_paths=processed)

### Tokenizer the dataset

In [19]:
def midi_valid(midi) -> bool:
    if any(ts.numerator != 4 for ts in midi.time_signature_changes):
        return False  # time signature different from 4/*, 4 beats per bar
    return True

if os.path.exists("tokenized"):
  shutil.rmtree("tokenized")

tokenizer.tokenize_dataset(        # 2 velocity and 1 duration values
    Path("/content", "pre-processed"),
    Path("/content", "tokenized"),
    midi_valid,
)

Tokenizing music files (content/tokenized): 100%|██████████| 493/493 [00:40<00:00, 12.19it/s]


### Utility function to read a JSON tokenized file

In [20]:
def read_json(path: str) -> dict:
  with open(path, "r") as f:
    return json.load(f)

### See the tokenized version of the previous file

In [21]:
if not skip:
  tokenized_example = os.path.join("tokenized", Path(example).stem + ".json")
  example_ids = read_json(tokenized_example)["ids"][0]
  print(f"Showing IDS of {tokenized_example}")
  print(np.array(example_ids))

## Read the tokenized version of files from the JSON

In [22]:
def read_json_files(json_file_paths):
    """Reads a list of JSON files and returns a list of objects.

    Args:
        json_file_paths: A list of file paths to JSON files.

    Returns:
        A list of objects, where each object represents the data from a JSON file.
        Returns an empty list if any error occurs during file processing.
    """
    objects = []
    for file_path in tqdm(json_file_paths):
        try:
            objects.append(read_json(file_path))
        except FileNotFoundError:
            print(f"Error: File not found - {file_path}")
            return [] # Return empty list on error
        except json.JSONDecodeError:
            print(f"Error decoding JSON in file: {file_path}")
            return [] # Return empty list on error
    return objects

# Example usage (assuming 'tokenized' directory contains JSON files):
tokenized_files = list(Path("/content", "tokenized").resolve().glob("**/*.json"))
data_objects = read_json_files(tokenized_files)

if data_objects:
    print(f"\nSuccessfully read {len(data_objects)} JSON files.")
    # Now you can work with the 'data_objects' list
    # For example, print the first object:
    # print(data_objects[0])
else:
    print("Error reading JSON files.")

100%|██████████| 493/493 [00:00<00:00, 768.20it/s]


Successfully read 493 JSON files.





## Create the list of tokenized songs, taking the IDs of each one

In [23]:
encoded = [np.array(song["ids"][0]) for song in data_objects]

### Listen the same example from before, decoded from its tokenization

In [24]:
if not skip:
  print(f"Showing decoded IDS of {tokenized_example}")

  to_play = play_tokens(example_ids, tokenizer, show_info = True)
  to_play

## Trim of initial and ending silence in each song

In [25]:
def trim(ids: np.ndarray, token_to_remove: int) -> np.ndarray:
  """
  Returns a new numpy array with initial and ending `token_to_remove` removed.
  """
  start_idx = 0
  end_idx = len(ids)

  while start_idx < len(ids) and ids[start_idx] == token_to_remove:
      start_idx += 1
  while end_idx > start_idx and ids[end_idx - 1] == token_to_remove:
      end_idx -= 1
  return ids[start_idx:end_idx]

bar_token = tokenizer.vocab["Bar_None"]
encoded = [trim(ids, bar_token) for ids in encoded]

## Remove long rests in the song

In [26]:
def shorten_sequences(arr: np.ndarray, target_value: int, max_length: int) -> np.ndarray:
  result = []
  current_sequence = []

  for value in arr:
    if value == target_value:
      current_sequence.append(value)
    else:
      if len(current_sequence) > max_length:
        result.extend([target_value] * max_length)
      else:
        result.extend(current_sequence)

      current_sequence = []
      result.append(value)

  if len(current_sequence) > max_length:
    result.extend([target_value] * max_length)
  else:
    result.extend(current_sequence)

  return np.array(result)

bar_token = tokenizer.vocab["Bar_None"]
max_rest_length = 5
encoded = [shorten_sequences(ids, bar_token, max_rest_length) for ids in encoded]

### Show an example from the resulting IDs

In [27]:
if not skip:
  random_ids = random.choice(encoded)[:1000]

  print("Decoded:")
  play_tokens(random_ids, tokenizer, show_info = True)

End of pre-processing, proceding with data and model preparation with Tensorflow

---



# Tensorflow data and model setup

## Creating a Tensorflow dataset with all IDs

In [28]:
import tensorflow as tf
import numpy as np

all_ids = np.concatenate(encoded)
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)

### Convert into sequences

In [29]:
seq_length = 400 #@param {type: 'slider', max: 500, min: 50, step: 50}
sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)

## Preparing labels

In [30]:
def split_input_target(sequence):
    input_seq = sequence[:-1]
    target_seq = sequence[1:]
    return input_seq, target_seq

dataset = sequences.map(split_input_target)

### Creating training batches

In [31]:
# Batch size
BATCH_SIZE = 32 #@param {type: 'slider', max: 256, min: 32, step: 32}

# Buffer size to shuffle the dataset
# (TF data is designed to work with possibly infinite sequences,
# so it doesn't attempt to shuffle the entire sequence in memory. Instead,
# it maintains a buffer in which it shuffles elements).
BUFFER_SIZE = 10000

dataset = (
    dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE))

dataset

<_PrefetchDataset element_spec=(TensorSpec(shape=(32, 400), dtype=tf.float64, name=None), TensorSpec(shape=(32, 400), dtype=tf.float64, name=None))>

## Splitting in Train, Validation and Test

In [32]:
ds_size = dataset.cardinality().numpy()
train_size = int(0.8 * ds_size)
val_size = int(0.1 * ds_size)
test_size = int(0.1 * ds_size)

In [33]:
train_ds = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid_ds = remaining.take(val_size)
test_ds = remaining.skip(val_size)

# Building the model

## Transformer

## Keras NLP

In [34]:
def create_model(seq_length,
                 vocab_size,
                 model_dim=256,
                 norm_epsilon=1e-5,
                 dropout=0.1,
                 num_layers=3,
                 intermediate_dim=512,
                 num_heads=4
                 ):
  inputs = keras.Input(shape=(seq_length,))

  # Embed our tokens with a positional embedding.
  embedding_layer = nlp_layers.TokenAndPositionEmbedding(
      vocabulary_size=vocab_size,
      sequence_length=seq_length,
      embedding_dim=model_dim,
  )
  outputs = embedding_layer(inputs)

  # Apply layer normalization and dropout to the embedding.
  outputs = keras.layers.LayerNormalization(epsilon=norm_epsilon)(outputs)
  outputs = keras.layers.Dropout(rate=dropout)(outputs)

  # Add a number of encoder blocks
  for i in range(num_layers):
      outputs = nlp_layers.TransformerEncoder(
          intermediate_dim=intermediate_dim,
          num_heads=num_heads,
          dropout=dropout,
          layer_norm_epsilon=norm_epsilon,
      )(outputs)

  outputs = keras.layers.Dense(units=vocab_size)(outputs)

  return keras.Model(inputs, outputs)

In [35]:
model = create_model(seq_length, tokenizer.vocab_size)
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
LEARNING_RATE = 5e-4

model.compile(loss=loss,
              optimizer=keras.optimizers.AdamW(LEARNING_RATE),
              weighted_metrics=["sparse_categorical_accuracy"],
              jit_compile=True,
              )

In [36]:
# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}.weights.h5")
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True
)

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True
)

In [37]:
EPOCHS = 10

history = model.fit(train_ds,
                    epochs=EPOCHS,
                    validation_data=valid_ds,
                    callbacks=[checkpoint_callback, early_stopping]
                    )

Epoch 1/10


FailedPreconditionError: Graph execution error:

Detected at node StatefulPartitionedCall defined at (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main

  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code

  File "/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py", line 37, in <module>

  File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 992, in launch_instance

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 619, in start

  File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start

  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever

  File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once

  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run

  File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 685, in <lambda>

  File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 738, in _run_callback

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 825, in inner

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 361, in process_one

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request

  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code

  File "<ipython-input-37-c0dcf454099a>", line 3, in <cell line: 3>

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/trainer.py", line 320, in fit

  File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/trainer.py", line 121, in one_step_on_iterator

DNN library initialization failed. Look at the errors above for more details.
	 [[{{node StatefulPartitionedCall}}]] [Op:__inference_one_step_on_iterator_11533]

## Generation

In [None]:
def generate_one_step(model, seed):
  predictions = model(seed)
  sampled_indices = tf.random.categorical(predictions[0], num_samples=1)
  sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()
  return sampled_indices

for seed_ids, _ in test_ds.take(1):
  seed = seed_ids
seed = seed[0]

In [None]:
print(seed)

In [None]:
midi = tokenizer.decode([seed])
midi.dump_midi("nesgen.mid")
playMidi("nesgen.mid")

In [None]:
def generate(seed, model, length = 5):
  import time
  start = time.time()
  next_ids = seed
  result = [next_ids]
  n = 0
  for n in tqdm(range(length)):
    next_ids = generate_one_step(model, tf.expand_dims(next_ids, axis=0))
    result.append(next_ids)

  end = time.time()
  print('\nRun time:', end - start, "\n", '_'*80, "\n")
  result = np.concatenate(result[1:])
  print("Shape of result: ", result.shape)
  return result

In [None]:
result = generate(seed, model, 5)

In [None]:
midi = tokenizer.decode([result])
midi.dump_midi("nesgen.mid")
playMidi("nesgen.mid")

# OLD

## Hyper-parameters

In [None]:
num_layers = 4
d_model = 128
dff = 256
num_heads = 2
dropout_rate = 0.1

In [None]:
from transformer import Transformer

model = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    input_vocab_size=tokenizer.vocab_size,
    target_vocab_size=tokenizer.vocab_size,
    dropout_rate=dropout_rate
    )

In [None]:
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam',loss=loss, metrics=['accuracy'])

---

## Try generating with un-trained model

In [None]:
from utility import playMidi

# Use the model
for input_example_batch, target_example_batch in dataset.take(1):
  example_batch_predictions = model(input_example_batch)
  print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

# Sample indices from predictions
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

# Decode
midi = tokenizer.decode([sampled_indices])
midi.dump_midi("boh5.mid")
playMidi("boh5.mid")

In [None]:
# Calculate the loss
#loss = tf.keras.losses.sparse_categorical_crossentropy(target_example_batch, example_batch_predictions, from_logits=True)
# Reduce mean to get a single scalar loss value
#loss = tf.reduce_mean(loss)

#print("Loss:", loss.numpy())

## Taken from tensorflow tutorial:

A newly initialized model shouldn't be too sure of itself, the output logits should all have similar magnitudes. To confirm this you can check that the exponential of the mean loss is approximately equal to the vocabulary size. A much higher loss means the model is sure of its wrong answers, and is badly initialized:

In [None]:
#print("Checking if it is near to vocabulary size")
#print(tf.exp(example_batch_mean_loss).numpy())
#print("Vocab size: ", encoding.vocab_size)

# Training

In [None]:
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss, metrics=['accuracy'])

In [None]:
# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}.weights.h5")
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True
)

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True
)

In [None]:
EPOCHS = 5

history = model.fit(train_ds,
                    epochs=EPOCHS,
                    validation_data=valid_ds,
                    callbacks=[checkpoint_callback, early_stopping]
                    )

# Generation

In [None]:
class OneStep(tf.keras.Model):
  def __init__(self, model, decoding, encoding, vocab_size, temperature=1.0):
    super().__init__()
    self.temperature = temperature
    self.model = model
    self.decode = decoding
    self.encode = encoding

    # Taken from tensorflow tutorial: useful to skip ids

    #skip_ids = self.ids_from_chars(['[UNK]'])[:, None]
    #sparse_mask = tf.SparseTensor(
        # Put a -inf at each bad index.
    #    values=[-float('inf')]*len(skip_ids),
    #    indices=skip_ids,
        # Match the shape to the vocabulary
    #    dense_shape=[vocab_size])
    #self.prediction_mask = tf.sparse.to_dense(sparse_mask)

  @tf.function
  def generate_one_step(self, input_ids):
    input_ids_ = tf.expand_dims(input_ids, axis=0)

    # Run the model.
    # predicted_logits.shape is [batch, char, next_char_logits]
    predicted_logits = self.model(inputs=input_ids_)
    # Only use the last prediction.
    predicted_logits = predicted_logits[:, -1, :]
    predicted_logits = predicted_logits/self.temperature

    # Taken from tensorflow tutorial: apply prediction mask to prevent certain
    # ids from being generated
    #predicted_logits = predicted_logits + self.prediction_mask

    # Sample the output logits to generate token IDs.
    predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
    predicted_ids = tf.squeeze(predicted_ids, axis=-1)

    # Return the ids and model state.

    return predicted_ids

In [None]:
one_step_model = OneStep(model, tokenizer.decode, tokenizer.encode, tokenizer.vocab_size)

In [None]:
midis = list(Path("/content/midis").glob("*.mid"))

In [None]:
def get_random_file() -> str:
  return random.choice(midis)

seed_ids = np.array(tokenizer.encode(get_random_file())[0].ids)
seed_ids = seed_ids[:seq_length]
print(seed_ids)

In [None]:
tf.config.run_functions_eagerly(True)
one_step_model.generate_one_step(seed_ids)

In [None]:
import time

start = time.time()
next_ids = seed_ids
result = [next_ids]
n = 0
while n < 1000:
  next_ids = one_step_model.generate_one_step(next_ids)
  if next_ids == tokenizer.vocab["Bar_None"]:
    continue
  n = n + 1
  result.append(next_ids)

end = time.time()
print('\nRun time:', end - start, "\n", '_'*80, "\n")
result = np.concatenate(result[1:])
print("Shape of result: ", result.shape)

## Save the generator

In [None]:
tf.saved_model.save(one_step_model, 'one_step')

## Reload the generator

In [None]:
one_step_reloaded = tf.saved_model.load('one_step')

In [None]:
def tokens_from_ids(ids, tokenizer):
  tokens = []
  for id in ids:
    for key, value in tokenizer.vocab.items():
      if value == id:
        tokens.append(key)
  return np.array(tokens)

In [None]:
example = train_ds.take(1).as_numpy_iterator().next()[0]

In [None]:
tokens_from_ids(example[0], tokenizer)

In [None]:
tokens_from_ids(result, tokenizer)

## Hear the result

In [None]:
print(tokens_from_ids(result, tokenizer))

In [None]:
from utility import playMidi

midi = tokenizer.decode([result])
midi.dump_midi("result.mid")
playMidi("result.mid")