In [47]:
import numpy as np
import numpy.matlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
%matplotlib inline

import math
import random
import time
import os
import cPickle as pickle

import tensorflow as tf
    

In [48]:
# The real project parses these options with argparse
# (https://docs.python.org/3/library/argparse.html); in the notebook a bare
# object is used instead and the options are attached to it as attributes.
class FakeArgParse():
    """Empty attribute container standing in for a parsed-arguments object."""
args = FakeArgParse()

# ----- general model params
args.train = False
args.rnn_size = 100  # number of hidden units in each recurrent cell
# training: 256-step sequences in batches of 32; otherwise one step, one sequence
if args.train:
    args.tsteps, args.batch_size = 256, 32
else:
    args.tsteps, args.batch_size = 1, 1
args.nmixtures = 8  # number of Gaussian mixtures in MDN

# ----- window params
args.kmixtures = 1  # number of Gaussian mixtures in attention mechanism (for soft convolution window)
# an <UNK> slot for unknown chars is added later, on top of these characters
args.alphabet = ' abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
args.tsteps_per_ascii = 25  # an approximate estimate

# ----- book-keeping
args.save_path = './saved/model.ckpt'
args.data_dir = './data'
args.log_dir = './logs/'
args.text = 'call me ishmael some years ago'
args.style = -1  # don't use a custom style
args.bias = 1.0
args.eos_prob = 0.4  # threshold probability for ending a stroke

In [49]:
# In the real project the model is a class. For the notebook, an empty object
# is used as a namespace so each cell can attach its piece of the graph to it.
class FakeModel():
    """Bare attribute holder standing in for the real model class."""
model = FakeModel()

In [50]:
# ----- sizes derived from the configuration
model.char_vec_len = len(args.alphabet) + 1  # plus one for <UNK> token
model.ascii_steps = len(args.text)

# ----- weight initializers (truncated normals); the window bias one is centred at -3
model.graves_initializer = tf.truncated_normal_initializer(
    mean=0., stddev=.075, seed=None, dtype=tf.float32)
model.window_b_initializer = tf.truncated_normal_initializer(
    mean=-3.0, stddev=.25, seed=None, dtype=tf.float32)

# ----- build the basic recurrent network architecture
cell_func = tf.contrib.rnn.LSTMCell  # could be GRUCell or RNNCell
cell_options = dict(state_is_tuple=True, initializer=model.graves_initializer)
model.cell0 = cell_func(args.rnn_size, **cell_options)
model.cell1 = cell_func(args.rnn_size, **cell_options)
model.cell2 = cell_func(args.rnn_size, **cell_options)

# input and target share one shape: [batch, tsteps, 3]
# NOTE(review): the 3 features are presumably (dx, dy, end-of-stroke) -- confirm against the data loader
stroke_shape = [None, args.tsteps, 3]
model.input_data = tf.placeholder(dtype=tf.float32, shape=stroke_shape)
model.target_data = tf.placeholder(dtype=tf.float32, shape=stroke_shape)

# zero initial state for each cell, built for the configured batch size
model.istate_cell0 = model.cell0.zero_state(batch_size=args.batch_size, dtype=tf.float32)
model.istate_cell1 = model.cell1.zero_state(batch_size=args.batch_size, dtype=tf.float32)
model.istate_cell2 = model.cell2.zero_state(batch_size=args.batch_size, dtype=tf.float32)

# slice the input volume along the time axis into one [batch, 3] tensor per tstep
per_step_inputs = tf.split(model.input_data, args.tsteps, 1)
inputs = [tf.squeeze(step_input, [1]) for step_input in per_step_inputs]

# build model.cell0 computational graph: per-tstep outputs plus the final state
outs_cell0, model.fstate_cell0 = tf.contrib.legacy_seq2seq.rnn_decoder(
    inputs, model.istate_cell0, model.cell0, loop_function=None, scope='cell0')

In [51]:
# ----- build the gaussian character window
def get_window(alpha, beta, kappa, c):
    """Build the soft window over the character sequence.

    alpha, beta, kappa -- window mixture parameters, passed straight to get_phi
    c -- [? x ascii_steps x alphabet] tf tensor holding the character sequence

    Returns (window, phi), where phi ~ [? x 1 x ascii_steps] are the weights
    over character positions and window ~ [?, alphabet] is phi applied to c.
    """
    n_positions = c.get_shape()[1].value  # number of items in sequence
    phi = get_phi(n_positions, alpha, beta, kappa)
    # [?,1,ascii_steps] x [?,ascii_steps,alphabet] -> [?,1,alphabet], then drop axis 1
    window = tf.squeeze(tf.matmul(phi, c), [1])
    return window, phi

