In [1]:
import sys
sys.path.insert(0, "/work/04233/sw33286/AIDA-SCRIPTS")

In [2]:
import dill
import random
import nltk
import numpy as np

import tensorflow as tf
import tensorflow.contrib.seq2seq as seq2seq
from tensorflow.contrib.layers import safe_embedding_lookup_sparse as embedding_lookup_unique
from tensorflow.contrib.rnn import LSTMCell, LSTMStateTuple, GRUCell

from helpers import Indexer
from helpers import batch
# from nltk.corpus import brown
from collections import defaultdict, Counter

### Prepare data

In [3]:
VOCAB = ['one','two','three','four','five',
         'six','seven','eight','nine','ten']
indexer = Indexer()
indexer.get_index('PAD')
indexer.get_index('EOS')
for word in VOCAB:
    indexer.get_index(word)

In [19]:
FROM_LEN = 3
TO_LEN = 8

def generate_sent(from_len, to_len):
    length = np.random.randint(from_len, to_len)
    return np.random.choice(VOCAB, length)

def to_code(sent):
    return [indexer.get_index(word) for word in sent]

def to_sent(code):
    return list(map(lambda w_idx:indexer.get_object(w_idx), code))
    
def get_batch(n, from_len=FROM_LEN, to_len=TO_LEN):
    return [to_code(generate_sent(from_len,to_len)) for _ in range(n)]

### Bi-LSTM Autoencoder

In [22]:
tf.reset_default_graph()
sess = tf.InteractiveSession()

batch_size = 10
vocab_size = len(indexer)
embed_size = 20
encoder_hidden_size = 10
decoder_hidden_size = encoder_hidden_size*2

# attention = True
# bidirectional = True

encoder_cell = LSTMCell(encoder_hidden_size)
decoder_cell = LSTMCell(decoder_hidden_size)

encoder_inputs = tf.placeholder(tf.int32, [None, None], name='encoder_inputs')
encoder_inputs_length = tf.placeholder(tf.int32, [None,], name='encoder_inputs_length')
decoder_targets = tf.placeholder(tf.int32, [None, None], name='decoder_targets')
decoder_targets_length = tf.placeholder(tf.int32, [None,], name='decoder_targets_length')

with tf.name_scope('DecoderTrainFeeds'):
    sequence_size, batch_size_ = tf.unstack(tf.shape(decoder_targets))
    EOS_SLICE = tf.ones([1, batch_size_], dtype=tf.int32) * indexer.get_index('EOS')
    PAD_SLICE = tf.ones([1, batch_size_], dtype=tf.int32) * indexer.get_index('PAD')
    decoder_train_inputs = tf.concat([EOS_SLICE, decoder_targets], axis=0) # [max-time+1, batch_size]
    decoder_train_length = decoder_targets_length + 1
    decoder_train_targets = tf.concat([decoder_targets, PAD_SLICE], axis=0)
    decoder_train_targets_seq_len,_ = tf.unstack(tf.shape(decoder_train_targets))
    decoder_train_targets_eos_mask = tf.one_hot(decoder_train_length-1,
                                                decoder_train_targets_seq_len,
                                                on_value=indexer.get_index('EOS'), 
                                                off_value=indexer.get_index('PAD'),
                                                dtype=tf.int32)
    decoder_train_targets_eos_mask = tf.transpose(decoder_train_targets_eos_mask, [1,0])
    decoder_train_targets = tf.add(decoder_train_targets,
                                   decoder_train_targets_eos_mask)
    loss_weights = tf.ones([
        batch_size,
        tf.reduce_max(decoder_train_length)
    ], dtype=tf.float32, name='loss_weights')
    
with tf.variable_scope('embedding') as scope:
    embedding_matrix = tf.get_variable('embedding_matrix', [vocab_size, embed_size], 
                                       initializer=tf.contrib.layers.xavier_initializer())
#     glove_feed = tf.placeholder(tf.float32, glove_embs.shape)
#     glove_init = embedding_matrix.assign(glove_feed)
    encoder_inputs_embedded = tf.nn.embedding_lookup(embedding_matrix, encoder_inputs)
    decoder_train_inputs_embedded = tf.nn.embedding_lookup(embedding_matrix, decoder_train_inputs)
    
with tf.variable_scope('BidirectionalEncoder') as scope:
    encoder_cell = LSTMCell(encoder_hidden_size)
    ((encoder_fw_outputs,encoder_bw_outputs),
     (encoder_fw_state,encoder_bw_state)) = (
        tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_cell,
                                        cell_bw=encoder_cell,
                                        inputs=encoder_inputs_embedded,
                                        sequence_length=encoder_inputs_length,
                                        dtype=tf.float32, time_major=True)
    )
    encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)
    if isinstance(encoder_fw_state, LSTMStateTuple):
        encoder_state_c = tf.concat((encoder_fw_state.c, encoder_bw_state.c), 1, name='bidirectional_concat_c')
        encoder_state_h = tf.concat((encoder_fw_state.h, encoder_bw_state.h), 1, name='bidirectional_concat_h')
        encoder_state = LSTMStateTuple(c=encoder_state_c, h=encoder_state_h)
    elif isinstance(encoder_fw_state, tf.Tensor):
        encoder_state = tf.concat((encoder_fw_state, encoder_bw_state), 1, name='bidirectional_concat')
        
