In [1]:
import os
import sys

# Put the project scripts directory (presumably where the local `helpers`
# module imported below lives) at the front of the import path. Set the
# AIDA_SCRIPTS_DIR environment variable to run on another machine instead of
# editing this hard-coded default.
sys.path.insert(0, os.environ.get("AIDA_SCRIPTS_DIR", "/work/04233/sw33286/AIDA-SCRIPTS"))

In [2]:
import random
import numpy as np

import tensorflow as tf
import tensorflow.contrib.seq2seq as seq2seq
from tensorflow.contrib.layers import safe_embedding_lookup_sparse as embedding_lookup_unique
from tensorflow.contrib.rnn import LSTMCell, LSTMStateTuple, GRUCell

from helpers import Indexer, batch
from collections import defaultdict, Counter

In [3]:
# Toy vocabulary: the training loop asks the model to reproduce random
# sequences of these words.
VOCAB = ['one', 'two', 'three', 'four', 'five',
         'six', 'seven', 'eight', 'nine', 'ten']

# Alternative, larger vocabulary:
# VOCAB = ['cat','dog','pig','horse','deer',
#          'car','bike','motorcycle','train','bus',
#          'hill','mountain','lake','river','valley',
#          'stool','table','closet','cabinet','bed',
#          'apple','pear','strawberry','grape','tomato']

indexer = Indexer()
# Register the special symbols first, then every vocabulary word in order.
# The EOS-mask arithmetic in the DecoderTrainFeeds cell only works if PAD ends
# up with id 0 (presumably the first id Indexer assigns -- verify in helpers).
for token in ['PAD', 'EOS'] + VOCAB:
    indexer.get_index(token)

In [4]:
# Bounds for np.random.randint(FROM_LEN, TO_LEN) in generate_datum; numpy's
# upper bound is exclusive, so sampled sequences have 5..14 words.
FROM_LEN, TO_LEN = 5, 15
# Length that generate_datum pads sequences up to when pad=True.
MAX_LEN = TO_LEN

def generate_datum(from_len=FROM_LEN, to_len=TO_LEN, pad=False):
    """Sample one random sequence of VOCAB word ids.

    Args:
        from_len: inclusive lower bound on the number of words.
        to_len: exclusive upper bound (np.random.randint excludes `high`).
        pad: if True and the sequence is shorter than MAX_LEN, right-pad it
            with the PAD id up to MAX_LEN entries.

    Returns:
        list of ids as returned by `indexer.get_index`.
    """
    n_words = np.random.randint(from_len, to_len)
    ids = [indexer.get_index(np.random.choice(VOCAB)) for _ in range(n_words)]
    if pad and n_words < MAX_LEN:
        ids.extend([indexer.get_index('PAD')] * (MAX_LEN - n_words))
    return ids

def to_sent(code):
    """Translate a sequence of ids back into tokens via `indexer.get_object`."""
    return list(map(indexer.get_object, code))

def get_batch(n, from_len=FROM_LEN, to_len=TO_LEN, pad=False):
    """Return `n` independently sampled sequences from generate_datum.

    Args:
        n: number of sequences to sample.
        from_len: inclusive lower bound on sequence length, forwarded to
            generate_datum.
        to_len: exclusive upper bound on sequence length, forwarded to
            generate_datum.
        pad: forwarded to generate_datum. The default False keeps
            variable-length lists; presumably `batch` pads them later when the
            feed is built.

    Returns:
        list of `n` id lists.
    """
    # Forward the length bounds; the arguments used to be accepted but ignored.
    return [generate_datum(from_len=from_len, to_len=to_len, pad=pad)
            for _ in range(n)]

In [10]:
# Start from an empty graph so re-running this cell does not pile up duplicate
# variables, and open an interactive session for the cells below.
tf.reset_default_graph()
sess = tf.InteractiveSession()

batch_size = 10                     # sequences per training minibatch
vocab_size = len(indexer)           # number of ids registered in `indexer`
embed_size = 20
encoder_hidden_size = 10
# The decoder is seeded with the concatenated forward+backward encoder states,
# so it must be twice as wide as a single encoder direction.
decoder_hidden_size = encoder_hidden_size*2
# z replaces the decoder's initial `h`, so this must equal decoder_hidden_size.
latent_size = 20

