In [98]:
import tensorflow as tf
from tensorflow.contrib.layers import xavier_initializer as xinit

- input size : [BxL]
- hdim : [10]
- num_pointers : [1] (2)
- batch_size : 32

In [99]:
B = 32
L = 20
hdim = 10
num_pointers = 2
vocab_size = 11

In [100]:
DECODE_SYM = 10

In [101]:
tf.reset_default_graph()

In [102]:
ip_seq = tf.placeholder(dtype=tf.int64, shape=[B, L], name='ip_seq')

In [103]:
label = tf.placeholder(dtype=tf.int64, shape=[num_pointers, B], name='label')

In [104]:
emb_mat = tf.get_variable('emb', [vocab_size, hdim],
                         initializer=xinit())

In [105]:
emb_ip_seq = tf.nn.embedding_lookup(emb_mat, ip_seq)

In [106]:
with tf.variable_scope('encoder'):
    ecell = tf.contrib.rnn.LSTMCell(hdim)
    enc_init_state = ecell.zero_state(B, dtype=tf.float32)
    enc_outputs, final_enc_state = tf.nn.dynamic_rnn(ecell,
                                                    inputs=emb_ip_seq,
                                                    initial_state=enc_init_state)

In [107]:
DECODE_SYM_TF = tf.constant(DECODE_SYM,
                           shape=[B],
                           dtype=tf.int64)

In [108]:
emb_dec_sym = tf.nn.embedding_lookup(emb_mat, DECODE_SYM_TF)

In [109]:
dec_state = final_enc_state
dec_input = emb_dec_sym

In [110]:
with tf.variable_scope('decoder'):
    dcell = tf.contrib.rnn.LSTMCell(hdim)
    a_j = []
    dec_outputs = []
    logits = []
    range_ = tf.range(start=0, limit=B, dtype=tf.int64)
    
    for i in range(num_pointers):
        if i > 0:
            tf.get_variable_scope().reuse_variables()
            
        dec_output, dec_state = dcell(dec_input, dec_state)

        v = tf.get_variable('v_blend', [hdim, 1], initializer=xinit())
        We = tf.get_variable(dtype=tf.float32, shape=[hdim, hdim], name='We')
        Wd = tf.get_variable(dtype=tf.float32, shape=[hdim, hdim], name='Wd')

        enc_outputs_reshaped = tf.reshape( enc_outputs, [B*L, hdim] )
        u_j = tf.reshape(tf.matmul(enc_outputs_reshaped, We), [B,L,hdim]) + tf.expand_dims(tf.matmul(dec_state.c, Wd), axis=1)
        u_j = tf.matmul(tf.reshape(u_j, [B*L, hdim]), v)

        logit = tf.reshape(u_j, [B,L])
        a_j_i = tf.nn.softmax(logit)
        pointer_idx = tf.stack([range_, tf.argmax(a_j_i, axis=1)])
        pointer_idx = tf.transpose(pointer_idx)
        
        dec_input = tf.nn.embedding_lookup(emb_mat,
                                           tf.gather_nd(ip_seq, pointer_idx))
        logits.append(logit)
        dec_outputs.append(tf.argmax(a_j_i, axis=1))

In [120]:
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, 
                                               logits=tf.stack(logits))

In [121]:
loss = tf.reduce_mean(cross_entropy)

In [122]:
optimizer = tf.train.AdadeltaOptimizer(learning_rate=1.0)

In [123]:
train_op = optimizer.minimize(loss)

In [124]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [136]:
num_epochs = 100000

In [139]:
for i in range(num_epochs):
    avg_loss = 0
    for j in range(len(train_batches)):
        loss_v, _ = sess.run([loss, train_op], feed_dict= {
            ip_seq : train_batches[0][0],
            label : train_batches[0][1]
        })
        avg_loss += loss_v
    if i%30 == 0 and i:
        print('{} : {}'.format(i, avg_loss/len(train_batches)))
        

30 : 0.6958453375846148
60 : 0.6953677218407393
90 : 0.6944473795592785
120 : 0.6941793970763683
150 : 0.6940209772437811
180 : 0.6941094025969505
210 : 0.6939002405852079
240 : 0.6940827872604132
270 : 0.6938756164163351
300 : 0.6938037294894457
330 : 0.6938047539442778
360 : 0.6938310135155916
390 : 0.6939037032425404


KeyboardInterrupt: 

In [146]:
op = sess.run([logits], feed_dict = {
    ip_seq : train_batches[10][0],
    label : train_batches[10][1]
})

In [145]:
train_batches[10][1]

array([[ 5,  4,  4,  4,  5,  6,  6,  6,  5,  6,  6,  4,  6,  5,  6,  5,  5,
         6,  4,  5,  6,  4,  5,  5,  5,  6,  4,  5,  6,  5,  5,  5],
       [ 9,  9,  7,  7,  9, 11, 10, 10, 10,  9, 10,  7,  9, 10, 10,  8, 10,
        11,  9,  8, 10,  8, 10,  9,  9,  9,  9, 10, 10, 10,  8, 10]])

In [152]:
train_batches[10][0]

