In [None]:
# default_exp attention.attention

# attention
### Introduction ### 
In this blog post I want to introduce attention, but I'm going to deviate from what I've seen as the usual introduction to it. I think it's important to see things in their historical context, so I'm going to explain the paper that introduced attention and the context in which it was introduced. Further, where most tutorials focus on the vectorized representation of the attention mechanism, I'm going to show its core in the fewest possible lines of code.


#### Neural Machine Translation by Jointly Learning to Align and Translate #### 

This paper introduces attention in the context of sequence-to-sequence modelling. It describes the state of neural machine translation (NMT) models at the time and the problem the authors set out to solve with attention.



The paper states that an issue with the encoder-decoder approach is that

`the last hidden state that the encoder outputs, which is the only thing used for the translation, may not be able to contain all the information necessary for the entire translation.`

This issue is the core of the problem the paper is trying to solve, so it's worth taking a minute to really understand what's going on.



The paper introduces a novel mechanism to improve sequence-to-sequence modelling. Let's first drill down on the setup and the problem, and then look at how attention solves that problem.

![](assets/seq2seq1.png)

If you're not familiar with how an RNN works, stop here and read up on that first; otherwise things won't be as clear. As the rest of you know, an RNN processes a sequence one token at a time. After processing each new token, it produces a hidden state: a vector representation of the sequence up to and including that token.

So, in the diagram, `h3` represents the sequence up to `morgen`.
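To make the one-token-at-a-time processing concrete, here is a minimal sketch using `nn.GRUCell`. Two things are assumptions: the post doesn't say which RNN cell the diagram uses, and the token embeddings are random stand-ins with arbitrary sizes. The loop keeps every intermediate hidden state, one per token.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

EMB_DIM, HID_DIM = 8, 16             # arbitrary sizes for illustration
cell = nn.GRUCell(EMB_DIM, HID_DIM)

# Random stand-ins for the embeddings of a 3-token source sequence
tokens = torch.randn(3, 1, EMB_DIM)  # (seq_len, batch_size, emb_dim)

h = torch.zeros(1, HID_DIM)          # initial hidden state
hidden_states = []
for x_t in tokens:                   # x_t: (batch_size, emb_dim), one token at a time
    h = cell(x_t, h)                 # new state depends on this token and the previous state
    hidden_states.append(h)

# hidden_states[t] summarizes the sequence up to and including token t
print(len(hidden_states), hidden_states[-1].shape)  # 3 torch.Size([1, 16])
```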

The seq-to-seq setup is commonly used for translation between languages. The last hidden state of the encoder, which represents the entire input sequence, becomes the initial hidden state of a second RNN, the decoder, which is responsible for outputting the translated sequence.
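Here is a sketch of that hand-off. It assumes single-layer `nn.GRU` modules and random stand-in embeddings; `encoder`, `decoder` and `z` are illustrative names (`z` follows the diagram's notation for the vector passed between them).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

EMB_DIM, HID_DIM = 8, 16
encoder = nn.GRU(EMB_DIM, HID_DIM)    # expects (seq_len, batch_size, emb_dim)
decoder = nn.GRU(EMB_DIM, HID_DIM)

src = torch.randn(5, 1, EMB_DIM)      # stand-in source embeddings
enc_outputs, z = encoder(src)         # z: final hidden state, (num_layers, batch_size, hid_dim)

# Without attention, z is the only thing the decoder receives from the source sentence
start = torch.randn(1, 1, EMB_DIM)    # stand-in embedding for the <start> token
dec_output, dec_hidden = decoder(start, z)
print(enc_outputs.shape, z.shape, dec_output.shape)
```

For a single-layer, unidirectional GRU, `z` is exactly the last row of `enc_outputs`. The other encoder outputs are computed but thrown away in this setup.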

This is the basic setup for a seq-to-seq model. Now let's look at how the paper describes the potential problem with this setup.


> Most of the proposed neural machine translation models belong to a family of encoder–decoders, with an encoder and a decoder for each language, or involve a language-specific encoder applied to each sentence whose outputs are then compared. An encoder neural network reads and encodes a source sentence into a fixed-length vector. A decoder then outputs a translation from the encoded vector. The whole encoder–decoder system, which consists of the encoder and the decoder for a language pair, is jointly trained to maximize the probability of a correct translation given a source sentence.

> A potential issue with this encoder–decoder approach is that a neural network needs to be able to compress all the necessary information of a source sentence into a fixed-length vector. This may make it difficult for the neural network to cope with long sentences, especially those that are longer than the sentences in the training corpus. Cho et al. (2014b) showed that indeed the performance of a basic encoder–decoder deteriorates rapidly as the length of an input sentence increases.

I want to highlight what I feel is the most important part of this description:

`A potential issue with this encoder–decoder approach is that a neural network needs to be able to compress all the necessary information of a source sentence into a fixed-length vector.`

What the paper is saying is that the fixed-length vector `z` might not be able to hold everything it has been fed, and might forget what came in the earlier part of the sentence. Taking the example in the diagram, if `z` is small enough, it might not remember `guten` after consuming `morgen`. Remember that the size of `z` is in our hands. I haven't seen any studies comparing the memory capacity of a vector with its size (I'm sure they exist), but intuitively a vector of size 10 should be able to remember this short sequence, while as sequences get longer we will eventually hit a ceiling.
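A quick way to see the "fixed-length" part: the final hidden state has the same size whether the input has 2 tokens or 50. This sketch assumes a single-layer `nn.GRU` with a hidden size of 10 (matching the size mentioned above) and random stand-in embeddings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

encoder = nn.GRU(input_size=8, hidden_size=10)  # a size-10 z

short_src = torch.randn(2, 1, 8)    # a 2-token sentence, e.g. "guten morgen"
long_src = torch.randn(50, 1, 8)    # a 50-token sentence

_, z_short = encoder(short_src)
_, z_long = encoder(long_src)

# Both sentences are squeezed into the same 10 numbers
print(z_short.shape, z_long.shape)  # torch.Size([1, 1, 10]) torch.Size([1, 1, 10])
```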

So we want to be able to recall, when needed, whichever parts of the input sequence are relevant. As mentioned earlier, we already have a hidden state for every token we processed. To use all of these hidden states, we need to answer a few basic questions:
1. How do we know which hidden state to pass to the decoder? 
2. If I need information from multiple hidden states how do I go about using them? 
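One way to answer both questions at once, and the one attention uses, is not to pick a single hidden state at all. Instead, give every hidden state a non-negative weight, make the weights sum to 1, and pass the decoder their weighted sum. In this sketch the weights are hand-picked for illustration; attention learns to compute them. The sizes are arbitrary.

```python
import torch

torch.manual_seed(0)

encoder_states = torch.randn(4, 6)   # 4 hidden states, each of size 6 (arbitrary)

# Hand-picked weights: mostly the first two states, a little of the rest
weights = torch.tensor([0.5, 0.3, 0.1, 0.1])

# Weighted sum: a single vector the same size as one hidden state
context = (weights.unsqueeze(1) * encoder_states).sum(dim=0)
print(context.shape)  # torch.Size([6])

# Picking exactly one hidden state is the special case of a one-hot weight vector
one_hot = torch.tensor([0.0, 1.0, 0.0, 0.0])
picked = (one_hot.unsqueeze(1) * encoder_states).sum(dim=0)
print(torch.allclose(picked, encoder_states[1]))  # True
```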

#### Enter: Attention ####

When we're using attention, the decoder doesn't rely on the last hidden state of the encoder alone. Its first input is the `<start>` token for the translation, and at each step it uses its current hidden state to extract information from all of the encoder's hidden states. Given the `<start>` token, we know we probably want information from the initial hidden states, maybe just the first few.
    
At this point we've described in plenty of detail the two inputs attention requires, and we're ready to look at how the mechanism works. The two inputs are:
1. A list of vectors from which I want to extract some information.
2. A vector that helps me gauge what information I want to extract.
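Here is the whole mechanism with those two inputs in a handful of lines. It is a sketch of additive (Bahdanau-style) scoring that uses plain tensors with random weights `W` and `v` in place of trained layers. The function name `additive_attention` and all sizes are hypothetical.

```python
import torch

def additive_attention(query, keys, W, v):
    # query: (q_dim,)  keys: (n, k_dim)  W: (q_dim + k_dim, a_dim)  v: (a_dim,)
    n = keys.shape[0]
    # Pair the query with every key, query first: (n, q_dim + k_dim)
    pairs = torch.cat((query.expand(n, -1), keys), dim=1)
    scores = torch.tanh(pairs @ W) @ v       # one score per key: (n,)
    weights = torch.softmax(scores, dim=0)   # non-negative, sums to 1
    context = weights @ keys                 # weighted sum of the keys: (k_dim,)
    return weights, context

torch.manual_seed(0)
q_dim, k_dim, a_dim, n = 4, 6, 5, 3
query = torch.randn(q_dim)                   # input 2: what I'm looking for
keys = torch.randn(n, k_dim)                 # input 1: where I'm looking
W, v = torch.randn(q_dim + k_dim, a_dim), torch.randn(a_dim)

weights, context = additive_attention(query, keys, W, v)
print(weights, context.shape)
```

The cells that follow do the same thing with batched tensors, and with `nn.Linear` layers (which also add a bias) in place of `W` and `v`.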
    


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
ENC_HID_DIM = 512
DEC_HID_DIM = 512

In [None]:
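# Each encoder output has size ENC_HID_DIM * 2 (consistent with a bidirectional encoder) and is
# concatenated with the decoder hidden state (DEC_HID_DIM) before this projection.
attn = nn.Linear((ENC_HID_DIM * 2) + DEC_HID_DIM, DEC_HID_DIM)
# Maps each DEC_HID_DIM-sized energy vector to a single score.
v = nn.Linear(DEC_HID_DIM, 1, bias = False)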
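Why `ENC_HID_DIM * 2`? The encoder outputs used below have size 1024, twice `ENC_HID_DIM`. That is what a bidirectional encoder produces, since each output concatenates a forward and a backward hidden state (the paper's encoder is a bidirectional RNN). Here is a quick check with a bidirectional `nn.GRU`, assuming an arbitrary embedding size of 256, since the full encoder isn't shown here.

```python
import torch
import torch.nn as nn

ENC_HID_DIM = 512
EMB_DIM = 256          # assumed embedding size, not taken from the post

rnn = nn.GRU(EMB_DIM, ENC_HID_DIM, bidirectional=True)
src = torch.randn(5, 1, EMB_DIM)       # (src_len, batch_size, emb_dim)
outputs, hidden = rnn(src)

# Forward and backward states are concatenated on the last dimension
print(outputs.shape)   # torch.Size([5, 1, 1024])
print(hidden.shape)    # torch.Size([2, 1, 512]): one final state per direction
```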


Here I'm going to create some artificial torch tensors to step through the mechanism. I'd encourage you to go through the complete implementation of the paper once you're done here; I'm attaching a notebook that has the complete implementation.

In [None]:

# The decoder hidden state. At the first step, this is the hidden state the decoder has
# when it receives the <start> token.
# Shape: (batch_size, decoder_hidden_dimension)
# batch_size: number of sequences we're processing
# decoder_hidden_dimension: hidden state dimension of the decoder RNN
hidden = torch.randn(1, DEC_HID_DIM)

# List of vectors (encoder outputs)
# Shape: (src_len, batch_size, encoder_hidden_dimension)
# src_len: length of the source sequence
# batch_size: number of sequences processed in one go
# encoder_hidden_dimension: size of each encoder output (ENC_HID_DIM * 2 = 1024 here)
encoder_outputs = torch.randn(5, 1, ENC_HID_DIM * 2)

In [None]:


# This is what the actual processing looks like. I've taken this code from the `forward` method of
# the attention class, and I'll print the shapes/values of variables to keep track of what's
# happening.

In [None]:
batch_size = encoder_outputs.shape[1]
src_len = encoder_outputs.shape[0]
print(batch_size)
print(src_len)

1
5


In [None]:
# Repeat the decoder hidden state src_len times.

# I'm going to perform an operation between the decoder hidden state and each encoder output, to see
# how much each encoder output will help me predict the next output. To do this, I'm creating one copy
# of the hidden state for each encoder output.

print(hidden.shape)
hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
print(hidden.shape)

torch.Size([1, 512])
torch.Size([1, 5, 512])
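`repeat` physically copies the hidden state `src_len` times. `expand` is an alternative that returns a view with the same values and no memory copy. This sketch compares the two; it is not what the post's class uses, and the sizes mirror the cells above.

```python
import torch

DEC_HID_DIM, src_len = 512, 5
hidden = torch.randn(1, DEC_HID_DIM)                       # (batch_size, dec_hid_dim)

repeated = hidden.unsqueeze(1).repeat(1, src_len, 1)       # copies the data
expanded = hidden.unsqueeze(1).expand(-1, src_len, -1)     # a view, no copy

print(repeated.shape, expanded.shape)   # torch.Size([1, 5, 512]) torch.Size([1, 5, 512])
print(torch.equal(repeated, expanded))  # True
```

Either works as input to `torch.cat`. `expand` only saves memory, and that matters more for long source sequences and large batches.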


In [None]:
# Re-arrange the encoder outputs to (batch_size, src_len, enc_dim) so they line up with `hidden`
print(encoder_outputs.shape)
encoder_outputs = encoder_outputs.permute(1, 0, 2)
print(encoder_outputs.shape)

torch.Size([5, 1, 1024])
torch.Size([1, 5, 1024])


In [None]:
# This is the first step of attention. I'm concatenating the decoder hidden state with each encoder
# output vector, for every position in the source sequence.

print(hidden.shape)
print(encoder_outputs.shape)
concat_hidden_encoder = torch.cat((hidden, encoder_outputs), dim = 2)
print(concat_hidden_encoder.shape)

torch.Size([1, 5, 512])
torch.Size([1, 5, 1024])
torch.Size([1, 5, 1536])


In [None]:
# `attn` is just a linear transformation (a matrix multiplication plus a bias) that converts each
# vector of size 1536 to size 512. tanh is then applied element-wise.
energy = torch.tanh(attn(concat_hidden_encoder)) 
print(energy.shape)

torch.Size([1, 5, 512])
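In the paper, the score is written as $e_{ij} = v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$, with separate matrices for the decoder state $s_{i-1}$ and the encoder output $h_j$. Concatenating and then applying one `nn.Linear` is the same computation, plus a bias: the first `DEC_HID_DIM` columns of `attn.weight` play the role of $W_a$ and the rest play $U_a$. This sketch checks that with fresh random tensors of the post's sizes, before the `tanh`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ENC_HID_DIM, DEC_HID_DIM, src_len = 512, 512, 5

attn = nn.Linear(ENC_HID_DIM * 2 + DEC_HID_DIM, DEC_HID_DIM)
hidden = torch.randn(1, src_len, DEC_HID_DIM)               # repeated decoder state
encoder_outputs = torch.randn(1, src_len, ENC_HID_DIM * 2)

# Concatenate-then-project, as in the cells above
energy_cat = attn(torch.cat((hidden, encoder_outputs), dim=2))

# Split the weight matrix (shape: out x in) by input columns
W_s = attn.weight[:, :DEC_HID_DIM]     # acts on hidden: (512, 512)
W_h = attn.weight[:, DEC_HID_DIM:]     # acts on encoder_outputs: (512, 1024)
energy_split = hidden @ W_s.T + encoder_outputs @ W_h.T + attn.bias

print(torch.allclose(energy_cat, energy_split, atol=1e-5))
```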


In [None]:
# Another Linear Transformation, converting vectors of size 512 to 1. 
v(energy).shape

torch.Size([1, 5, 1])

In [None]:
attention = v(energy).squeeze(2)
print(attention.shape)

torch.Size([1, 5])
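The batched tensor code can hide the fact that each score is computed independently for each encoder output. This sketch computes the same scores in a plain loop over source positions and checks them against the vectorized version. It uses freshly created `attn` and `v` layers with the post's dimensions and drops the batch dimension for readability.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ENC_HID_DIM, DEC_HID_DIM, src_len = 512, 512, 5
attn = nn.Linear(ENC_HID_DIM * 2 + DEC_HID_DIM, DEC_HID_DIM)
v = nn.Linear(DEC_HID_DIM, 1, bias=False)

s = torch.randn(DEC_HID_DIM)                 # decoder hidden state (one sequence)
enc = torch.randn(src_len, ENC_HID_DIM * 2)  # encoder outputs (one sequence)

# Loop: one scalar score per encoder output h_j
loop_scores = torch.stack([
    v(torch.tanh(attn(torch.cat((s, h_j))))).squeeze()
    for h_j in enc
])

# Vectorized: all positions at once
vec_scores = v(torch.tanh(attn(torch.cat((s.expand(src_len, -1), enc), dim=1)))).squeeze(1)

print(loop_scores.shape)  # torch.Size([5])
print(torch.allclose(loop_scores, vec_scores, atol=1e-5))
```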


In [None]:
# The variable `attention` now holds a score for each encoder output vector, a measure of how useful
# it will be in determining the next output. Lastly, I want to normalize the scores, which I do with
# softmax: the same softmax used in classification.
encoder_softmax = F.softmax(attention, dim=1)

In [None]:
# This is like a classification probability distribution: it sums to 1. Here the values can be
# seen as weights describing the importance of each encoder output vector.
encoder_softmax[0]

tensor([0.2748, 0.1554, 0.1157, 0.1979, 0.2562], grad_fn=<SelectBackward>)
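Softmax exponentiates each score and divides by the total, which is why the weights come out positive and sum to 1. This sketch uses hand-picked scores rather than the values printed above.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, 0.5, 1.5, 2.0]])   # (batch_size, src_len)

manual = scores.exp() / scores.exp().sum(dim=1, keepdim=True)
builtin = F.softmax(scores, dim=1)

print(torch.allclose(manual, builtin))  # True
print(builtin.sum(dim=1))               # sums to 1, up to rounding
```

Equal scores get equal weights, and a higher score always gets a higher weight, so softmax preserves the ranking the scores express.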

In [None]:
# Add a dimension so the weights have shape (batch_size, 1, src_len) for batched matrix multiplication
encoder_softmax = encoder_softmax.unsqueeze(1)

In [None]:
encoder_softmax.shape

torch.Size([1, 1, 5])

In [None]:
# This is what I was after. The softmaxed vector gives the weight of each encoder output in determining
# the next output. Multiplying each encoder output by its weight and adding them all up gives a single
# vector containing the information I was looking for from the encoder outputs.

# This particular operation happens in the Decoder of the seq-to-seq model.
weighted = torch.bmm(encoder_softmax, encoder_outputs)

In [None]:
encoder_outputs.shape

torch.Size([1, 5, 1024])

In [None]:
weighted.shape

torch.Size([1, 1, 1024])
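`torch.bmm` with a `(batch, 1, src_len)` weight tensor is just the weighted sum written as a matrix product. This sketch spells the sum out explicitly and checks that it matches, using random tensors of the post's sizes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
src_len = 5
encoder_outputs = torch.randn(1, src_len, 1024)                    # (batch, src_len, enc_dim)
weights = F.softmax(torch.randn(1, src_len), dim=1).unsqueeze(1)   # (batch, 1, src_len)

weighted = torch.bmm(weights, encoder_outputs)                     # (batch, 1, enc_dim)

# The same thing spelled out: sum over j of weight_j * encoder_output_j
explicit = sum(weights[0, 0, j] * encoder_outputs[0, j] for j in range(src_len))

print(weighted.shape)  # torch.Size([1, 1, 1024])
print(torch.allclose(weighted[0, 0], explicit, atol=1e-5))
```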

In [None]:
# Putting it all together, this is what the attention class looks like: 

class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)
        
    def forward(self, hidden, encoder_outputs):
        #hidden = [batch size, dec hid dim]
        #encoder_outputs = [src len, batch size, enc hid dim * 2]
        
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]
        
        #repeat decoder hidden state src_len times
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        
        #hidden = [batch size, src len, dec hid dim]
        #encoder_outputs = [batch size, src len, enc hid dim * 2]
        
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) 
        
        #energy = [batch size, src len, dec hid dim]

        attention = self.v(energy).squeeze(2)
        
        #attention= [batch size, src len]
        
        return F.softmax(attention, dim=1)

Here I have tried to cover the very core of the attention mechanism. Before moving forward, I would highly recommend going through the entire implementation here. I have added several debuggers (pdb) in the Encoder, Decoder, Attention and the final model.

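To summarise: given a list of vectors and another vector, attention produces a set of weights, one per vector in the list and summing to 1, that signify how relevant each vector in the list is to the other vector. The weighted sum of the list using these weights is the information passed on to the decoder.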