## Ptr-Net 

* Vinyals et al. (2016)
* Task: sorting digit strings (vocab=range(0,10))
    * Input: ['2','3','1']
    * Output: ['1','2','3']


* Why: this is for when the outputs of the decoder are directly related to the inputs. E.g. in string sorting, the outputs' length should equal to that of the inputs. Thus, the model should only predicts a distribution over input strings, unlike regular enc-dec, which predicts distributions over a fixed vocabulary. This is not a problem for word sorting, for instance, but it is for sentence sorting, because there isn't a fixed vocabulary. 

In [52]:
# Add custom import path

import sys
sys.path.insert(0, '/home/jacobsuwang/Documents/UTA2018/NEURAL-NETS/ATTENTION/CODE/01-import-folder')

### MAKING DATA

In [60]:
import utils # utils.batch produces time-major inputs (examples below).
import random
import numpy as np

VOCAB = set(['PAD','EOS','1','2','3','4','5','6','7','8','9','0'])
NUMBERS = ['1','2','3','4','5','6','7','8','9','0']
MAX_LEN = len(NUMBERS)
WORD2IDX = {'PAD':0,'EOS':1,'1':2,'2':3,'3':4,'4':5,'5':6,'6':7,'7':8,'8':9,'9':10,'0':11}
IDX2WORD = {idx:word for word,idx in WORD2IDX.iteritems()}
BATCH_SIZE = 10

def random_datum(length):
    """
    Takes a length int, returns an encoded (input-sequence, output-sequence) tuple.
    """
    input_seq = np.random.choice(NUMBERS, length, replace=False)
    output_seq = sorted(input_seq)
    input_seq = map(lambda w:WORD2IDX[w], input_seq)
    output_seq = map(lambda w:WORD2IDX[w], output_seq)
    return input_seq, output_seq

def random_batch(batch_size, length_from=2, length_to=MAX_LEN):
    """
    Takes a batch size, and the length range for generated entries,
    returns an encoded (input-batch, output-batch), where each batch
    is a list of input-sequence/output-sequence.
    """
    if length_from >= length_to:
        raise ValueError('length_from must be strictly smaller than length_to')
    lengths = np.random.randint(length_from, length_to, size=batch_size)
    input_batch, output_batch = [], []
    for length in lengths:
        input_seq, output_seq = random_datum(length)
        input_batch.append(input_seq)
        output_batch.append(output_seq)
    return input_batch, output_batch

def decode_seq(seq):
    """
    Code to digit-string translation.
    """
    return map(lambda idx:IDX2WORD[idx], seq)

def decode_batch(batch):
    """
    Digit-string to code translation.
    """
    return map(lambda seq:decode_seq(seq), batch)

# Example:
# >> a, b = random_batch(3)
# >> print a
# >> print b
# [[10, 2], [6, 8, 10, 5, 2, 3, 7, 4, 9], [10, 7, 6, 8, 3, 11, 4, 5, 9]]
# [[2, 10], [2, 3, 4, 5, 6, 7, 8, 9, 10], [11, 3, 4, 5, 6, 7, 8, 9, 10]]
# >> print decode_batch(a)
# [['9', '1'],
#  ['5', '7', '9', '4', '1', '2', '6', '3', '8'],
#  ['9', '6', '5', '7', '2', '0', '3', '4', '8']]
# >> decode_batch(b)
# [['1', '9'],
#  ['1', '2', '3', '4', '5', '6', '7', '8', '9'],
#  ['0', '2', '3', '4', '5', '6', '7', '8', '9']]
# >> utils.batch(a)
# (array([[10,  6, 10],
#         [ 2,  8,  7],
#         [ 0, 10,  6],
#         [ 0,  5,  8],
#         [ 0,  2,  3],
#         [ 0,  3, 11],
#         [ 0,  7,  4],
#         [ 0,  4,  5],
#         [ 0,  9,  9]], dtype=int32), [2, 9, 9])
# >> utils.batch(b)
# (array([[ 2,  2, 11],
#         [10,  3,  3],
#         [ 0,  4,  4],
#         [ 0,  5,  5],
#         [ 0,  6,  6],
#         [ 0,  7,  7],
#         [ 0,  8,  8],
#         [ 0,  9,  9],
#         [ 0, 10, 10]], dtype=int32), [2, 9, 9])

### MAKING MODEL

In [61]:
import numpy as np
import tensorflow as tf

from tensorflow.contrib.rnn import LSTMCell, LSTMStateTuple

