In [1]:
import time
from datetime import datetime
import numpy as np
import tensorflow as tf
from model.vae import cnn_vae_rnn
from util.miditools import piano_roll_to_pretty_midi

##################################################################
# Setting up constants and loading the data
##################################################################

  from ._conv import register_converters as _register_converters


In [3]:
snapshot_interval = 200
log_interval = 50

checkpoint_file = './tfmodel/exp-musedata-bigru-iter-%s-trainloss-%s-c-major.tfmodel'
mudb_file = '../Nottingham/preprocessing/CN_mudb_train.npz'
dev_file = '../Nottingham/preprocessing/CN_mudb_valid.npz'

train_data = np.load(mudb_file)
dev_data = np.load(dev_file)

# print range(train_data['bars'])


In [5]:
fs = train_data['fs']
num_timesteps = int(fs)
bars = train_data['bars']
devBars = dev_data['bars']
np.random.shuffle(bars)

print devBars.shape
print len(bars)

(8901, 256, 128)
34176


In [6]:
note_range = int(devBars.shape[2])


T = int(train_data['T']) #16
num_batches = int(bars.shape[0])

height = num_timesteps #19
width = note_range #128
n_visible = note_range * num_timesteps
n_epochs = 100

z_dim = 350
X_dim = width * height
n_hidden = z_dim
h_dim = z_dim
batch_size = 32

trainBarsBatch = np.reshape(devBars, (-1, T, height, width, 1))
trainBarsBatches = []
i = 0
while i < trainBarsBatch.shape[0] - 32:
    trainBarsBatches.append(trainBarsBatch[i:i+32])
    i += 32
devBarsBatch = np.reshape(devBars, (-1, T, height, width, 1))
devBarsBatches = []
i = 0
while i < devBarsBatch.shape[0] - 32:
    devBarsBatches.append(devBarsBatch[i:i+32])
    i += 32
#devBarsBatch = np.array_split(devBarsBatch, batch_size)
initializer = tf.contrib.layers.xavier_initializer()

audio_sr = 44100

devLoss = True
devInterval = 100

In [None]:
##################################################################
# Loading the model
##################################################################
with tf.name_scope('placeholders'):
    z = tf.placeholder(tf.float32, shape=[None, z_dim], name="Generated_noise")
    #(batch x T x width x height x channels)
    z_rnn_samples = tf.placeholder(tf.float32, shape=[None, T, height, width, 1], name="Generated_midi_input")
    
    X = tf.placeholder(tf.float32, shape=[None, T, height, width, 1], name="Training_samples")
    kl_annealing = tf.placeholder(tf.float32, name="KL_annealing_multiplier")

model = cnn_vae_rnn(X, z, z_rnn_samples, X_dim, z_dim=z_dim, h_dim=h_dim, initializer=initializer, keep_prob=1.0)

X_samples, out_samples, logits = (model['X_samples'], model['out_samples'], model['logits'])
z_mu, z_logvar = (model['z_mu'], model['z_logvar'])

##################################################################
# Losses
##################################################################
with tf.name_scope("Loss"):
    X_labels = tf.reshape(X, [-1, width*height])

    with tf.name_scope("cross_entropy"):
        recon_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=X_labels), 1)
    with tf.name_scope("kl_divergence"):
        kl_loss = kl_annealing * 0.5 * tf.reduce_sum(tf.square(z_mu) + tf.exp(z_logvar) - z_logvar - 1.,1) 
    
    recon_loss = tf.reduce_mean(tf.reshape(recon_loss, [-1, T]), axis=1)

    loss = tf.reduce_mean(recon_loss + kl_loss)

##################################################################
# Optimizer
##################################################################
with tf.name_scope("Optimizer"):
    solver = tf.train.AdamOptimizer()
    grads = solver.compute_gradients(loss)
    grads = [(tf.clip_by_norm(g, clip_norm=1), v) for g, v in grads]
    train_op = solver.apply_gradients(grads)

##################################################################
# Logging
##################################################################
with tf.name_scope("Logging"):
    recon_loss_ph = tf.placeholder(tf.float32)
    kl_loss_ph = tf.placeholder(tf.float32)
    loss_ph = tf.placeholder(tf.float32)
    audio_ph = tf.placeholder(tf.float32)

    tf.summary.scalar("Reconstruction_loss", recon_loss_ph)
    tf.summary.scalar("KL_loss", kl_loss_ph)
    tf.summary.scalar("Loss", loss_ph)
    tf.summary.audio("sample_output", audio_ph, audio_sr)
    log_op = tf.summary.merge_all()

writer = tf.summary.FileWriter('./tb/', graph=tf.get_default_graph())

sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
#sess = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))

# Run Initialization operations
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()

