### Task: sorting digits (e.g. [3,1,2] -> [1,2,3])

* Encoder-Decoder + Attention (Bahdanau et al. 2015)

In [1]:
# Add custom import path

import sys
sys.path.insert(0, '/home/jacobsuwang/Documents/UTA2018/NEURAL-NETS/ATTENTION/CODE/01-import-folder')

### MAKING DATA

In [68]:
import utils
import random
import numpy as np

VOCAB = set(['PAD','EOS','1','2','3','4','5','6','7','8','9','0'])
NUMBERS = ['1','2','3','4','5','6','7','8','9','0']
MAX_LEN = len(NUMBERS) + 2
# WORD2IDX = {'PAD':0,'EOS':1,'1':2,'2':3,'3':4,'4':5,'5':6,'6':7,'7':8,'8':9,'9':10,'0':11}
WORD2IDX = {'PAD':0,'EOS':1,'1':2,'2':3,'3':4,'4':5,'5':6,'6':7,'7':8,'8':9,'9':10,'0':11}
IDX2WORD = {idx:word for word,idx in WORD2IDX.iteritems()}
# BATCH_SIZE = 10


def random_datum(n):
    input_seq = list(np.random.choice(NUMBERS, n, replace=False)) 
        # e.g. ['5', '6', '3', '9', '1'].
    sorted_seq = sorted(input_seq)
    output_seq = [input_seq.index(word)+2 for word in sorted_seq]
        # index in ascending.
        # e.g. [4, 2, 0, 1, 3], for the input above.
    input_seq = input_seq
    return input_seq, output_seq

def encode_seq(seq):
    return [WORD2IDX[word] for word in seq]

def decode_seq(seq):
    return [IDX2WORD[idx] for idx in seq]

def decode_order(seq):
    return [aug_idx-2 for aug_idx in seq]

def decode_by_index(seq, idx_seq, end_idx, correction=-2): # -2: PAD and EOS
    decoded = []
    for idx in range(end_idx):
        decoded.append(seq[idx_seq[idx]+correction])
    return decoded

def random_batch(batch_size, input_length_from=2, input_length_to=MAX_LEN-2):
    if input_length_from >= input_length_to:
        raise ValueError('length_from >= length_to')
    input_lengths = np.random.randint(input_length_from, input_length_to, size=batch_size)
    input_batch, output_batch = [], []
    for length in input_lengths:
        input_seq, output_seq = random_datum(length)
        input_batch.append(encode_seq(input_seq))
        output_batch.append(output_seq)
    return input_batch, output_batch    


### MAKING MODEL

In [64]:
import math

import numpy as np
import tensorflow as tf
import tensorflow.contrib.seq2seq as seq2seq
from tensorflow.contrib.layers import safe_embedding_lookup_sparse as embedding_lookup_unique
from tensorflow.contrib.rnn import LSTMCell, LSTMStateTuple, GRUCell

In [69]:
# Graph

tf.reset_default_graph()
sess = tf.InteractiveSession()

# Set configuration

# digit_from = 2
# digit_to = 6+1
# word_from = 7
# word_to = 11+1
# seqlen_from = 3
# seqlen_to = 8

batch_size = 10 
vocab_size = len(VOCAB)
input_embedding_size = 10

encoder_hidden_units = 20
decoder_hidden_units = encoder_hidden_units * 2 # because encoder is going to be bidirectional.

attention = True     # togglable
bidirectional = True # currently hardcoded

encoder_cell = LSTMCell(encoder_hidden_units)
decoder_cell = LSTMCell(decoder_hidden_units)

# _init_placeholder()

encoder_inputs = tf.placeholder(shape=(None,None), dtype=tf.int32, name='encoder_inputs') # [max_time, batch_size]
encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length') 
decoder_targets = tf.placeholder(shape=(None,None), dtype=tf.int32, name='decoder_targets')
decoder_targets_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='decoder_inputs_length')

# _init_decoder_train_connectors()
#  - adding EOS to placehoder (and get seqlen right).

