In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Main script used for training."""
from tensorflow.keras.callbacks import TensorBoard, CSVLogger
from tensorflow.keras.models import load_model
import keras.metrics
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
from nmp import model as mod
from nmp import dataset, ev_metrics
from nmp.dataset import pyplot_piano_roll
from nmp import plotter
from pathlib import Path
import time
import math
import pypianoroll
from pypianoroll import Multitrack, Track
import numpy as np
import random
import copy

# P = Path(__file__).parent.absolute()  # use this instead when run as a plain script
P = Path(os.path.abspath(''))  # Current working directory; compatible with Jupyter Notebook
# Raw string: in a normal literal '\d' is an invalid escape sequence
# (DeprecationWarning, and SyntaxWarning since Python 3.12).
P2 = Path(r'S:\datasets')  # Dataset path

PLOTS = P / 'plots'  # Plots path
FS = 24  # Sampling frequency in Hz (24 Hz -> ~41.7 ms per frame)
Q = 0  # Quantize? (passed to dataset.Dataset as `quant`)
st = 1  # Past timesteps (also used in the CSV log file name)
num_ts = 1  # Predicted timesteps (also used in the CSV log file name)
DOWN = 12  # Downsampling factor
# D = "data/Piano-midi.de"  # Dataset
D = "data/Nottingham"  # Dataset
# D = "data/MuseData"  # Dataset

# Note range is half-open [LOW_LIM, HIGH_LIM): NUM_NOTES = HIGH_LIM - LOW_LIM.
LOW_LIM = 33  # A1 (lowest note kept)
HIGH_LIM = 97  # exclusive; highest note kept is 96 = C7

# LOW_LIM = 36  # C2
# HIGH_LIM = 85  # exclusive; highest note C6 = 84

# Complete 88-key keyboard
# LOW_LIM = 21  # A0
# HIGH_LIM = 109  # exclusive; highest note C8 = 108

NUM_NOTES = HIGH_LIM - LOW_LIM
CROP = [LOW_LIM, HIGH_LIM]  # y-limits used to crop piano-roll plots

LOAD = 0  # 1 -> load a saved model instead of building a new one
TRANS = 0  # Transposition flag; not referenced by the active cells

### Generate list of MIDI files

In [None]:
def _midi_files(folder):
    """Return the names of the ``.mid`` files in ``folder``, sorted by name.

    ``os.listdir`` returns entries in arbitrary, filesystem-dependent order.
    Sorting makes the order of pieces (and therefore of the concatenated
    datasets) reproducible across machines and runs.
    """
    return sorted(name for name in os.listdir(folder) if name.endswith('.mid'))


train_list = _midi_files(P / D / 'train')
validation_list = _midi_files(P / D / 'valid')
test_list = _midi_files(P / D / 'test')

print("\nTrain list:  ", train_list)
print("\nValidation list:  ", validation_list)
print("\nTest list:  ", test_list)

## Datasets

### Generate data from lists
Training, validation and test sets.

In [None]:
start = time.time()

# One Dataset per split, built from the file lists of the previous cell.
train, validation, test = (
    dataset.Dataset(files, P / D / folder, fs=FS, bl=0, quant=Q)
    for files, folder in ((train_list, 'train'),
                          (validation_list, 'valid'),
                          (test_list, 'test'))
)

# Turn each split into RNN-ready arrays restricted to [LOW_LIM, HIGH_LIM).
for split, label in ((train, "training"), (validation, "validation"), (test, "test")):
    split.build_rnn_dataset(label, down=DOWN, low_lim=LOW_LIM, high_lim=HIGH_LIM)

end = time.time()
print("Done")
print("Loading time: %.2f" % (end - start))

In [None]:
# Shapes of dataset[0] and dataset[1] for every split.
for split in (train, validation, test):
    print(split.dataset[0].shape)
    print(split.dataset[1].shape)

# Test targets, restricted to the first NUM_NOTES columns.
test_target_roll = test.dataset[1][:, :NUM_NOTES]
pyplot_piano_roll(test_target_roll, cmap="Oranges",
                  low_lim=LOW_LIM, high_lim=HIGH_LIM)
plt.title("Test target")
# plt.ylim(CROP)

In [None]:
import tensorflow as tf

seq_length = 100  # frames per example; each window keeps one extra frame for the shifted target


def _to_windows(frames):
    """Wrap ``frames`` in tf.data and cut it into windows of ``seq_length + 1`` rows.

    ``from_tensor_slices`` slices along the first axis. ``batch(...,
    drop_remainder=True)`` groups consecutive slices into non-overlapping
    windows and discards the incomplete tail. Returns (slice dataset, window dataset).
    """
    slices = tf.data.Dataset.from_tensor_slices(frames)
    return slices, slices.batch(seq_length + 1, drop_remainder=True)


train_dataset, train_sequences = _to_windows(train.dataset[0])
valid_dataset, valid_sequences = _to_windows(validation.dataset[0])
test_dataset, test_sequences = _to_windows(test.dataset[0])

In [None]:
def split_input_target(chunk):
    """Split a window of frames into an (input, target) pair for next-step prediction.

    The input is every frame except the last. The target is the same window
    shifted one frame ahead, so ``target[t]`` is the frame following ``input[t]``.
    """
    return chunk[:-1], chunk[1:]


def split_input_target_base(chunk):
    """Return the window without its last frame as both input and target.

    NOTE(review): unlike ``split_input_target``, the target is not shifted, so it
    equals the input. If this is meant as a "repeat the previous frame" baseline,
    the target should presumably be ``chunk[1:]``; confirm the intent.
    """
    return chunk[:-1], chunk[:-1]

# Turn every (seq_length + 1)-frame window into an (input, next-frame target) pair.
train_data, valid_data, test_data = (
    windows.map(split_input_target)
    for windows in (train_sequences, valid_sequences, test_sequences)
)
# Baseline pairs: the target is the same (unshifted) slice as the input.
baseline_data = test_sequences.map(split_input_target_base)

In [None]:
# Batch size
BATCH_SIZE = 64

# Buffer size to shuffle the dataset.
# tf.data is designed to work with possibly infinite sequences, so it does not
# shuffle the entire sequence in memory. Instead it keeps a buffer of
# BUFFER_SIZE elements and samples from it.
BUFFER_SIZE = 10000

# The pipelines are rebuilt from the *_sequences datasets instead of rebinding
# `train_data = train_data.batch(...)`. That form added one more batch
# dimension every time this cell was re-run.
# drop_remainder=True keeps every batch at exactly BATCH_SIZE (presumably
# required because the model is built with a fixed batch size).
train_data = (train_sequences.map(split_input_target)
              .shuffle(BUFFER_SIZE)
              .batch(BATCH_SIZE, drop_remainder=True))
valid_data = (valid_sequences.map(split_input_target)
              .shuffle(BUFFER_SIZE)
              .batch(BATCH_SIZE, drop_remainder=True))
# Test and baseline pipelines use batch size 1 for per-window prediction.
test_data = test_sequences.map(split_input_target).batch(1, drop_remainder=True)
baseline_data = test_sequences.map(split_input_target_base).batch(1, drop_remainder=True)
print(train_data)
print(valid_data)
print(test_data)
print(baseline_data)

### Save dataset

### Piano rolls of training dataset
Input and output piano rolls

In [None]:
# Disabled: piano rolls of the training inputs and targets. Uncomment to inspect.
# NOTE(review): indexes train.dataset[0] as 3-D ([:, 0, :]) while other cells
# slice dataset arrays as 2-D; confirm against the current dataset layout.
# plt.rcParams["figure.figsize"] = (20, 8)
# pyplot_piano_roll(train.dataset[0][:, 0, :],
#                   low_lim=LOW_LIM, high_lim=HIGH_LIM)
# plt.title("Train data")
# plt.ylim(CROP)
# pyplot_piano_roll(train.dataset[1][:, :NUM_NOTES], cmap="Oranges",
#                   low_lim=LOW_LIM, high_lim=HIGH_LIM)
# plt.title("Train target")
# plt.ylim(CROP)

## Keras
### Build the model

In [None]:
import importlib
import time

BS = BATCH_SIZE  # Batch size used to build the training model

# Re-import the project modules so edits to nmp.model and nmp.dataset are picked
# up without restarting the kernel. Names bound with `from ... import ...`
# (e.g. pyplot_piano_roll) still refer to the previously loaded objects.
for _module in (mod, dataset):
    importlib.reload(_module)

In [None]:
# File written by the "Save model to file" cell. Previously `model_path` was
# never defined, so LOAD = 1 failed with NameError.
model_path = P / 'models' / 'simpleRNN-nottingham.h5'

if LOAD:
    model = load_model(filepath=str(model_path),
                       custom_objects=None,
                       compile=True)

else:
    model = mod.build_gru_model(NUM_NOTES, BS)
    mod.compile_model(model, 'binary_crossentropy', 'adam',
                      metrics=['accuracy'])

model.summary()

now = datetime.now()
stamp = now.strftime("%Y%m%d-%H%M%S")  # one timestamp shared by both loggers
log_dir = P / 'logs'
# CSVLogger opens its file directly, so make sure the logs folder exists.
log_dir.mkdir(parents=True, exist_ok=True)

# TensorBoard logs go to logs/<timestamp>/.
logger = TensorBoard(log_dir=log_dir / stamp,
                     write_graph=True, update_freq='epoch')

# Per-epoch metrics as CSV: logs/<timestamp>-<st>-<num_ts>.csv
csv_logger = CSVLogger(log_dir / (stamp + '-' + str(st) + '-' + str(num_ts) + '.csv'),
                       separator=',', append=False)

### Try the model
Try the model before training

In [None]:
# Disabled smoke test: run one training batch through the untrained model and
# print the output shape. Uncomment to use. (The last axis is presumably notes
# here, not a vocabulary as the inherited comment says.)
# for input_example_batch, target_example_batch in train_data.take(1):
#     example_batch_predictions = model(tf.cast(input_example_batch, tf.float32))
#     print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

### Checkpoints

In [None]:
# Checkpoint files are named models/training_checkpoints/ckpt_<epoch>.
# With the ModelCheckpoint defaults (save_freq='epoch', save_best_only=False),
# weights are written after every epoch.
checkpoint_dir = P / 'models' / 'training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                                         save_weights_only=True)

### Fit the model
Define the batch size `BS` and the number of `epochs`.

#### fit generator

In [None]:
# Disabled alternative: train from dataset.generate(...) generators over the raw
# (input, target) arrays instead of the tf.data pipelines. Kept for reference;
# uncomment to use.
# # Fit generator the model.
# BS = 64  # Batch size
# epochs = 20
# start = time.time()
# size_train = math.ceil(train.dataset[0].shape[0] / BS)
# spe_train = size_train
# size_valid = math.ceil(validation.dataset[0].shape[0] / BS)
# spe_valid = size_valid
# print("Train dataset shape: ", train.dataset[0].shape, "\n")
# print("Train dataset target shape: ", train.dataset[1].shape, "\n")

# # Fit generator. Data should be shuffled before fitting.
# history = model.fit(dataset.generate((train.dataset[0], train.dataset[1]), trans=1), epochs=epochs,
#           steps_per_epoch=spe_train,
#           validation_data=dataset.generate((validation.dataset[0], validation.dataset[1])),
#           validation_steps=spe_valid,
#           callbacks=[logger, csv_logger])

# end = time.time()

#### fit

In [None]:
# Fit the model.
BS = BATCH_SIZE  # Batch size
epochs = 500
start = time.time()

# train_data / valid_data are tf.data pipelines that were already shuffled and
# batched earlier in the notebook. Keras ignores `shuffle` (and `batch_size`)
# when x is a tf.data.Dataset, so neither is passed here.
history = model.fit(train_data, validation_data=valid_data, epochs=epochs,
                    callbacks=[logger, csv_logger, checkpoint_callback])

end = time.time()

### History

In [None]:
elapsed = end - start
print("\nTraining time: ", elapsed, "\n")
# history.history maps each logged metric to its list of per-epoch values.
hist = pd.DataFrame(history.history)
hist

### Plot the loss of the training and validation sets

In [None]:
# Training vs. validation loss per epoch, drawn on an explicit Axes.
fig, ax = plt.subplots(constrained_layout=True, figsize=(5, 4))
# (The original `ms=8` was dropped: with the '-' style no markers are drawn.)
ax.plot(hist['val_loss'], '-', lw=3, c='tab:orange', label='Validation', alpha=0.8)
ax.plot(hist['loss'], '-', lw=1, c='tab:red', label='Train', alpha=0.8)
ax.set_xlabel('Epoch')
# ax.set_xticks(range(epochs))
ax.legend()
ax.set_title('Loss: Binary cross-entropy')
# Hide the top and right frame lines.
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# ax.set_ylim([0.12, 0.16])
PLOTS.mkdir(parents=True, exist_ok=True)
# The output format is selected with `format=`. `fmt=` is not a savefig option
# and can raise TypeError with the PostScript backend of recent Matplotlib.
fig.savefig(PLOTS / 'rnn.eps', format='eps')
print("Training time: ", (end-start))

## Save model to file
The model can be loaded with:
``` python
load_model(filepath=str(P / 'models' / 'simpleRNN-nottingham.h5'), compile=True)
```

In [None]:
# Saved as a single HDF5 file. NOTE(review): the file name says "simpleRNN",
# but the model is built with mod.build_gru_model; consider renaming.
saved_model_file = str(P / 'models' / 'simpleRNN-nottingham.h5')
model.save(saved_model_file, save_format='h5')
# model.save(str(P / 'models' / 'lstm-z-de.h5'), save_format='h5')

### Model evaluation

In [None]:
def _metrics_frame(values, column):
    """Label the ``model.evaluate`` results in ``values`` with ``model.metrics_names``.

    Returns the ``{metric name: value}`` dict and a one-column DataFrame
    indexed by 'metric' whose single column is named ``column``.
    """
    named = {name: values[i] for i, name in enumerate(model.metrics_names)}
    frame = pd.DataFrame(list(named.items()), columns=['metric', column]).set_index('metric')
    return named, frame


print("Evaluation on train set:")
e_train = model.evaluate(train_data)

print("\nEvaluation on validation set:")
e_valid = model.evaluate(valid_data)

# Test-set evaluation is disabled. NOTE(review): presumably because test_data is
# batched with size 1 while this model was built for BATCH_SIZE; confirm.
# print("\nEvaluation on test set:")
# e_test = model.evaluate(test_data)
# results3, res3 = _metrics_frame(e_test, 'test')
res3 = pd.DataFrame([])  # empty placeholder for the disabled test column

results, res = _metrics_frame(e_train, 'train')
results2, res2 = _metrics_frame(e_valid, 'validation')

result = pd.concat([res, res2, res3], axis=1, sort=False)
result

## Make predictions
Predictions on the test dataset.

### Restore last checkpoint

Rebuild the model with batch size 1 and restore the weights from the latest checkpoint, so the test set can be fed with a different batch size than the one used in training.

In [None]:
# Rebuild the network with batch size 1 (the second argument of build_gru_model
# is the batch size, as in the training cell) and load the most recent weights
# written by checkpoint_callback.
model = mod.build_gru_model(NUM_NOTES, 1)
mod.compile_model(model, 'binary_crossentropy', 'adam',
                  metrics=['accuracy'])
# NOTE(review): tf.train.latest_checkpoint returns None when checkpoint_dir holds
# no checkpoint; load_weights then fails. Run the fit cell first.
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
# NOTE(review): [1, None] is (batch, time) with no feature axis, but the inputs
# fed below presumably carry note features. Confirm this build call is needed
# and correct for build_gru_model.
model.build(tf.TensorShape([1, None]))


In [None]:
# States are reset once, so any RNN state carries over between consecutive test
# windows. NOTE(review): presumably intended for a stateful model; confirm.
model.reset_states()
predauc = []  # per-window mean AUC of the model predictions
baseauc = []  # per-window mean AUC of the "repeat the input frame" baseline

merged_input = []
merged_output = []
merged_pred = []

for input_batch, label_batch in test_data:
    predictions = model(tf.cast(input_batch, tf.float32))

    # Drop the batch axis (test_data is batched with size 1).
    pred = np.array(tf.squeeze(predictions, 0))
    # Binarized predictions from dataset.ranked_threshold; currently not used below.
    predictions_bin = dataset.ranked_threshold(pred, steps=1, how_many=5)

    inp = np.array(tf.squeeze(input_batch, 0))
    out = np.array(tf.squeeze(label_batch, 0))

    # Baseline: use the input frames (t) as the prediction for the targets (t + 1).
    pred_auc = ev_metrics.compute_auc(out, pred, NUM_NOTES)
    base_auc = ev_metrics.compute_auc(out, inp, NUM_NOTES)
    predauc.append(np.mean(np.mean(pred_auc)))
    baseauc.append(np.mean(np.mean(base_auc)))

    # Keep every window so an overall AUC can be computed on the merged rolls.
    merged_input.append(inp)
    merged_output.append(out)
    merged_pred.append(pred)

merged_input = np.concatenate(merged_input)
merged_output = np.concatenate(merged_output)
merged_pred = np.concatenate(merged_pred)

pred_auc_merged = ev_metrics.compute_auc(merged_output, merged_pred, NUM_NOTES)
base_auc_merged = ev_metrics.compute_auc(merged_output, merged_input, NUM_NOTES)

print("Pred AUC-ROC (mean of subsets): ", np.mean(predauc))
print("Base AUC-ROC (mean of subsets): ", np.mean(baseauc))

print("Pred AUC-ROC (global): ", np.mean(np.mean(pred_auc_merged)))
print("Base AUC-ROC (global): ", np.mean(np.mean(base_auc_merged)))

# Piano rolls of the merged test windows: targets, predictions and baseline.
plt.rcParams["figure.figsize"] = (13, 4)
PLOTS.mkdir(parents=True, exist_ok=True)
for roll, cmap, title, file_name in (
        (merged_output, "Greens", "Target (labels)", "roll1.png"),
        (merged_pred[:, :NUM_NOTES], "Blues", "Predictions", "roll2.png"),
        (merged_input[:, :NUM_NOTES], "Reds", "Baseline (equal to inputs)", "roll3.png")):
    pyplot_piano_roll(roll, cmap=cmap, low_lim=LOW_LIM, high_lim=HIGH_LIM)
    plt.title(title)
    plt.ylim(CROP)
    plt.savefig(PLOTS / file_name)