# The encoder cells are created inside the BidirectionalEncoder scope below;
# building one here was dead code that was overwritten before use.
decoder_cell = LSTMCell(decoder_hidden_size)

# Token-id feeds are time-major, [max_time, batch]; the encoder runs with
# time_major=True and EOS/PAD rows are concatenated along axis 0.
# Length feeds are [batch].
encoder_inputs = tf.placeholder(dtype=tf.int32, shape=(None, None), name='encoder_inputs')
encoder_inputs_length = tf.placeholder(dtype=tf.int32, shape=(None,), name='encoder_inputs_length')
decoder_targets = tf.placeholder(dtype=tf.int32, shape=(None, None), name='decoder_targets')
decoder_targets_length = tf.placeholder(dtype=tf.int32, shape=(None,), name='decoder_targets_length')

with tf.name_scope('DecoderTrainFeeds'):
    # Runtime dimensions of the time-major targets [max_time, batch].
    sequence_size, batch_size_ = tf.unstack(tf.shape(decoder_targets))
    EOS_SLICE = tf.ones([1, batch_size_], dtype=tf.int32) * indexer.get_index('EOS')
    PAD_SLICE = tf.ones([1, batch_size_], dtype=tf.int32) * indexer.get_index('PAD')
    # Teacher-forcing inputs: EOS doubles as the start symbol at t=0.
    decoder_train_inputs = tf.concat([EOS_SLICE, decoder_targets], axis=0) # [max-time+1, batch_size]
    decoder_train_length = decoder_targets_length + 1
    # Targets get one extra row; EOS is then written at each sequence's own
    # end position (index targets_length == train_length-1).
    decoder_train_targets = tf.concat([decoder_targets, PAD_SLICE], axis=0)
    decoder_train_targets_seq_len,_ = tf.unstack(tf.shape(decoder_train_targets))
    # one_hot -> [batch, max_time+1], transposed to time-major below.
    decoder_train_targets_eos_mask = tf.one_hot(decoder_train_length-1,
                                                decoder_train_targets_seq_len,
                                                on_value=indexer.get_index('EOS'),
                                                off_value=indexer.get_index('PAD'),
                                                dtype=tf.int32)
    decoder_train_targets_eos_mask = tf.transpose(decoder_train_targets_eos_mask, [1,0])
    # NOTE: adding the mask only yields "EOS at the end, tokens unchanged
    # elsewhere" if the PAD id is 0 and the end slot holds PAD (assumes `batch`
    # pads with the PAD id -- TODO confirm in helpers).
    decoder_train_targets = tf.add(decoder_train_targets,
                                   decoder_train_targets_eos_mask)

    # Uniform weights on every predicted position (padding included). The batch
    # dimension comes from the fed data (batch_size_), not the static
    # `batch_size`, so feeds of any batch size line up with sequence_loss.
    loss_weights = tf.ones([
        batch_size_,
        tf.reduce_max(decoder_train_length)
    ], dtype=tf.float32, name='loss_weights')
        # weights on predictions, usually set as uniform, unless otherwise is needed.

with tf.variable_scope('embedding') as scope:
    # One trainable vector per vocabulary id, shared by encoder and decoder.
    embedding_matrix = tf.get_variable(
        name='embedding_matrix',
        shape=[vocab_size, embed_size],
        initializer=tf.contrib.layers.xavier_initializer())
    # To initialise from pre-trained GloVe vectors instead:
    #   glove_feed = tf.placeholder(tf.float32, glove_embs.shape)
    #   glove_init = embedding_matrix.assign(glove_feed)
    encoder_inputs_embedded = tf.nn.embedding_lookup(params=embedding_matrix,
                                                     ids=encoder_inputs)
    decoder_train_inputs_embedded = tf.nn.embedding_lookup(params=embedding_matrix,
                                                           ids=decoder_train_inputs)
    