#get phi for every character position u (returns a [? x 1 x ascii_steps] tensor) that defines the window
def get_phi(ascii_steps, alpha, beta, kappa):
    """Evaluate the window weights phi at each character position.

    ascii_steps -- number of character positions u = 0 .. ascii_steps-1
    alpha, beta, kappa -- [?,kmixtures,1] tf tensors

    For every mixture k and position u this computes
        alpha_k * exp(-beta_k * (kappa_k - u)^2)
    and sums over the mixture axis, keeping that axis with size 1.
    """
    positions = np.linspace(0, ascii_steps-1, ascii_steps)  # weight all the U items in the sequence
    # broadcasting [?,kmixtures,1] against [ascii_steps] gives [?,kmixtures,ascii_steps]
    squared_offset = tf.square(tf.subtract(kappa, positions))
    per_mixture = tf.multiply(alpha, tf.exp(tf.multiply(-beta, squared_offset)))
    return tf.reduce_sum(per_mixture, 1, keep_dims=True)  # phi ~ [?,1,ascii_steps]

def get_window_params(i, out_cell0, kmixtures, prev_kappa, reuse=True):
    """Project one cell0 output onto the window parameters alpha, beta, kappa.

    i -- timestep index; accepted for the caller's convenience but not used here
    out_cell0 -- [?,hidden] output of the first recurrent cell at this timestep
    kmixtures -- number of window mixtures
    prev_kappa -- kappa from the previous timestep, added to the new kappa term
    reuse -- forwarded to the 'window' variable scope; pass False on the first
             call so window_w / window_b get created, True afterwards

    Returns alpha, beta, kappa, each ~ [?,kmixtures,1].
    """
    hidden = out_cell0.get_shape()[1]
    n_out = 3*kmixtures  # one alpha, beta and kappa per mixture
    with tf.variable_scope('window', reuse=reuse):
        window_w = tf.get_variable("window_w", [hidden, n_out], initializer=model.graves_initializer)
        window_b = tf.get_variable("window_b", [n_out], initializer=model.window_b_initializer)

    # "alpha, beta, kappa hats" ~ [?,n_out]; exp makes all three strictly positive
    abk_hats = tf.nn.xw_plus_b(out_cell0, window_w, window_b)
    abk = tf.exp(tf.reshape(abk_hats, [-1, n_out, 1]))

    alpha, beta, kappa_step = tf.split(abk, 3, 1)  # each ~ [?,kmixtures,1]
    # kappa accumulates: the positive step is added to the previous position
    return alpha, beta, kappa_step + prev_kappa

In [52]:
# placeholders: the starting kappa and the character sequence to be written
model.init_kappa = tf.placeholder(dtype=tf.float32, shape=[None, args.kmixtures, 1])
model.char_seq = tf.placeholder(dtype=tf.float32, shape=[None, model.ascii_steps, model.char_vec_len])
wavg_prev_kappa = model.init_kappa
prev_window = model.char_seq[:,0,:]

# add gaussian window result to every cell0 output, one tstep at a time
reuse = False  # the window variables are created on the first pass only
for i, cell0_out in enumerate(outs_cell0):
    alpha, beta, new_kappa = get_window_params(i, cell0_out, args.kmixtures, wavg_prev_kappa, reuse=reuse)
    window, phi = get_window(alpha, beta, new_kappa, model.char_seq)
    with_window = tf.concat((cell0_out, window), 1)  # concat outputs
    outs_cell0[i] = tf.concat((with_window, inputs[i]), 1)  # concat input data
    # the next tstep starts from the mean of this tstep's kappa along the kmixtures dimension
    wavg_prev_kappa = tf.reduce_mean(new_kappa, reduction_indices=1, keep_dims=True)
    reuse = True

# keep the tensors from the last tstep (for generation)
model.window = window
model.phi = phi
model.new_kappa = new_kappa
model.alpha = alpha
model.wavg_prev_kappa = wavg_prev_kappa

In [54]:
# ----- finish building second recurrent cell (fed with the windowed cell0 outputs)
outs_cell1, model.fstate_cell1 = tf.contrib.legacy_seq2seq.rnn_decoder(
    outs_cell0, model.istate_cell1, model.cell1,
    loop_function=None, scope='cell1')  # use scope from training

# ----- finish building third recurrent cell (fed with the cell1 outputs)
outs_cell2, model.fstate_cell2 = tf.contrib.legacy_seq2seq.rnn_decoder(
    outs_cell1, model.istate_cell2, model.cell2,
    loop_function=None, scope='cell2')

# concat outputs for efficiency: join the per-tstep outputs along axis 1, then
# reshape so every row holds the rnn_size outputs of a single tstep
joined_cell2 = tf.concat(outs_cell2, 1)
out_cell2 = tf.reshape(joined_cell2, [-1, args.rnn_size])

In [55]:
# Dense layer on top of the rnn cells; its output feeds the mixture density network.
# params = end_of_stroke + 6 parameters per Gaussian
n_out = 6 * args.nmixtures + 1
with tf.variable_scope('mdn_dense'):
    output_w = tf.get_variable(
        "output_w", [args.rnn_size, n_out], initializer=model.graves_initializer)
    output_b = tf.get_variable(
        "output_b", [n_out], initializer=model.graves_initializer)

# data flows through dense nn: out_cell2 * output_w + output_b
output = tf.nn.xw_plus_b(out_cell2, output_w, output_b)