# The Attention Mechanism   

## Introduction ## 


Hello Everyone!

Welcome to DeepLearningDots. This is the first video in a series about the Attention Mechanism. Rather than starting with the internals of the attention mechanism, this video is going to really hammer home the intuition behind it, and we'll only open up the internals once that intuition is in place.

We're going to talk about what the attention mechanism does and what the output looks like. 

______________________________________________________________________________________________________________

![](assets/Caption_Generator_Attention.png)

Now, before we look at some code, let's use some visuals to get a feel for what the attention mechanism does.

Here we have the results of a model that receives an image as input and returns a caption for it. This model uses attention while generating the caption: after receiving the image, and before predicting each word, the model uses the attention mechanism to answer the question

`OK, to predict the next word, which part of the image should I be focusing on?`

Before we move on to the next example, let's talk about what I mean by the word `focusing` and how it manifests itself in code. To tell the model what to focus on, the attention mechanism returns a weight for each part of the image. Each weight is a number between 0 and 1, and the image regions with higher weights are the ones the model should focus on. (In this model the regions are locations in a convolutional feature map rather than individual pixels, and the weights over all regions sum to 1.)
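Here is a minimal sketch of what "one weight between 0 and 1 per region" looks like. It assumes a 14×14 grid of image regions and uses random numbers in place of the scores a trained captioning model would compute, so it only illustrates the shapes and the normalization:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

GRID = 14  # assumed grid size: the image is summarised as GRID x GRID regions

# Hypothetical unnormalized scores, one per region. A real captioning model
# would compute these from the image features and the words generated so far.
scores = torch.randn(GRID * GRID)

# Softmax turns the scores into weights between 0 and 1 that sum to 1.
weights = F.softmax(scores, dim=0)

# Reshaping the weights back into the grid gives the kind of heat map
# that is overlaid on the image in the figure above.
heat_map = weights.view(GRID, GRID)

# The region that gets the most weight, i.e. where the model "focuses".
row, col = divmod(int(weights.argmax()), GRID)
print(heat_map.shape, round(heat_map.sum().item(), 4), (row, col))
```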

So, in this first example, the output of the attention mechanism indicates that, to predict the word `frisbee`, the model should be focusing on (or _attending_ to) the frisbee in the image.


Now, I think the last example is particularly cool. Before predicting the word `trees`, the attention mechanism says to focus on all the greenery in the background. Even though the giraffe dominates the image, the attention mechanism is able to tell us where the area of focus should be.

Before we move on, let's talk about this example again, but this time let's use some technical jargon.

These examples are from the paper [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/pdf/1502.03044.pdf)



![](assets/Align_And_Translate_Attention.png)

This depicts the usage of attention in a model that translates English sentences into French. In this example, the sentence on the X-axis is the input (English) sentence and the sentence on the Y-axis is the generated (French) translation. The attention mechanism is used in a similar manner to the example we just went through: to predict each word, the model uses attention to state which part of the source sentence should be focused on.

Now, let's focus on the word `en` in the translation. The attention mechanism is saying you should be focusing on the word `in`, and what I find even more fascinating is that it's also saying that, to predict the word `en`, some information (some amount of context) is also provided by the words `signed` and `August`.

This example is from the paper which invented the attention mechanism: [NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE](https://arxiv.org/pdf/1409.0473v7.pdf)

In other words, I tell the attention mechanism how far I've got with the prediction (the current state), and in return it tells me which parts of the input are important.

___________________________

## I've seen what attention does visually; let's take a step forward ## 

![](assets/Align_And_Translate_Attention.png)

We've already seen that, given a list of things, the attention mechanism provides a weight for each item in the list. In this example, a list of words is given to the attention mechanism, and we want to know which words are more important and should therefore get more weight.

Let's take a look at what this looks like in code. Before we look at the function from the inside, though, let's see how the attention function is used.

In [None]:
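import torch
import torch.nn as nn
import torch.nn.functional as F
HDDN_DIM, ENC_DIM = 5, 5  # size of the query vector and of each encoder output

# Query Entity: (batch_size, HDDN_DIM)
hidden = torch.randn(1,HDDN_DIM) # hidden = entity_hidden
# List of Entities: 14 source tokens, (src_len, batch_size, ENC_DIM)
encoder_outputs = torch.randn(14, 1, ENC_DIM) # encoder_outputs = [entity_1, entity_2, entity_3]

In [None]:
encoder_outputs.shape

torch.Size([14, 1, 5])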

In [None]:
# Usage of the attention function. 
encoder_outputs_attention_weights, energy = attention(hidden,encoder_outputs)

# Weighted sum of the items: (1, 1, src_len) @ (1, src_len, ENC_DIM) -> (1, 1, ENC_DIM)
context_vector = torch.bmm(
    encoder_outputs_attention_weights.unsqueeze(1),
    encoder_outputs.permute(1, 0, 2))


If this video has one overriding objective, it is that you, the viewer, get comfortable with the last cell, and we're going to discuss this cell at length. Here I'm using the attention function (which I've defined below, and which we'll look at eventually) to calculate a set of weights, where `encoder_outputs` is the "list of things" we saw in our visuals above.

The Input 
---------------
Here the input variables are randomly initialized, but it is important to note that if we were performing neural machine translation (NMT), the real variables would have exactly the same tensor shapes as the ones I've initialized. So, let's take a look at what the inputs to the attention mechanism are supposed to represent:

**encoder_outputs**: This is the list of entities for which the weights are going to be calculated. In the context of NMT it represents the source sentence: one encoder hidden state for each token the encoder has read. Here our source sentence has 14 tokens, so `encoder_outputs` holds 14 hidden-state vectors, with shape `(src_len, batch_size, ENC_DIM)`.

**hidden**: This is a fixed-length vector that represents (for our NMT example) the translated sentence so far. So, if the model has output 3 words in French, it is supposed to represent the sequence of those 3 words. More importantly, it tells the attention mechanism which word we want to predict next.

Here I'd like to take a moment to think about this variable in more abstract terms. I'd argue that the purpose of the attention mechanism is to query information from a list of vectors. When I'm translating something, I'm asking "Which part of the source sentence should I focus on to predict the next word?", and when I'm generating a caption, I'm asking "Which part of the image should I focus on to predict the next word?". So, I like to think of the _hidden_ variable as a query.
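To see the query idea in action, here's a minimal sketch. It assumes three made-up 2-d items and a plain dot product as the scoring operation (the `attention` function in this video uses a learned energy function instead). Asking two different questions of the same list of items puts the weight on different items:

```python
import torch
import torch.nn.functional as F

# Three hypothetical "items" (e.g. encoder states), each a 2-d vector.
items = torch.tensor([[1.0, 0.0],
                      [0.0, 1.0],
                      [0.7, 0.7]])

def attend(query, items):
    """Score each item against the query with a dot product (an assumed
    scoring choice), then normalize the scores into weights summing to 1."""
    scores = items @ query          # one score per item, shape (n_items,)
    return F.softmax(scores, dim=0)

# Two different "questions" asked of the same list of items.
q_first = torch.tensor([5.0, 0.0])   # points towards item 0
q_second = torch.tensor([0.0, 5.0])  # points towards item 1

w_first = attend(q_first, items)
w_second = attend(q_second, items)
print(w_first.argmax().item(), w_second.argmax().item())  # → 0 1
```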




In [None]:
encoder_outputs_attention_weights, energy

(tensor([[0.0897, 0.0583, 0.0661, 0.0786, 0.0655, 0.0809, 0.0639, 0.0688, 0.0652,
          0.0619, 0.0821, 0.0815, 0.0630, 0.0745]], grad_fn=<SoftmaxBackward>),
 tensor([[ 0.1802, -0.2512, -0.1253,  0.0478, -0.1342,  0.0772, -0.1589, -0.0844,
          -0.1381, -0.1903,  0.0922,  0.0851, -0.1732, -0.0048]],
        grad_fn=<SqueezeBackward1>))

In [None]:
encoder_outputs_attention_weights.sum()

tensor(1.0000, grad_fn=<SumBackward0>)

In [None]:
encoder_outputs.shape

torch.Size([14, 1, 5])

The Output 
---------------
Now, here the attention function returns two variables, but when you're actually building a model you're never really going to need `energy`. We've already discussed how attention tells us which things to focus on in a list of entities: it returns a weight for each item in the list, and a higher weight means the item is more important. So, `encoder_outputs_attention_weights` holds one weight for each item in `encoder_outputs`. Here's an important property of the weights attention returns, and it holds for every attention variant we look at in this video: the weights always sum up to 1. This is what we mean when we say that these scores are normalized.

`energy` represents the unnormalized weights.
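A small sketch of how the unnormalized energies relate to the normalized weights, assuming softmax is the normalization (as it is in the `attention` function in this video). The energy values are made up:

```python
import torch
import torch.nn.functional as F

# Hypothetical energies (unnormalized scores) for four items: any real
# numbers, including negative ones, are allowed.
energy = torch.tensor([2.0, -1.0, 0.5, 0.0])

weights = F.softmax(energy, dim=0)

# 1. The weights are positive and sum to 1 (the "normalized" property).
# 2. The ranking is preserved: the largest energy gets the largest weight.
# 3. Adding the same constant to every energy leaves the weights unchanged.
shifted = F.softmax(energy + 10.0, dim=0)
print(weights, weights.sum())
```

The last point is why raw energies are hard to interpret on their own: once they are normalized, only their differences matter.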

There's an interesting point to note here: the attention function does not depend on the number of items in the list. Regardless of how many items there are, it uses the query object and returns one weight for each item in _encoder_outputs_.
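Here's a sketch of why the number of items doesn't matter: the parameters of the energy function only mention vector sizes, so the same layers can score a list of any length. The function below mirrors the `attention` function defined later in this video; the names `additive_attention`, `HID_DIM` and `ATT_DIM` are hypothetical stand-ins:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
HID_DIM, ENC_DIM, ATT_DIM = 5, 5, 5   # hypothetical sizes

# The learned parameters: their shapes involve vector sizes only,
# never the number of items in the list.
attn = nn.Linear(HID_DIM + ENC_DIM, ATT_DIM)
v = nn.Linear(ATT_DIM, 1, bias=False)

def additive_attention(hidden, encoder_outputs):
    """hidden: (1, HID_DIM); encoder_outputs: (src_len, 1, ENC_DIM)."""
    src_len = encoder_outputs.shape[0]
    hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)  # (1, src_len, HID_DIM)
    enc = encoder_outputs.permute(1, 0, 2)              # (1, src_len, ENC_DIM)
    energy = torch.tanh(attn(torch.cat((hidden, enc), dim=2)))  # (1, src_len, ATT_DIM)
    scores = v(energy).squeeze(2)                       # (1, src_len)
    return F.softmax(scores, dim=1)                     # one weight per item

hidden = torch.randn(1, HID_DIM)
for src_len in (3, 14, 50):
    w = additive_attention(hidden, torch.randn(src_len, 1, ENC_DIM))
    print(src_len, tuple(w.shape), round(w.sum().item(), 4))
```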

BMM
--------
Now we're at a point where I have a list of entities from which I want some information, and a query object I can use to fetch that information. Using these two things and the attention function, I have calculated a set of weights telling me which entities in my list are more important than the others. Let's see how these weights are used when developing a model.

As we've already seen, the attention function does not depend on the number of entities in _encoder_outputs_, so I need a way to use these weights that does not depend on it either. How that is done is a bit tricky, and it was not intuitive to me when I first went through it.

The objective of the attention function is to query information, and in my model I want to represent the fetched information as a single fixed-length vector. To do that, I multiply each item in `encoder_outputs` by its corresponding weight and add all the weighted items together.

Here's an example

In [None]:
encoder_outputs.shape

torch.Size([14, 1, 5])

In [None]:
encoder_outputs_attention_weights.shape

torch.Size([1, 14])

In [None]:
encoder_outputs_small = torch.tensor([
    [1.,2.], # [1.,2.]*0.7 = [0.7, 1.4]
    [2.,1.], # [2.,1.]*0.2 = [0.4, 0.2]
    [1.,1.]  # [1.,1.]*0.1 = [0.1, 0.1]
]).view([3,1,2])

attn_weights_small = torch.tensor([
    [0.7],
    [0.2],
    [0.1]
]).view(1,3)

In [None]:
encoder_outputs_small.shape

torch.Size([3, 1, 2])

In [None]:
attn_weights_small.shape

torch.Size([1, 3])

In [None]:
torch.bmm(
    attn_weights_small.unsqueeze(1),
    encoder_outputs_small.permute(1, 0, 2))
# This vector is supposed to be 70% of [1., 2.] + 20% of [2., 1.] + 10% of [1., 1.]

tensor([[[1.2000, 1.7000]]])

The context vector obtained is supposed to be a vector representation of: 
1. The image with the focus on the frisbee from the first example. 
2. The source sentence with the focus on the word 'in' in the NMT example. 
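To convince ourselves that `torch.bmm` really computes the weighted sum described above, here is a sketch that reuses the toy numbers and compares it with an explicit multiply-and-add:

```python
import torch

# The same toy numbers as in the cells above.
encoder_outputs_small = torch.tensor([[1., 2.], [2., 1.], [1., 1.]]).view(3, 1, 2)
attn_weights_small = torch.tensor([[0.7, 0.2, 0.1]])             # (1, 3)

# torch.bmm version: (1, 1, 3) @ (1, 3, 2) -> (1, 1, 2)
via_bmm = torch.bmm(attn_weights_small.unsqueeze(1),
                    encoder_outputs_small.permute(1, 0, 2))

# Explicit weighted sum: scale each item by its weight, then add them up.
weights_per_item = attn_weights_small.t().unsqueeze(2)           # (3, 1, 1)
via_sum = (weights_per_item * encoder_outputs_small).sum(dim=0)  # (1, 2)

print(via_bmm.view(-1), via_sum.view(-1))  # both ≈ tensor([1.2000, 1.7000])
```

`bmm` is just a batched matrix product, and multiplying a row of weights by a matrix whose rows are the items is exactly "multiply each item by its weight and add".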

## Internals of the Attention Mechanism ## 

OK. After a lot of effort, we're at a stage where we can start looking at the internals of the attention function. If you've been able to follow this video so far, I believe this next section should be a breeze.

In [None]:
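import torch.nn as nn

HDDN_DIM = 5  # size of the query vector (decoder hidden state)
ENC_DIM = 5   # size of each encoder output vector
DEC_DIM = 5   # size of the hidden layer inside the energy function

# attn maps each concatenated [query; item] pair to a DEC_DIM vector,
# v then reduces that vector to a single unnormalized score
attn = nn.Linear(HDDN_DIM + ENC_DIM, DEC_DIM)
v = nn.Linear(DEC_DIM, 1, bias = False)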

In [None]:
attn

Linear(in_features=10, out_features=5, bias=True)

In [None]:
hidden = torch.randn(1,HDDN_DIM) # hidden = entity_hidden
encoder_outputs = torch.randn(10, 1, ENC_DIM) # encoder_outputs = [entity_1, entity_2, entity_3]

In [None]:
encoder_outputs.shape #(seq_len, batchsize, vector_size)

torch.Size([10, 1, 5])

In [None]:
# query entity
hidden.shape

torch.Size([1, 5])

In [None]:
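def attention(hidden, encoder_outputs):
    # hidden: (batch_size, HDDN_DIM); encoder_outputs: (src_len, batch_size, ENC_DIM)
    batch_size = encoder_outputs.shape[1]  # not needed below, kept for reference
    src_len = encoder_outputs.shape[0]

    # Repeat the decoder hidden state (the query) src_len times, one copy per item
    hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)   # (batch_size, src_len, HDDN_DIM)
    encoder_outputs = encoder_outputs.permute(1, 0, 2)   # (batch_size, src_len, ENC_DIM)

    # Energy function: a small MLP applied to every [query; item] pair
    energy = torch.tanh(
        attn(
            torch.cat(
                (hidden,
                encoder_outputs), dim = 2)
        )
    )                                                    # (batch_size, src_len, DEC_DIM)
    # One unnormalized score per item (this is what the text calls `energy`)
    attention = v(energy).squeeze(2)                     # (batch_size, src_len)

    # Normalized weights (summing to 1 over src_len) and the raw scores
    return F.softmax(attention, dim=1), attention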

Walking Back 
--------------------

You should be familiar with the cell below, as we've just discussed it. Let's try to walk back from what we know about this cell to what we can deduce without actually looking at the internals. We have a query object, `hidden`, and a list of entities, `encoder_outputs`. The attention function is going to return a normalized weight for each item in the list.


Now, I think we can deduce the following from this information:
We're going to calculate a score by performing **some operation** between the query object (`hidden`) and each item in `encoder_outputs`, and after calculating all the scores we're going to normalize them so that they fulfill the property of summing up to 1.
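The deduction above can be sketched with an explicit loop: one score per item via **some operation**, followed by a softmax. Here the operation is assumed to be the same additive energy function as in the `attention` function (a linear layer on the concatenated pair, `tanh`, then `v`); layer sizes and inputs are made up. The vectorized version at the end shows that the `repeat` + `torch.cat` trick computes the same scores in one go:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
DIM = 5
attn = nn.Linear(2 * DIM, DIM)             # same shapes as the cells above
v = nn.Linear(DIM, 1, bias=False)

hidden = torch.randn(1, DIM)               # the query
encoder_outputs = torch.randn(10, 1, DIM)  # the list of entities

# Step 1: "some operation" between the query and each item, one at a time.
scores = []
for item in encoder_outputs:                             # item: (1, DIM)
    pair = torch.cat((hidden, item), dim=1)              # (1, 2*DIM)
    scores.append(v(torch.tanh(attn(pair))).squeeze())   # scalar score
scores = torch.stack(scores)                             # (10,)

# Step 2: normalize so the weights sum to 1.
weights = F.softmax(scores, dim=0)

# Vectorized form: repeat the query once per item and score all pairs at once.
hidden_rep = hidden.unsqueeze(1).repeat(1, encoder_outputs.shape[0], 1)  # (1, 10, DIM)
pairs = torch.cat((hidden_rep, encoder_outputs.permute(1, 0, 2)), dim=2)  # (1, 10, 2*DIM)
batched = v(torch.tanh(attn(pairs))).squeeze(2)                           # (1, 10)
print(torch.allclose(scores, batched.squeeze(0), atol=1e-5))
```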

In [None]:
encoder_outputs_attention_softmax_weights, energy  = attention(hidden, encoder_outputs)

# Weighted sum of the items: (1, 1, src_len) @ (1, src_len, ENC_DIM) -> (1, 1, ENC_DIM)
context_vector = torch.bmm(
    encoder_outputs_attention_softmax_weights.unsqueeze(1),
    encoder_outputs.permute(1, 0, 2))

As Bahdanau et al. put it: "The probability $\alpha_{ij}$, or its associated energy $e_{ij}$, reflects the importance of the annotation $h_j$ with respect to the previous hidden state $s_{i-1}$ in deciding the next state $s_i$ and generating $y_i$. Intuitively, this implements a mechanism of attention in the decoder. The decoder decides parts of the source sentence to pay attention to."
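The passage above can be written out as three equations, following the notation of Bahdanau et al. ($T_x$ is the source length and $c_i$ the context vector). In the code, `hidden` plays the role of $s_{i-1}$, the rows of `encoder_outputs` are the annotations $h_j$, and `attn` applied to the concatenation $[s_{i-1}; h_j]$ computes $W_a s_{i-1} + U_a h_j$ plus a bias (one linear layer on a concatenation is equivalent to two weight matrices added together). `v` is $v_a$, the second value returned by `attention` is $e_{ij}$, `F.softmax` gives $\alpha_{ij}$, and the `torch.bmm` call computes $c_i$:

$$
e_{ij} = a(s_{i-1}, h_j) = v_a^\top \tanh\left(W_a s_{i-1} + U_a h_j\right)
$$

$$
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})}, \qquad c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
$$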

In [None]:
HDDN_DIM = 5
ENC_DIM = 5
DEC_DIM = 5

attn = nn.Linear(HDDN_DIM + ENC_DIM, DEC_DIM)
v = nn.Linear(DEC_DIM, 1, bias = False)

In [None]:
hidden = torch.randn(1,HDDN_DIM) # hidden = entity_hidden
encoder_outputs = torch.randn(14, 1, ENC_DIM) # encoder_outputs = [entity_1, entity_2, entity_3]

### The Attention Function ### 

Let's zoom into the attention function. I've taken each line and put it in its own cell.

In [None]:
encoder_outputs.shape

torch.Size([14, 1, 5])

In [None]:
batch_size = encoder_outputs.shape[1]
src_len = encoder_outputs.shape[0]

In [None]:
#repeat decoder hidden state src_len times
hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
encoder_outputs = encoder_outputs.permute(1, 0, 2)

# Since I want to calculate a score by performing some operation between each item in `encoder_outputs`
# and the query object, I need as many copies of the query object as there are items in
# `encoder_outputs`
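A small sketch of what that copying looks like, with made-up sizes (`src_len = 4`, `dim = 5`). `expand`, which the original code does not use, is shown as an alternative that produces the same values as a view without copying memory:

```python
import torch

src_len, dim = 4, 5
hidden = torch.randn(1, dim)                    # one query vector
encoder_outputs = torch.randn(src_len, 1, dim)  # the list of items

# One copy of the query per item: (1, dim) -> (1, 1, dim) -> (1, src_len, dim)
hidden_rep = hidden.unsqueeze(1).repeat(1, src_len, 1)

# Pair every copy with one item: (1, src_len, dim + dim)
pairs = torch.cat((hidden_rep, encoder_outputs.permute(1, 0, 2)), dim=2)

# expand() gives the same values as a view, without copying memory.
hidden_view = hidden.unsqueeze(1).expand(1, src_len, dim)
print(pairs.shape, torch.equal(hidden_rep, hidden_view))
```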

In [None]:
# energy function 
energy = torch.tanh(
    attn(
        torch.cat(
            (hidden,
            encoder_outputs), dim = 2)
    )
) 

attention = v(energy).squeeze(2)

In [None]:
F.softmax(attention, dim=1)

tensor([[0.0732, 0.0731, 0.0642, 0.0703, 0.0646, 0.0657, 0.0823, 0.0777, 0.0837,
         0.0767, 0.0706, 0.0537, 0.0651, 0.0790]], grad_fn=<SoftmaxBackward>)

# The Energy Function # 

![](assets/energy_functions.png)

Now, a question worth asking here is: "Is the energy function defined above special? Are there other forms of it?" The answer is yes. The picture describes a few other forms of the energy function. I believe the only real constraint on the energy function is that we should be able to backpropagate through it. So, let's see another form of the energy function in the example below.

I have taken this image from [Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/pdf/1508.04025.pdf), which also introduces another kind of attention. I'm not going to go into details here, but it introduces a "local" attention mechanism: instead of attending to all the words from our NMT example above, it attends to a smaller window of them. I believe the NMT [implementation](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html) available on the PyTorch website uses that approach.
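For reference, the Luong et al. paper lists three energy (score) functions: `dot` ($h_t^\top \bar{h}_s$), `general` ($h_t^\top W_a \bar{h}_s$) and `concat` ($v_a^\top \tanh(W_a[h_t; \bar{h}_s])$). Here is a minimal sketch of all three with made-up sizes and randomly initialized weights. Each returns one score per source position, and any of them can be fed into the same softmax:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
DIM = 5
h_t = torch.randn(DIM)       # decoder state (the query)
h_s = torch.randn(7, DIM)    # 7 encoder states (the list of things)

W_general = nn.Linear(DIM, DIM, bias=False)      # W_a for "general"
W_concat = nn.Linear(2 * DIM, DIM, bias=False)   # W_a for "concat"
v_concat = nn.Linear(DIM, 1, bias=False)         # v_a for "concat"

def score_dot(h_t, h_s):
    # h_t^T h_s for every source position
    return h_s @ h_t

def score_general(h_t, h_s):
    # h_t^T W_a h_s for every source position
    return W_general(h_s) @ h_t

def score_concat(h_t, h_s):
    # v_a^T tanh(W_a [h_t; h_s]) for every source position
    pairs = torch.cat((h_t.expand_as(h_s), h_s), dim=1)  # (7, 2*DIM)
    return v_concat(torch.tanh(W_concat(pairs))).squeeze(1)

for fn in (score_dot, score_general, score_concat):
    weights = torch.softmax(fn(h_t, h_s), dim=0)
    print(fn.__name__, tuple(weights.shape), round(weights.sum().item(), 4))
```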

## General Attention ## 

In [None]:
n_hidden = 5
# W_a of the "general" energy function (this redefines `attn` from earlier)
attn = nn.Linear(n_hidden, n_hidden)

In [None]:
def get_att_weight(dec_output, enc_outputs):  # attention weights of one `dec_output` over all `enc_outputs`
    n_step = len(enc_outputs)
    attn_scores = torch.zeros(n_step)  # attn_scores : [n_step]

    for i in range(n_step):
        attn_scores[i] = get_att_score(dec_output, enc_outputs[i])

    # Normalize scores to weights in range 0 to 1 that sum to 1
    return F.softmax(attn_scores, dim=0).view(1, 1, -1), attn_scores

def get_att_score(dec_output, enc_output):  # enc_output : [batch_size(=1), n_hidden]
    score = attn(enc_output)  # score : [batch_size, n_hidden], i.e. W_a h_s (+ bias)
    return torch.dot(dec_output.view(-1), score.view(-1))  # inner product -> scalar score h_t^T W_a h_s

In [None]:
encoder_outputs = torch.randn(3, 1, 5)
hidden = torch.randn(1, 1, 5)

In [None]:
encoder_outputs_attention_softmax_weights, energy = get_att_weight(hidden, encoder_outputs)



In [None]:
encoder_outputs_attention_softmax_weights

tensor([[[3.4176e-04, 8.2757e-01, 1.7209e-01]]], grad_fn=<ViewBackward>)

In [None]:
energy

tensor([-5.8196,  1.9725,  0.4020], grad_fn=<CopySlices>)
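The loop in `get_att_weight` scores one item at a time. As a sketch, with the layer and inputs re-created using the same shapes as above (so the numbers differ from the outputs shown), the same scores can be computed for all items at once with one matrix product:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_hidden = 5
attn = nn.Linear(n_hidden, n_hidden)   # same layer shape as above

encoder_outputs = torch.randn(3, 1, n_hidden)
hidden = torch.randn(1, 1, n_hidden)

# Loop version, mirroring get_att_score: one dot product per item.
loop_scores = torch.stack([
    torch.dot(hidden.view(-1), attn(enc).view(-1)) for enc in encoder_outputs
])

# Vectorized version: apply attn to all items at once, then one matmul.
vec_scores = attn(encoder_outputs.squeeze(1)) @ hidden.view(-1)   # (3,)

weights = F.softmax(vec_scores, dim=0)
print(torch.allclose(loop_scores, vec_scores, atol=1e-5), weights.sum().item())
```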

## Further Reading ## 
Before we move on to the next topic: we've now covered enough ground for you to be able to look at implementations of the attention mechanism and/or read research papers that use it for various tasks.

[**NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE**](https://arxiv.org/pdf/1409.0473v7.pdf): This paper introduced the attention mechanism, in the context of the NMT task. [Here](https://paperswithcode.com/paper/neural-machine-translation-by-jointly) are implementations of this paper on Papers with Code.


[**Effective Approaches to Attention-based Neural Machine Translation**](https://arxiv.org/pdf/1508.04025.pdf): This paper introduces new energy functions, as well as local attention, as discussed above.

## Self-Attention ## 

In my opinion, the attention mechanism was an important discovery because, when coupled with RNNs and CNNs, it improved upon the state-of-the-art (SOTA) results. Self-attention matters for a slightly different reason. In 2017, researchers at Google published a research paper ("Attention Is All You Need") detailing a new architecture that had self-attention as a core component. This architecture is called `The Transformer`. Its invention was something of a watershed moment (even though it feels like those happen every other day in the Deep Learning field): it is a really powerful standalone model, and its performance on a whole host of NLP tasks has been truly awesome.

Since 2017, a large amount of research has been done to improve upon the Transformer model. So many such papers have been published that Hugging Face's [transformers](https://huggingface.co/transformers/) library now collects implementations of a large number of them, ready to be used for NLP tasks.

To explore self-attention we're going to be using this library and we will try to decode the implementation of self-attention in this library. 

In [None]:
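import torch
from transformers import AutoTokenizer, DistilBertModel
from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention
from transformers.models.distilbert.configuration_distilbert import DistilBertConfig

In [None]:
model_checkpoint = "distilbert-base-uncased"
config = DistilBertConfig.from_pretrained(model_checkpoint)  # assumed source of `config` (dim=768, n_heads=12)
embeddings = DistilBertModel.from_pretrained(model_checkpoint).embeddings  # assumed source of `embeddings`

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)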

In [None]:
tokenizer 

PreTrainedTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [None]:
def embeddify(text):
    """Return the input embeddings (1, seq_len, dim) and the tokens for `text`."""
    token_ids = tokenizer(text)['input_ids']
    input_ids = torch.tensor(token_ids).view(1, len(token_ids))  # (batch_size=1, seq_len)
    return embeddings(input_ids), tokenizer.convert_ids_to_tokens(token_ids)

In [None]:
multi_head_attn = MultiHeadSelfAttention(config)

In [None]:
embeddified_text, tokens = embeddify('Ronaldo is one of the best football players in the world')
# x = torch.randn(1,10,config.dim) # (bs, seq_length, dim)



In [None]:
embeddified_text.shape

torch.Size([1, 14, 768])

In [None]:
tokens

['[CLS]',
 'ronald',
 '##o',
 'is',
 'one',
 'of',
 'the',
 'best',
 'football',
 'players',
 'in',
 'the',
 'world',
 '[SEP]']

In [None]:
mask = torch.ones(1,14)  # 1 = real token to attend to; positions with 0 would be masked out

In [None]:
multi_head_attn_op = multi_head_attn(
    embeddified_text,
    embeddified_text,
    embeddified_text,
    mask)

In [None]:
len(multi_head_attn_op)

1

In [None]:
multi_head_attn_op[0].shape

torch.Size([1, 14, 768])

I realize I've written a lot of code here, but I believe most of it should not be that scary. Hugging Face has also created a library for tokenization. Over time, practices have been developed for NLP tasks that, keeping all other things the same, give better results, and the tokenizer library encapsulates those practices.

In this demo I am using the `MultiHeadSelfAttention` class from the DistilBERT implementation. DistilBERT is a smaller, distilled version of BERT, which is itself built from Transformer encoder layers.

The `embeddify` function returns a tensor holding a vector-space representation of each token in the sentence, and I've passed that tensor as input to `MultiHeadSelfAttention`. As output, it returns one vector for each entry in the input: the output shape `(1, 14, 768)` matches the input shape. Let's check out a visual representation of what happened here:

![](assets/transformer_self-attention_visualization.png)

Now, think of someone who, while discovering self-attention, looked at some version of the `embeddified_text` tensor (a sequence of entities where each entity is represented by a fixed-length vector) and tried the following:

They wanted to use attention so that each entity extracts information from the very sequence it belongs to. Here, the vector representing each word is treated as the query object, and attention is used to fetch a set of weights over the entities of that same sequence. I hope it is now clear why this is called self-attention. We can now also decode what the output `multi_head_attn_op[0]` is.

So, for every entity:
1. I compute an energy function where the entity is the query and the "list of things" is the sequence itself.
2. I apply softmax to the output of the energy function to get the weights.
3. I multiply the weights with each vector and add them together to get a context vector, like before.
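The three steps above can be sketched for a whole sequence at once. This is a simplified, single-head version under stated assumptions: the token vectors are random stand-ins for `embeddified_text`, there are no learned query/key/value projections, and the energy function is a scaled dot product (the DistilBERT class below adds learned projections and multiple heads):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, dim = 6, 8
x = torch.randn(seq_len, dim)   # one vector per token (hypothetical embeddings)

# Step 1: energies between every token (as query) and every token (as item),
# using a scaled dot product as the energy function. Row i = query i.
energy = x @ x.t() / math.sqrt(dim)    # (seq_len, seq_len)

# Step 2: softmax each row so each query's weights sum to 1.
weights = F.softmax(energy, dim=-1)    # (seq_len, seq_len)

# Step 3: weighted sum of the token vectors -> one context vector per token.
context = weights @ x                  # (seq_len, dim)
print(context.shape, weights.sum(dim=-1))
```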

So, for each entity in my sequence I get a context vector. The weights depicted in our visual above are supposed to represent the softmaxed weights where the word 'it' was the query, and using those weights we get a context vector for 'it'. Now, what do the weights tell us? They tell us that the word 'it' has a strong relationship with the words 'The animal' compared to the word 'because', and if we read the sentence, that kind of makes sense.


Before we start looking at how this is implemented, a note on what multi-head means. Instead of having one vector representing each word, each word is represented by multiple smaller vectors, one per head. The attention mechanism is then applied multiple times in parallel, once per head, creating multiple context vectors. Best to explain it further with code:
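Here's a sketch of the splitting step, using DistilBERT-base sizes (12 heads, 768 dimensions, so 64 dimensions per head) and a random stand-in for the projected token vectors. The reshapes mirror the `shape` and `unshape` helpers inside the class below:

```python
import torch

bs, seq_len, n_heads, dim = 1, 14, 12, 768
dim_per_head = dim // n_heads               # 64

x = torch.randn(bs, seq_len, dim)           # e.g. the projected queries

# Split each 768-d token vector into 12 vectors of 64 dims, one per head,
# and move the head axis forward so each head is processed in parallel.
heads = x.view(bs, -1, n_heads, dim_per_head).transpose(1, 2)
print(heads.shape)  # torch.Size([1, 12, 14, 64])

# Undo it: put the heads back next to each other and concatenate them.
merged = heads.transpose(1, 2).contiguous().view(bs, -1, n_heads * dim_per_head)
print(torch.equal(merged, x))  # True
```

Head `h` of token `t` is simply the slice `x[0, t, h*64:(h+1)*64]`, so attention runs on 12 independent 64-d "views" of each word.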

In [None]:
chatuur_multi_head_attn = Chatuur_MultiHeadSelfAttention(config)

In [None]:
op = chatuur_multi_head_attn(embeddified_text,
                             embeddified_text,
                             embeddified_text,
                             mask)

Execution stops at the `import pdb;pdb.set_trace()` breakpoint in `forward`, on the line `bs, q_length, dim = query.size()`. After stepping with `n` down to `context = unshape(context)  # (bs, q_length, dim)`, we inspect the intermediate shapes:

ipdb> p context.shape
torch.Size([1, 12, 14, 64])
ipdb> p weights.shape
torch.Size([1, 12, 14, 14])
ipdb> v.shape
torch.Size([1, 12, 14, 64])
ipdb> q

In [None]:
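import math
import torch
import torch.nn as nn
# Head-pruning helpers used by prune_heads; recent transformers 4.x releases expose them here
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer

# Here we have the MultiHeadSelfAttention class from the transformers library,
# copied so that we can step through it.
class Chatuur_MultiHeadSelfAttention(nn.Module):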
    def __init__(self, config):
        super().__init__()

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

        self.pruned_heads = set()

    def prune_heads(self, heads):
        attention_head_size = self.dim // self.n_heads
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
        # Prune linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length)

        Returns:
            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
            seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
        """
        # import pdb;pdb.set_trace()  # uncomment to step through forward() as in the session above
        bs, q_length, dim = query.size()
        k_length = key.size(1)
        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
        # assert key.size() == value.size()

        dim_per_head = self.dim // self.n_heads

        mask_reshp = (bs, 1, 1, k_length)

        def shape(x):
            """ separate heads """
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """ group heads """
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        # query object 
        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        
        # list of things
        # shape() splits the last dimension into n_heads vectors of dim_per_head each
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

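        # Scaled dot-product attention ("Attention Is All You Need"): dividing by sqrt(dim_per_head)
        # keeps large dot products from pushing the softmax into regions with tiny gradients.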
        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        
        
        # Dot Energy function. 
        # so we get a score for every (query word, key word) pair: each word acts as a query
        # and attends over every word of the same sequence, including itself.
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        

        
        # Will talk about this later
        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
        scores.masked_fill_(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)

        weights = nn.Softmax(dim=-1)(scores)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)
        else:
            return (context,)

## Conclusion ## 

Now, I am leaving a few questions for the next video: 
1. The purpose of q_lin, k_lin and v_lin. 
2. What a mask is, and why we need one.