In [16]:
from IPython.display import Video

In [17]:
Video('https://jalammar.github.io/images/seq2seq_2.mp4')

# An introduction to seq2seq models, Attention and Transformers

This presentation is heavily inspired by the blogs of [Jay Alammar](http://jalammar.github.io/) and [Christopher Olah](http://colah.github.io/).

## Introduction

* Sequence-to-sequence (seq2seq) models are `Deep Learning` models used in many tasks, such as:
    * Machine Translation
    * Text Summarization
    * Text Generation
* They take in a sequence of items and output another sequence of items
    * Here we focus on words as input and output

**Here is how a trained seq2seq model works for the task of machine translation**

<video controls src="https://jalammar.github.io/images/seq2seq_2.mp4" alt="Seq2seq machine translation" width="80%"/>

## Digging into the black box

The model is composed of an **encoder** and a **decoder**.

#### Encoder
* Takes each input item (word) one by one
* Processes them and captures their information
* Outputs a *Context* vector as its result of processing the entire input

#### Decoder
* Takes the entire *Context* vector as its input
* Processes it and decodes the information into the desired output (another language, in the machine translation task)
* Outputs items (words) one by one

**Machine translation task, step by step**
<video controls src="https://jalammar.github.io/images/seq2seq_4.mp4" alt="Seq2seq machine translation step by step" width="80%"/>

* The context is a vector of numbers representing the information the encoder captured from the input
    * Its size is a design choice (a hyperparameter of the model)
* Both the encoder and the decoder are Recurrent Neural Networks under the hood
    * We introduced RNNs, and specifically LSTMs, in the previous series

**This is what the context vector looks like**

<img src="https://jalammar.github.io/images/context.png" alt="Context Vector" width="80%"/>

### Word Embedding

We discussed the word embedding methods `Word2Vec` and `GloVe` in the previous series of tutorials. To summarize, word embedding converts words and sentences into vectors of numbers so that we can feed them to neural networks.

Seq2seq models, and specifically their encoders, are no exception: we must embed the input text before feeding it to the network.
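
As a rough illustration of the lookup step, here is a minimal sketch of an embedding table in NumPy. The vocabulary, the embedding size of 4, and the random values are all assumptions made for the example; real embeddings would come from `Word2Vec`, `GloVe`, or training.

```python
import numpy as np

# Hypothetical toy vocabulary and embedding size (not from the original tutorial).
vocab = {"je": 0, "suis": 1, "etudiant": 2}
embedding_dim = 4

rng = np.random.default_rng(0)
# One row per word; random values stand in for trained embeddings.
embedding_matrix = rng.normal(size=(len(vocab), embedding_dim))


def embed(sentence):
    """Map a list of words to an array of shape (len(sentence), embedding_dim)."""
    indices = [vocab[word] for word in sentence]
    return embedding_matrix[indices]


embedded = embed(["je", "suis", "etudiant"])
print(embedded.shape)  # (3, 4)
```

Each row of `embedded` is the vector that would be fed to the encoder at one time step.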

**This is what the embedded vectors for the sentence look like**


<img src="https://jalammar.github.io/images/embedding.png" alt="Embedded Vector" width="80%"/>

### Recap of RNN

<video controls src="https://jalammar.github.io/images/RNN_1.mp4" alt="RNNs step by step" width="80%"/>

1. Hidden state 0 and input vector 1 (current word) are fed to the RNN
2. The result of that would be hidden state 1 and output vector 1

The unrolled version of RNNs may help to understand their operation better

<img src="http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/RNN-unrolled.png" alt="Unrolled RNN" width="80%"/>


3. Similarly, the hidden state 1 and the input vector 2 (next word) are fed to the RNN
4. Hidden state 2 and output vector 2 are the outputs
5. This process continues until no further input is left

The math behind the scenes is a series of matrix products followed by tanh and softmax:

<img src="https://datascience-enthusiast.com/figures/rnn_step_forward.png" alt="Behind the scenes RNN" width="80%"/>

* The W matrices are the weights of the RNN, which are trained and optimized
* The x vector is the embedded word vector (input feature vector)
* The a vector is the hidden state
* The y vector is the output
* t and t-1 denote the current time step and the previous time step, respectively
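
A single forward step can be sketched in NumPy as follows. This assumes the common formulation `a_t = tanh(W_ax x_t + W_aa a_prev + b_a)` and `y_t = softmax(W_ya a_t + b_y)`; the dimensions and random weights are placeholders chosen for the example.

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax over a 1-D vector."""
    e = np.exp(z - np.max(z))
    return e / e.sum()


def rnn_step(x_t, a_prev, W_ax, W_aa, W_ya, b_a, b_y):
    """One RNN time step: returns the new hidden state and the output vector."""
    a_t = np.tanh(W_ax @ x_t + W_aa @ a_prev + b_a)
    y_t = softmax(W_ya @ a_t + b_y)
    return a_t, y_t


# Placeholder sizes: input, hidden state, and output dimensions.
n_x, n_a, n_y = 4, 3, 5
rng = np.random.default_rng(0)
W_ax = rng.normal(size=(n_a, n_x))
W_aa = rng.normal(size=(n_a, n_a))
W_ya = rng.normal(size=(n_y, n_a))
b_a = np.zeros(n_a)
b_y = np.zeros(n_y)

a0 = np.zeros(n_a)           # hidden state 0
x1 = rng.normal(size=n_x)    # input vector 1 (embedded word)
a1, y1 = rnn_step(x1, a0, W_ax, W_aa, W_ya, b_a, b_y)
print(a1.shape, y1.shape)    # (3,) (5,)
```

Feeding `a1` back in together with the next input vector gives hidden state 2 and output vector 2, and so on.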

**Note that there is also a backpropagation process for training the network and adjusting the weights, but we don't discuss it here**

### Back to encoder-decoder architecture

Now that we know how RNNs work, we can continue with the encoder-decoder network.

<video controls src="https://jalammar.github.io/images/seq2seq_5.mp4" alt="En-De step by step" width="80%"/>

At each pulse, the RNN in the encoder or decoder processes its input and generates the output and hidden state for that time step.

The hidden states in the encoder keep propagating from one time step to the next until they reach the last step of the encoder. The final hidden state vector is the `Context Vector` that is passed to the decoder as its input.
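
The propagation described above can be sketched as a simple loop. This is a minimal illustration that assumes a plain tanh RNN cell with random placeholder weights, not the LSTM used in the demo later on.

```python
import numpy as np


def encode(inputs, W_ax, W_aa, b_a):
    """Run a plain RNN over the inputs; return all hidden states and the last one."""
    a = np.zeros(W_aa.shape[0])  # hidden state 0
    hidden_states = []
    for x_t in inputs:
        # Each step mixes the current input with the previous hidden state.
        a = np.tanh(W_ax @ x_t + W_aa @ a + b_a)
        hidden_states.append(a)
    return np.stack(hidden_states), a


n_x, n_a = 4, 3  # placeholder embedding and hidden sizes
rng = np.random.default_rng(1)
W_ax = rng.normal(size=(n_a, n_x))
W_aa = rng.normal(size=(n_a, n_a))
b_a = np.zeros(n_a)

sentence = rng.normal(size=(5, n_x))  # five embedded words
hidden_states, context = encode(sentence, W_ax, W_aa, b_a)
print(hidden_states.shape, context.shape)  # (5, 3) (3,)
```

Without attention, only `context` (the last row of `hidden_states`) reaches the decoder.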

Now let's unroll the process even more.


<video controls src="https://jalammar.github.io/images/seq2seq_6.mp4" alt="En-De step by step unrolled" width="80%"/>

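The decoder works much like the encoder, since its architecture is very similar. However, it does not read the source words: its initial hidden state is the context vector, and from there it emits the output words one by one. (In the Keras demo at the end of this section, the decoder additionally receives the previously generated character as its input at each step.)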

### Encoder-decoder weakness and the concept of Attention

The `Context vector` tends to be the bottleneck for this model. In long sentences, by the time the RNN reaches the later words, the hidden state has already forgotten much of the earlier words, because it is overwritten at every step as it propagates through the RNN cells.


#### Attention

Attention helps with the `context vector` bottleneck problem by providing context for **each word** rather than for the whole sentence. This helps the decoder focus on the relevant and important parts of the encoded input at each decoding step.

So the **encoder** with attention sends more information to the decoder by providing **all** of the hidden states.

The **decoder** with attention takes all of the hidden states and does the following:
1. Processes the hidden state for each word and gives it a score
2. Amplifies the important hidden states for each time step and drowns out the less informative, less important ones

**Here is how encoder-decoder with attention works for the task of machine translation**

<video controls src="https://jalammar.github.io/images/seq2seq_7.mp4" alt="En-De with attention step by step" width="80%"/>


**Now let's see how the hidden states pass along decoder cells and how they are scored**

<video controls src="https://jalammar.github.io/images/attention_process.mp4" alt="Decoder with attention step by step" width="80%"/>


To summarize what happens in the decoder:
1. At each time step, the previous decoder hidden state is fed to the decoder RNN cell (in the animation, the decoder RNN input is the `<END>` token, as the decoder has no source words to read)
2. The output of the RNN is calculated as the new hidden state
3. The encoder hidden states are scored against this new decoder hidden state, the scores are normalized with softmax, and the hidden states are weighted accordingly, so that the important ones are amplified
4. The results of steps 3 and 2 are concatenated to form the final decoder hidden state at that time step
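
Steps 3 and 4 can be sketched in NumPy as below. The dot-product scoring function is an assumption made for this example (other scoring functions exist), and the toy hidden states are chosen by hand so the effect is easy to see.

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax over a 1-D vector."""
    e = np.exp(z - np.max(z))
    return e / e.sum()


def attention_step(decoder_hidden, encoder_hidden_states):
    """Score, weight, and combine encoder hidden states for one decoder step."""
    # Assumed scoring function: dot product with the decoder hidden state.
    scores = encoder_hidden_states @ decoder_hidden
    # Softmax turns the scores into weights that sum to 1.
    weights = softmax(scores)
    # Weighted sum: important hidden states are amplified, others drowned out.
    context = weights @ encoder_hidden_states
    # Concatenate the attention context with the decoder hidden state.
    combined = np.concatenate([context, decoder_hidden])
    return weights, context, combined


# Toy example: three encoder hidden states, one per input word.
encoder_hidden_states = np.eye(3)
decoder_hidden = np.array([0.0, 5.0, 0.0])  # most similar to the second state

weights, context, combined = attention_step(decoder_hidden, encoder_hidden_states)
print(weights.round(3))  # the second weight dominates
print(combined.shape)    # (6,)
```

The weights depend only on how well each encoder hidden state matches the decoder state, not on the position of the word.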

**To visualize how the encoder hidden states are scored, let's look at this example**
<video controls src="https://jalammar.github.io/images/seq2seq_9.mp4" alt="Translation encoder hidden states scored" width="80%"/>

*Note that hidden states are not weighted according to their position; they are weighted according to their importance, which does not necessarily follow the word order.*

<img src="https://jalammar.github.io/images/attention_sentence.png" alt="Encoder hidden state amplification" width="80%"/>


#### Long Short-Term Memory Networks - LSTM

LSTMs are a variant of RNNs that improves performance. Specifically, they are better at preserving the context of previously seen words in later time steps. We introduced them in more detail in the previous series of tutorials.
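
As a brief reminder, a single LSTM step can be sketched as below, assuming the standard formulation with forget, input, and output gates. The packed weight layout and the sizes are choices made for this example only. The two returned states play the same role as `state_h` and `state_c` in the Keras demo.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM step. W has shape (4 * n_h, n_h + n_x); b has shape (4 * n_h,)."""
    n_h = h_prev.shape[0]
    z = W @ np.concatenate([h_prev, x_t]) + b
    f = sigmoid(z[0:n_h])            # forget gate: how much old memory to keep
    i = sigmoid(z[n_h:2 * n_h])      # input gate: how much new content to write
    o = sigmoid(z[2 * n_h:3 * n_h])  # output gate: how much memory to expose
    g = np.tanh(z[3 * n_h:])         # candidate cell content
    c_t = f * c_prev + i * g         # cell state carries long-term context
    h_t = o * np.tanh(c_t)           # hidden state passed to the next step
    return h_t, c_t


n_x, n_h = 2, 3  # placeholder input and hidden sizes
rng = np.random.default_rng(0)
W = rng.normal(size=(4 * n_h, n_h + n_x))
b = np.zeros(4 * n_h)

h1, c1 = lstm_step(rng.normal(size=n_x), np.zeros(n_h), np.zeros(n_h), W, b)
print(h1.shape, c1.shape)  # (3,) (3,)
```

Because the cell state is updated additively through the forget and input gates, earlier information can survive more time steps than in a plain RNN.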

We use an LSTM here to implement a demo. For the sake of time, we won't implement the attention mechanism.


The example here is heavily inspired by the content from the [Keras blog](https://blog.keras.io/).

In [1]:
# !pip install keras
# !pip install numpy
# !pip install tensorflow

from keras.models import Model
from keras.layers import Input, LSTM, Dense
import numpy as np

In [2]:
batch_size = 64  # Batch size for training.
epochs = 10  # Number of epochs to train for.
latent_dim = 256  # Latent dimensionality of the encoding space.
num_samples = 10000  # Number of samples to train on.
# Path to the data txt file on disk.
data_path = './data/fra-eng/fra.txt'

In [3]:
# Vectorize the data.
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
with open(data_path, 'r', encoding='utf-8') as f:
    lines = f.read().split('\n')
for line in lines[: min(num_samples, len(lines) - 1)]:
    input_text, target_text, _ = line.split('\t')
    # We use "tab" as the "start sequence" character
    # for the targets, and "\n" as "end sequence" character.
    target_text = '\t' + target_text + '\n'
    input_texts.append(input_text)
    target_texts.append(target_text)
    # Sets ignore duplicates, so the characters can be added directly.
    input_characters.update(input_text)
    target_characters.update(target_text)

In [4]:
input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max([len(txt) for txt in input_texts])
max_decoder_seq_length = max([len(txt) for txt in target_texts])

print('Number of samples:', len(input_texts))
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)

Number of samples: 10000
Number of unique input tokens: 71
Number of unique output tokens: 93
Max sequence length for inputs: 15
Max sequence length for outputs: 59


In [5]:
input_token_index = dict(
    [(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict(
    [(char, i) for i, char in enumerate(target_characters)])

encoder_input_data = np.zeros(
    (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
decoder_input_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
decoder_target_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
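
The three arrays above are one-hot tensors of shape `(samples, timesteps, tokens)`. Before they are filled in the next cell, here is a minimal sketch on a single toy string showing the one-step shift between decoder input and decoder target. The string `"\tHi\n"` and the helper `to_text` are made up for this illustration, and padding is left out for brevity.

```python
import numpy as np

# Toy target: tab is the start character, newline is the end character.
target_text = "\tHi\n"
chars = sorted(set(target_text))
index = {ch: i for i, ch in enumerate(chars)}

decoder_in = np.zeros((len(target_text), len(chars)), dtype="float32")
decoder_out = np.zeros((len(target_text), len(chars)), dtype="float32")
for t, ch in enumerate(target_text):
    decoder_in[t, index[ch]] = 1.0
    if t > 0:
        # The target is the input shifted one step ahead, without the start character.
        decoder_out[t - 1, index[ch]] = 1.0


def to_text(one_hot):
    """Hypothetical helper: turn one-hot rows back into a string, skipping empty rows."""
    return "".join(chars[row.argmax()] for row in one_hot if row.any())


print(repr(to_text(decoder_in)))   # '\tHi\n'
print(repr(to_text(decoder_out)))  # 'Hi\n'
```

During training, the decoder sees `decoder_in` and is asked to predict `decoder_out`, that is, the next character at every time step.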

In [6]:
for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.
    encoder_input_data[i, t + 1:, input_token_index[' ']] = 1.
    for t, char in enumerate(target_text):
        # decoder_target_data is ahead of decoder_input_data by one timestep
        decoder_input_data[i, t, target_token_index[char]] = 1.
        if t > 0:
            # decoder_target_data will be ahead by one timestep
            # and will not include the start character.
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.
    decoder_input_data[i, t + 1:, target_token_index[' ']] = 1.
    decoder_target_data[i, t:, target_token_index[' ']] = 1.
# Define an input sequence and process it.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]

In [7]:
# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None, num_decoder_tokens))
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                     initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

In [8]:
# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Run training
model.compile(optimizer='rmsprop', loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7fc1e4326eb0>

In [9]:
# Save model
model.save('./data/s2s.h5')

In [10]:
# Next: inference mode (sampling).
# Here's the drill:
# 1) encode input and retrieve initial decoder state
# 2) run one step of decoder with this initial state
# and a "start of sequence" token as target.
# Output will be the next target token
# 3) Repeat with the current target token and current states

In [11]:
# Define sampling models
encoder_model = Model(encoder_inputs, encoder_states)

decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)

In [12]:
# Reverse-lookup token index to decode sequences back to
# something readable.
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())


In [13]:
def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq)

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index['\t']] = 1.

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict(
            [target_seq] + states_value)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or
           len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        # Update states
        states_value = [h, c]

    return decoded_sentence

In [14]:
for seq_index in range(100):
    # Take one sequence (part of the training set)
    # for trying out decoding.
    input_seq = encoder_input_data[seq_index: seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    print('-')
    print('Input sentence:', input_texts[seq_index])
    print('Decoded sentence:', decoded_sentence)

-
Input sentence: Go.
Decoded sentence: Restez à la mais.

-
Input sentence: Hi.
Decoded sentence: Restez la moi.

-
Input sentence: Hi.
Decoded sentence: Restez la moi.

-
Input sentence: Run!
Decoded sentence: Laissez-moi !

-
Input sentence: Run!
Decoded sentence: Laissez-moi !

-
Input sentence: Who?
Decoded sentence: Qui est alle ?

-
Input sentence: Wow!
Decoded sentence: Fais son en aite.

-
Input sentence: Fire!
Decoded sentence: Attends un chanter !

-
Input sentence: Help!
Decoded sentence: Restez !

-
Input sentence: Jump.
Decoded sentence: Restez à l'aire.

-
Input sentence: Stop!
Decoded sentence: Restez !

-
Input sentence: Stop!
Decoded sentence: Restez !

-
Input sentence: Stop!
Decoded sentence: Restez !

-
Input sentence: Wait!
Decoded sentence: Restez à la mainon.

-
Input sentence: Wait!
Decoded sentence: Restez à la mainon.

-
Input sentence: Go on.
Decoded sentence: Restez à l'aire.

-
Input sentence: Go on.
Decoded sentence: Restez à l'aire.

-
Input sentence: Go