with tf.name_scope('DecoderTrainFeeds'):
    sequence_size, batch_size_ = tf.unstack(tf.shape(decoder_targets)) # [max_time, batch_size]
    EOS_SLICE = tf.ones([1, batch_size_], dtype=tf.int32) * WORD2IDX['EOS']
    PAD_SLICE = tf.ones([1, batch_size_], dtype=tf.int32) * WORD2IDX['PAD']
    decoder_train_inputs = tf.concat([EOS_SLICE, decoder_targets], axis=0) 
        # add EOS to the beginning.
        # node that this only changes the structure of the graph, not real input.
    decoder_train_length = decoder_targets_length + 1 # and adjust length accordingly. 
#     decoder_train_length = decoder_targets_length # and adjust length accordingly.     
    decoder_train_targets = tf.concat([decoder_targets, PAD_SLICE], axis=0) # add PAD to the end.
    decoder_train_targets_seq_len, _ = tf.unstack(tf.shape(decoder_train_targets)) # seq_len = max_time
    # the mask picks out EOS to 1 (the rest 0)
    # so later when you add it to decoder_train_targets,
    # the end-of-seq PAD placeholder will be replaced
    # by 1, which is EOS.
    decoder_train_targets_eos_mask = tf.one_hot(decoder_train_length - 1, # indices of EOS tokens.
                                                decoder_train_targets_seq_len, # depth: dim of one-hot vecs.
                                                on_value=WORD2IDX['EOS'], off_value=WORD2IDX['PAD'], # 1 for on, 0 for off.
                                                dtype=tf.int32)
    decoder_train_targets_eos_mask = tf.transpose(decoder_train_targets_eos_mask, [1, 0]) 
        # one_hot naturally produces [batch_size, max_time]
        # we translate it to time-major, i.e. [max_time, batch_size]?
    decoder_train_targets = tf.add(decoder_train_targets,
                                   decoder_train_targets_eos_mask) # add EOS (in index) to end of target sequence
    loss_weights = tf.ones([
        batch_size,
        tf.reduce_max(decoder_train_length) # max_time
    ], dtype=tf.float32, name='loss_weights') # [batch_size, max_time]
        # on init, all weights are equally important.

# _init_embeddings()
#  - looking up embedding matrix.

with tf.variable_scope('embedding') as scope:
    sqrt3 = math.sqrt(3) # unif(-sqrt(3),sqrt(3)) has var = 1.
    initializer = tf.random_uniform_initializer(-sqrt3, sqrt3)
    embedding_matrix = tf.get_variable(
        name='embedding_matrix',
        shape=[vocab_size, input_embedding_size],
        initializer=initializer,
        dtype=tf.float32
    )
    encoder_inputs_embedded = tf.nn.embedding_lookup(embedding_matrix, encoder_inputs)
    decoder_train_inputs_embedded = tf.nn.embedding_lookup(embedding_matrix, decoder_train_inputs)
        # decoder_train_inputs: decoder_targets prepended with EOS.

# _init_bidirectional_encoder()
#  - make encoder_state: [batch_size, hidden_size] 
#  - make encoder_outputs: [max_time, batch_size, emb_size]
#  h_t = f(x_t, h_t-1)    # the LSTM cell.
#  c = q({h_1, ..., h_T}) # the context vector.

with tf.variable_scope('BidirectionalEncoder') as scope:
    encoder_cell = LSTMCell(encoder_hidden_units)
    ((encoder_fw_outputs,encoder_bw_outputs), # both have [max_time, batch_size, emb_size]
     (encoder_fw_state,encoder_bw_state)) = ( # (final) state tuples: (c=[batch_size,emb_size],h=same)
            tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_cell,
                                            cell_bw=encoder_cell,
                                            inputs=encoder_inputs_embedded,
                                            sequence_length=encoder_inputs_length,
                                            dtype=tf.float32, time_major=True)
        )  
    encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2) # concat on emb dim.
    if isinstance(encoder_fw_state, LSTMStateTuple):
        encoder_state_c = tf.concat((encoder_fw_state.c, encoder_bw_state.c), 1, name='bidirectional_concat_c')
        encoder_state_h = tf.concat((encoder_fw_state.h, encoder_bw_state.h), 1, name='bidirectional_concat_h')
        encoder_state = LSTMStateTuple(c=encoder_state_c, h=encoder_state_h)
    elif isinstance(encoder_fw_state, tf.Tensor):
        encoder_state = tf.concat((encoder_fw_state, encoder_bw_state), 1, name='bidirectional_concat')
    