array([[5, 3, 2, 4, 3, 6, 8, 6, 9, 6, 3, 3, 3, 1, 3, 4, 0, 0, 0, 0],
       [1, 5, 2, 5, 8, 9, 6, 6, 9, 9, 3, 4, 3, 3, 2, 1, 0, 0, 0, 0],
       [1, 3, 3, 3, 7, 6, 9, 6, 2, 1, 2, 1, 1, 0, 0, 0, 0, 0, 0, 0],
       [3, 3, 4, 3, 7, 7, 8, 8, 2, 5, 3, 4, 1, 3, 0, 0, 0, 0, 0, 0],
       [3, 2, 3, 2, 3, 6, 9, 9, 7, 6, 5, 4, 2, 2, 2, 0, 0, 0, 0, 0],
       [3, 1, 3, 4, 2, 4, 6, 9, 7, 6, 7, 7, 4, 1, 1, 3, 4, 0, 0, 0],
       [1, 5, 2, 1, 5, 5, 7, 9, 9, 8, 9, 1, 1, 2, 3, 1, 0, 0, 0, 0],
       [4, 5, 5, 3, 3, 4, 8, 6, 7, 6, 8, 2, 2, 1, 5, 0, 0, 0, 0, 0],
       [4, 1, 4, 3, 3, 7, 8, 9, 8, 9, 8, 1, 4, 2, 2, 0, 0, 0, 0, 0],
       [2, 4, 4, 2, 2, 5, 6, 6, 9, 6, 5, 3, 2, 2, 1, 0, 0, 0, 0, 0],
       [1, 1, 5, 3, 5, 4, 7, 9, 9, 6, 9, 3, 3, 2, 4, 3, 0, 0, 0, 0],
       [2, 3, 5, 3, 6, 8, 8, 8, 3, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0],
       [5, 5, 5, 5, 5, 2, 9, 9, 8, 9, 3, 1, 4, 3, 4, 5, 0, 0, 0, 0],
       [1, 4, 1, 1, 1, 7, 8, 6, 6, 9, 6, 5, 5, 5, 3, 0, 0, 0, 0, 0],
       [4, 4, 1, 4, 5, 3, 7, 7, 9,

## DATA

In [126]:
import random
import numpy as np

In [127]:
def generate_nested_sequence(length, min_seglen=5, max_seglen=10):
    """Generate low-high-low sequence, with indexes of the first/last high/middle elements"""

    # Low (1-5) vs. High (6-10)
    seq_before = [(random.randint(1,5)) for x in range(random.randint(min_seglen, max_seglen))]
    seq_during = [(random.randint(6,9)) for x in range(random.randint(min_seglen, max_seglen))]
    seq_after = [random.randint(1,5) for x in range(random.randint(min_seglen, max_seglen))]
    seq = seq_before + seq_during + seq_after

    # Pad it up to max len with 0's
    seq = seq + ([0] * (length - len(seq)))
    return [seq, len(seq_before), len(seq_before) + len(seq_during)-1]


In [128]:
data = generate_nested_sequence(length=L)

In [129]:
def create_one_hot(l ,inp):
    return inp

In [130]:
train_segment_length_min = 4
train_segment_length_max = 6

batches = 32
train_batches = []
for batch in range(batches):
    
    seqs = []
    start_indices = []
    end_indices = []
    for i in range(B):
        seq, start, end = generate_nested_sequence(L, 
                                                    train_segment_length_min, 
                                                    train_segment_length_max)

        start_, end_ = create_one_hot(L, start),  create_one_hot(L, end)
        seqs.append(np.array(seq)), start_indices.append(start_), end_indices.append(end_)

    seqs          = np.stack(seqs)
    start_indices = np.array(start_indices)
    end_indices   = np.array(end_indices)
    indices = np.stack([start_indices, end_indices])
    train_batches.append(np.array([seqs, indices]))
                         
train_batches[0][1].shape


(2, 32)

In [131]:
train_batches[0][0]

array([[5, 4, 1, 5, 8, 7, 9, 7, 8, 4, 5, 2, 3, 1, 0, 0, 0, 0, 0, 0],
       [3, 3, 3, 2, 2, 8, 9, 7, 7, 9, 1, 3, 4, 4, 1, 0, 0, 0, 0, 0],
       [5, 4, 3, 2, 3, 9, 8, 6, 6, 9, 6, 5, 3, 4, 2, 0, 0, 0, 0, 0],
       [2, 3, 2, 4, 3, 3, 9, 8, 6, 8, 5, 5, 3, 3, 3, 3, 0, 0, 0, 0],
       [2, 3, 4, 5, 4, 3, 7, 6, 9, 9, 2, 3, 3, 3, 4, 0, 0, 0, 0, 0],
       [2, 1, 3, 1, 3, 6, 7, 7, 6, 3, 2, 1, 2, 0, 0, 0, 0, 0, 0, 0],
       [3, 2, 3, 4, 8, 9, 9, 9, 9, 7, 3, 3, 1, 1, 3, 0, 0, 0, 0, 0],
       [5, 2, 2, 2, 4, 7, 6, 8, 8, 3, 1, 3, 1, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 4, 3, 3, 2, 7, 9, 9, 6, 6, 1, 3, 1, 2, 2, 4, 0, 0, 0],
       [1, 3, 4, 1, 6, 6, 8, 8, 2, 5, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0],
       [2, 1, 1, 1, 8, 9, 7, 8, 2, 4, 3, 2, 2, 5, 0, 0, 0, 0, 0, 0],
       [5, 3, 4, 1, 1, 8, 6, 6, 9, 6, 3, 3, 3, 2, 4, 0, 0, 0, 0, 0],
       [4, 4, 3, 4, 1, 9, 8, 7, 7, 3, 3, 4, 4, 5, 0, 0, 0, 0, 0, 0],
       [5, 4, 5, 5, 5, 6, 6, 9, 7, 7, 6, 2, 4, 4, 1, 2, 1, 0, 0, 0],
       [2, 4, 5, 5, 5, 9, 7, 8, 8,

In [111]:
import numpy as np

In [117]:
a = np.random.randint(0, 9, size=[2,5])
a

array([[8, 7, 8, 1, 7],
       [2, 2, 7, 1, 2]])

In [119]:
a.reshape([5,2])

array([[8, 7],
       [8, 1],
       [7, 2],
       [2, 7],
       [1, 2]])