In [1]:
import time
from datetime import datetime
import numpy as np
import tensorflow as tf
from model.vae import conditional_cnn_vae_rnn
from util.miditools import piano_roll_to_pretty_midi

In [2]:
##################################################################
# Setting up constants and loading the data
##################################################################

snapshot_interval = 200
log_interval = 50

checkpoint_file = './tfmodel/exp-pmde-iter-%s.tfmodel'
mudb_file = '/home/eko/winter2018/Piano-midi.de/preprocessing/mudb_train.npz'

In [3]:
train_data = np.load(mudb_file)

fs = train_data['fs']
num_timesteps = int(fs)
bars = train_data['bars']
np.random.shuffle(bars)
note_range = int(bars.shape[2])
T = int(train_data['T'])
num_batches = int(bars.shape[0])

height = num_timesteps
width = note_range
n_visible = note_range * num_timesteps
n_epochs = 10

z_dim = 500
X_dim = width * height
n_hidden = z_dim
h_dim = z_dim

initializer = tf.contrib.layers.xavier_initializer()

audio_sr = 44100

IndexError: tuple index out of range

In [None]:
##################################################################
# Loading the model
##################################################################
with tf.name_scope('placeholders'):
    z = tf.placeholder(tf.float32, shape=[None, z_dim], name="Generated_noise")
    z_rnn_samples = tf.placeholder(tf.float32, shape=[None, height, width, 1], name="Generated_midi_input")
    
    X = tf.placeholder(tf.float32, shape=[None, height, width, 1], name="Training_samples")
    kl_annealing = tf.placeholder(tf.float32, name="KL_annealing_multiplier")

model = conditional_cnn_vae_rnn(X, z, z_rnn_samples, X_dim, z_dim=z_dim, h_dim=h_dim, initializer=initializer, keep_prob=0.6)

X_samples, out_samples, logits = (model['X_samples'], model['out_samples'], model['logits'])
z_mu, z_logvar = (model['z_mu'], model['z_logvar'])

##################################################################
# Losses
##################################################################
with tf.name_scope("Loss"):
    X_labels = tf.reshape(X, [-1, width*height])

    with tf.name_scope("cross_entropy"):
        recon_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=X_labels), 1)
    with tf.name_scope("kl_divergence"):
        kl_loss = kl_annealing * 0.5 * tf.reduce_sum(tf.square(z_mu) + tf.square(z_logvar) - tf.log(tf.square(z_logvar)) - 1,1) 
    
    loss = tf.reduce_mean(recon_loss + kl_loss)

##################################################################
# Optimizer
##################################################################
with tf.name_scope("Optimizer"):
    solver = tf.train.AdamOptimizer()
    grads = solver.compute_gradients(loss)
    grads = [(tf.clip_by_norm(g, clip_norm=10), v) for g, v in grads]
    train_op = solver.apply_gradients(grads)

##################################################################
# Logging
##################################################################
with tf.name_scope("Logging"):
    recon_loss_ph = tf.placeholder(tf.float32)
    kl_loss_ph = tf.placeholder(tf.float32)
    loss_ph = tf.placeholder(tf.float32)
    audio_ph = tf.placeholder(tf.float32)

    tf.summary.scalar("Reconstruction_loss", recon_loss_ph)
    tf.summary.scalar("KL_loss", kl_loss_ph)
    tf.summary.scalar("Loss", loss_ph)
    tf.summary.audio("sample_output", audio_ph, audio_sr)
    log_op = tf.summary.merge_all()

writer = tf.summary.FileWriter('./tb/', graph=tf.get_default_graph())

sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
#sess = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))

# Run Initialization operations
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()

loss_avg = 0.0
decay = 0.99
time0 = time.time()
##################################################################
# Optimization loop
##################################################################
for e in range(n_epochs):
    i = 0
    for batch in bars:
        batch_in = np.reshape(batch, (T, height, width, 1))
        kl_an = 1.0#min(1.0, ((e+1)*num_batches) / (2*num_batches))
        _,loss_out, kl, recon = sess.run([train_op, loss, kl_loss, recon_loss], feed_dict={X: batch_in, kl_annealing: kl_an})

        if (i % log_interval) == 0:
            loss_avg = decay*loss_avg + (1-decay)*loss_out
            print('\titer = %d, local_loss (cur) = %f, local_loss (avg) = %f, kl = %f'
                % (i, loss_out, loss_avg, kl))
            
            time_spent = time.time() - time0
            print('\n\tTotal time elapsed: %f sec. Average time per batch: %f sec\n' %
                (time_spent, time_spent / (i+1)))
            #Random samples
            z_in = np.random.randn(1, z_dim)
            z_rnn_out = np.zeros((1,height,width,1))
            first = True
            for j in range(T):
                samples = sess.run(X_samples, feed_dict={z: np.random.randn(1, z_dim), z_rnn_samples: z_rnn_out})
                frames = j + 1
                if first:
                    frames = 2
                z_rnn_out = samples.reshape((frames, height, width, 1))
                if first:
                    z_rnn_out = np.expand_dims(z_rnn_out[-1,:,:,:], axis=0)
                    first = False
            samples = samples.reshape((num_timesteps*T, note_range))
            thresh_S = samples >= 0.5
            
            pm_out = piano_roll_to_pretty_midi(thresh_S.T * 127, fs=fs)
            midi_out = './tb/audio/test002_{0}.mid'.format(datetime.now().strftime("%Y.%m.%d.%H:%M:%S"))
            wav_out = './tb/audio/test002_{0}.wav'.format(datetime.now().strftime("%Y.%m.%d.%H:%M:%S"))
            audio = pm_out.synthesize() 
            audio = audio.reshape((1, len(audio)))
            #Write out logs
            summary = sess.run(log_op, feed_dict={recon_loss_ph: np.mean(recon), kl_loss_ph: np.mean(kl),
                                                 loss_ph: loss_out, audio_ph: audio})
            writer.add_summary(summary, i)
        
        if (i % snapshot_interval) == 0:
            saver.save(sess, checkpoint_file % i)

        i += 1