In [62]:
# Graph

tf.reset_default_graph()
sess = tf.InteractiveSession()

vocab_size = len(VOCAB)
input_embedding_size = 20

encoder_hidden_units = 20
decoder_hidden_units = encoder_hidden_units * 2

#                    decoder 
#                    target
# 
# [] -> [] -> [#] -> [] -> []
#                     |    ^
# encoder             |    |    
# inputs              |    |        "Fish-hook" inputs to next step:
#    ^                v    |        u_j^i = v^T tanh(W1*e_j + W2*d_i) (Section 2.3)
#    |_____________ attention   <=  next_input = softmax(u^i)
#    

# Inputs
#
#   encoder_inputs: [max_time, batch_size]
#   encoder_inputs: [batch_size]
#   decoder_targets: [max_time, batch_size]
#
encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs') 
encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length') 
decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_targets')

# Embeddings
#
#   embeddings: [vocab_size, emb_size]
#   encoder_inputs_embedded: [max_time, batch_size, emb_size]
#
embeddings = tf.Variable(tf.random_uniform([vocab_size, input_embedding_size], -1.0, 1.0), dtype=tf.float32)
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)

encoder_cell = LSTMCell(encoder_hidden_units)
((encoder_fw_outputs,encoder_bw_outputs), 
 (encoder_fw_final_state,encoder_bw_final_state)) = ( 
        tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_cell,
                                        cell_bw=encoder_cell,
                                        inputs=encoder_inputs_embedded,
                                        sequence_length=encoder_inputs_length,
                                        dtype=tf.float32, time_major=True)
    )

encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2) # concat on emb dim.
encoder_final_state_c = tf.concat((encoder_fw_final_state.c, encoder_bw_final_state.c), 1) # same thing.
encoder_final_state_h = tf.concat((encoder_fw_final_state.h, encoder_bw_final_state.h), 1)
encoder_final_state = LSTMStateTuple(
    c=encoder_final_state_c,
    h=encoder_final_state_h
)

decoder_cell = LSTMCell(decoder_hidden_units)
encoder_max_time, batch_size = tf.unstack(tf.shape(encoder_inputs)) # use runtime shape even w/ None.
decoder_lengths = encoder_inputs_length + 3 # +2 steps, +1 for EOS.

# Weights for decoder to make predictions
W = tf.Variable(tf.random_uniform([decoder_hidden_units, vocab_size], -1, 1), dtype=tf.float32) # for dec only!
b = tf.Variable(tf.zeros([vocab_size]), dtype=tf.float32)

# EOS and PAD
#
#    eos_step_embedded: initializing the first input to the decoder.
#    pad_step_embedded: the last "prediction" when the input length is exhausted. 
#
eos_time_slice = tf.ones([batch_size], dtype=tf.int32, name='EOS') 
pad_time_slice = tf.zeros([batch_size], dtype=tf.int32, name='PAD')
eos_step_embedded = tf.nn.embedding_lookup(embeddings, eos_time_slice) 
pad_step_embedded = tf.nn.embedding_lookup(embeddings, pad_time_slice)

# Loop initializer
#
#   handles the transitions in decoder after the first state
#              ___
#   output ->  |  |
#              # -|-> #
#               / |   ^
#          state  |   | <- next_input (inpt)
#   inp_embs_____attention
#

def loop_fn_initial():
    initial_elements_finished = (0 >= decoder_lengths) # all false (i.e. not done) at the init step.
    initial_input = eos_step_embedded                  # it's a [batch_size] length boolean vector.   
        # "input": it's the input for the next state.
        # in this case, the first cell of the decoder.
    initial_cell_state = encoder_final_state
    initial_cell_output = None 
    initial_loop_state = None 
    return (initial_elements_finished,
            initial_input,
            initial_cell_state,
            initial_cell_output,
            initial_loop_state)

# Weights for the Ptr-attention
#
#   u_j^i = v^T tanh(W1*e_j + W2*d_i) (Section 2.3)
#
W1 = tf.Variable(tf.random_uniform([encoder_hidden_units, encoder_hidden_units], -1, 1),
                 dtype=tf.float32) 
W2 = tf.Variable(tf.random_uniform([decoder_hidden_units, encoder_hidden_units], -1, 1), 
                 dtype=tf.float32) 
v = tf.Variable(tf.random_uniform([encoder_hidden_units, 1], -1, 1),
                dtype=tf.float32) 

