[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/timolai-andrievich/philosophy-2-assignment-2-keras-model/blob/main/train.ipynb)

In [None]:
import datetime
import glob
import pathlib
from typing import Tuple

# Training hyperparameters

In [None]:
# Number of consecutive beats in each input sequence; the beat right after
# the sequence is the prediction target.
SEQ_LEN = 8

# Number of samples grouped into one batch of the training/test datasets.
BATCH_SIZE = 256

# Factor applied by the `weighted_mse` loss to the squared error of pitches
# that are absent (0) in the target beat.
WEIGHTED_MSE_ZERO_COEFFICIENT = 0.05

# Import third-party libraries

In [None]:
# Install pretty_midi into the notebook's environment; needed when the
# notebook is run in Google Colab. `--quiet` reduces pip's output.
%pip install pretty_midi --quiet

In [None]:
import numpy as np
import pretty_midi
from sklearn.model_selection import train_test_split
import tensorflow as tf
import tqdm

# Download the [Maestro](https://magenta.tensorflow.org/datasets/maestro) dataset

In [None]:
# Directory that is expected to exist once the MAESTRO v2.0.0 MIDI archive
# has been downloaded and unpacked.
data_dir = pathlib.Path('data/maestro-v2.0.0')

# Fetch and extract the archive only when that directory is missing, so
# re-running the cell does not download the dataset again.
if not data_dir.exists():
    tf.keras.utils.get_file(
        fname='maestro-v2.0.0-midi.zip',
        origin='https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip',
        cache_dir='.',
        cache_subdir='data',
        extract=True,
    )

# Load dataset for training

In [None]:
def group_notes(pretty_midi_object: "pretty_midi.PrettyMIDI") -> np.ndarray:
    """Group the notes in the midi file by the beats they were played in.

    Row ``i`` of the result covers the time span from beat ``i`` to beat
    ``i + 1``. A note marks its pitch in every row whose span it overlaps,
    starting with the beat it begins in, so notes shorter than one beat are
    kept rather than dropped.

    Only the notes of the first instrument in the file are read.

    Args:
        pretty_midi_object (pretty_midi.PrettyMIDI): The pretty midi object
            containing the information about the melody.

    Returns:
        np.ndarray: An ``int8`` array of shape ``(number of beats - 1, 128)``
            containing the bitmap of the notes that were played in a certain
            beat of the melody: entry ``[i, p]`` is 1 when pitch ``p`` sounds
            during beat ``i`` and 0 otherwise. The array has zero rows when
            the file has fewer than two beats.
    """
    beats = np.asarray(pretty_midi_object.get_beats())
    # N beat times delimit N - 1 spans; never ask numpy for a negative size.
    beat_count = max(len(beats) - 1, 0)
    result = np.zeros((beat_count, 128), np.int8)
    if beat_count == 0 or not pretty_midi_object.instruments:
        return result

    for note in pretty_midi_object.instruments[0].notes:
        # 'right' yields the index just past the last beat time <= note.start,
        # so subtracting one gives the span the note starts in. Notes that
        # begin before the first beat are clamped to row 0.
        start_beat = max(int(np.searchsorted(beats, note.start, 'right')) - 1, 0)
        # 'left' yields the first beat time >= note.end, i.e. the exclusive
        # end row: a note ending exactly on a beat does not spill into the
        # following span.
        end_beat = int(np.searchsorted(beats, note.end, 'left'))
        # Slices past the last row are clipped by numpy, which handles notes
        # that outlast the final beat.
        result[start_beat:end_beat, note.pitch] = 1
    return result

Load all the notes from the dataset into one tensor:

In [None]:
# Gather one note bitmap per MIDI file and join them in a single pass:
# concatenating inside the loop would re-copy everything loaded so far on
# every iteration.
# The leading empty chunk keeps the result a (0, 128) int8 array when no
# file matches the pattern.
per_file_notes = [np.zeros((0, 128), np.int8)]
# glob returns paths in arbitrary order; sorting keeps the row order (and
# therefore the seeded train/test split built from it) the same on every run.
for file in tqdm.tqdm(sorted(glob.glob('data/*/*/*.midi'))):
    per_file_notes.append(group_notes(pretty_midi.PrettyMIDI(file)))
notes = np.concatenate(per_file_notes)

