In [None]:
import random
import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl
import keras
from keras.models import Sequential, Model, load_model
from keras.layers import LSTM, Dense, RepeatVector, TimeDistributed, Bidirectional, Input, BatchNormalization, \
    multiply, concatenate, Flatten, Activation, dot, LeakyReLU
from keras.optimizers import Adam
from keras.utils import plot_model
from keras.callbacks import EarlyStopping
import tensorflow as tf

import pydot as pyd
from keras.utils.vis_utils import plot_model, model_to_dot
keras.utils.vis_utils.pydot = pyd

import wandb
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint

# Ask TensorFlow to turn on memory growth for every GPU it reports, instead of
# leaving the default allocation behaviour. The loop is a no-op without GPUs.
gpus = tf.config.experimental.list_physical_devices('GPU')
for physical_gpu in gpus:
    tf.config.experimental.set_memory_growth(device=physical_gpu, enable=True)

In [None]:
# Load the pre-split dataset bundle. The context manager closes the file handle
# as soon as unpickling finishes; a bare open() passed straight to pkl.load
# would leave it open until garbage collection.
# NOTE(review): unpickling can execute arbitrary code — only load a
# dataset.pkl that comes from a trusted source.
with open('dataset.pkl', 'rb') as dataset_file:
    data = pkl.load(dataset_file)

# Inputs and labels for the train / validation / test splits. Later cells index
# .shape[1] and .shape[2] on the training arrays, so they are at least 3-D.
dataset_train_input = data['dataset_train_input']
dataset_train_label = data['dataset_train_label']
dataset_val_input = data['dataset_val_input']
dataset_val_label = data['dataset_val_label']
dataset_test_input = data['dataset_test_input']
dataset_test_label = data['dataset_test_label']

# Velocity bounds stored with the dataset (presumably the range used to
# normalize velocities — confirm; they are not used in the cells shown here).
velocity_min = data['velocity_min']
velocity_max = data['velocity_max']


In [None]:
# Symbolic placeholders shaped like a single training sample (axes 1 and 2 of
# the arrays; the leading axis is left out of the shape).
encoder_input_shape = (dataset_train_input.shape[1], dataset_train_input.shape[2])
label_shape = (dataset_train_label.shape[1], dataset_train_label.shape[2])

input_train = Input(shape=encoder_input_shape)
label_train = Input(shape=label_shape)

for symbolic_tensor in (input_train, label_train):
    print(symbolic_tensor)

In [None]:
# Number of LSTM units; also recorded in the wandb config as "n_hidden".
N_HIDDEN = 4

# Hyperparameters and run metadata tracked through wandb.config.
run_hyperparameters = dict(
    n_hidden=N_HIDDEN,
    activation_1="LeakyRelu",
    dropout=0.2,
    optimizer="adam",
    loss="mse_cosine_loss",
    metric="mae",
    epoch=20,
)

# Start the run in the wandb project where it will be logged.
wandb.init(project="midi-velocity-infer-v2", config=run_hyperparameters)
config = wandb.config

In [None]:
# Encoder: one LSTM over the input sequence. With return_state=True the layer
# yields three tensors — its output plus the final hidden and cell states.
# return_sequences=False means only the last step's output is returned.
encoder_lstm = LSTM(
    units=config.n_hidden,
    activation=LeakyReLU(),
    input_shape=(dataset_train_input.shape[1], dataset_train_input.shape[2]),
    return_sequences=False,
    return_state=True,
)
encoder_last_h1, encoder_last_h2, encoder_last_c = encoder_lstm(input_train)

for encoder_tensor in (encoder_last_h1, encoder_last_h2, encoder_last_c):
    print(encoder_tensor)

In [None]:
# Batch-normalize the encoder's final output and cell state before they seed
# the decoder.
# NOTE: both names are rebound in place, so re-running this cell without first
# re-running the encoder cell stacks another BatchNormalization on top.
encoder_last_h1 = BatchNormalization(momentum=0.9)(encoder_last_h1)
print(encoder_last_h1)
encoder_last_c = BatchNormalization(momentum=0.9)(encoder_last_c)
print(encoder_last_c)

# Repeat the encoder summary once per *output* step. The decoder's sequence
# length becomes the model's output length, which must equal the label length
# for the loss to line up, so the count is taken from the labels rather than
# from the inputs (the result is identical whenever the two lengths agree).
decoder = RepeatVector(dataset_train_label.shape[1])(encoder_last_h1)
print(decoder)

In [None]:
# Decoder LSTM: emits one hidden vector per repeated step, starting from the
# normalized encoder states.
# The unit count is read from config.n_hidden — the same source the encoder
# uses — because initial_state requires the decoder's state size to equal the
# encoder's. The module constant N_HIDDEN would diverge from it as soon as the
# wandb config is overridden (e.g. by a sweep), breaking the state hand-off.
decoder = LSTM(
    config.n_hidden, activation=LeakyReLU(), dropout=config.dropout,
    return_sequences=True, return_state=False
)(decoder, initial_state=[encoder_last_h1, encoder_last_c])
print(decoder)

