In [None]:
import tensorflow as tf
import numpy as np
from utils.gmm import mix_coef,gmm2d,gmm_loss
from utils.model import lstm_model
from utils.batch_generator import batch_generator, dataset_generator
import scipy.stats
import copy
import time
import os


In [None]:
# Model hyper-parameters.
# NOTE: the model-definition cell below re-assigns these same names before
# building the graph, so those are the values actually used. They are mirrored
# here so this cell does not advertise different settings (it previously said
# batch_size = 5 while the graph was built with 50).
batch_size = 50
seq_length = 600
seq_length_model = seq_length  # kept as an alias of seq_length for compatibility
# number of LSTM hidden units
n_units = 900
# RMSProp learning rate
lr = 0.001
# number of Gaussian mixture components
N_mixtures = 20

In [None]:
############################################ MODEL #############################################################
# Build the training graph (batched LSTM -> linear layer -> GMM parameters -> loss)
# plus a second graph that runs the same LSTM cell one step at a time for sampling.
# Re-running this cell starts again from a fresh default graph.
tf.reset_default_graph()
######################### define constants #####################################
batch_size = 50
seq_length = 600
# number of units of the hidden layer
n_units = 900
# learning rate for the RMSProp optimizer
lr = 0.001
# parameters for gmm
N_mixtures = 20
# 6 values per mixture component (pi, mu_x, mu_y, std_x, std_y, rho) plus one
# end-of-stroke parameter, matching the 7 outputs unpacked from mix_coef below.
tot_mixtures = N_mixtures*6 + 1

# Output projection from the LSTM hidden state to the raw mixture parameters.
out_weights = tf.get_variable("w_y", [n_units, tot_mixtures])
out_bias = tf.get_variable("b_y", [tot_mixtures])

########################### Define placeholders #######################################
# input batch of strokes: (batch, seq_length, 3)
X = tf.placeholder(dtype=tf.float32, shape=[None,seq_length,3])
# targets, same layout as X
targets = tf.placeholder(dtype=tf.float32, shape=[None,seq_length,3])
# a single stroke, fed one step at a time while sampling
sample_stroke = tf.placeholder(dtype=tf.float32, shape=[1,1,3])

########################## Define network ############################################
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=n_units,state_is_tuple=True)
# The initial states have a fixed batch dimension, so every training/validation
# batch must contain exactly batch_size sequences.
init_state = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
init_sample = cell.zero_state(batch_size=1, dtype=tf.float32)

# flatten targets to (batch*seq_length, 3)
flat_targets = tf.reshape(targets, [-1, 3])
# split the 3 columns into end-of-stroke flag, x and y
split_e, split_x, split_y = tf.split(value=flat_targets, axis=1, num_or_size_splits=3)

# Training pass over the full sequences.
split_outputs, final_state = tf.nn.dynamic_rnn(cell=cell, inputs=X,
                                   initial_state=init_state,
                                   dtype=tf.float32)

# Sampling pass: one time step through the same `cell` object. TF >= 1.2 cells
# build their variables once, so the LSTM weights are shared with training.
output_sample, out_sample = tf.nn.dynamic_rnn(cell=cell, inputs=sample_stroke,
                                   initial_state=init_sample,
                                   dtype=tf.float32)

flat_outputs = tf.reshape(split_outputs,[-1,n_units])
output = tf.matmul(flat_outputs,out_weights) + out_bias
# For sampling:
flat_outputs_sample = tf.reshape(output_sample,[-1,n_units])
output_sample = tf.matmul(flat_outputs_sample,out_weights) + out_bias

################################## Using utils.gmm functions ################################
# get mixture gmm coeff:
op_pi, op_mu_x, op_mu_y, op_std_x, op_std_y, op_rho, op_param_e  = mix_coef(output)
# For sampling
sk_pi, sk_mu_x, sk_mu_y, sk_std_x, sk_std_y, sk_rho, sk_param_e  = mix_coef(output_sample)