with tf.variable_scope('BidirectionalEncoder') as scope:
    # One cell object per direction. Sharing a single LSTMCell instance as both
    # cell_fw and cell_bw depends on TF 1.0 behaviour; later TF 1.x releases can
    # reject reusing a cell under a different scope.
    encoder_fw_cell = LSTMCell(encoder_hidden_size)
    encoder_bw_cell = LSTMCell(encoder_hidden_size)
    ((encoder_fw_outputs,encoder_bw_outputs),
     (encoder_fw_state,encoder_bw_state)) = (
        tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_fw_cell,
                                        cell_bw=encoder_bw_cell,
                                        inputs=encoder_inputs_embedded,
                                        sequence_length=encoder_inputs_length,
                                        dtype=tf.float32, time_major=True)
    )
    # Time-major [max_time, batch, 2*encoder_hidden_size]; the Decoder cell uses
    # this as its attention memory.
    encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)
    if isinstance(encoder_fw_state, LSTMStateTuple):
        # Cell state passes through deterministically: [batch_size, encoder_hidden_size*2].
        encoder_state_c = tf.concat((encoder_fw_state.c, encoder_bw_state.c), 1, name='bidirectional_concat_c')

        ## Variational reparameterization: the hidden state h is replaced by a
        ## sample z = z_mean + sigma * eps before it seeds the decoder.
        # z becomes the `h` half of the decoder's initial LSTMStateTuple, so it
        # must be exactly as wide as the decoder cell.
        assert latent_size == decoder_hidden_size, (
            'latent_size ({}) must equal decoder_hidden_size ({})'.format(
                latent_size, decoder_hidden_size))

        # [batch_size, encoder_hidden_size*2]
        encoder_state_h = tf.concat((encoder_fw_state.h, encoder_bw_state.h), 1, name='bidirectional_concat_h')
        out_mean_w = tf.get_variable('out_mean_w', [encoder_hidden_size*2, latent_size],
                                     initializer=tf.contrib.layers.xavier_initializer())
        out_mean_b = tf.get_variable('out_mean_b', [latent_size],
                                     initializer=tf.contrib.layers.xavier_initializer())
        out_log_sigma_w = tf.get_variable('out_log_sigma_w', [encoder_hidden_size*2, latent_size],
                                          initializer=tf.contrib.layers.xavier_initializer())
        out_log_sigma_b = tf.get_variable('out_log_sigma_b', [latent_size],
                                          initializer=tf.contrib.layers.xavier_initializer())
        # Both are [batch_size, latent_size].
        z_mean = tf.add(tf.matmul(encoder_state_h, out_mean_w), out_mean_b)
        z_log_sigma_sq = tf.add(tf.matmul(encoder_state_h, out_log_sigma_w), out_log_sigma_b)
        # Noise takes the runtime shape of z_mean, so any batch size can be fed;
        # the static `batch_size` only matches the training minibatch.
        eps = tf.random_normal(tf.shape(z_mean), 0, 1, dtype=tf.float32)
        # sigma = exp(0.5 * log sigma^2): same value as sqrt(exp(.)), but the
        # intermediate overflows much later.
        z = tf.add(z_mean, tf.multiply(tf.exp(0.5 * z_log_sigma_sq), eps))

        encoder_state = LSTMStateTuple(c=encoder_state_c, h=z)
        # Non-variational variant:
        # encoder_state = LSTMStateTuple(c=encoder_state_c, h=encoder_state_h)

    elif isinstance(encoder_fw_state, tf.Tensor):
        # NOTE(review): this non-LSTM path defines no z_mean / z_log_sigma_sq,
        # which the latent loss needs.
        encoder_state = tf.concat((encoder_fw_state, encoder_bw_state), 1, name='bidirectional_concat')
        
