![banner](https://raw.githubusercontent.com/priyammaz/HAL-DL-From-Scratch/main/src/visuals/banner.png)

# Attention

Attention networks have become crucial in state-of-the-art architectures, most notably Transformers! Today we will delve a bit deeper into attention and how it works. Although attention was originally introduced for sequence modeling, it has found its way into Computer Vision, Graphs, and basically every other domain, demonstrating the flexibility of the architecture. We will stick to the sequence modeling perspective today to build intuition for how it works. To start, let's revisit the original sequence modeling mechanism: **Recurrent Neural Networks**

## Recap: Recurrent Neural Networks
<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/recurrent_neural_network_diagram.png" width="800"/>
</div>

In recurrent neural networks, we typically pass our sequence in one timestep at a time and produce an output at each step. When we pass in $x_1$, we create a hidden state $h_1$ that captures the relevant information in the input, and this hidden state is then used to produce the output $y_1$. What makes it an RNN is that when we pass in the second timestep $x_2$ to produce the hidden state $h_2$, we also feed in the previous hidden state $h_1$, so $h_2$ already contains information about the past! Therefore our output $y_2$ is informed by both $x_2$ and $x_1$, encoded through the hidden states. If we keep this going, when we make the prediction $y_{100}$, we will be using a hidden state that has encoded information from all the inputs $x_1$ to $x_{100}$.

Everything explained so far is a causal RNN: to make a prediction at some timestep $t$, we can only use the input timesteps $\le t$. We can easily extend this to a bidirectional RNN, where a prediction at time $t$ can look at the entire sequence. In this case we really have two hidden states, one that reads the sequence forward (summarizing the past) and another that reads it backward (summarizing the future)!

Whether you use a causal or bidirectional model depends on what you want to do. For Named Entity Recognition (i.e. determining whether each word in a sentence is an entity), you can look at the entire sentence. On the other hand, if you want to forecast the future, like a stock price, you have to use a causal model, since you can only look at the past to predict the future.
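To make the recurrence concrete, here is a minimal NumPy sketch of a vanilla tanh RNN with random, untrained weights. The function names (`rnn_forward`, `bidirectional_rnn`, `make_params`) and all sizes are hypothetical choices for illustration, not code from this tutorial. The check at the end demonstrates the causal property. Changing the input at one timestep (index 5, 0-indexed) leaves every earlier forward hidden state untouched. In the backward half of the bidirectional model, only the later timesteps are untouched.

```python
import numpy as np

def rnn_forward(x, W_xh, W_hh, b_h):
    """Causal RNN: h_t = tanh(x_t @ W_xh + h_{t-1} @ W_hh + b_h), starting from h_0 = 0.

    x: (seq_len, input_dim) -> returns all hidden states, shape (seq_len, hidden_dim)
    """
    h = np.zeros(W_hh.shape[0])               # empty memory before the first timestep
    hidden_states = []
    for x_t in x:                             # one timestep at a time (the sequential loop)
        h = np.tanh(x_t @ W_xh + h @ W_hh + b_h)
        hidden_states.append(h)
    return np.stack(hidden_states)

def bidirectional_rnn(x, fwd_params, bwd_params):
    """Concatenate a forward pass (summarizes x_1..x_t) with a backward pass (summarizes x_t..x_T)."""
    h_fwd = rnn_forward(x, *fwd_params)
    h_bwd = rnn_forward(x[::-1], *bwd_params)[::-1]   # read the sequence in reverse, then realign
    return np.concatenate([h_fwd, h_bwd], axis=-1)

def make_params(rng, input_dim, hidden_dim):
    # Random weights purely for illustration (a trained model would learn these)
    return (rng.normal(size=(input_dim, hidden_dim)) * 0.5,
            rng.normal(size=(hidden_dim, hidden_dim)) * 0.5,
            np.zeros(hidden_dim))

rng = np.random.default_rng(0)
seq_len, input_dim, hidden_dim = 8, 4, 16
x = rng.normal(size=(seq_len, input_dim))
fwd, bwd = make_params(rng, input_dim, hidden_dim), make_params(rng, input_dim, hidden_dim)

# Causality check: change the input at index 5 and see which hidden states move
x_changed = x.copy()
x_changed[5] += 1.0
h, h_changed = rnn_forward(x, *fwd), rnn_forward(x_changed, *fwd)
print(np.array_equal(h[:5], h_changed[:5]))   # True: earlier states never see x[5]

hb, hb_changed = bidirectional_rnn(x, fwd, bwd), bidirectional_rnn(x_changed, fwd, bwd)
print(np.array_equal(hb[6:, hidden_dim:], hb_changed[6:, hidden_dim:]))   # True: backward half after index 5
```

The Python `for` loop is the sequential bottleneck. Each hidden state needs the previous one, so the timesteps cannot be computed in parallel.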

All this sounds good, but there was one glaring problem: memory. The hidden state we use to encode the history can only hold so much information, so as the sequence gets longer, the model starts to forget. This matters a lot in Natural Language Processing, as there may be important relations between parts of a book that are pages, or even chapters, apart. To address this issue, attention-augmented RNNs were introduced in the paper [Neural Machine Translation By Jointly Learning To Align and Translate](https://arxiv.org/pdf/1409.0473).

## Attention Augmented RNN

If I had to define attention in two words, it would be: **Weighted Average**. (The paper calls the hidden states *annotations*, but they are the same thing!) Let's go back to our RNN. Before we make the prediction $y_t$, the RNN has produced a sequence of hidden states $h_t$ that encode the input sequence $x_t$. The problem, again, is that $h_t$ for large $t$ will have forgotten important information about the early inputs. So what if we let every timestep get to know the others again? We can produce a context vector $c_t$ that is a weighted average of all the hidden states in a bidirectional architecture, or of only the current and previous hidden states in a causal architecture. Because every context vector $c_t$ is a weighted average over the timesteps, it is directly reminded of distant timesteps, which solves our *memory* problem!
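Written out with explicit weights, the context vector is the following. The weight symbol $\alpha_{tj}$ is our notation for illustration, and $T$ denotes the sequence length.

$$c_t = \sum_{j} \alpha_{tj}\, h_j, \qquad \alpha_{tj} \ge 0, \qquad \sum_{j} \alpha_{tj} = 1$$

Here $j$ runs over $1, \dots, T$ in the bidirectional case and over $1, \dots, t$ in the causal case. As a small worked example, take made-up hidden states $h_1 = (1, 0)$, $h_2 = (0, 1)$, $h_3 = (1, 1)$ and causal weights $(\alpha_{31}, \alpha_{32}, \alpha_{33}) = (0.2, 0.3, 0.5)$:

$$c_3 = 0.2\,(1, 0) + 0.3\,(0, 1) + 0.5\,(1, 1) = (0.7,\ 0.8)$$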

<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/rnn_with_attention.png" width="800"/>
</div>

I keep saying weighted average because, for each timestep, the model has to learn which timesteps carry the most important information and weight them higher! In the paper, the weights are learned through an alignment model. This is just a feedforward network that scores how well the hidden state at time $t$ relates to those around it in the sequence. These scores are then passed through a softmax so that the weights are non-negative and sum to 1, and the context vectors are computed from them. This means every context vector is a customized weighted average that has learned exactly what information to emphasize at its timestep.
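Here is a minimal NumPy sketch of an additive (feedforward) alignment model of the form $v_a^\top \tanh(W_a q + U_a h_j)$, followed by the softmax and the weighted average. In the paper, the query $q$ is the decoder's previous state and the $h_j$ are the encoder annotations. To match the simplified description above, this sketch instead scores one hidden state against the others. The names `additive_attention`, `causal_t`, `W_a`, `U_a`, `v_a` and all sizes are illustrative assumptions, and the weights are random rather than trained.

```python
import numpy as np

def softmax(scores):
    exp = np.exp(scores - scores.max())        # shift by the max for numerical stability
    return exp / exp.sum()

def additive_attention(query, hidden_states, W_a, U_a, v_a, causal_t=None):
    """Additive alignment model in the style of Bahdanau et al.

    score_j = v_a . tanh(W_a @ query + U_a @ h_j)   (a one-hidden-layer feedforward network)
    weights = softmax(scores)                        (non-negative, sum to 1)
    context = sum_j weights_j * h_j                  (weighted average of hidden states)
    If causal_t is given, only hidden states 0..causal_t are scored.
    """
    H = hidden_states if causal_t is None else hidden_states[: causal_t + 1]
    scores = np.tanh(W_a @ query + H @ U_a.T) @ v_a   # (num_states,): one score per hidden state
    weights = softmax(scores)
    context = weights @ H                             # (hidden_dim,)
    return context, weights

rng = np.random.default_rng(0)
T, hidden_dim, attn_dim = 6, 8, 10
H = rng.normal(size=(T, hidden_dim))                 # stand-in hidden states h_0..h_5
W_a = rng.normal(size=(attn_dim, hidden_dim))        # random, untrained weights
U_a = rng.normal(size=(attn_dim, hidden_dim))
v_a = rng.normal(size=attn_dim)

context, weights = additive_attention(H[3], H, W_a, U_a, v_a, causal_t=3)
print(weights.shape, round(float(weights.sum()), 6))  # (4,) 1.0
```

With `causal_t=3`, only the four hidden states up to and including index 3 receive weight. Passing `causal_t=None` gives the bidirectional version, which scores all six.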

### Problems

There were some issues with this, though, some of which were already known about RNNs:
- **Efficient but Slow**: The RNN mechanism loops through the sequence one timestep at a time, which makes training slow because the timesteps cannot be processed in parallel, although inference is efficient
- **Lack of Positional Information**: Our context vectors are just weighted averages of hidden states, with no explicit information about position or time, yet in most sequence tasks the order in which your data appears is very important
- **Redundancy**: We are effectively learning the same thing twice: the hidden states encode sequential information, and the attention mechanism also encodes sequential information

### Attention is All You Need!

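The groundbreaking paper [Attention is All You Need](https://arxiv.org/pdf/1706.03762) solved all of the problems above, but introduced a new one: computational cost. The attention matrix we will build below has one entry for every pair of tokens, so it grows quadratically with sequence length. Let's first look at what the proposed attention mechanism is doing!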


<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/attention_mechanism_visual.png" width="800"/>
</div>

The input is a sequence of embedding vectors and the output is a sequence of context vectors. Let's quickly look at the formulation:

$$\text{Attention}(Q,K,V) = \text{Softmax}(\frac{QK^T}{\sqrt{d_e}})V$$

We see some new notation now, $Q$, $K$, and $V$, so let's define them:

- $Q$: Queries, the tokens we are interested in
- $K$: Keys, the other tokens we want to compare our queries against
- $V$: Values, they are the values we will weight in our weighted average

This is a little weird, so let's step through it! First, an important note: $Q$, $K$, and $V$ are three projections of our original input $X$. This means we have three linear layers that all take the same input $X$ to produce $Q$, $K$, and $V$.
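Putting the formula and the three projections together, here is a minimal single-head NumPy sketch of scaled dot-product self-attention for one sequence. It uses bias-free projection matrices and random, untrained weights. The function name `self_attention` and the optional `causal` flag are illustrative assumptions. The flag connects back to the causal vs. bidirectional distinction from the RNN section: it blocks each timestep from attending to later ones.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)     # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v, causal=False):
    """Single-head scaled dot-product self-attention for one sequence (no batch dimension).

    X:              (seq_len, d_model) input embeddings
    W_q, W_k, W_v:  (d_model, d_e) projection matrices (bias-free linear layers)
    causal:         if True, timestep i may only attend to timesteps j <= i
    """
    Q, K, V = X @ W_q, X @ W_k, X @ W_v        # three projections of the same input X
    d_e = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_e)            # (seq_len, seq_len) similarity of every pair
    if causal:
        future = np.triu(np.ones(scores.shape, dtype=bool), k=1)   # True above the diagonal
        scores = np.where(future, -np.inf, scores)                 # exp(-inf) = 0 after softmax
    attn = softmax(scores, axis=-1)            # each row sums to 1
    return attn @ V, attn                      # context vectors: (seq_len, d_e)

rng = np.random.default_rng(0)
seq_len, d_model, d_e = 5, 16, 8
X = rng.normal(size=(seq_len, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d_e)) for _ in range(3))

out, attn = self_attention(X, W_q, W_k, W_v)
print(out.shape, attn.shape)                   # (5, 8) (5, 5)
_, causal_attn = self_attention(X, W_q, W_k, W_v, causal=True)
print(np.allclose(np.triu(causal_attn, k=1), 0.0))   # True: no weight on future timesteps
```

In PyTorch, the three projections would typically be `nn.Linear(d_model, d_e)` layers. PyTorch 2.x also provides `torch.nn.functional.scaled_dot_product_attention`, which fuses the score, softmax, and weighting steps.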

### Step 1: Compute the Attention Matrix with $\text{Softmax}(QK^T)$

The first step is computing $\text{Softmax}(QK^T)$, where $Q$ and $K$ both have shape (Sequence Length x Embedding Dimension). The result has shape (Sequence Length x Sequence Length). This is what it looks like!

<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/computing_attention.png?raw=true" width="800"/>
</div>

In the image above, the softmax has already been applied to the attention matrix (the step itself is not drawn, for simplicity), so each row adds up to 1, like a probability distribution.
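The shapes in this step can be checked with a few lines of NumPy. Random matrices stand in for the projected queries and keys (an assumption for illustration). Following this step's heading, the $1/\sqrt{d_e}$ scaling is left out here.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_e = 4, 3
Q = rng.normal(size=(seq_len, d_e))   # stand-ins for the projected queries
K = rng.normal(size=(seq_len, d_e))   # stand-ins for the projected keys

scores = Q @ K.T                      # (seq_len, seq_len): one score per (query, key) pair
scores = scores - scores.max(axis=-1, keepdims=True)                # per-row stability shift
attn = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)  # row-wise softmax

print(attn.shape)                             # (4, 4)
print(np.allclose(attn.sum(axis=-1), 1.0))    # True: each row is a probability distribution
```

The softmax is taken along each row (`axis=-1`), so each query's weights over all keys form a distribution.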

**Recap: Dot Product**

As a quick reminder, this whole mechanism depends on the dot product, and more specifically, on its geometric interpretation:

$$a\cdot b = \sum_{i=1}^n a_i b_i = |a||b|\cos(\theta)$$

What the dot product really measures is the similarity between vectors. Remember that $\cos(0) = 1$ is the largest value cosine can take, which happens when $a$ and $b$ point in exactly the same direction. So for vectors of a given length, the dot product is largest when they point in the same direction, zero when they are orthogonal, and negative when they point in opposite directions.
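A quick numerical check of both forms of the dot product, using plain Python and hand-picked 2-D vectors chosen for illustration:

```python
import math

def dot(a, b):
    # Algebraic form: sum_i a_i * b_i
    return sum(x * y for x, y in zip(a, b))

def norm(a):
    return math.sqrt(dot(a, a))

a = (3.0, 4.0)            # |a| = 5
b = (4.0, 3.0)            # |b| = 5
theta = math.atan2(a[1], a[0]) - math.atan2(b[1], b[0])   # angle between the two 2-D vectors

print(dot(a, b))                                # 24.0
print(norm(a) * norm(b) * math.cos(theta))      # approximately 24.0 (geometric form)

# Same length as a, different directions: the dot product tracks alignment
for name, c in [("same", (3.0, 4.0)), ("orthogonal", (-4.0, 3.0)), ("opposite", (-3.0, -4.0))]:
    print(name, dot(a, c))   # same 25.0, orthogonal 0.0, opposite -25.0
```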

**Recap: Matrix Multiplication**

Also remember, matrix multiplication is just a bunch of dot products, repeating the multiply/add operation over and over. When we multiply matrix $A$ by matrix $B$, entry $(i, j)$ of the result is the dot product of row $i$ of $A$ and column $j$ of $B$!
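To see this, we can rebuild a matrix product entry by entry from dot products. The small matrices below are arbitrary examples.

```python
import numpy as np

A = np.arange(6, dtype=float).reshape(2, 3)    # 2 x 3
B = np.arange(12, dtype=float).reshape(3, 4)   # 3 x 4

C = A @ B                                      # 2 x 4
# Rebuild C one entry at a time: C[i, j] is the dot product of row i of A and column j of B
C_loop = np.empty((A.shape[0], B.shape[1]))
for i in range(A.shape[0]):
    for j in range(B.shape[1]):
        C_loop[i, j] = np.dot(A[i, :], B[:, j])

print(np.allclose(C, C_loop))   # True
```

With $A = Q$ and $B = K^T$, column $j$ of $K^T$ is row $j$ of $K$. So entry $(i, j)$ of $QK^T$ is exactly the dot product of query $i$ with key $j$.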

With these quick recaps, let's go back to the image above. When we multiply $Q$ by $K^T$, we take the dot product of each vector in the sequence $Q$ with each vector in the sequence $K$, measuring their similarity. Again, $Q$ and $K$ are just projections of the original data $X$, so we are really computing the similarity between every possible pair of timesteps in $X$. We could have just computed $XX^T$, which would also give a similarity between every pair of timesteps. Using learned projections of $X$ instead of the raw inputs gives the model more learnable parameters, so it can further accentuate similarities and differences between timesteps!

The final result of this operation is the attention matrix, which holds the similarity between every possible pair of tokens.
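One concrete difference between the raw similarity $XX^T$ and the projected version concerns symmetry. $XX^T$ is always symmetric: timestep $i$ scores timestep $j$ exactly as $j$ scores $i$. With two separate projections, the scores $(XW_q)(XW_k)^T$ generally are not symmetric. The random matrices below are illustrative stand-ins, not trained weights.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 6))                   # 4 timesteps, 6 features
W_q = rng.normal(size=(6, 6))                 # hypothetical query projection
W_k = rng.normal(size=(6, 6))                 # hypothetical key projection

raw = X @ X.T                                 # raw similarity between timesteps
projected = (X @ W_q) @ (X @ W_k).T           # QK^T with learned-style projections

print(np.allclose(raw, raw.T))                # True: always symmetric
print(np.allclose(projected, projected.T))    # False for these random weights
```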

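**Note** I didn't include anything about the $\frac{1}{\sqrt{d_e}}$ term in the formula yet. It is a scaling constant that keeps the variance of the attention scores from growing with the dimension. If the components of a query and a key are independent with mean 0 and variance 1, their dot product is a sum of $d_e$ such products and has variance $d_e$. Dividing by $\sqrt{d_e}$ brings that variance back to 1. Without this scaling, large scores push the softmax toward a nearly one-hot output where the gradients are tiny, so the scaling leads to more stable training!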
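The variance argument can be checked empirically. This sketch assumes, as in the argument above, that query and key components are independent standard normal values. The sample count and the two dimensions are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pairs = 10_000
results = {}

for d_e in (16, 256):
    # Query and key components drawn independently with mean 0 and variance 1 (an assumption)
    q = rng.standard_normal((n_pairs, d_e))
    k = rng.standard_normal((n_pairs, d_e))
    raw = np.sum(q * k, axis=-1)               # n_pairs dot products q . k
    scaled = raw / np.sqrt(d_e)                # the 1/sqrt(d_e) factor from the formula
    results[d_e] = (raw.var(), scaled.var())
    print(f"d_e={d_e}: var(raw) = {raw.var():.1f}, var(scaled) = {scaled.var():.2f}")
# var(raw) comes out close to d_e, while var(scaled) stays close to 1 for both sizes
```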

### Step 2: Weighting the Values Matrix

Now that we know how each timestep is related to every other timestep, we can do our weighted average by multiplying the attention matrix with the values $V$! After this step, the vector for each timestep is no longer just that timestep's data. Instead, it is a weighted average of all the value vectors in the sequence, weighted by how strongly each one is related to the timestep of interest.
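This small NumPy check shows that each row of $\text{Softmax}(\cdot)V$ really is a weighted average of the value vectors. A random row-normalized matrix stands in for a computed attention matrix, and random vectors stand in for $V$ (both assumptions for illustration).

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_e = 4, 3
V = rng.normal(size=(seq_len, d_e))             # one value vector per timestep
logits = rng.normal(size=(seq_len, seq_len))
attn = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)   # rows sum to 1

out = attn @ V                                  # (seq_len, d_e): one context vector per timestep

# Row i of the output is the attention-weighted average of all value vectors
i = 2
manual = sum(attn[i, j] * V[j] for j in range(seq_len))
print(np.allclose(out[i], manual))              # True
```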


<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/encoder_attention_vis.png?raw=true" width="800"/>
</div>