def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):
    
    def get_next_input(): 
        mt, bc, _ = tf.unstack(tf.shape(encoder_inputs_embedded)) # get runtime shape of max_time, batch.
        EW1 = tf.reshape(tf.tensordot(encoder_inputs_embedded, W1, axes=[[2],[0]]),
                         [mt, bc, encoder_hidden_units]) # [max_time, batch_size, emb_size]
        DW2 = tf.matmul(previous_state.h, W2) # [batch_size, emb_size]
        EW1_add_DW2 = tf.add(EW1, DW2)
        # v^T * tanh(W1*e_j + W2*d_i) in batch mode.
        attention_mat = tf.reshape(tf.nn.tanh(tf.squeeze(tf.tensordot(EW1_add_DW2, v, axes=[[2],[0]]), 
                                                         axis=2)), [mt,bc]) # unnormalized attention mat, [mt, bc]
        attention_norm_mat = tf.nn.softmax(attention_mat, dim=0) # cols sum to 1 now.
        selector = tf.one_hot(tf.argmax(attention_norm_mat, axis=0), depth=encoder_max_time,
                              on_value=1.0, off_value=0.0, axis=0) # selects max-attended input embs.
        inputs_embedded_selected = tf.transpose(
            tf.multiply(
                tf.transpose(encoder_inputs_embedded, [2,0,1]), 
                selector), 
            [1,2,0]
        ) # transposing: to allow for multiply broadcast.
        inputs_embedded_selected = tf.reduce_sum(
            tf.reshape(inputs_embedded_selected, [mt, bc, encoder_hidden_units]), # explicitly interpret dims.
            axis=0 # compress max_time dimension to only output selected embeddings.
        )   
        next_input = inputs_embedded_selected
        return next_input
    
    elements_finished = (time >= decoder_lengths)
        # this returns a boolean tensor, e.g. [1, 1, 1, 0]
        # this means the first three steps are done, but not the last.
        # when all the steps are done, i.e. time (the real time) is larger than
        # the specified max decoding steps, the vector is all 1.
        # then the next line will return 1.    
    finished = tf.reduce_all(elements_finished) 
    inpt = tf.cond(finished, lambda: pad_step_embedded, get_next_input)
        # if finished, return a pad for next input (i.e. the feed to next step)
        # otherwise, return get_next_input as usual.
    state = previous_state
    output = previous_output
    loop_state = None
    # outputs:
    # elements_finished: a [batch_size] boolean vector.
    # inpt: [batch_size, emb_size] tensor for the next cell.
    # state: (c,h) tuple, raw_rnn takes care of it.
    # output: stored [batch_size, emb_size] tensor.
    # loop_state: rnn_raw takes care of it.
    return (elements_finished,
            inpt, 
            state,
            output,
            loop_state)  

def loop_fn(time, previous_output, previous_state, previous_loop_state):
    # time: an int32 scalar raw_rnn uses to keep track of time-steps internally.
    # previous_output: [max_time, batch_size, emb_size] tensor.
    # previous_state: (c,h) tuple.
    # previous_loop_state: raw_rnn uses to keep track of where it is in the loop (automatic).
    if previous_state is None:
        assert previous_output is None and previous_state is None
        return loop_fn_initial()
    else:
        return loop_fn_transition(time, previous_output, previous_state, previous_loop_state)

    
decoder_outputs_ta, decoder_final_state, _ = tf.nn.raw_rnn(decoder_cell, loop_fn)
    # *_ta: the RNN output (TensorArray <- for dynamic use)
    # *_final_state: 2-tuple of [batch_size, emb_size] (i.e. c and h). of no use for seq2seq.
    # _: final_loop_state, which no one gives a cupcake (used internally by *.raw_rnn backend).
decoder_outputs = decoder_outputs_ta.stack() # [max_time, batch_size, emb_concat] 

decoder_max_step, decoder_batch_size, decoder_dim = tf.unstack(tf.shape(decoder_outputs))
decoder_outputs_flat = tf.reshape(decoder_outputs, (-1, decoder_dim))
    # for matmul, we do
    # [max_time, batch_size, emb_concat], [max_time*batch_size, emb_concat]
decoder_logits_flat = tf.add(tf.matmul(decoder_outputs_flat, W), b)
decoder_logits = tf.reshape(decoder_logits_flat, (decoder_max_step, decoder_batch_size, vocab_size))
    # put it back into the original shaping scheme.
decoder_prediction = tf.cast(tf.argmax(decoder_logits, 2), dtype=tf.int32)