# _init_decoder()

with tf.variable_scope('Decoder') as scope:
    def output_fn(outputs):
        return tf.contrib.layers.linear(outputs, vocab_size, scope=scope)
    if not attention:
        # no-attention: decoding conditions only on a lump-sum encoded final context vector c.
        # p(y) = \prod_t p(y_t | {y_1, ..., y_t-1}, c) # decoder's soft prediction (Bahdanau15:(2)).
        #
        decoder_fn_train = seq2seq.simple_decoder_fn_train(encoder_state=encoder_state)
            # simple_decoder_fn_train: made for seq2seq.dynamic_rnn_decoder later.
        decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
            output_fn=output_fn,
            encoder_state=encoder_state,
            embeddings=embedding_matrix,
            start_of_sequence_id=WORD2IDX['EOS'],
            end_of_sequence_id=WORD2IDX['EOS'],
            maximum_length=tf.reduce_max(encoder_inputs_length) + 3,
            num_decoder_symbols=vocab_size
        )
    else:
        # attention: decoding conditions on a distinct context vector i at time-step i.
        # p(y_i) = p(y_i | {y_1, ..., y_i-1}, c_i) # (Bahdanau15:(4))
        # for this we need to feed *all* the hidden states of the encoder to the attention layer.
        # c_i = \sum_j a_ij * h_j, j: index over all hidden states (Bahdanau15:(5)).
        #   a_ij = softmax(e_ij) (Bahdanau15:(6)),
        #     interpretation: the probability that state i (dec) is aligned with state j (enc).
        #   e_ij = a(s_i-1, h_j), where s_i-1 is the *decoder's hidden state*.
        #     the function a(dec-hid, enc-hid) here is configured as an FFNN in Bahdanau15. 
        # 
        attention_states = encoder_state_h
            # ematvey's original below is incorrect as per Bahdanau15,
            # the attention layer takes the encoder *hidden states*, not *outputs*.
            # attention_states = tf.transpose(encoder_outputs, [1, 0, 2]) # -> [batch_size, max_time, num_units]
        (attention_keys,     # `to be compared with target states' (Q: target state? more like supervision for the attention network.)
         attention_values,   # `to be used to construct context vectors' (i.e. c_i in equation (4,5))
         attention_score_fn, # `to compute similarity between key and target states' (i.e. a(dec-hid,enc-hid))
         attention_construct_fn) = seq2seq.prepare_attention( # construct_fn: build attention states (i.e. a_ij).
            attention_states=attention_states,
            attention_option='bahdanau',
            num_units=decoder_hidden_units # num_units = hidden_size 
                # Q: could this be a different number?.          dec-hid  enc-hid
                # no. because we are feeding s_i here for e_ij = a(s_i-1, h_j).
        )
        decoder_fn_train = seq2seq.attention_decoder_fn_train( # simple_decoder_fn + attention.
            encoder_state=encoder_state,
            attention_keys=attention_keys,
            attention_values=attention_values,
            attention_score_fn=attention_score_fn,
            attention_construct_fn=attention_construct_fn,
            name='attention_decoder'
        )
        decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
            output_fn=output_fn,
            encoder_state=encoder_state,
            attention_keys=attention_keys,
            attention_values=attention_values,
            attention_score_fn=attention_score_fn,
            attention_construct_fn=attention_construct_fn,
            embeddings=embedding_matrix,
            start_of_sequence_id=WORD2IDX['EOS'],
            end_of_sequence_id=WORD2IDX['EOS'],
            maximum_length=tf.reduce_max(encoder_inputs_length) + 3,
            num_decoder_symbols=vocab_size
        )
    (decoder_outputs_train,
     decoder_state_train,
     decoder_context_state_train) = (
        seq2seq.dynamic_rnn_decoder(
            cell=decoder_cell,
            decoder_fn=decoder_fn_train,
            inputs=decoder_train_inputs_embedded,
            sequence_length=decoder_train_length,
            time_major=True,
            scope=scope
        )
    )
    decoder_logits_train = output_fn(decoder_outputs_train)
    decoder_prediction_train = tf.argmax(decoder_logits_train, axis=-1, name='decoder_prediction_train')
    scope.reuse_variables()
    (decoder_logits_inference,
     decoder_state_inference,
     decoder_context_state_inference) = (
        seq2seq.dynamic_rnn_decoder(
            cell=decoder_cell,
            decoder_fn=decoder_fn_inference,
            time_major=True,
            scope=scope
        )
    )
    decoder_prediction_inference = tf.argmax(decoder_logits_inference, axis=-1, name='decoder_prediction_inference')
    