with tf.variable_scope('Decoder') as scope:
    def output_fn(outputs):
        return tf.contrib.layers.linear(outputs, vocab_size, scope=scope)
    attention_states = tf.transpose(encoder_outputs, [1,0,2]) # [batch_size,max-time,hidden_size]
    (attention_keys,
     attention_values,
     attention_score_fn,
     attention_construct_fn) = seq2seq.prepare_attention(
        attention_states=attention_states,
        attention_option='bahdanau',
        num_units=decoder_hidden_size
    )
    decoder_fn_train = seq2seq.attention_decoder_fn_train(
        encoder_state=encoder_state,
        attention_keys=attention_keys,
        attention_values=attention_values,
        attention_score_fn=attention_score_fn,
        attention_construct_fn=attention_construct_fn,
        name='attention_decoder'
    )
    decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
        output_fn=output_fn,
        encoder_state=encoder_state,
        attention_keys=attention_keys,
        attention_values=attention_values,
        attention_score_fn=attention_score_fn,
        attention_construct_fn=attention_construct_fn,
        embeddings=embedding_matrix,
        start_of_sequence_id=indexer.get_index('EOS'),
        end_of_sequence_id=indexer.get_index('EOS'),
        maximum_length=tf.reduce_max(encoder_inputs_length) + 3,
        num_decoder_symbols=vocab_size
    )
    (decoder_outputs_train,
     decoder_state_train,
     decoder_context_state_train) = (
        seq2seq.dynamic_rnn_decoder(
            cell=decoder_cell,
            decoder_fn=decoder_fn_train,
            inputs=decoder_train_inputs_embedded,
            sequence_length=decoder_train_length,
            time_major=True,
            scope=scope
        )
    )
    decoder_logits_train = output_fn(decoder_outputs_train)
    decoder_prediction_train = tf.argmax(decoder_logits_train, axis=-1, name='decoder_prediction_train')
    scope.reuse_variables()
    (decoder_logits_inference,
     decoder_state_inference,
     decoder_context_state_inference) = (
        seq2seq.dynamic_rnn_decoder(
            cell=decoder_cell,
            decoder_fn=decoder_fn_inference,
            time_major=True,
            scope=scope
        )
    )
    decoder_prediction_inference = tf.argmax(decoder_logits_inference, axis=-1, name='decoder_prediction_inference')
    
logits = tf.transpose(decoder_logits_train, [1,0,2])
targets = tf.transpose(decoder_train_targets, [1,0])
loss = seq2seq.sequence_loss(logits=logits, targets=targets, weights=loss_weights)
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)

sess.run(tf.global_variables_initializer())

In [25]:
def make_train_inputs(input_seq, target_seq):
    inputs_, inputs_length_ = batch(input_seq)
    targets_, targets_length_ = batch(target_seq)
    return {
        encoder_inputs: inputs_,
        encoder_inputs_length: inputs_length_,
        decoder_targets: targets_,
        decoder_targets_length: targets_length_,
    }

In [28]:
loss_track = []
num_batches = 10000
verbose = 1000
try:
    for b in range(num_batches):
        batch_inputs = get_batch(batch_size)
        fd = make_train_inputs(batch_inputs, batch_inputs)
        _, l = sess.run([train_op, loss], fd)
        loss_track.append(l)
        if b==0 or b%verbose==0:
            print('batch {}'.format(b))
            print('  minibatch loss: {}'.format(sess.run(loss, fd)))
            for i,(e_in, dt_pred) in enumerate(zip(
                fd[encoder_inputs].T, # [max-time,batch_size] -> [batch_size,max-time]
                sess.run(decoder_prediction_train, fd).T
            )):
                print('  sample {}:'.format(i+1))
                print('    enc input           > {}'.format([w for w in to_sent(e_in) if w!='PAD']))
                print('    dec train predicted > {}'.format([w for w in to_sent(dt_pred) if w!='PAD']))
                if i>=2:
                    break
            print('\n')
except KeyboardInterrupt:
    print('training interrupted')

batch 0
  minibatch loss: 0.4709915220737457
  sample 1:
    enc input           > ['seven', 'two', 'two', 'one', 'eight', 'five']
    dec train predicted > ['seven', 'two', 'two', 'one', 'eight', 'five', 'EOS']
  sample 2:
    enc input           > ['five', 'nine', 'seven', 'eight', 'one', 'six']
    dec train predicted > ['five', 'nine', 'seven', 'eight', 'one', 'six', 'EOS']
  sample 3:
    enc input           > ['eight', 'one', 'nine']
    dec train predicted > ['eight', 'one', 'nine', 'EOS']


batch 1000
  minibatch loss: 0.4195615351200104
  sample 1:
    enc input           > ['four', 'seven', 'nine', 'ten']
    dec train predicted > ['four', 'seven', 'nine', 'ten', 'EOS']
  sample 2:
    enc input           > ['one', 'nine', 'four']
    dec train predicted > ['one', 'nine', 'four', 'EOS']
  sample 3:
    enc input           > ['five', 'five', 'one', 'eight', 'six', 'one']
    dec train predicted > ['five', 'one', 'one', 'eight', 'six', 'one', 'EOS']


batch 2000
  minibatch los