# Report accuracy here
correct_raw = tf.cast(tf.equal(decoder_prediction, decoder_targets), tf.int32)
mask = tf.cast(tf.not_equal(decoder_targets, WORD2IDX['PAD']), tf.int32) # EOSs are not 0ed out, there are BATCH_SIZE of them.
total_seqlen = tf.cast(tf.reduce_sum(encoder_inputs_length), tf.float32)
correct = tf.multiply(correct_raw, mask)
accuracy = tf.cast(tf.reduce_sum(correct)-BATCH_SIZE, tf.float32) / total_seqlen # -BATCH_SIZE correction
# accuracy = tf.cast(tf.reduce_sum(correct), tf.float32) / total_seqlen # w/o the correction.

# Optimization
stepwise_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
    labels=tf.one_hot(decoder_targets, depth=vocab_size, dtype=tf.float32),
    logits=decoder_logits
)
loss = tf.reduce_mean(stepwise_cross_entropy)
train_op = tf.train.AdamOptimizer().minimize(loss)

init = tf.global_variables_initializer()
sess.run(init)

### EVALUATING MODEL

In [63]:
def next_feed(batch_size):
    batch_enc, batch_dec = random_batch(batch_size)
    encoder_inputs_, encoder_inputs_lengths_ = utils.batch(batch_enc)
    decoder_targets_, _ = utils.batch([seq + [WORD2IDX['EOS']] + [WORD2IDX['PAD']]*2 for seq in batch_dec])
    return {
        encoder_inputs: encoder_inputs_,
        encoder_inputs_length: encoder_inputs_lengths_,
        decoder_targets: decoder_targets_
    }

loss_track = []

max_batches = 5001
batches_in_epoch = 1000
num_test_batches = 100

try:
    for batch in range(max_batches):
        fd = next_feed(BATCH_SIZE)
        _, l = sess.run([train_op, loss], fd)
        loss_track.append(l)
        if batch == 0 or batch % batches_in_epoch == 0:
            print('Batch {}'.format(batch))
            print('  minibatch loss: {} | accuracy {}'.format(*sess.run([loss, accuracy], fd)))
            # predict_ = sess.run(decoder_prediction, fd)
            predict_, lengths_ = sess.run([decoder_prediction, encoder_inputs_length], fd) # make use of seqlen
            
            for i, (inp, pred, tar, length) in enumerate(zip(fd[encoder_inputs].T, predict_.T, fd[decoder_targets].T, lengths_)):
                print('  sample {}:'.format(i + 1))
                print('    input     > {}'.format(decode_seq(inp)))
                print('    predicted > {}'.format(decode_seq(pred)))
                print('    target    > {}'.format(decode_seq(tar)))
                if i >= 2:
                    break
            print
            
    # EVALUATE ON A BIG TEST SET
    loss_track, accuracy_track = [], []
    for _ in range(num_test_batches):
        fd = next_feed(BATCH_SIZE)
        l, a = sess.run([loss, accuracy], fd)
        loss_track.append(l)
        accuracy_track.append(a)
    print('Evaluation results (on {} batches):'.format(num_test_batches))
    print('  average loss: {} | average accuracy {}'.format(np.mean(loss_track), np.mean(accuracy_track)))
    print
    
    
except KeyboardInterrupt:
    print('training interrupted')

Batch 0
  minibatch loss: 2.58508086205 | accuracy -0.159999996424
  sample 1:
    input     > ['2', '7', '6', '9', '5', '4', '0', '1']
    predicted > ['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0']
    target    > ['0', '1', '2', '4', '5', '6', '7', '9', 'EOS', 'PAD', 'PAD']
  sample 2:
    input     > ['5', '9', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD']
    predicted > ['1', '9', '9', '9', '9', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD']
    target    > ['5', '9', 'EOS', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD']
  sample 3:
    input     > ['8', '6', '3', '4', '9', '2', 'PAD', 'PAD']
    predicted > ['4', '0', '0', '0', '0', '0', '0', '0', '0', 'PAD', 'PAD']
    target    > ['2', '3', '4', '6', '8', '9', 'EOS', 'PAD', 'PAD', 'PAD', 'PAD']

Batch 1000
  minibatch loss: 0.247449174523 | accuracy 1.0
  sample 1:
    input     > ['6', '5', '8', '4', '2', '7', '1', '0', '9']
    predicted > ['0', '1', '2', '4', '5', '6', '7', '8', '9', 'EOS', 'PAD', 'PAD']
    target 