# [SU] report accuracy here
correct_raw = tf.cast(tf.equal(tf.cast(decoder_prediction_train, tf.int32), decoder_train_targets), tf.int32)
mask = tf.cast(tf.not_equal(decoder_train_targets, WORD2IDX['PAD']), tf.int32) # EOSs are not 0ed out, there are BATCH_SIZE of them.
total_seqlen = tf.cast(tf.reduce_sum(encoder_inputs_length), tf.float32)
correct = tf.multiply(correct_raw, mask)
accuracy = tf.cast(tf.reduce_sum(correct)-batch_size, tf.float32) / total_seqlen # -BATCH_SIZE correction
    
# _init_optimizer()

logits = tf.transpose(decoder_logits_train, [1, 0, 2]) # to [batch_size, max_time, num_units]
targets = tf.transpose(decoder_train_targets, [1, 0])  # to [batch_size, max_time]
loss = seq2seq.sequence_loss(logits=logits, targets=targets, weights=loss_weights)
    # `Weighted cross-entropy loss for a sequence of logits (per example)'
    # logits: [batch_size, max_time, num_units]
    # targets: [batch_size, max_time]
    # weights: [batch_size, max_time]
train_op = tf.train.AdamOptimizer().minimize(loss)

# run training

init = tf.global_variables_initializer()
sess.run(init)

In [45]:
tf.shape(decoder_targets)

<tf.Tensor 'Shape_1:0' shape=(2,) dtype=int32>

In [41]:
tf.unstack(tf.shape(decoder_targets))

[<tf.Tensor 'unstack:0' shape=() dtype=int32>,
 <tf.Tensor 'unstack:1' shape=() dtype=int32>]

In [22]:
decoder_prediction_train.shape

TensorShape([Dimension(None), Dimension(None)])

In [70]:
def make_train_inputs(input_seq, target_seq):
    # batch_enc, batch_dec = random_batch(word_from,word_to,digit_from,digit_to,seqlen_from,seqlen_to,batch_size)
        # this is called in ematvey's code as:
        # batch_data = next(batches)
        # fd = model.make_train_inputs(batch_data, batch_data)
    inputs_, inputs_length_ = utils.batch(input_seq) 
        # equiv encoder_inputs_, encoder_inputs_lengths_ = utils.batch(batch_enc)
    targets_, targets_length_ = utils.batch(target_seq)
        # equiv decoder_targets_, _ = utils.batch([seq + [word2idx['EOS']] + [word2idx['PAD']]*2 for seq in batch_dec])
        # the EOS addition is done in a function above, so no need here.
    return {
        encoder_inputs: inputs_,
        encoder_inputs_length: inputs_length_,
        decoder_targets: targets_,
        decoder_targets_length: targets_length_
    }

loss_track = []
max_batches = 5000
batches_in_epoch=1000
num_test_batches = 100

try:
    for batch in range(max_batches+1):