# Tensor copy of `notes`, converted to float16; this is what `load_index`
# slices samples from.
notes_in_tensor = tf.convert_to_tensor(notes, tf.float16)

In [None]:
def load_index(idx: int) -> Tuple[tf.Tensor, tf.Tensor]:
    """Build one training sample that starts at row ``idx``.

    Args:
        idx (int): Index of the first row of the input window in
            ``notes_in_tensor``.

    Returns:
        Tuple[tf.Tensor, tf.Tensor]: The ``SEQ_LEN`` consecutive rows
            starting at ``idx`` (model input) and the single row that
            immediately follows them (prediction target).
    """
    window_end = idx + SEQ_LEN
    input_window = notes_in_tensor[idx:window_end]
    next_beat = notes_in_tensor[window_end]
    return input_window, next_beat

Construct training and test datasets:

In [None]:
# Each index is the first row of one (SEQ_LEN-row window, following row)
# sample produced by `load_index`.
# NOTE(review): the last start index that is still in range is
# `notes.shape[0] - SEQ_LEN - 1`; the extra `- 1` here leaves it out.
indexes = list(range(notes.shape[0] - SEQ_LEN - 1))

# Hold out 10% of the start indexes for testing; the fixed seed makes the
# split repeatable for the same list of indexes.
train_idx, test_idx = train_test_split(
    indexes,
    test_size=0.1,
    random_state=42,
)

# Both pipelines turn start indexes into samples, batch them and prefetch
# upcoming batches.
train_ds = (
    tf.data.Dataset.from_tensor_slices(train_idx)
    .map(load_index)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
test_ds = (
    tf.data.Dataset.from_tensor_slices(test_idx)
    .map(load_index)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Creating and training the model

The `weighted_mse` loss function: 
$$ L(y, \hat{y}) = \lambda_{0}\mathop{\mathbb{E}}\limits_{i=0}^{127}(y_i - \hat{y}_i)^2 \mathbb{C}_i^0 + \\ 
\mathop{\mathbb{E}}\limits_{i=0}^{127}(y_i - \hat{y}_i)^2 \mathbb{C}_i^1 $$
where $\mathbb{C}^0_i$ is 1 if there is no note with pitch $i$ in the beat and 0 otherwise, $\mathbb{C}^1_i$ is 1 if there is a note with pitch $i$ in the beat and 0 otherwise, $y$ denotes the real values, $\hat{y}$ denotes the predicted values, and $\lambda_0$ is the `WEIGHTED_MSE_ZERO_COEFFICIENT` hyperparameter.

In [None]:
@tf.function
def weighted_mse(y_true, y_pred):
    """Mean squared error that down-weights positions where the target is 0.

    Args:
        y_true: Target values; used directly as the weight of the "note
            present" term and as ``1 - y_true`` for the "note absent" term.
        y_pred: Predicted values, combined element-wise with ``y_true``.

    Returns:
        The mean over the last axis of the weighted squared errors, where
        the "note absent" term is scaled by
        ``WEIGHTED_MSE_ZERO_COEFFICIENT``.
    """
    squared_error = (y_true - y_pred) ** 2
    # Error on positions where a note is present in the target.
    present_term = squared_error * y_true
    # Error on positions where no note is present in the target.
    absent_term = squared_error * (1 - y_true)
    weighted = absent_term * WEIGHTED_MSE_ZERO_COEFFICIENT + present_term
    return tf.reduce_mean(weighted, axis=-1)

## Topology of the model

In [None]:
# Sequence model: an LSTM layer with 256 units followed by a dense layer
# with 128 sigmoid outputs (the note bitmaps it is trained on have 128
# pitch columns).
model = tf.keras.Sequential()
model.add(tf.keras.layers.LSTM(256))
model.add(tf.keras.layers.Dense(128, activation='sigmoid'))

# Train with the Adam optimizer against the custom weighted MSE loss.
model.compile(loss=weighted_mse, optimizer='adam')

## Training and evaluating

In [None]:
# Train the model for 3 epochs on the training dataset.
model.fit(train_ds, epochs=3)

In [None]:
# Evaluate the trained model on the held-out test dataset.
model.evaluate(test_ds)

In [None]:
# Save the trained model to 'net.h5' in the working directory.
# NOTE(review): the model was compiled with the custom `weighted_mse` loss;
# reloading it presumably needs `custom_objects` or `compile=False` — confirm.
model.save('net.h5')