# compute loss, normalised by the number of time steps in one batch:
op_loss = gmm_loss(split_e, split_x, split_y,op_pi, op_mu_x, op_mu_y, op_std_x, op_std_y, op_rho, op_param_e )
total_loss = op_loss/(batch_size*seq_length) 

##################################### TRAINING ######################################
parameters = tf.trainable_variables()
optimizer = tf.train.RMSPropOptimizer(learning_rate=lr)
grds = optimizer.compute_gradients(total_loss)
# Clip output-layer gradients to +/-100 and every other (LSTM) gradient to +/-10.
# Group the gradients by variable identity rather than list position. A
# positional split depends on variable creation order and silently leaves any
# trainable variable past index 3 untrained. Variables with no gradient are skipped.
out_var_ids = {id(out_weights), id(out_bias)}
out_grds = [(g, v) for g, v in grds if g is not None and id(v) in out_var_ids]
LSTM_grds = [(g, v) for g, v in grds if g is not None and id(v) not in out_var_ids]
clipped_grds = ([(tf.clip_by_value(grad1, -100., 100.), var1) for grad1, var1 in out_grds]
                + [(tf.clip_by_value(grad2, -10., 10.), var2) for grad2, var2 in LSTM_grds])
train_op = optimizer.apply_gradients(clipped_grds)
init_op = tf.global_variables_initializer()

# Named handles for the sampling ops.
# NOTE(review): all eight ops request the name "op_to_restore", so TF uniquifies
# them (op_to_restore, op_to_restore_1, ...). A restoring script must use those
# suffixed names; the names are kept unchanged for checkpoint compatibility.
# NOTE(review): tf.identity on the LSTMStateTuple `out_sample` auto-packs (c, h)
# into a single [2, 1, n_units] tensor. Split the fetched value back into (c, h)
# before feeding it to `init_sample`.
out_sample = tf.identity(out_sample, name="op_to_restore")
sk_pi = tf.identity(sk_pi, name="op_to_restore")
sk_mu_x = tf.identity(sk_mu_x, name="op_to_restore")
sk_mu_y = tf.identity(sk_mu_y, name="op_to_restore")
sk_std_x = tf.identity(sk_std_x, name="op_to_restore")
sk_std_y = tf.identity(sk_std_y, name="op_to_restore")
sk_rho = tf.identity(sk_rho, name="op_to_restore")
sk_param_e = tf.identity(sk_param_e, name="op_to_restore")

In [None]:
################# prepare data for training ##########################################
# Sequence length handed to dataset_generator. It gets its own name so it does
# not overwrite `seq_length` (600), which the model graph was built with.
data_seq_length = 700
# encoding='bytes' only matters for pickled (Python 2) content, so strokes.npy
# presumably holds a pickled object array. numpy >= 1.16.3 refuses to load that
# unless allow_pickle=True.
# SECURITY: unpickling can execute arbitrary code; only load trusted files like this.
inputs = np.load('data/strokes.npy', encoding='bytes', allow_pickle=True)
all_inputs = dataset_generator(inputs, data_seq_length)

In [None]:
# Train/validation split: the first 5000 sequences are used for training and the
# next 1000 are held out. np.copy makes independent arrays, so shuffling X_train
# in place later leaves all_inputs untouched.
n_train, n_val = 5000, 1000
X_train = np.copy(all_inputs[:n_train])
X_val = np.copy(all_inputs[n_train:n_train + n_val])

In [None]:
# Training schedule: number of epochs and number of batches per epoch
# (floor of training-set size / batch size).
num_epochs = 20
num_batches = X_train.shape[0] // batch_size

In [None]:
# Per-step training history filled by the training loop:
# train/validation losses and wall-clock timings per epoch and per batch.
all_train_losses, all_valid_losses, time_epochs, time_batches = [], [], [], []

