# Learning seq2seq

I'm trying to learn how seq2seq works. I understand it in principle but need to figure out how to actually implement it in TensorFlow. At this point, the seq2seq library is a little too dense for me to follow line by line, so I'll be following a smaller implementation from [Illia Polosukhin](https://medium.com/@ilblackdragon/tensorflow-sequence-to-sequence-3d9d2e238084).

In [2]:
import logging
import tensorflow as tf  # TensorFlow 1.x: tf.contrib was removed in TensorFlow 2
from tensorflow.contrib import layers

Some global constants: the vocabulary indices reserved for the start (GO), end, and unknown tokens.

In [1]:
GO_TOKEN = 0
END_TOKEN = 1
UNK_TOKEN = 2
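
The following is a minimal plain-Python sketch of the data the model expects, showing how these reserved ids get used. The toy vocabulary and the `encode`, `decoder_input`, and `length` helpers are hypothetical. Padding with `END_TOKEN` is an assumption: the model counts every id that isn't `END_TOKEN` as part of a sequence's length, which suggests sequences are padded that way.

```python
GO_TOKEN = 0
END_TOKEN = 1
UNK_TOKEN = 2

# Hypothetical toy vocabulary; ids 0-2 are reserved for the special tokens above.
vocab = {"the": 3, "cat": 4, "sat": 5}

def encode(sentence, vocab, max_length):
    """Map words to ids (unknown words -> UNK_TOKEN), then pad with END_TOKEN (assumed)."""
    ids = [vocab.get(word, UNK_TOKEN) for word in sentence.split()]
    ids = ids[:max_length]  # truncate sentences that are too long
    return ids + [END_TOKEN] * (max_length - len(ids))

def decoder_input(target_ids):
    """Prepend GO_TOKEN, mirroring the tf.concat step in the model."""
    return [GO_TOKEN] + target_ids

def length(ids):
    """Count ids that are not END_TOKEN, mirroring reduce_sum(to_int32(not_equal(ids, 1)))."""
    return sum(1 for i in ids if i != END_TOKEN)

batch = [encode(s, vocab, 5) for s in ["the cat sat", "the dog sat"]]
print(batch)                            # [[3, 4, 5, 1, 1], [3, 2, 5, 1, 1]]
print([length(row) for row in batch])   # [3, 3]
print(decoder_input(batch[0]))          # [0, 3, 4, 5, 1, 1]
print(length(decoder_input(batch[0])))  # 4
```

Prepending `GO_TOKEN` makes the decoder-side length one longer than the raw target, because 0 is not `END_TOKEN`.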

Define the seq2seq model. I'll try annotating with comments as much as possible.
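
The two decoding helpers in the model differ mainly in what they feed the decoder at each step. `TrainingHelper` feeds the ground-truth previous token, which is called teacher forcing. `GreedyEmbeddingHelper` looks up the decoder's own argmax prediction in `embeddings` and feeds that back, stopping once it emits `end_token`. This plain-Python sketch shows only that feeding rule. The score table and function names are hypothetical stand-ins for the embedding + GRU + projection a real decoder runs, and no recurrent state is carried.

```python
GO_TOKEN, END_TOKEN = 0, 1

# Hypothetical scores over a 5-word vocabulary, keyed by the previously fed token id.
# A real decoder would embed the previous token, update its GRU state (and attention),
# and project to vocabulary logits; this table skips all of that.
NEXT_SCORES = {
    0: [0.0, 0.1, 0.0, 0.9, 0.0],  # after GO, token 3 scores highest
    1: [0.0, 1.0, 0.0, 0.0, 0.0],  # after END, END again
    2: [0.0, 0.0, 0.1, 0.0, 0.9],  # after token 2 (UNK), token 4
    3: [0.0, 0.2, 0.0, 0.0, 0.8],  # after 3, token 4
    4: [0.0, 0.7, 0.0, 0.3, 0.0],  # after 4, END
}

def argmax(scores):
    return max(range(len(scores)), key=lambda i: scores[i])

def teacher_forced(decoder_inputs):
    """TrainingHelper-style: every step's input is the ground-truth token,
    whatever the model predicted at the previous step."""
    return [argmax(NEXT_SCORES[token]) for token in decoder_inputs]

def greedy(start_token=GO_TOKEN, end_token=END_TOKEN, max_steps=10):
    """GreedyEmbeddingHelper-style: feed back the model's own argmax prediction
    until it emits end_token (or max_steps is reached)."""
    predictions, token = [], start_token
    for _ in range(max_steps):
        token = argmax(NEXT_SCORES[token])
        predictions.append(token)
        if token == end_token:
            break
    return predictions

# Ground truth is [3, 2, END]; during training the decoder sees [GO, 3, 2].
print(teacher_forced([GO_TOKEN, 3, 2]))  # [3, 4, 4]: the wrong 4 at step 2 is not fed back
print(greedy())                          # [3, 4, 1]: the 4 is fed back, then END is predicted
```

Teacher forcing keeps one early mistake from derailing the rest of a training sequence. At inference there is no ground truth, so the greedy helper has to rely on its own predictions.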

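The attention step in `decode` relies on three arguments of `tf.contrib.seq2seq.BahdanauAttention`:

- `num_units` sets the depth of the attention layer's internal projections. It doesn't have to equal the GRU size, although this code reuses the same value.
- `memory` is the sequence being attended over.
- `memory_sequence_length` masks the padded encoder steps.

This plain-Python sketch computes the additive score that attention is based on, score_j = v · tanh(W_memory h_j + W_query s), then applies a softmax and takes the weighted sum of the encoder outputs. In TensorFlow the matrices and `v` are learned. Here they are fixed toy values, and the helper names are hypothetical.

```python
import math

def matvec(matrix, vector):
    return [sum(m * x for m, x in zip(row, vector)) for row in matrix]

def bahdanau_attention(query, memory, memory_length, W_query, W_memory, v):
    """Additive attention: score_j = v . tanh(W_memory h_j + W_query s).

    query: decoder state s; memory: encoder outputs h_0..h_{T-1};
    positions j >= memory_length are padding and get zero weight.
    """
    projected_query = matvec(W_query, query)
    scores = []
    for j, h in enumerate(memory):
        if j >= memory_length:
            scores.append(float("-inf"))  # masked, like memory_sequence_length
            continue
        key = matvec(W_memory, h)
        scores.append(sum(vi * math.tanh(k + q) for vi, k, q in zip(v, key, projected_query)))
    # softmax over the scores; exp(-inf) is 0.0, so padded steps get no weight
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # context vector: attention-weighted sum of the encoder outputs
    context = [sum(w * h[d] for w, h in zip(weights, memory)) for d in range(len(memory[0]))]
    return weights, context

identity = [[1.0, 0.0], [0.0, 1.0]]
memory = [[1.0, 0.0], [0.0, 0.0], [5.0, 5.0]]  # third row is padding
weights, context = bahdanau_attention([0.0, 0.0], memory, 2, identity, identity, [1.0, 1.0])
print([round(w, 3) for w in weights])  # [0.682, 0.318, 0.0]
print([round(c, 3) for c in context])  # [0.682, 0.0]
```

The padded `[5.0, 5.0]` row would dominate the context if it weren't masked, which is why the model passes `input_lengths` as `memory_sequence_length`.
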
In [2]:
def seq2seq(mode, features, labels, params):
    vocab_size = params['vocab_size'] # the number of permitted vocabulary words
    embed_dim = params['embed_dim'] # the embedding dimension of the words
    num_units = params['num_units'] # the number of hidden units (state size) of the GRU cell
    input_max_length = params['input_max_length'] # the maximum accepted length for inputs
    output_max_length = params['output_max_length'] # the maximum accepted length for outputs
    
    inp = features['input'] # input sequences from passed parameter
    output = features['output'] # target sequences, also passed in features; the decoder is fed them during training, so labels goes unused in this function
    batch_size = tf.shape(inp)[0] # infer the batch size from the first dimension of the inputs
    start_tokens = tf.zeros([batch_size], dtype=tf.int64) # a [batch_size] vector of GO_TOKEN ids (GO_TOKEN is 0, so zeros work)
    # tf.expand_dims reshapes start_tokens from [batch_size] to [batch_size, 1];
    # tf.concat along axis 1 then prepends it as the first column of output (like np.hstack)
    train_output = tf.concat([tf.expand_dims(start_tokens, 1), output], 1)
    # tf.not_equal gives a boolean tensor with the same shape as its input: True where the id != END_TOKEN
    # tf.to_int32 turns True/False into 1/0, and tf.reduce_sum over axis 1 counts them per sequence,
    # which gives each sequence's length (assuming sequences are padded with END_TOKEN)
    input_lengths = tf.reduce_sum(tf.to_int32(tf.not_equal(inp, END_TOKEN)), 1)
    output_lengths = tf.reduce_sum(tf.to_int32(tf.not_equal(train_output, END_TOKEN)), 1)
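    # embed_sequence maps each id in a [batch, time] tensor to an embed_dim vector,
    # creating an 'embeddings' matrix of shape [vocab_size, embed_dim] under the given scope
    input_embed = layers.embed_sequence(inp,
                                        vocab_size=vocab_size,
                                        embed_dim=embed_dim,
                                        scope='embed')
    # reuse=True makes the decoder share the encoder's embedding matrix
    # (without it, creating 'embed/embeddings' a second time raises a ValueError)
    output_embed = layers.embed_sequence(train_output,
                                         vocab_size=vocab_size,
                                         embed_dim=embed_dim,
                                         scope='embed',
                                         reuse=True)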
    # tf.variable_scope is a context manager that prefixes variable names with 'embed/'
    # reuse=True makes tf.get_variable look up an existing variable instead of creating one,
    # so this retrieves the 'embed/embeddings' matrix that embed_sequence created above;
    # the greedy inference helper needs it to embed its own predictions
    with tf.variable_scope('embed', reuse=True):
        embeddings = tf.get_variable('embeddings')
        
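    # the encoder is a single-layer GRU
    cell = tf.contrib.rnn.GRUCell(num_units=num_units)
    # dynamic_rnn unrolls the cell over the time axis: encoder_outputs holds every step's output
    # ([batch, time, num_units]) and encoder_final_state is the last state ([batch, num_units]);
    # no sequence_length is passed, so padded steps are processed too
    encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(cell, input_embed, dtype=tf.float32)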
    
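    # TrainingHelper feeds the decoder the ground-truth embedding at each step (teacher forcing),
    # reading output_embed for up to output_lengths steps
    train_helper = tf.contrib.seq2seq.TrainingHelper(output_embed, output_lengths)
    # GreedyEmbeddingHelper is for inference, when no target is available: it feeds back the embedding
    # of the previous step's argmax prediction, starting from start_tokens and stopping at end_token
    pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embeddings, start_tokens=tf.to_int32(start_tokens), end_token=END_TOKEN)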
    
    # function for the actual decoding, used during both training and inference
    def decode(helper, scope, reuse=None):
        with tf.variable_scope(scope, reuse=reuse):
            # Bahdanau (additive) attention over the encoder outputs:
            # num_units sets the depth of the attention layer's projections,
            # memory is what the decoder attends to (the encoder outputs),
            # and memory_sequence_length masks out the padded encoder steps
            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                num_units=num_units, memory=encoder_outputs,
                memory_sequence_length=input_lengths)
            # 