#         batch_enc, batch_dec = random_batch(word_from,word_to,digit_from,digit_to,seqlen_from,seqlen_to,batch_size)
        batch_enc, batch_dec = random_batch(batch_size)
        fd = make_train_inputs(batch_enc, batch_dec) 
            # ematvey: ..(batch_data, batch_data)
            # because he does copy task, and i do translation.
        _, l = sess.run([train_op, loss], fd)
        loss_track.append(l)
        
        if batch == 0 or batch % batches_in_epoch == 0:
            print('batch {}'.format(batch))
#             print('  minibatch loss: {}'.format(sess.run(loss, fd)))
            print('  minibatch loss: {} | accuracy: {}'.format(*sess.run([loss, accuracy], fd)))
            for i, (e_in, dt_pred, length) in enumerate(zip(
                    fd[encoder_inputs].T,
                    sess.run(decoder_prediction_train, fd).T,
                    sess.run(encoder_inputs_length, fd).T
                )):
                decoded_e_in = decode_seq(e_in)
                print('  sample {}:'.format(i + 1))
                print('    enc input     > {}'.format(decoded_e_in))
                print('    dec train predicted > {}'.format(decode_by_index(decoded_e_in, dt_pred, end_idx=length)))
                if i >= 2:
                    break
            print
    
#     # AFTER TRAING: want to match shape
#     # dec_pred_ has an EOS at the end of seq
#     # at the early stage, the network doesn't predict
#     # this EOS correctly, so dec_pred_'s like, a whole bunch of random numbers.
#     # after training, it gets good, and almost always predict EOS correctly.
#     batch_enc, batch_dec = random_batch(batch_size)
#     fd = make_train_inputs(batch_enc, batch_dec) 
#     enc_inp_ = sess.run(encoder_inputs, fd)
#     dec_pred_, dec_tar_, dec_tr_tar_ = sess.run([decoder_prediction_train, decoder_targets, decoder_train_targets], fd)
#     print 'encoder input:'
#     print enc_inp_
#     print enc_inp_.shape
#     print 'decoder pred:'
#     print dec_pred_
#     print dec_pred_.shape
#     print 'decoder target:'
#     print dec_tar_
#     print dec_tar_.shape
#     print 'decoder train target:'
#     print dec_tr_tar_
#     print dec_tr_tar_.shape
#     assert 1==0

    # EVALUATE ON A BIG TEST SET
    loss_track, accuracy_track = [], []
    for _ in range(num_test_batches):
        batch_enc, batch_dec = random_batch(batch_size)
        fd = make_train_inputs(batch_enc, batch_dec) 
        l, a = sess.run([loss, accuracy], fd)
        loss_track.append(l)
        accuracy_track.append(a)
    print('Evaluation results (on {} batches):'.format(num_test_batches))
    print('  average loss: {} | average accuracy {}'.format(np.mean(loss_track), np.mean(accuracy_track)))
    print    

            
except KeyboardInterrupt:
    print('training interrupted')        

batch 0
  minibatch loss: 2.48968148232 | accuracy: -0.0895522385836
  sample 1:
    enc input     > ['4', '8', '2', '9', '3', '5', '1', '6', '7']
    dec train predicted > ['4', '1', '1', '1', '7', '7', '6', '6', '6']
  sample 2:
    enc input     > ['3', '8', '2', '4', '7', '9', '0', '5', 'PAD']
    dec train predicted > ['7', '0', '0', 'PAD', 'PAD', 'PAD', 'PAD', '5']
  sample 3:
    enc input     > ['1', '0', '5', '9', '6', '2', '7', '8', '4']
    dec train predicted > ['0', '0', '5', '6', '8', '6', '4', '4', '4']

batch 1000
  minibatch loss: 0.723235249519 | accuracy: 0.804878056049
  sample 1:
    enc input     > ['5', '4', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD']
    dec train predicted > ['4', '5']
  sample 2:
    enc input     > ['5', '8', '0', '9', '4', 'PAD', 'PAD']
    dec train predicted > ['0', '4', '5', '9', '9']
  sample 3:
    enc input     > ['0', '3', '6', '5', 'PAD', 'PAD', 'PAD']
    dec train predicted > ['0', '3', '6', '6']

batch 2000
  minibatch loss: 0.495182573795