#### seq2seq example run through

* SRC from [higepon](https://gist.github.com/higepon/eb81ba0f6663a57ff1908442ce753084) to get to grips with the API.
* TF guide for NMT [here](https://www.tensorflow.org/tutorials/seq2seq)
* Documentation [here](https://www.tensorflow.org/api_guides/python/contrib.seq2seq)

In [1]:
# Import statements
import numpy as np
import tensorflow as tf
seq2seq = tf.contrib.seq2seq

  return f(*args, **kwds)


In [2]:
# Parameter setup
hparams = tf.contrib.training.HParams(batch_size=3,
                                      encoder_length=4,
                                      decoder_length=5,
                                      num_units=6,
                                      src_vocab_size=7,
                                      tgt_vocab_size=9,
                                      embedding_size=8,
                                      optimizer="Adam",
                                      learning_rate=0.01,
                                      max_gradient_norm=5.0,
                                      beam_width=9,
                                      use_attention=False,
                                     )

# Symbol for starting to decode
tgt_sos_id = 7

# Symbol for end of decode process
tgt_eos_id = 8

# Refresh graph
tf.reset_default_graph()

# Variable initialiser
initializer = tf.contrib.layers.xavier_initializer()

### Encoder declaration

In [3]:
# Encoder input placeholder is time major [enc_len, batch_size]
enc_inputs = tf.placeholder(dtype=tf.int32, shape=(hparams.encoder_length, hparams.batch_size), name="enc_input")

# Encoder embedding matrix [can be initialised with GloVe] [src_vocab_size, embedding_size]
enc_embeddings = tf.get_variable(shape=(hparams.src_vocab_size, hparams.embedding_size), name="enc_embed")

# Lookup enc_input in the embedding matrix [enc_len, batch_size, embedding_size]
enc_input_embs = tf.nn.embedding_lookup(params=enc_embeddings, ids=enc_inputs)

# LSTM encoder
lstm_encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)

##### LSTM Computation graph for Encoding

In [4]:
# Run Dynamic RNN: outputs [enc_len, batch_size, num_units]: enc_state [batch_size, num_units]
encoder_outputs, encoder_state = tf.nn.dynamic_rnn(cell=lstm_encoder_cell, 
                                                   inputs=enc_input_embs, 
                                                   time_major=True,
                                                   dtype=tf.float32
                                                  )

### Decoder declaration

In [5]:
# Decoder input placeholder is time major [dec_len, batch_size]
# Decoder lengths [batch_size]

# For training we use 100% teacher forcing so each time step uses Ground Truth inputs. Inference is 100% free running
dec_inputs = tf.placeholder(dtype=tf.int32, shape=(hparams.decoder_length, hparams.batch_size),name="dec_inputs")
dec_lens = tf.placeholder(dtype=tf.int32, shape=(hparams.batch_size), name="dec_len")

# Decoder embedding matrix [can be initialised with GloVe] [tgt_vocab_size, embedding_size]
dec_embeddings = tf.get_variable(shape=(hparams.tgt_vocab_size, hparams.embedding_size), name="dec_embed")

# Lookup dec_inputs 
dec_input_embs = tf.nn.embedding_lookup(params=dec_embeddings, ids=dec_inputs)

# LSTM decoder
lstm_decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)

# Projection layer between LSTM outputs and vocab size for softmax learning model
projection_layer = tf.layers.Dense(units=hparams.tgt_vocab_size, activation=None, kernel_initializer=initializer, use_bias=False)

##### seq2seq decoding declaration

In [6]:
# Declare decoding helper
dec_helper = seq2seq.TrainingHelper(inputs=dec_input_embs, sequence_length=dec_lens, time_major=True, name="dec_helper")

if hparams.use_attention:
    # Attention states are batch major [batch_size, max_time, num_units]
    attention_states = tf.transpose(encoder_outpus, [1, 0, 2])

    luong_attn_mechanism = seq2seq.LuongAttention(num_units=hparams.num_units, 
                                                 memory=attention_states, 
                                                 memory_sequence_length=None
                                                 )

    decoder_cell = seq2seq.AttentionWrapper(cell=lstm_decoder_cell,
                                            attention_mechanism=luong_attn_mechanism,
                                            attention_layer_size=hparams.num_units
                                            )

    dec_initial_state = lstm_decoder_cell.zero_state(hparams.batch_size, tf.float32).clone(cell_state=encoder_state)
else:
    # Decoding with dec_helper when encoder output feeds to decoder input
    dec_initial_state = encoder_state

decoder = seq2seq.BasicDecoder(cell=lstm_decoder_cell,
                               helper=dec_helper,
                               initial_state=dec_initial_state,
                               output_layer=projection_layer)

##### LSTM Computation graph for Decoding with seq2seq helpers

In [7]:
# s2s version of decoding
# final_outputs.rnn_output : RNN states [batch_size, dec_len, tgt_vocab_size]
# final_outputs.sample_id  : Argmax for IDs from RNN outputs [batch_size, dec_len]
# final_state              : Final RNN states [batch_size, num_units]
# final_seq_lens           : Decoded sequences [batch_size, dec_len]

final_outputs, _final_state, _final_seq_lens = seq2seq.dynamic_decode(decoder)

output_logits = final_outputs.rnn_output

In [8]:
print("rnn_output.shape=", final_outputs.rnn_output.shape)
print("sample_id.shape=", final_outputs.sample_id.shape)
print("final_state=", _final_state)
print("final_sequence_lengths.shape=", _final_seq_lens.shape)

rnn_output.shape= (3, ?, 9)
sample_id.shape= (3, ?)
final_state= LSTMStateTuple(c=<tf.Tensor 'decoder/while/Exit_3:0' shape=(3, 6) dtype=float32>, h=<tf.Tensor 'decoder/while/Exit_4:0' shape=(3, 6) dtype=float32>)
final_sequence_lengths.shape= (3,)


##### Setup loss and training operations

In [9]:
# Training targets for decoder [batch_size, dec_len]
dec_targets = tf.placeholder(dtype=tf.int32, shape=(hparams.batch_size, hparams.decoder_length))

# Setup loss computation [could also use seq2seq.sequence_loss]: https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/sequence_loss
smxe_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=dec_targets,
                                                           logits=output_logits,
                                                           name="smxe_loss")