loss_avg = 0.0
decay = 0.99
min_loss = 100.0
min_dev_loss = 200.0
time0 = time.time()
##################################################################
# Optimization loop
##################################################################
i = 0
for e in range(n_epochs):
    print("%s EPOCH %d %s" % ("".join(10*["="]), e, "".join(10*["="])))
    for batch in trainBarsBatches:
        kl_an = 1.0#min(1.0, (i / 10) / 200.)
        _,loss_out, kl, recon = sess.run([train_op, loss, kl_loss, recon_loss], feed_dict={X: batch, kl_annealing: kl_an})

        if (i % log_interval) == 0:
            loss_avg = decay*loss_avg + (1-decay)*loss_out
            print('\titer = %d, local_loss (cur) = %f, local_loss (avg) = %f, kl = %f'
                % (i, loss_out, loss_avg, np.mean(kl)))
            
            time_spent = time.time() - time0
            print('\n\tTotal time elapsed: %f sec. Average time per batch: %f sec\n' %
                (time_spent, time_spent / (i+1)))
            #Random samples
            z_in = np.random.randn(1, z_dim)
            z_rnn_out = np.zeros((T,height,width,1))
            first = True
            for j in range(T):
                z_rnn_out = np.expand_dims(z_rnn_out, axis=0)
                samples = sess.run(X_samples, feed_dict={z: np.random.randn(1, z_dim), X: z_rnn_out})
                frames = j + 1
                samples = samples.reshape((-1, height, width, 1))
                z_rnn_out = np.concatenate([samples[:frames], np.zeros((T-frames, height, width, 1))])
            samples = samples.reshape((num_timesteps*(T), note_range))
            thresh_S = samples >= 0.5
            
            pm_out = piano_roll_to_pretty_midi(thresh_S.T * 127, fs=fs)
            midi_out = './tb/audio/test002_{0}.mid'.format(datetime.now().strftime("%Y.%m.%d.%H:%M:%S"))
            wav_out = './tb/audio/test002_{0}.wav'.format(datetime.now().strftime("%Y.%m.%d.%H:%M:%S"))
            audio = pm_out.synthesize() 
            audio = audio.reshape((1, len(audio)))
            #Write out logs
            summary = sess.run(log_op, feed_dict={recon_loss_ph: np.mean(recon), kl_loss_ph: np.mean(kl),
                                                 loss_ph: loss_out, audio_ph: audio})
            writer.add_summary(summary, i)
        
        if devLoss and i % devInterval == 0:
            #dls = []
            #for dbatch in devBarsBatches:
            #    dev_loss_out, kl, recon = sess.run([loss, kl_loss, recon_loss], feed_dict={X: dbatch, kl_annealing: kl_an})
            #    dls.append(dev_loss_out)
            #dev_loss_out = sum(dls) / len(dls)
            #print("Dev set loss %.2f" % dev_loss_out)

            if loss_out < min_dev_loss:
                print("Saving checkpoint with train loss %d" % loss_out)
                min_dev_loss = loss_out
                saver.save(sess, checkpoint_file % (i, str(int(loss_out))))
        i += 1

	iter = 0, local_loss (cur) = 1709.529419, local_loss (avg) = 17.095294, kl = 233.321899

	Total time elapsed: 4.123551 sec. Average time per batch: 4.123551 sec

	iter = 50, local_loss (cur) = 238.183105, local_loss (avg) = 19.306172, kl = 5.837939

	Total time elapsed: 156.880856 sec. Average time per batch: 3.076095 sec

	iter = 100, local_loss (cur) = 169.861908, local_loss (avg) = 20.811730, kl = 2.354903

	Total time elapsed: 302.692744 sec. Average time per batch: 2.996958 sec

Saving checkpoint with train loss 169
	iter = 150, local_loss (cur) = 161.750061, local_loss (avg) = 22.221113, kl = 0.021598

	Total time elapsed: 442.439389 sec. Average time per batch: 2.930062 sec

	iter = 200, local_loss (cur) = 159.876190, local_loss (avg) = 23.597664, kl = 0.066941

	Total time elapsed: 578.155188 sec. Average time per batch: 2.876394 sec

Saving checkpoint with train loss 159
	iter = 250, local_loss (cur) = 105.702789, local_loss (avg) = 24.418715, kl = 0.008332

	Total time elaps

  synthesized /= np.abs(synthesized).max()


Saving checkpoint with train loss 56
	iter = 4150, local_loss (cur) = 57.789780, local_loss (avg) = 79.864583, kl = 0.000648

	Total time elapsed: 11326.900548 sec. Average time per batch: 2.728716 sec

	iter = 4200, local_loss (cur) = 68.425217, local_loss (avg) = 79.750189, kl = 0.000406

	Total time elapsed: 11463.729300 sec. Average time per batch: 2.728810 sec

	iter = 4250, local_loss (cur) = 89.036041, local_loss (avg) = 79.843048, kl = 0.001378

	Total time elapsed: 11600.151801 sec. Average time per batch: 2.728805 sec

	iter = 4300, local_loss (cur) = 50.073235, local_loss (avg) = 79.545349, kl = 0.000843

	Total time elapsed: 11737.583509 sec. Average time per batch: 2.729036 sec

Saving checkpoint with train loss 50
	iter = 4350, local_loss (cur) = 62.562973, local_loss (avg) = 79.375526, kl = 0.001677

	Total time elapsed: 11874.792936 sec. Average time per batch: 2.729210 sec

	iter = 4400, local_loss (cur) = 85.216888, local_loss (avg) = 79.433939, kl = 0.000431

	Total 