with tf.variable_scope('Decoder') as scope:
    def output_fn(outputs):
        """Project decoder outputs to vocab_size logits.

        The linear layer is created in the 'Decoder' scope. After
        scope.reuse_variables() below, the inference decoder presumably picks
        up the same projection weights.
        """
        return tf.contrib.layers.linear(outputs, vocab_size, scope=scope)
    # encoder_outputs is time-major; the attention memory is batch-major.
    # NOTE(review): attention reads encoder_outputs directly, so input
    # information can reach the decoder without passing through the sampled z.
    attention_states = tf.transpose(encoder_outputs, [1,0,2]) # [batch_size,max-time,hidden_size]
    # Bahdanau (additive) attention; the keys/values and score/construct fns
    # are shared by the training and inference decoder_fns below.
    (attention_keys,
     attention_values,
     attention_score_fn,
     attention_construct_fn) = seq2seq.prepare_attention(
        attention_states=attention_states,
        attention_option='bahdanau',
        num_units=decoder_hidden_size
    )
    # Training decoder_fn: the ground-truth inputs (decoder_train_inputs_embedded)
    # are supplied explicitly to dynamic_rnn_decoder below (teacher forcing).
    decoder_fn_train = seq2seq.attention_decoder_fn_train(
        encoder_state=encoder_state,
        attention_keys=attention_keys,
        attention_values=attention_values,
        attention_score_fn=attention_score_fn,
        attention_construct_fn=attention_construct_fn,
        name='attention_decoder'
    )
    # Inference decoder_fn: presumably feeds back its own predicted ids,
    # embedded with embedding_matrix. EOS is both the start and the stop symbol,
    # and decoding is capped at (longest input + 3) steps.
    decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
        output_fn=output_fn,
        encoder_state=encoder_state,
        attention_keys=attention_keys,
        attention_values=attention_values,
        attention_score_fn=attention_score_fn,
        attention_construct_fn=attention_construct_fn,
        embeddings=embedding_matrix,
        start_of_sequence_id=indexer.get_index('EOS'),
        end_of_sequence_id=indexer.get_index('EOS'),
        maximum_length=tf.reduce_max(encoder_inputs_length) + 3,
        num_decoder_symbols=vocab_size
    )
    # Training pass: returns (outputs, final state, final context state).
    (decoder_outputs_train,
     decoder_state_train,
     decoder_context_state_train) = (
        seq2seq.dynamic_rnn_decoder(
            cell=decoder_cell,
            decoder_fn=decoder_fn_train,
            inputs=decoder_train_inputs_embedded,
            sequence_length=decoder_train_length,
            time_major=True,
            scope=scope
        )
    )
    # Time-major logits [max_time+1, batch, vocab_size] and their argmax ids.
    decoder_logits_train = output_fn(decoder_outputs_train)
    decoder_prediction_train = tf.argmax(decoder_logits_train, axis=-1, name='decoder_prediction_train')
    # Everything created under 'Decoder' from here on reuses the training
    # variables, so the inference decoder shares all weights with training.
    # This must stay after the training decoder and output_fn have been built.
    scope.reuse_variables()
    # output_fn is passed to the inference decoder_fn, so these outputs are
    # presumably already logits.
    (decoder_logits_inference,
     decoder_state_inference,
     decoder_context_state_inference) = (
        seq2seq.dynamic_rnn_decoder(
            cell=decoder_cell,
            decoder_fn=decoder_fn_inference,
            time_major=True,
            scope=scope
        )
    )
    decoder_prediction_inference = tf.argmax(decoder_logits_inference, axis=-1, name='decoder_prediction_inference')

    
# sequence_loss expects batch-major tensors; the decoder works time-major.
logits = tf.transpose(decoder_logits_train, perm=[1, 0, 2])
targets = tf.transpose(decoder_train_targets, perm=[1, 0])

# Reconstruction term: weighted token cross-entropy, averaged (sequence_loss
# defaults average over both batch and time).
generation_loss = seq2seq.sequence_loss(logits=logits,
                                        targets=targets,
                                        weights=loss_weights)
# Closed-form KL(N(z_mean, sigma^2) || N(0, I)): the bracketed terms are summed
# over latent dims per example, averaged over the batch, then scaled by -0.5.
kl_terms_per_example = tf.reduce_sum(
    1 + z_log_sigma_sq - tf.square(z_mean) - tf.exp(z_log_sigma_sq), axis=1)
latent_loss = -0.5 * tf.reduce_mean(kl_terms_per_example)
loss = generation_loss + latent_loss

train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

sess.run(tf.global_variables_initializer())

In [11]:
def make_train_inputs(input_seq, target_seq):
    """Build the feed_dict for one training step.

    Args:
        input_seq: list of id sequences fed to the encoder.
        target_seq: list of id sequences the decoder should reproduce.

    Returns:
        dict mapping the four input placeholders to what the project-local
        `batch` helper returns, a (padded ids, lengths) pair for each list.
        The graph consumes the ids time-major ([max_time, batch]); presumably
        that is the layout `batch` produces -- TODO confirm in helpers.
    """
    enc_ids, enc_lengths = batch(input_seq)
    tgt_ids, tgt_lengths = batch(target_seq)
    feed = {}
    feed[encoder_inputs] = enc_ids
    feed[encoder_inputs_length] = enc_lengths
    feed[decoder_targets] = tgt_ids
    feed[decoder_targets_length] = tgt_lengths
    return feed