# Reduce loss in batch by mean
loss_ = tf.reduce_mean(smxe_loss)

# Training operations
global_step = tf.Variable(initial_value=0, 
                          name="global_step", 
                          trainable=False, 
                          collections=[tf.GraphKeys.GLOBAL_STEP,
                                       tf.GraphKeys.GLOBAL_VARIABLES]
                         )

train_op = tf.contrib.layers.optimize_loss(loss=loss_,
                                           global_step=global_step,
                                           learning_rate=hparams.learning_rate,
                                           optimizer=hparams.optimizer,
                                           clip_gradients=hparams.max_gradient_norm,
                                          )

##### TF Interactive session

In [10]:
sess = tf.InteractiveSession()

Example training data

In [11]:
# Encoder 

tweet1 = np.array([1, 2, 3, 4])
tweet2 = np.array([0, 5, 6, 3])

# Make batch data [TIME MAJOR]
train_encoder_inputs = np.stack((tweet1, tweet2, tweet1), axis=0).astype(np.int32).T

print("Tweets")
print(train_encoder_inputs)

Tweets
[[1 0 1]
 [2 5 2]
 [3 6 3]
 [4 3 4]]


In [12]:
# Decoder

training_decoder_input1 = [tgt_sos_id, 2, 3, 4, 5]
training_decoder_input2 = [tgt_sos_id, 5, 6, 4, 3]

training_target_label1 = [2, 3, 4, 5, tgt_eos_id]
training_target_label2 = [5, 6, 4, 3, tgt_eos_id]

training_target_labels = np.stack((training_target_label1, 
                                    training_target_label1, 
                                    training_target_label1), axis=0).astype(np.int32)
print("Replies")
print(training_target_labels)

training_decoder_inputs = np.stack((training_decoder_input1, 
                                    training_decoder_input2, 
                                    training_decoder_input1), axis=0).astype(np.int32).T
print("Inputs")
print(training_decoder_inputs)

training_decoder_lens = np.ones((hparams.batch_size), dtype=int) * hparams.decoder_length

print("Decoder lengths")
print(training_decoder_lens)

Replies
[[2 3 4 5 8]
 [2 3 4 5 8]
 [2 3 4 5 8]]
Inputs
[[7 7 7]
 [2 5 2]
 [3 6 3]
 [4 4 4]
 [5 3 5]]
Decoder lengths
[5 5 5]


##### Training

In [13]:
sess.run(tf.global_variables_initializer())

In [14]:
for i in range(100):
    _, loss_val = sess.run(fetches=[train_op, loss_],
                           feed_dict={
                               enc_inputs: train_encoder_inputs,
                               dec_targets: training_target_labels,
                               dec_inputs: training_decoder_inputs,
                               dec_lens: training_decoder_lens
                           })
    print(loss_val)

