In [1]:
# Environment setup for the rest of the notebook.
# Switch the working directory to /src and list it; the modules imported in the
# next cell (model, custom, params, data, utils) appear in that listing.
%cd /src
!ls
import os
# Device-order override is intentionally left disabled (see the referenced issue).
#os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
# Sets CUDA_VISIBLE_DEVICES to "1" for this kernel process.
# NOTE(review): presumably this must run before TensorFlow is imported to take
# effect, and with CUDA_DEVICE_ORDER unset the physical GPU that index 1 maps to
# follows CUDA's default ordering — confirm on the target machine.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

/src
LICENSE    deprecated	   music-transformer-train.ipynb  train.py
README.md  dist_train.py   params.py			  utils.py
custom	   generate.py	   preprocess.py
data.py    midi_processor  readme_src
dataset    model.py	   requirements.txt


In [None]:
from model import MusicTransformerDecoder
from custom.layers import *
from custom import callback
import params as par
from tensorflow.python.keras.optimizer_v2.adam import Adam
from data import Data
import utils
import argparse
import datetime
import sys

# NOTE(review): `tf` is not imported by name in this cell; presumably it arrives
# via `from custom.layers import *` — confirm. The return value of this call is
# discarded, so the line only queries eager mode and does not change any state.
tf.executing_eagerly()

# set arguments
# Hyper-parameters and paths consumed by the statements below.
l_r = 0.0001          # fixed learning rate; None would select callback.CustomSchedule instead
batch_size = 1        # sequences per batch requested from dataset.slide_seq2seq_batch
pickle_dir = '/pickle'  # not referenced in the visible part of this notebook
max_seq = 256         # passed to both the model (max_seq) and the batch sampler
epochs = 1            # number of passes of the outer training loop
is_reuse = False      # not referenced in the visible part of this notebook
load_path = None      # forwarded to the model as loader_path
save_path = './save'  # target passed to mt.save() during training
multi_gpu = True      # not referenced in the visible part of this notebook
num_layer = 6         # forwarded to the model as num_layer


# load data
# Build the project's Data wrapper over the preprocessed directory and print its
# repr (the captured output reports how many files it holds).
# NOTE(review): absolute path — assumes the data is mounted at /data; confirm.
dataset = Data('/data/classic_piano_preprocessed')
print(dataset)


# load model
# Pick the learning rate: a fixed value when `l_r` is set, otherwise the
# project's CustomSchedule built from the configured embedding dimension.
if l_r is None:
    learning_rate = callback.CustomSchedule(par.embedding_dim)
else:
    learning_rate = l_r

# Adam optimizer with the beta/epsilon settings used for this run.
opt = Adam(
    learning_rate,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)


# define model
# Instantiate the decoder and attach the optimizer and loss.
# NOTE(review): embedding_dim is hard-coded to 256 here while the learning-rate
# schedule reads par.embedding_dim — confirm the two are meant to agree.
mt = MusicTransformerDecoder(
    embedding_dim=256,
    vocab_size=par.vocab_size,
    num_layer=num_layer,
    max_seq=max_seq,
    dropout=0.2,
    debug=False,
    loader_path=load_path,
)
mt.compile(
    optimizer=opt,
    loss=callback.transformer_dist_train_loss,
)


# define tensorboard writer
# Each run logs under logs/mt_decoder/<timestamp>/, with separate train and eval
# sub-directories so the two curves can be compared side by side.
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
log_root = 'logs/mt_decoder/' + current_time
train_log_dir = log_root + '/train'
eval_log_dir = log_root + '/eval'

train_summary_writer = tf.summary.create_file_writer(train_log_dir)
eval_summary_writer = tf.summary.create_file_writer(eval_log_dir)


# Train Start
# For every epoch, draw `len(dataset.files) // batch_size` training batches and
# run one optimizer step on each. On every 100th batch index the loop also
# evaluates on one 'eval' batch, saves the model to `save_path`, writes
# TensorBoard summaries and prints the current train/eval metrics.
idx = 0  # summary step counter; advances once per logging pass, not per batch
for e in range(epochs):
    mt.reset_metrics()
    for b in range(len(dataset.files) // batch_size):
        try:
            batch_x, batch_y = dataset.slide_seq2seq_batch(batch_size, max_seq)
        except Exception as err:
            # Best-effort sampling: skip a batch that cannot be built, but report
            # it. Catching `Exception` instead of using a bare `except` lets
            # KeyboardInterrupt / SystemExit stop the run as expected.
            print('Skipping batch {}/{}: {!r}'.format(e, b, err))
            continue
        result_metrics = mt.train_on_batch(batch_x, batch_y)

        if b % 100 == 0:
            # Periodic evaluation and checkpoint.
            eval_x, eval_y = dataset.slide_seq2seq_batch(batch_size, max_seq, 'eval')
            eval_result_metrics, weights = mt.evaluate(eval_x, eval_y)
            mt.save(save_path)

            with train_summary_writer.as_default():
                if b == 0:
                    # Once per epoch: histograms of the raw batch values,
                    # indexed by epoch rather than by `idx`.
                    tf.summary.histogram("target_analysis", batch_y, step=e)
                    tf.summary.histogram("source_analysis", batch_x, step=e)

                tf.summary.scalar('loss', result_metrics[0], step=idx)
                tf.summary.scalar('accuracy', result_metrics[1], step=idx)

            with eval_summary_writer.as_default():
                if b == 0:
                    # Project-specific check, run once per epoch on the eval batch.
                    mt.sanity_check(eval_x, eval_y, step=e)

                tf.summary.scalar('loss', eval_result_metrics[0], step=idx)
                tf.summary.scalar('accuracy', eval_result_metrics[1], step=idx)
                # One image summary per entry of `weights`, scoped as layer_<i>/w.
                for i, weight in enumerate(weights):
                    with tf.name_scope("layer_%d" % i):
                        with tf.name_scope("w"):
                            utils.attention_image_summary(weight, step=idx)
            idx += 1
            print('\n====================================================')
            print('Epoch/Batch: {}/{}'.format(e, b))
            print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(result_metrics[0], result_metrics[1]))
            print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(eval_result_metrics[0], eval_result_metrics[1]))




<class Data has "329" files>