In [13]:
loss_track = []
num_batches = 50000
verbose = 1000        # print a progress report every `verbose` batches
num_shown = 3         # (input, prediction) pairs printed per report
try:
    for b in range(num_batches):
        batch_inputs = get_batch(batch_size)
        # Autoencoding: every sequence is its own decoding target.
        fd = make_train_inputs(batch_inputs, batch_inputs)
        _, batch_loss = sess.run([train_op, loss], fd)
        loss_track.append(batch_loss)
        if b % verbose == 0:
            print('batch {}'.format(b))
            # Report the loss of the step just taken. Re-running `loss` here
            # would draw fresh `eps` noise against already-updated weights, so
            # the printed value would not match the one stored in loss_track.
            print('  minibatch loss: {}'.format(batch_loss))
            # Feeds and predictions are time-major; .T gives one row per sequence.
            inputs_by_seq = fd[encoder_inputs].T
            preds_by_seq = sess.run(decoder_prediction_train, fd).T
            for i, (e_in, dt_pred) in enumerate(zip(inputs_by_seq[:num_shown],
                                                     preds_by_seq[:num_shown])):
                print('  sample {}:'.format(i+1))
                print('    enc input           > {}'.format([w for w in to_sent(e_in) if w!='PAD']))
                print('    dec train predicted > {}'.format([w for w in to_sent(dt_pred) if w!='PAD']))
            print('\n')
except KeyboardInterrupt:
    # Lets a long run be stopped from the notebook while keeping loss_track.
    print('training interrupted')
    
    
# batch 47000
#   minibatch loss: 0.0032705413177609444
#   sample 1:
#     enc input           > ['four', 'nine', 'two', 'seven', 'seven', 'four', 'eight', 'eight', 'four', 'six']
#     dec train predicted > ['four', 'nine', 'two', 'seven', 'seven', 'four', 'eight', 'eight', 'four', 'six', 'EOS']
#   sample 2:
#     enc input           > ['nine', 'four', 'one', 'two', 'five', 'five', 'six', 'six']
#     dec train predicted > ['nine', 'four', 'one', 'two', 'five', 'five', 'six', 'six', 'EOS']
#   sample 3:
#     enc input           > ['five', 'six', 'six', 'six', 'six', 'seven', 'nine', 'two', 'ten', 'six', 'three']
#     dec train predicted > ['five', 'six', 'six', 'six', 'six', 'seven', 'nine', 'two', 'ten', 'six', 'three', 'EOS']


# batch 48000
#   minibatch loss: 0.0031417866703122854
#   sample 1:
#     enc input           > ['two', 'seven', 'five', 'seven', 'ten', 'five', 'five']
#     dec train predicted > ['two', 'seven', 'five', 'seven', 'ten', 'five', 'five', 'EOS']
#   sample 2:
#     enc input           > ['four', 'two', 'three', 'three', 'six', 'four', 'three', 'two', 'nine', 'ten', 'eight', 'eight', 'one']
#     dec train predicted > ['four', 'two', 'three', 'three', 'six', 'four', 'three', 'two', 'nine', 'ten', 'eight', 'eight', 'one', 'EOS']
#   sample 3:
#     enc input           > ['two', 'two', 'ten', 'two', 'three', 'seven']
#     dec train predicted > ['two', 'two', 'ten', 'two', 'three', 'seven', 'EOS']


# batch 49000
#   minibatch loss: 0.002898368751630187
#   sample 1:
#     enc input           > ['six', 'four', 'eight', 'one', 'seven', 'nine', 'four', 'eight', 'two']
#     dec train predicted > ['six', 'four', 'eight', 'one', 'seven', 'nine', 'four', 'eight', 'two', 'EOS']
#   sample 2:
#     enc input           > ['six', 'one', 'five', 'nine', 'four', 'four', 'four', 'five', 'one']
#     dec train predicted > ['six', 'one', 'five', 'nine', 'four', 'four', 'four', 'five', 'one', 'EOS']
#   sample 3:
#     enc input           > ['five', 'six', 'seven', 'ten', 'ten', 'four', 'six', 'three', 'seven', 'eight', 'one']
#     dec train predicted > ['five', 'six', 'seven', 'ten', 'ten', 'four', 'six', 'three', 'seven', 'eight', 'one', 'EOS']