2.22891
2.20544
2.18285
2.16054
2.13812
2.11518
2.0913
2.06614
2.03949
2.01127
1.98148
1.9502
1.91751
1.88354
1.84836
1.81212
1.77494
1.73703
1.69861
1.65991
1.62113
1.58244
1.54393
1.50564
1.46761
1.42981
1.39218
1.35454
1.31654
1.27778
1.2382
1.19829
1.15907
1.12136
1.08512
1.0497
1.01487
0.981253
0.949387
0.918887
0.888829
0.858923
0.829603
0.801231
0.773712
0.746931
0.721111
0.696525
0.673082
0.650437
0.628419
0.607146
0.586745
0.567149
0.548197
0.529822
0.512077
0.495011
0.47859
0.462743
0.447444
0.432717
0.418585
0.405034
0.392023
0.379509
0.367469
0.355887
0.344744
0.334017
0.323676
0.313698
0.304065
0.294767
0.285795
0.277137
0.268782
0.260721
0.252945
0.245448
0.238225
0.231271
0.224578
0.218138
0.211939
0.205969
0.200217
0.194673
0.189328
0.184177
0.179212
0.17443
0.169825
0.165391
0.161123
0.157012
0.15305
0.149229
0.14554
0.141976


#### Model is now trained, test with inference on toy task

In [15]:
# Inference helper. Similar to training helper but free running over teacher forcing
tiled_sos_id = tf.fill([hparams.batch_size], tgt_sos_id)

inference_helper = seq2seq.GreedyEmbeddingHelper(embedding=dec_embeddings,
                                                 start_tokens=tiled_sos_id,
                                                 end_token=tgt_eos_id)

# Inference Decoder runs the decoding process with free running inputs
inference_decoder = seq2seq.BasicDecoder(cell=lstm_decoder_cell,
                                         helper=inference_helper,
                                         initial_state=dec_initial_state,
                                         output_layer=projection_layer)

# Limit realistic RNN length
maxit = tf.round(tf.reduce_max(hparams.encoder_length) * 2)

outputs, _, _ = seq2seq.dynamic_decode(inference_decoder, maximum_iterations=maxit)
translations = outputs.sample_id


In [16]:
# Input tweets
inference_encoder_inputs = train_encoder_inputs

In [17]:
replies = sess.run(fetches=[translations], 
                   feed_dict={
                       enc_inputs: inference_encoder_inputs
                   })

print(replies)

[array([[2, 3, 4, 5, 8],
       [2, 3, 4, 5, 8],
       [2, 3, 4, 5, 8]], dtype=int32)]


##### Well thats only 66% correct, lets try a Beam search


In [18]:
# Beam Search
# Replicate encoder infos beam_width times
decoder_initial_state = tf.contrib.seq2seq.tile_batch(
    dec_initial_state, multiplier=hparams.beam_width)

# Define a beam-search decoder
inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        cell=lstm_decoder_cell,
        embedding=dec_embeddings,
        start_tokens=tf.fill([hparams.batch_size], tgt_sos_id),
        end_token=tgt_eos_id,
        initial_state=decoder_initial_state,
        beam_width=hparams.beam_width,
        output_layer=projection_layer,
        length_penalty_weight=0.0)

# Dynamic decoding
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
    inference_decoder, maximum_iterations=maxit)

# Convert to [batch_size, seq_len, beam_size]
translations = tf.transpose(outputs.predicted_ids,[0, 2, 1]) 

In [19]:
replies = sess.run(translations, feed_dict={
                       enc_inputs: inference_encoder_inputs,
                       dec_lens: training_decoder_lens
                   })

print(replies)

[[[2 3 4 5 8 8]
  [2 3 5 8 8 8]
  [2 3 4 4 5 8]
  [2 3 4 8 8 8]
  [2 3 4 4 8 8]
  [2 3 4 5 5 8]
  [2 3 8 8 8 8]
  [2 2 3 4 5 8]
  [2 3 4 4 4 8]]

 [[2 3 4 5 8 8]
  [2 3 5 8 8 8]
  [2 3 4 4 5 8]
  [2 3 4 8 8 8]
  [2 3 4 4 8 8]
  [2 3 4 5 5 8]
  [2 3 8 8 8 8]
  [2 2 3 4 5 8]
  [2 3 4 4 4 8]]

 [[2 3 4 5 8 8]
  [2 3 5 8 8 8]
  [2 3 4 4 5 8]
  [2 3 4 8 8 8]
  [2 3 4 4 8 8]
  [2 3 4 5 5 8]
  [2 3 8 8 8 8]
  [2 2 3 4 5 8]
  [2 3 4 4 4 8]]]


In [20]:
sess.close()