In [None]:
# Apply one shared Dense layer to every decoder time step, sized to the last
# axis of the training labels.
output_projection = Dense(units=dataset_train_label.shape[2])
out = TimeDistributed(output_projection)(decoder)
print(out)

In [None]:
# Wrap the encoder/decoder graph (input_train -> out) in a trainable model.
model = Model(inputs=input_train, outputs=out)
# `learning_rate` is the supported keyword; the `lr` alias is deprecated in
# Keras 2.x and is not accepted by newer optimizer implementations.
opt = Adam(learning_rate=0.0001)

from keras.losses import mse, cosine_similarity
def make_mse_cosine_loss(alpha):
    """Build a loss that blends Keras' cosine-similarity loss with MSE.

    Args:
        alpha: Weight given to the cosine term; the MSE term gets ``1 - alpha``.

    Returns:
        A function ``mse_cosine_loss(y_true, y_pred)`` that evaluates
        ``alpha * cosine_similarity(...) + (1 - alpha) * mse(...)``.
    """
    def mse_cosine_loss(y_true, y_pred):
        # y_pred = tf.clip_by_value(y_pred, clip_value_min=0, clip_value_max=127)
        # keras.losses.cosine_similarity is documented as the *negative*
        # similarity, so adding it to a minimized loss rewards alignment.
        # NOTE(review): it reduces over the last axis by default; if the labels
        # carry a single feature per step this term only reflects sign
        # agreement — confirm against the label shape.
        cosine_term = alpha * cosine_similarity(y_true, y_pred)
        mse_term = (1 - alpha) * mse(y_true, y_pred)
        return cosine_term + mse_term
    return mse_cosine_loss


# Weight of the cosine term in the loss used for training.
ALPHA = 0.10
mse_cosine_loss = make_mse_cosine_loss(ALPHA)

def clipped_loss(y_true, y_pred):
    """Mean squared error computed after clamping predictions to [0, 127].

    Args:
        y_true: Target values.
        y_pred: Predicted values; clamped to the range [0, 127] before the
            error is computed.

    Returns:
        The result of ``tf.losses.mean_squared_error`` on the targets and the
        clamped predictions.
    """
    clamped_pred = tf.clip_by_value(y_pred, clip_value_min=0, clip_value_max=127)
    return tf.losses.mean_squared_error(y_true, clamped_pred)

# Optimize the blended cosine/MSE loss; mean absolute error is reported as a
# metric alongside it.
model.compile(
    optimizer=opt,
    loss=mse_cosine_loss,
    metrics=['mae'],
)
model.summary()

In [None]:
# Write a diagram of the model, annotated with tensor shapes and layer names,
# to model_plot.png.
plot_model(
    model,
    to_file='model_plot.png',
    show_shapes=True,
    show_layer_names=True,
)

In [None]:
import logging
import os
from datetime import datetime

# Timestamp that becomes part of the saved model's file name.
current_time = datetime.now().strftime('%Y-%m-%d_%H-%M_%S')

# NOTE(review): this directory is created, but the model below is saved to the
# working directory rather than into it — confirm which location is intended.
os.makedirs('saved_models', exist_ok=True)

tf.get_logger().setLevel(logging.ERROR)  # TODO: comment if you need to debug

# Callbacks: log metrics to wandb (log_freq=5) and checkpoint the model to
# "models". EarlyStopping is not wired into this run.
wandb_callbacks = [
    WandbMetricsLogger(log_freq=5),
    WandbModelCheckpoint("models"),
]
history = model.fit(
    dataset_train_input,
    dataset_train_label,
    epochs=config.epoch,
    validation_data=(dataset_val_input, dataset_val_label),
    callbacks=wandb_callbacks,
)
wandb.finish()

# Persist the trained model; the name records the timestamp, the hidden size
# and the loss name taken from the run config.
model_filename = f'mvi-v2-{current_time}-h{config.n_hidden}-{config.loss}-no_attention.h5'
model.save(model_filename)

In [None]:
# MAE values stored in the fit history for the training and validation data.
train_mae = history.history['mae']
valid_mae = history.history['val_mae']

# Draw both curves on the same axes, then label and show the figure.
for mae_curve, curve_label in ((train_mae, 'train mae'), (valid_mae, 'validation mae')):
    plt.plot(mae_curve, label=curve_label)
plt.title('train vs. validation accuracy (mae)')
plt.xlabel('epoch')
plt.ylabel('mae')
plt.legend(
    loc='upper center',
    bbox_to_anchor=(0.5, -0.15),
    fancybox=True,
    shadow=False,
    ncol=2,
)
plt.show()