In [None]:
######################################################## TRAINING ####################################
# Train for num_epochs epochs of num_batches steps each. Each epoch draws one
# validation batch, which is evaluated after every training step, and writes a
# checkpoint at the end of the epoch.
sess = tf.Session() 
sess.run(init_op)
saver = tf.train.Saver(tf.global_variables())
# Batches must use the sequence length the graph was built with (the X
# placeholder), not the global `seq_length`, which other cells may rebind.
model_seq_length = int(X.shape[1])
for epoch in range(num_epochs):
    start_e = time.time()
    valid_batch, valid_labels = batch_generator(X_val, batch_size, model_seq_length)
    valid_feed = {X: valid_batch, targets: valid_labels}
    np.random.shuffle(X_train)
    for btch in range(num_batches):
        train_batch, labels_batch = batch_generator(X_train, batch_size, model_seq_length)
        feed_dict = {X: train_batch, targets: labels_batch}
        start_b = time.time()
        train_loss, out_state, _ = sess.run([total_loss, final_state, train_op], feed_dict)
        print(train_loss)
        all_train_losses.append(train_loss)

        # Fetch the tensor itself (not a 1-element list) so validation losses
        # are scalars, like the training losses.
        valid_loss = sess.run(total_loss, valid_feed)
        all_valid_losses.append(valid_loss)
        end_b = time.time()
        time_batches.append(end_b - start_b)
    # Epoch timing is recorded once per epoch. It previously sat inside the batch
    # loop and appended one "epoch" duration per batch.
    end_e = time.time()
    time_epochs.append(end_e - start_e)
    saver.save(sess, './LSTM_900_2xclip.chkp') 
    print('end of epoch number', epoch, 'valid loss is', valid_loss)

In [None]:
# Generate desired_seq_length strokes. Each step runs the one-step sampling graph,
# draws the next stroke from the predicted mixture, and feeds the stroke and the
# new LSTM state back in.
desired_seq_length = 600
sample = np.zeros((desired_seq_length, 3))
current_stroke = np.zeros((1, 1, 3))
# All-zero initial LSTM state. Evaluating the existing init_sample op gives the
# same values as a fresh cell.zero_state without adding ops to the graph on every run.
previous_stroke = sess.run(init_sample)
for k in range(desired_seq_length):
    feed_dict = {sample_stroke: current_stroke, init_sample: previous_stroke}
    [pi0, mu1, mu2, std1, std2, rho0, param_e0, next_stroke] = sess.run(
        [sk_pi, sk_mu_x, sk_mu_y, sk_std_x, sk_std_y, sk_rho, sk_param_e, out_sample],
        feed_dict)
    # Pick a mixture component with probability given by the predicted weights pi.
    # It was previously drawn uniformly from a hard-coded 20, ignoring the model.
    # Assumes mix_coef returns non-negative (softmax) weights. Renormalising in
    # float64 keeps np.random.choice's sum-to-one check happy despite float32 rounding.
    pi_row = np.asarray(pi0[0], dtype=np.float64)
    mix = np.random.choice(pi_row.size, p=pi_row / pi_row.sum())
    mu = np.array([mu1[0, mix], mu2[0, mix]])
    s1, s2, r = std1[0, mix], std2[0, mix], rho0[0, mix]
    # 2x2 covariance matrix of the chosen bivariate Gaussian
    sigma = np.array([[s1 * s1, r * s1 * s2],
                      [r * s1 * s2, s2 * s2]])
    # generate data
    z = np.random.multivariate_normal(mu, sigma, 1)
    # Draw the end-of-stroke flag as a scalar 0/1. param_e0 is presumably a
    # 1-element array; ravel so the draw does not depend on its exact rank.
    z_e = scipy.stats.bernoulli.rvs(float(np.ravel(param_e0)[0]))
    sample[k] = [z_e, z[0, 0], z[0, 1]]

    current_stroke = np.zeros((1, 1, 3))
    current_stroke[0][0] = [z_e, z[0, 0], z[0, 1]]
    # init_sample is a (c, h) tuple, so the value fed back must be a 2-tuple.
    # Indexing [0]/[1] works whether next_stroke is the packed [2, 1, n_units]
    # array or an LSTMStateTuple.
    previous_stroke = (next_stroke[0], next_stroke[1])
