![banner](https://raw.githubusercontent.com/priyammaz/HAL-DL-From-Scratch/main/src/visuals/banner.png)

# Attention

Attention networks have become crucial in state-of-the-art architectures, namely Transformers! Today we will be delving a bit deeper into attention and how it works! Although attention was mainly intended for sequence modeling, it has found its way into Computer Vision, Graphs and basically every other domain, demonstrating the flexibility of the architecture. Let's discuss this from a sequence modeling perspective today though, just to build intuition on how this works. To start the explanation, let's refer back to the original sequence modeling mechanism: **Recurrent Neural Networks**

## Recap: Recurrent Neural Networks
<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/recurrent_neural_network_diagram.png" width="800"/>
</div>

In recurrent neural networks, what we typically do is take our sequence, pass in a single timestep at a time, and produce an output. This means when we pass in $x_1$ we create a hidden state $h_1$ that captures all the relevant information in the input, and this hidden state is then used to produce the output $y_1$. Now what makes it an RNN is that when we pass in the second timestep $x_2$ to produce the hidden state $h_2$, the hidden state already contains information about the past from $h_1$! Therefore our output $y_2$ is informed by both $x_2$ and $x_1$, encoded through the hidden states. If we keep this going, when we want to make a prediction at $y_{100}$, we will be using a hidden state that has encoded information from all the inputs $x_1$ to $x_{100}$. Everything explained so far is a causal RNN: to make a prediction at some timestep $t$, we can only use the input timesteps $\leq t$. We can easily expand this to a bidirectional RNN though, where to make a prediction at time $t$ we can look at the entire sequence. In this case we will really have two hidden states, one that looks backwards and another that looks forward! Whether you use causal or bidirectional depends a lot on what you want to do. If you want to do Named Entity Recognition (i.e. determine if each word in a sentence is an entity), you can look at the entire sentence to do this. On the other hand, if you want to forecast the future, like a stock price, then you have to use a causal model, as you can only look at the past to predict the future.
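To make the recurrence concrete, here is a minimal sketch of the causal case using PyTorch's `nn.RNNCell`. The sizes and variable names are made up for illustration; the point is only that the same cell is reused at every timestep and that each $h_t$ is built from $x_t$ and $h_{t-1}$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not taken from the text above)
batch_size, seq_len, input_dim, hidden_dim = 2, 5, 3, 4
x = torch.randn(batch_size, seq_len, input_dim)

# One recurrent cell, reused at every timestep
cell = nn.RNNCell(input_dim, hidden_dim)

# The initial hidden state is all zeros; every h_t depends on x_t and h_{t-1}
h = torch.zeros(batch_size, hidden_dim)
hidden_states = []
for t in range(seq_len):
    h = cell(x[:, t], h)
    hidden_states.append(h)

# Stack the per-timestep states into (Batch x Seq Len x Hidden Dim)
hidden_states = torch.stack(hidden_states, dim=1)
print("Hidden states:", hidden_states.shape)
```

The explicit Python loop over `seq_len` is also the reason this style of model is slow to train: step `t` cannot start until step `t-1` has finished.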

All this sounds well and good, but there was one glaring problem: Memory. The hidden states we use to encode the history can only contain so much information, i.e. as the sequence length becomes longer the model will start to forget. This matters a lot for things like Natural Language Processing, as there may be important relations between parts of a book that are pages, or even chapters, apart. To solve this issue, Attention Augmented RNNs were introduced in the paper [Neural Machine Translation By Jointly Learning To Align and Translate](https://arxiv.org/pdf/1409.0473). 

## Attention Augmented RNN

If I had to use two words to define attention it would be: **Weighted Average**. In the paper, they call the hidden states *annotations*, but they are the same thing! So let's go back to our RNN again. Before we do our prediction for $y_t$, we have a sequence of hidden states $h_t$, produced by the RNN mechanism, that contain the information about the sequence $x_t$ itself. The problem is again that $h_t$ for large values of $t$ will have forgotten important information about the early inputs $x_t$ (small values of $t$). So what if we got everyone to know each other again? We can produce a context vector $c_t$ that is a weighted average of all the hidden states in the case of a bidirectional architecture, or just the previous hidden states in a causal architecture. This means the context vector $c_t$ at any time $t$ is a weighted average of all of the timesteps, so it is reminded about more distant timesteps, solving our *memory* problem!

<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/rnn_with_attention.png" width="800"/>
</div>

Now I keep saying weighted average, and this is because for each of the timesteps, the model has to learn which information is the most important to know at that time, and then weight it higher! As per the paper, the weights were learned through an alignment model, which was just a feedforward network, that scores how well the hidden state at time $t$ is related to those around it in the sequence. These scores were then passed through a softmax to ensure all the learned weights sum up to 1, and then the context vectors are computed based on them! This means every context vector is a customized weighted average that learned exactly what information to put emphasis on at that timestep. 
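As a rough sketch of that idea (not the paper's exact setup, where the scores compare a decoder state against encoder annotations), we can score every pair of hidden states with a tiny feedforward alignment model, softmax the scores, and take the weighted average. The names `w_q`, `w_k`, `v` and `align_dim` are hypothetical and only used for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len, hidden_dim, align_dim = 2, 6, 8, 16

# Stand-in for the hidden states an RNN would have produced
hidden_states = torch.randn(batch_size, seq_len, hidden_dim)

# Hypothetical feedforward alignment model:
# score(h_i, h_j) = v^T tanh(W_q h_i + W_k h_j)
w_q = nn.Linear(hidden_dim, align_dim, bias=False)
w_k = nn.Linear(hidden_dim, align_dim, bias=False)
v = nn.Linear(align_dim, 1, bias=False)

# Broadcasting (B x T x 1 x A) + (B x 1 x T x A) scores every pair (i, j)
pairwise = torch.tanh(w_q(hidden_states).unsqueeze(2) + w_k(hidden_states).unsqueeze(1))
scores = v(pairwise).squeeze(-1)        # (B x T x T)

# Softmax turns each row of scores into weights that sum to 1
weights = scores.softmax(dim=-1)

# Each context vector is a weighted average of all hidden states
context = weights @ hidden_states       # (B x T x hidden_dim)
print("Context vectors:", context.shape)
```

This is the bidirectional flavor, since every timestep averages over the whole sequence; a causal version would only average over earlier hidden states.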

### Problems

There were some issues with this though, some which were already known about RNNs:
- **Efficient but Slow**: The RNN mechanism loops through the sequence one timestep at a time, so training cannot be parallelized across timesteps and is very slow, even though inference is efficient
- **Lack of Positional Information**: Our context vectors are just weighted averages of hidden states, so there is no information about position or time, but obviously in most sequence tasks the order in which your data appears is very important
- **Redundancy**: We are effectively learning the same thing twice here, the hidden states encode sequential information, but the attention mechanism also encodes sequential information

### Attention is All You Need!

The groundbreaking paper [Attention is All You Need](https://arxiv.org/pdf/1706.03762) solved all of the problems above, but added a new one: Computational Cost. Let's first look at what the proposed Attention mechanism is doing!


<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/attention_mechanism_visual.png" width="800"/>
</div>

The input is a sequence of embedding vectors and the output is a sequence of context vectors. Let's quickly look at the formulation for this:

$$\text{Attention}(Q,K,V) = \text{Softmax}(\frac{QK^T}{\sqrt{d_e}})V$$

We see some new notation show up now, $Q$, $K$, $V$, so let's define them:

- $Q$: Queries, they are the token we are interested in
- $K$: Keys, they are the other tokens we want to compare our query against
- $V$: Values, they are the values we will weight in our weighted average

This is a little weird so let's step through it! First important note: $Q$, $K$, and $V$ are three projections of our original data input $X$. This basically means we have three linear layers that all take the same input $X$ to produce our $Q$, $K$, $V$. 

### Step 1: Compute the Attention Matrix with $\text{Softmax}(QK^T)$

So the first step is computing $\text{Softmax}(QK^T)$, where $Q$ and $K$ both have the shape (Sequence Length x Embedding Dimension). The output of this computation will be (Sequence Length x Sequence Length). This is what it looks like!

<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/computing_attention.png?raw=true" width="800"/>
</div>

In the image above, I also applied the softmax (not shown for simplicity), so each row of the attention matrix adds up to 1 (like probabilities).

**Recap: Dot Product**

As a quick reminder, this whole mechanism depends on the dot product, and more specifically, its geometric interpretation

$$a\cdot b = \sum_{i=1}^n a_i b_i = |a||b|\cos(\theta)$$

What the dot product really signifies is the similarity between vectors. Remember the cosine of 0 is just 1, so the highest possible cosine value occurs when the vectors $a$ and $b$ point in the exact same direction. This means vectors that are similar in direction have a larger dot product. 
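Here is a tiny sanity check of that geometric picture, using hand-picked 2D vectors (chosen purely for illustration): one pointing the same way as `a`, one orthogonal to it, and one pointing the opposite way.

```python
import torch

a = torch.tensor([1.0, 0.0])
candidates = {
    "same direction": torch.tensor([2.0, 0.0]),
    "orthogonal": torch.tensor([0.0, 3.0]),
    "opposite": torch.tensor([-1.0, 0.0]),
}

results = {}
for name, b in candidates.items():
    dot = torch.dot(a, b)
    # Rearranging a.b = |a||b|cos(theta) gives the cosine of the angle
    cos = dot / (a.norm() * b.norm())
    results[name] = (dot.item(), cos.item())
    print(f"{name}: dot = {dot.item()}, cos = {cos.item()}")
```

The dot product is positive when the vectors agree in direction, zero when they are unrelated (orthogonal), and negative when they point away from each other.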

**Recap: Matrix Multiplication**

Also remember, matrix multiplication is basically just a bunch of dot products, repeating the multiply/add operation over and over. If we are multiplying matrix $A$ with matrix $B$, what we are really doing is taking the dot product of every row of $A$ with every column of $B$!

So with our quick recaps, let's go back to the image above. When we multiply $Q$ by $K^T$, we are taking each vector in the sequence $Q$ and each vector in the sequence $K$ and computing their dot product similarity. Again, $Q$ and $K$ are just projections of the original data $X$, so really we are just computing the similarity between every possible combination of timesteps in $X$. We also could have just done $XX^T$, which would technically be the same kind of similarity computation, but by using projections of $X$ rather than the raw inputs themselves, we allow the model to have more learnable parameters so it can further accentuate similarities and differences between different timesteps!

The final result of this operation is the attention matrix, which holds the similarity between every possible pair of tokens. 
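A quick way to convince yourself of this is to fill in the result one entry at a time with explicit dot products and compare against `@`. The matrix sizes here are arbitrary.

```python
import torch

torch.manual_seed(0)

A = torch.randn(3, 4)
B = torch.randn(4, 5)

# Entry (i, j) of the product is the dot product of row i of A and column j of B
manual = torch.zeros(3, 5)
for i in range(3):
    for j in range(5):
        manual[i, j] = torch.dot(A[i], B[:, j])

matches = torch.allclose(manual, A @ B, atol=1e-6)
print("Manual dot products match A @ B:", matches)
```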

**Note** I didn't include anything about the $\frac{1}{\sqrt{d_e}}$ term in the formula. This is just a normalization constant that ensures the variance of the attention matrix isn't too large after our matrix multiplication. This just leads to more stable training!
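To see where the $\sqrt{d_e}$ comes from, here is a small experiment under a simplifying assumption: the entries of the query and key vectors are independent standard normal values. In that case the dot product of two $d$-dimensional vectors has variance roughly $d$, and dividing by $\sqrt{d}$ brings it back to roughly 1 no matter how large $d$ is.

```python
import torch

torch.manual_seed(0)

num_pairs = 10000
variances = {}

for d in [16, 128, 1024]:
    # Independent random "queries" and "keys" (an idealized assumption)
    q = torch.randn(num_pairs, d)
    k = torch.randn(num_pairs, d)

    # One dot product per (query, key) pair
    dots = (q * k).sum(dim=-1)
    scaled = dots / d ** 0.5

    variances[d] = (dots.var().item(), scaled.var().item())
    print(f"d={d}: raw variance = {variances[d][0]:.1f}, scaled variance = {variances[d][1]:.2f}")
```

Without the scaling, larger embedding dimensions produce larger scores, which pushes the softmax toward putting nearly all of its weight on a single token.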

### Step 2: Weighting the Values Matrix

Now that we have our similarities of how each timestep is related to all the other timesteps, we can now do our weighted average! After the weighted average computation, each vector for each timestep isn't just the data of the timestep but rather a weighted average of all the vectors in the sequence and how they are related to that timestep of interest. 


<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/encoder_attention_vis.png?raw=true" width="800"/>
</div>

The output of this operation gives us the sequence of context vectors!

### Enforcing Causality

What we have seen so far is the equivalent to a Bidirectional RNN. The weighted average operation we are doing is between a timestep of interest and all timesteps before and after it. If we wanted a causal model, where a context vector only depends on the timesteps before it, then we need to apply a causal mask to our attention mechanism. 

<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/causal_masking.png" width="800"/>
</div>

As you can see, we apply a mask to all entries where the index of the column (our key index) is greater than the index of the row (our query index). In practice, once we apply this mask to our attention matrix, we can then multiply by our values. You will see that the context vector at time $t$ then only depends on the current and previous timesteps, as we make sure the weights on future vectors of $V$ are zeroed out!
<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/decoder_attention_vis.png" width="800"/>
</div>

### Let's Build This!

Now that we have everything we need, we can build it! We won't be training any models now, just defining and exploring the architecture here. To do so, we will define some data in the form of `Batch x Sequence Length x Embed Dim`. The embedding dimension is basically the size of the vector we want to use to represent a single timestep, and the sequence length is how many timesteps there are in total. 

In [2]:
import torch
import torch.nn as nn

### Lets Define some Random Data ###
batch_size = 4
sequence_length = 64
embed_dim = 128

x = torch.randn(batch_size, sequence_length, embed_dim)
print("Shape of Input is:", x.shape)

Shape of Input is: torch.Size([4, 64, 128])


### Implement Attention Without Any Learnable Parameters $\text{Softmax}(\frac{XX^T}{\sqrt{d_e}})X$

This whole attention operation is again very flexible and there is technically no reason to have any learnable parameters in its formulation (other than the obvious reason of wanting better predictive performance). So let's quickly implement the formula as is, using the raw inputs $X$ rather than doing any learned projections of the data. 

**Step 1**

Compute $XX^T$ which will provide the similarity score between every pair of vectors in $X$. This will be contained inside a `batch x sequence_length x sequence_length` matrix

**Step 2**

After computing the similarity score, we can check and see that the variance of the similarity matrix is extremely high. This is the main reason for the normalization of dividing by the square root of the embedding dimension. Since the similarity scores are later passed through a softmax to compute a probability vector, dividing by a constant basically acts as a temperature parameter that softens the distribution (keeping the softmax from saturating) and provides more stable training!

**Step 3**

Each row of our `sequence_length x sequence_length` matrix is the similarity of how one timestep is related to all other timesteps! What we want to do is, instead of raw similarity scores, we will convert them to probabilities, so when we do the weighted average on our values matrix, the weights add up to 1! 

In [3]:
### First compute XX^T for similarity score between every pair of tokens ###
similarity = (x @ x.transpose(1,2))

### Normalize the Similarity Scores ###
print("Prenormalization Variance:", similarity.var())
similarity_norm = similarity / (embed_dim**0.5)
print("Normed Similarity Variance:", similarity_norm.var())

### Check the Shape of our Similarity Tensor ###
print("Shape of Normed Similarity:", similarity_norm.shape)

### Compute softmax on every row of the attention matrix (i.e along the last dimension) ###
attention_mat = similarity_norm.softmax(dim=-1)

### Verify each row adds up to 1 ###
summed_attention_mat = attention_mat.sum(axis=-1)
print("Everything Equal to One:", torch.allclose(summed_attention_mat, torch.ones_like(summed_attention_mat)))

### Multiply our Attention Matrix against its Values (X in our case) for our Weighted Average ###
context_vectors = attention_mat @ x

print("Output Shape:", context_vectors.shape)

Prenormalization Variance: tensor(367.6834)
Normed Similarity Variance: tensor(2.8725)
Shape of Normed Similarity: torch.Size([4, 64, 64])
Everything Equal to One: True
Output Shape: torch.Size([4, 64, 128])


That's it! This is basically all the attention computation is doing mathematically; we will add in the learnable parameters in just a bit. Something important to bring to your attention is that the input shape of $x$ and the output shape of the context vectors are identical. Again, the input is the raw data, and the output is a weighted average based on how every token is related to all the other ones. The shapes not changing is quite convenient, and allows us to stack a bunch of attention mechanisms on top of one another! 

### Let's Add Learnable Parameters 

This time, instead of using $X$ as our input, we will create our three projections of $X$ (Queries, Keys and Values), but then repeat the operation we just did here! Also for convenience, I will wrap it all in a PyTorch class so we can continue adding stuff onto it as we go on!

Now what are these projections exactly? They are pointwise (or per timestep) projections! Remember, in our example here, each timestep is encoded by a vector of size 128. We will create three learnable weight matrices, incorporated inside the Linear modules in PyTorch, that take these 128 numbers per timestep and project them to another 128 numbers (it is typical to keep the embedding dimension the same). This is a per timestep operation, not an across timesteps operation (the across timesteps part occurs within the attention computation). Obviously, PyTorch will accelerate this per timestep operation by doing it in parallel, but regardless, different timesteps don't get to see each other in the projection step. 
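One way to check the "timesteps don't see each other" claim is to change a single timestep of the input and see which outputs of a linear layer move. This is only an illustrative sketch; the timestep index 10 and the names `proj`, `x_demo` and `x_changed` are arbitrary choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

proj = nn.Linear(128, 128)

x_demo = torch.randn(1, 64, 128)

# Copy the input and overwrite only timestep 10
x_changed = x_demo.clone()
x_changed[:, 10] = torch.randn(128)

with torch.no_grad():
    out_a = proj(x_demo)
    out_b = proj(x_changed)

# A timestep counts as "changed" if any of its 128 output features differ
changed = ~torch.isclose(out_a, out_b).all(dim=-1).squeeze(0)
changed_timesteps = changed.nonzero().flatten().tolist()
print("Timesteps whose projection changed:", changed_timesteps)
```

Only the timestep we touched changes, which is exactly what pointwise means. Mixing information across timesteps is left entirely to the attention matrix.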

In [6]:
class Attention(nn.Module):
    def __init__(self, embedding_dimension):
        super().__init__()

        self.embed_dim = embedding_dimension 
        
        ### Create Pointwise Projections ###
        self.query = nn.Linear(embedding_dimension, embedding_dimension)
        self.key = nn.Linear(embedding_dimension, embedding_dimension)
        self.value = nn.Linear(embedding_dimension, embedding_dimension)

    def forward(self, x):

        ### Create Queries, Keys and Values from X ###
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        ### Do the same Computation from above, just with our QKV Matrices instead of X ###
        similarity = (q @ k.transpose(1,2)) / (self.embed_dim ** 0.5)
        attention  = similarity.softmax(axis=-1)
        output = attention @ v

        return output

attention = Attention(embedding_dimension=128)
output = attention(x)
print(output.shape)

torch.Size([4, 64, 128])


### MultiHeaded Attention

Now we have a small problem! Remember, the attention matrix encodes the similarity between each pair of timesteps in your sequence. But in many cases, language being a prime example, there can be different types of relationships between different pairs of words, while our attention computation is restricted to learning only one of them. The solution to this is **MultiHeaded Attention**. Inside each attention computation, what if we have 2 attention matrices, or 8, or however many we want! The more we have, the larger the diversity of relationships we can learn!

#### Single Headed Attention Recap

Let's summarize everything we have seen so far with a single visual, and we will call this a Single Head of attention. We will also have our embedding dimension for each word in the sequence be 9, and the sequence length is 8. 

<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/single_headed_attention_visual.png" width="800"/>
</div>

This is again called single headed attention because we only compute a single attention matrix following the logic above! 

#### Moving to MultiHeaded Attention

For multiheaded attention there isn't really a lot changing. Remember, to create our $Q$, $K$, $V$ in single headed attention, we have 3 linear projection layers that take in the embedding dimension and output the same embedding dimension (in our case it takes in 9 and outputs 9). But in multiheaded attention, we can actually reduce our embedding dimension to a smaller value, do the attention computation on the tokens with this condensed embedding dimension, repeat it a bunch of times, and then concatenate together the outputs. 

<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/multiheaded_attention_visual.png" width="800"/>
</div>

It is general practice to have the number of heads you pick to be a divisor of the embedding dimension. For example, in our case, our original embedding dimension is 9, so we can pick 3 heads because 9 is divisible by 3. This also means our head dimension would now be 3, because 9/3 = 3. In typical transformers, the embedding dimension is 768, and they typically have 12 heads of attention. This means each head of attention will have a dimension of 64 because 768/12 = 64. 

The main reason we want it to divide evenly is that we have three heads, and each one takes in an embedding dimension of 9, compresses it to 3 before computing attention, and then outputs a tensor of embedding size 3. We can then take our 3 tensors, each having an embedding dimension of 3, and concatenate them together, returning us back to the 9 that we began with! Again, this is just for convenience, so the embedding dimension of the input and output tensors doesn't change in any way. The last problem is that each head of attention is computed individually, so the final concatenated tensor has a bunch of heads of attention packed together, but we never got to share information between the different heads. This is why we have the final head projection, which takes in the embedding dimension of 9 from our concatenated tensor and outputs an embedding dimension of 9, therefore meshing information across the heads of our embedding dimension.

Let's go ahead and build this module as shown in the figure above! Basically, each head will have 3 projection layers for Q, K, V, we will perform the attention computation, and then stick all the results back together at the end!

In [36]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dimension, num_heads):
        super().__init__()

        ### Make sure Embedding Dimension is Divisible by Num Heads ###
        assert embedding_dimension % num_heads == 0, f"Make sure your embed_dim {embedding_dimension} is divisible by the number of heads {num_heads}"

        self.embed_dim = embedding_dimension
        self.num_heads = num_heads

        ### Compute Head Dimension ###
        self.head_dim = self.embed_dim // self.num_heads


        ### Create a ModuleList of ModuleDicts which holds the Q,K,V projections for each head ###
        self.multihead_qkv = nn.ModuleList()

        ### For each Head create the QKV ###
        for head in range(self.num_heads):

            ### Create a dictionary of the 3 projection  layers we need ###
            qkv_proj = nn.ModuleDict(
                [
                    ["Q", nn.Linear(self.embed_dim, self.head_dim)],
                    ["K", nn.Linear(self.embed_dim, self.head_dim)],
                    ["V", nn.Linear(self.embed_dim, self.head_dim)],
                ]
            )

            ### Store Dictionary in List ###
            self.multihead_qkv.append(qkv_proj)

        ### Create final Projection layer, it will be applied to the concatenated heads, which have shape embed_dim again ###
        self.head_mesh = nn.Linear(self.embed_dim, self.embed_dim)
        
    def forward(self, x):

        ### Create a list to store each head's output ###
        head_outs = []
        
        ### Loop Through Each head of Attention ###
        for head in self.multihead_qkv:

            ### Access layers like a dictionary (ModuleDict) ###
            ### q,k,v will be (Batch x Seq len x head_dim)
            q = head["Q"](x)
            k = head["K"](x)
            v = head["V"](x)

            ### Now do the same Attention computation as before! ###
            similarity = (q @ k.transpose(1,2)) / (self.head_dim ** 0.5)
            attention  = similarity.softmax(axis=-1)
            output = attention @ v

            ### Store this output in the head_outs ###
            head_outs.append(output)

        ### head_outs has num_heads tensors, each with the compressed embedding dimension of head_dim ###
        ### We can concatenate them all back together along the embedding dimension just like we did in the image above ###
        head_outs = torch.cat(head_outs, dim=-1)

        ### head_outs will have the same shape now as our input x! ###
        if head_outs.shape != x.shape:
            raise Exception("Something has gone wrong in the attention computation")

        ### Now each head was computed independently, we need them to get to know each other, so pass our head_outs through final projection ### 
        output = self.head_mesh(head_outs)

        return output
        

embed_dim = 9
num_heads = 3
seq_len = 8
mha = MultiHeadAttention(embed_dim, num_heads)

### Create a random tensor in the shape (Batch x Seq Len x Embed Dim) ###
rand = torch.randn(3,seq_len,embed_dim)

### Pass through MHA ###
output = mha(rand)


### Increasing Efficiency 

We now have a successful Multihead Attention layer!! This basically has all the math and logic of attention, except for one small issue: efficiency. Typically we want to avoid for loops as much as possible in our PyTorch code, as being able to vectorize and do things in parallel makes much better use of the GPUs we train on. To make this more efficient though, there's something we need to understand first: PyTorch Linear layers on multidimensional tensors!

#### Linear Layers on MultiDimensional Tensors

We have already seen `nn.Linear(input_dim, output_dim)` many times, and this module expects a tensor of shape `[Batch x input_dim]` and outputs `[Batch x output_dim]`. But what if our input is `[Batch x Dim1 x Dim2 x input_dim]`, then what happens? Basically, PyTorch will automagically flatten all the dimensions other than the last one, apply the linear layer, and then return to the expected shape, so we would get an output of `[Batch x Dim1 x Dim2 x output_dim]`. Another way of thinking about this is that PyTorch linear layers are only applied to the last dimension of your tensor. Let's do a quick example!

In [39]:
fc = nn.Linear(10,30)

tensor_1 = torch.randn(5,10)
tensor_1_out = fc(tensor_1)
print("Input Shape:", tensor_1.shape, "Output Shape:", tensor_1_out.shape)

tensor_2 = torch.randn(5,1,2,3,4,10)
tensor_2_out = fc(tensor_2)
print("Input Shape:", tensor_2.shape, "Output Shape:", tensor_2_out.shape)

Input Shape: torch.Size([5, 10]) Output Shape: torch.Size([5, 30])
Input Shape: torch.Size([5, 1, 2, 3, 4, 10]) Output Shape: torch.Size([5, 1, 2, 3, 4, 30])


### Packing Linear Layers

Another important idea is packing our linear layers together. Let's think about our example again: each projection for Q, K and V has a Linear layer that takes in 9 values and outputs 3 values, and we repeat this 3 times, once for each head. Let's just think about our Queries for now.

- Query for Head 1: Take in input x with embedding dim 9 and outputs tensor with embedding dimension 3
- Query for Head 2: Take in input x with embedding dim 9 and outputs tensor with embedding dimension 3
- Query for Head 3: Take in input x with embedding dim 9 and outputs tensor with embedding dimension 3

Well what if we reframed this? What if we had a single linear layer that takes input x with embedding dim 9 and outputs something with embedding dim 9? Afterwards, we can cut the output into our three heads of attention. Let's do a quick example!

In [44]:
tensor = torch.randn(1,8,9)
fc = nn.Linear(9,9)

### Pass tensor through layer to make Queries ###
q = fc(tensor)
print("Shape of all Queries:", q.shape)

### Cut Embedding dimension into 3 heads ###
q_head1, q_head2, q_head3 = torch.chunk(q, 3, axis=-1)
print("Shape of each Head of Query:", q_head1.shape)


Shape of all Queries: torch.Size([1, 8, 9])
Shape of each Head of Query: torch.Size([1, 8, 3])
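To check that the packed layer really is three small layers in disguise, we can slice its weight and bias into three pieces and apply each piece separately. This is a sketch that relies on `nn.Linear` storing its weight as `(out_features x in_features)`, so chunking the weight along dimension 0 lines up with chunking the output along the last dimension. The names `packed` and `x_demo` are just for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

x_demo = torch.randn(1, 8, 9)
packed = nn.Linear(9, 9)

with torch.no_grad():
    # One big projection, then cut into 3 heads of size 3
    q_all = packed(x_demo)
    q_heads = torch.chunk(q_all, 3, dim=-1)

    # Rows 0-2 of the weight produce head 1, rows 3-5 head 2, rows 6-8 head 3
    w_heads = torch.chunk(packed.weight, 3, dim=0)
    b_heads = torch.chunk(packed.bias, 3, dim=0)

    # Each slice acts like its own Linear(9, 3) layer
    all_match = all(
        torch.allclose(q_h, F.linear(x_demo, w_h, b_h), atol=1e-5)
        for q_h, w_h, b_h in zip(q_heads, w_heads, b_heads)
    )

print("Packed layer matches three separate Linear(9, 3) layers:", all_match)
```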


### MultiDimensional Matrix Multiplication

So, we have condensed our 9 linear layers (3 heads with 3 projections each for Q, K, V) into just 3 linear layers, where we have packed all the heads into them. But after we chunk up our Q, K, V tensors each into three more tensors for each head, we will still need to do the looping operation to go through the corresponding q, k, v matrices. Can we parallelize this too? Of course! We just need to better understand higher dimensional matrix multiplication. 

#### Recap:

Matrix multiplication is typically seen like this: multiplying an `[AxB]` matrix by a `[BxC]` matrix produces an `[AxC]` matrix. But what if we have a `[Batch x dim1 x A x B]` tensor multiplied by a `[Batch x dim1 x B x C]` tensor? Matrix multiplication again only happens on the last two dimensions, so because our first tensor ends with an `[AxB]` and the second tensor ends with a `[BxC]`, the resulting matrix multiplication will be `[Batch x dim1 x A x C]`. Let's see a quick example!

In [53]:
a = torch.randn(1,2,6,4)
b = torch.randn(1,2,4,3)
print("Final Output Shape:", (a@b).shape)

Final Output Shape: torch.Size([1, 2, 6, 3])
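Put differently, the leading dimensions behave like an implicit for loop over independent matrix multiplications. Here is a small check of that, with arbitrary sizes and hypothetical names `a_demo` and `b_demo`:

```python
import torch

torch.manual_seed(0)

a_demo = torch.randn(2, 3, 6, 4)
b_demo = torch.randn(2, 3, 4, 5)

# One call handles every (batch, dim1) pair at once
batched = a_demo @ b_demo

# The same thing written as explicit loops over the leading dimensions
looped = torch.zeros(2, 3, 6, 5)
for i in range(2):
    for j in range(3):
        looped[i, j] = a_demo[i, j] @ b_demo[i, j]

same = torch.allclose(batched, looped, atol=1e-5)
print("Batched matmul matches the looped version:", same)
```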


### The Trick of Parallelizing Heads

Now for the trick of parallelizing our heads by using everything we have just seen! All we need to do is split the embedding dimension up and move the heads out of the way so the computation can occur. Remember, we have our $Q$, $K$, and $V$ matrices right now that each contain all the projected heads and are in the shape `[Batch x Seq_len x Embed_dim]` ([batch x 8 x 9] in our case). 

- Step 1: Split the embedding dimension into the number of heads and the head dimension. We already know that our embedding dimension is divisible by the number of heads, as that's how we set it, so we can do `[Batch x Seq_len x Embed_dim]` -> `[Batch x Seq_len x Num_Heads x Head_Dim]`. (This would be taking our [batch x 8 x 9] and converting to [batch x 8 x 3 x 3])
- Step 2: In single headed attention, the attention computation happens between two matrices of shape `[Seq_len x Embed_Dim]` for queries and `[Embed_Dim x Seq_len]` for our transposed keys. In the case of multihead attention, the matrix multiplication happens across the head dimension rather than the full embedding dimension. If our Queries, Keys and Values are in the shape `[Batch x Seq_len x Num_Heads x Head_Dim]`, we can just transpose the Seq_len and Num_Heads dimensions and make a tensor of shape `[Batch x Num_Heads x Seq_len x Head_Dim]`. This way when I do Queries multiplied by the transpose of Keys I will be doing `[Batch x Num_Heads x Seq_len x Head_Dim]` multiplied by `[Batch x Num_Heads x Head_Dim x Seq_len]`, creating the attention matrix `[Batch x Num_Heads x Seq_len x Seq_len]`. Therefore we have effectively created, for every sample in the batch and for every head of attention, a unique attention matrix! Thus we have parallelized the attention matrix computation.
- Step 3: Now that we have all our attention matrices `[Batch x Num_Heads x Seq_len x Seq_len]`, we can perform our scaling by our constant, and perform softmax across every row of each attention matrix (along the last dimension).
- Step 4: Multiply our attention matrix `[Batch x Num_Heads x Seq_len x Seq_len]` by the Values, which after the same transpose are in the shape `[Batch x Num_Heads x Seq_len x Head_Dim]`, which will get us to `[Batch x Num_Heads x Seq_len x Head_Dim]`!
- Step 5: Put our Num_Heads and Head_Dim dimensions back together. We can transpose the Num_Heads and Seq_len dimensions again, which gives us `[Batch x Seq_len x Num_Heads x Head_Dim]`, and flatten the last two dimensions, finally giving us `[Batch x Seq_len x Embed_Dim]`.
- Step 6: This flattening operation is equivalent to concatenating all the heads of attention, so we can pass this through our final projection layer so all the heads of attention get to know each other!
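The steps above can be sketched as follows, with random tensors standing in for the already projected $Q$, $K$, $V$ (so no learned layers and no final projection here). The helper `split_heads` is a hypothetical name. As a check, the result is compared against the slower approach of chunking into heads and looping over them.

```python
import torch

torch.manual_seed(0)

batch, seq_len, num_heads, head_dim = 2, 8, 3, 3
embed_dim = num_heads * head_dim

# Stand-ins for the packed Q, K, V projections: [Batch x Seq_len x Embed_Dim]
q = torch.randn(batch, seq_len, embed_dim)
k = torch.randn(batch, seq_len, embed_dim)
v = torch.randn(batch, seq_len, embed_dim)

def split_heads(t):
    # Steps 1 and 2: [B x T x E] -> [B x T x H x Dh] -> [B x H x T x Dh]
    return t.reshape(batch, seq_len, num_heads, head_dim).transpose(1, 2)

qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)

# Steps 2 and 3: one attention matrix per sample and per head, scaled and softmaxed
attn = (qh @ kh.transpose(-2, -1)) / head_dim ** 0.5    # [B x H x T x T]
attn = attn.softmax(dim=-1)

# Step 4: weighted average of the values within each head
out = attn @ vh                                         # [B x H x T x Dh]

# Step 5: [B x H x T x Dh] -> [B x T x H x Dh] -> [B x T x E]
out = out.transpose(1, 2).reshape(batch, seq_len, embed_dim)

# Reference: chunk into heads, loop over them, concatenate
loop_outs = []
for qi, ki, vi in zip(q.chunk(num_heads, dim=-1),
                      k.chunk(num_heads, dim=-1),
                      v.chunk(num_heads, dim=-1)):
    w = ((qi @ ki.transpose(1, 2)) / head_dim ** 0.5).softmax(dim=-1)
    loop_outs.append(w @ vi)
loop_out = torch.cat(loop_outs, dim=-1)

matches = torch.allclose(out, loop_out, atol=1e-6)
print("Attention shape:", attn.shape)
print("Parallel heads match the looped heads:", matches)
```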

### Let's Build Our Final Attention Implementation!

We will also include some extra dropout layers typically added to attention computations. 

In [59]:
class SelfAttentionEncoder(nn.Module):
    """
    Self Attention Proposed in `Attention is All  You Need` - https://arxiv.org/abs/1706.03762
    """
    
    def __init__(self,
               embed_dim=768,
               num_heads=12, 
               attn_p=0,
               proj_p=0):
        """
        
        Args:
            embed_dim: Transformer Embedding Dimension
            num_heads: Number of heads of computation for Attention 
            attn_p: Probability for Dropout on the Attention weights
            proj_p: Probability for Dropout on final Projection
        """
        
        super(SelfAttentionEncoder, self).__init__()

        ### Make Sure Embed Dim is Divisible by Num Heads ###
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        ### Define all our Projections ###
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_p)

        ### Define Post Attention Projection ###
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)
        
    def forward(self, x):
        
        batch, seq_len, embed_dim = x.shape

        ### Compute Q, K, V Projections,and Reshape/Permute to [Batch x Num Heads x Seq Len x Head Dim] 
        q = self.q_proj(x).reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(1,2).contiguous()
        k = self.k_proj(x).reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(1,2).contiguous()
        v = self.v_proj(x).reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(1,2).contiguous()
        
        ### Perform Attention Computation ###
        attn = (q @ k.transpose(-2,-1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v
        
        ### Bring Back to [Batch x Seq Len x Embed Dim] ###
        x = x.transpose(1,2).reshape(batch, seq_len, embed_dim)

        ### Pass through Projection so Heads get to know each other ###
        x = self.proj(x)
        x = self.proj_drop(x)
        
        return x

embed_dim = 9
num_heads = 3
seq_len = 8
a = SelfAttentionEncoder(embed_dim, num_heads)

### Create a random tensor in the shape (Batch x Seq Len x Embed Dim) ###
rand = torch.randn(3,seq_len,embed_dim)

### Pass through MHA ###
output = a(rand)
print("Final Output:", output.shape)

Final Output: torch.Size([3, 8, 9])


### Enforcing Causality

Everything we have done so far is an encoder-style Transformer, which is analogous to a bidirectional RNN in that a single timestep can look both forwards and backwards. To enforce causality, each timestep may only look backwards, so we have to add a causal mask! The causal mask looks like the following:

<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/causal_masking.png" width="800"/>
</div>

Basically, every Query vector can only attend to the Key vectors that are at the same timestep or before! Now how do we actually do this? Unfortunately it's not as easy as just zeroing out numbers in our attention matrix, because we still need every row of the attention matrix to add up to one (like a probability vector!). We want something like this though:

<div>
<img src="https://raw.githubusercontent.com/priyammaz/PyTorch-Adventures/main/src/visuals/decoder_attention_vis.png" width="800"/>
</div>

As we can see, because we have zeroed out the attention weights on future keys, the corresponding values get multiplied by 0. Therefore the weighted-average context vector of a timestep is computed only from the values at that timestep and earlier, never the future!
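
To make this concrete, here is a tiny sketch using hand-picked (hypothetical) causal attention weights rather than anything learned. It checks that changing the last value vector leaves the context vectors of the earlier timesteps alone:

```python
import torch

# Hypothetical causal attention weights: each row sums to 1, future keys get weight 0
weights = torch.tensor([[1.0, 0.0, 0.0],
                        [0.5, 0.5, 0.0],
                        [0.2, 0.3, 0.5]])

# Three value vectors (one per timestep), each of dimension 4
values = torch.randn(3, 4)
context = weights @ values

# Change only the value vector of the last timestep
changed = values.clone()
changed[2] = torch.randn(4)
context_changed = weights @ changed

# The first two timesteps put zero weight on the last value, so they should not move
print(torch.allclose(context[:2], context_changed[:2]))
```

Only the last row of the context matrix is allowed to depend on the last value vector, which is exactly the behavior we want from a causal model.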


### Computing the Reweighted Causal Attention Mask

Let's pretend the raw output of $QK^T$, before the softmax, is the matrix below:

\begin{equation}
\begin{bmatrix}
  7       & -8   & 6  \\
  -3       & 2   & 4   \\
  1       & 6  & -2   \\
\end{bmatrix}
\end{equation}

Remember, the equation for softmax is:

$$\text{Softmax}(\vec{x})_i = \frac{e^{x_i}}{\sum_{j=1}^N{e^{x_j}}}$$

Then, we can compute the softmax for each row of the matrix above:

\begin{equation}
\text{Softmax}
\begin{bmatrix}
  7       & -8   & 6  \\
  -3       & 2   & 4   \\
  1       & 6  & -2   \\
\end{bmatrix} = 
\begin{bmatrix}
  \frac{e^{7}}{e^{7}+e^{-8}+e^{6}}       & \frac{e^{-8}}{e^{7}+e^{-8}+e^{6}}   & \frac{e^{6}}{e^{7}+e^{-8}+e^{6}}  \\
  \frac{e^{-3}}{e^{-3}+e^{2}+e^{4}}       & \frac{e^{2}}{e^{-3}+e^{2}+e^{4}}   & \frac{e^{4}}{e^{-3}+e^{2}+e^{4}}  \\
  \frac{e^{1}}{e^{1}+e^{6}+e^{-2}}       & \frac{e^{6}}{e^{1}+e^{6}+e^{-2}}   & \frac{e^{-2}}{e^{1}+e^{6}+e^{-2}}  \\
\end{bmatrix} = 
\begin{bmatrix}
  0.73       & 0.0000002   & 0.27   \\
  0.0008       & 0.12   & 0.88 \\
  0.007       & 0.99  & 0.0003  \\
\end{bmatrix}
\end{equation}
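
If you would rather let PyTorch do the arithmetic, a quick sketch like the one below lets you check the hand-computed values. The only assumption here is that we apply softmax along the last dimension, so every row is normalized on its own:

```python
import torch

# The example raw scores (QK^T before the softmax)
scores = torch.tensor([[ 7.0, -8.0,  6.0],
                       [-3.0,  2.0,  4.0],
                       [ 1.0,  6.0, -2.0]])

# dim=-1 normalizes each row independently
weights = torch.softmax(scores, dim=-1)

print(weights)
print(weights.sum(dim=-1))  # every row should add up to 1
```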

But what we want is for the upper triangle (everything above the diagonal) to have weights of 0, with the rest of each row adding up to 1. So let's take the second row of the matrix above to see how we can do that.

$$x_2 = [-3, 2, 4]$$

Because this is the second row, we need to zero out the softmax output for everything after the second index (so in our case just the last value). Let's replace the value 4 with $-\infty$. Then we can write it as:

$$x_2 = [-3, 2, -\infty]$$

Let's now take the softmax of this vector!

$$\text{Softmax}(x_2) = [\frac{e^{-3}}{e^{-3}+e^{2}+e^{-\infty}}, \frac{e^{2}}{e^{-3}+e^{2}+e^{-\infty}}, \frac{e^{-\infty}}{e^{-3}+e^{2}+e^{-\infty}}]$$

Remember, $e^{-\infty}$ is equal to 0, so we can solve this!

$$\text{Softmax}(x_2) = [\frac{e^{-3}}{e^{-3}+e^{2}+0}, \frac{e^{2}}{e^{-3}+e^{2}+0}, \frac{0}{e^{-3}+e^{2}+0}] = [0.0067, 0.9933, 0.0000]$$
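
The same computation can be sketched in plain Python. The `softmax` helper below is just an illustrative function written for this check (it subtracts the max first for numerical stability, which does not change the result):

```python
import math

def softmax(xs):
    # Subtract the max for numerical stability; the softmax output is unchanged
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Second row of the scores, with the future position replaced by -infinity
x2_masked = [-3.0, 2.0, float("-inf")]
weights = softmax(x2_masked)

print([round(w, 4) for w in weights])  # [0.0067, 0.9933, 0.0]
print(sum(weights))
```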

So we have exactly what we want! The attention weight of the last value is set to 0, so when we are on the second vector $x_2$, we cannot look forward to the future value vector $v_3$, and the remaining weights add up to 1, so it's still a probability vector! To do this correctly for the entire matrix, we can just replace the upper triangle of $QK^T$ (everything above the diagonal) with $-\infty$. This would look like:

\begin{equation}
\begin{bmatrix}
  7       & -\infty   & -\infty  \\
  -3       & 2   & -\infty   \\
  1       & 6  & -2   \\
\end{bmatrix}
\end{equation}

Taking the softmax of the rows of this matrix then gives:

\begin{equation}
\text{Softmax}
\begin{bmatrix}
  7       & -\infty   & -\infty  \\
  -3       & 2   & -\infty   \\
  1       & 6  & -2   \\
\end{bmatrix} = 
\begin{bmatrix}
  1       & 0   & 0  \\
  0.0067  & 0.9933 & 0   \\
  0.007       & 0.99  & 0.0003   \\
\end{bmatrix}
\end{equation}
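
Here is a minimal sketch of that whole-matrix version in PyTorch, assuming we build the mask with `triu` (the variable names are just for illustration):

```python
import torch

scores = torch.tensor([[ 7.0, -8.0,  6.0],
                       [-3.0,  2.0,  4.0],
                       [ 1.0,  6.0, -2.0]])

# True strictly above the diagonal, i.e. at the "future" positions
future = torch.ones(3, 3).triu(diagonal=1).bool()

# Replace future positions with -infinity, then softmax each row
masked_scores = scores.masked_fill(future, float("-inf"))
weights = masked_scores.softmax(dim=-1)

print(masked_scores)
print(weights)
print(weights.sum(dim=-1))  # rows still add up to 1
```

Notice the bottom row is untouched, because the last timestep is allowed to look at everything before it.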

Therefore, the best way to apply our attention mask is to fill the top right triangle with $-\infty$ and then take the softmax! So let's go ahead and add an option for causality to the attention module we wrote above!


In [64]:
class SelfAttention(nn.Module):
    """
    Self Attention Proposed in `Attention is All You Need` - https://arxiv.org/abs/1706.03762
    """
    
    def __init__(self,
               embed_dim=768,
               num_heads=12, 
               attn_p=0,
               proj_p=0,
               causal=True, 
               seq_len=512):
        """
        
        Args:
            embed_dim: Transformer Embedding Dimension
            num_heads: Number of heads of computation for Attention 
            attn_p: Probability for Dropout on the Attention weights
            proj_p: Probability for Dropout on final Projection
            causal: Do you want to apply a causal mask?
            seq_len: What is the max sequence length this model can expect?
        """
        
        super(SelfAttention, self).__init__()

        ### Make Sure Embed Dim is Divisible by Num Heads ###
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.causal = causal
        
        ### Define all our Projections ###
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_p)

        ### Define Post Attention Projection ###
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)

        ### Create NonLearnable Buffer to Act as our Causal Mask ###
        if causal:

            ### Create a Seq_Len x Seq_Len tensor full of Ones
            ones = torch.ones((seq_len, seq_len))

            ### Fill Top right triangle with Zeros (as we dont want to attend to them) ###
            causal_mask = torch.tril(ones) 

            ### Add extra dimensions for Batch size and Number of Heads, and convert to Boolean ###
            causal_mask = causal_mask.reshape(1,1,seq_len,seq_len).bool()

            ### Store as a Buffer, as this mask doesn't need to be learned ###
            self.register_buffer("causal_mask", causal_mask)
        
        
    def forward(self, x):
        
        batch, seq_len, embed_dim = x.shape

        ### Compute Q, K, V Projections, and Reshape/Permute to [Batch x Num Heads x Seq Len x Head Dim] ###
        q = self.q_proj(x).reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(1,2).contiguous()
        k = self.k_proj(x).reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(1,2).contiguous()
        v = self.v_proj(x).reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(1,2).contiguous()
        
        ### Perform Attention Computation ###
        attn = (q @ k.transpose(-2,-1)) * (self.head_dim ** -0.5)

        if self.causal:
            ####################################################################################
            ### FILL ATTENTION MASK WITH -Infinity ###
            attn = attn.masked_fill(self.causal_mask[:,:,:seq_len,:seq_len] == 0, float('-inf'))
            ####################################################################################
    
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v
        
        ### Bring Back to [Batch x Seq Len x Embed Dim] ###
        x = x.transpose(1,2).reshape(batch, seq_len, embed_dim)

        ### Pass through Projection so Heads get to know each other ###
        x = self.proj(x)
        x = self.proj_drop(x)
        
        return x

embed_dim = 9
num_heads = 3
seq_len = 8
a = SelfAttention(embed_dim, num_heads, causal=True)

### Create a random tensor in the shape (Batch x Seq Len x Embed Dim) ###
rand = torch.randn(3,seq_len,embed_dim)

### Pass through MHA ###
output = a(rand)
print("Final Output:", output.shape)

Final Output: torch.Size([3, 8, 9])
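
A matching output shape doesn't tell us whether the mask is actually doing its job, so here is a rough sanity check. To keep it self-contained it uses a stripped-down, single-head stand-in (the `causal_self_attention` function and its random weight matrices are made up for this sketch, not the class above), but the masking logic is the same:

```python
import torch

torch.manual_seed(0)

def causal_self_attention(x, w_q, w_k, w_v):
    # Single-head causal attention; x is [Seq Len x Embed Dim]
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = (q @ k.T) * (q.shape[-1] ** -0.5)
    seq_len = x.shape[0]
    allowed = torch.tril(torch.ones(seq_len, seq_len)).bool()
    scores = scores.masked_fill(~allowed, float("-inf"))
    return scores.softmax(dim=-1) @ v

seq_len, embed_dim = 6, 4
w_q, w_k, w_v = (torch.randn(embed_dim, embed_dim) for _ in range(3))

x = torch.randn(seq_len, embed_dim)
out = causal_self_attention(x, w_q, w_k, w_v)

# Overwrite only the last two timesteps with new random inputs
x_future_changed = x.clone()
x_future_changed[4:] = torch.randn(2, embed_dim)
out_changed = causal_self_attention(x_future_changed, w_q, w_k, w_v)

# Outputs for the first four timesteps should not depend on the future inputs
print(torch.allclose(out[:4], out_changed[:4], atol=1e-6))
```

If you drop the `masked_fill` line and rerun, the early timesteps will generally change too, since they would then be attending to the future.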


### That's It!

That is basically everything you need to know about attention! If you are actually going to use attention in your model, it is better to use an optimized CUDA implementation rather than this one, purely for speed and memory efficiency. These compute the same attention under the hood, just faster! Some examples are [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) or [FlashAttention](https://github.com/Dao-AILab/flash-attention), which are hardware-aware and make better use of the GPU for these operations!
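
As a rough sketch of what swapping in SDPA could look like, the snippet below compares a manual causal attention against `torch.nn.functional.scaled_dot_product_attention` on random tensors. It assumes PyTorch 2.0 or newer (where this function is available) and that Q, K, V are already in [Batch x Num Heads x Seq Len x Head Dim] form:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch, num_heads, seq_len, head_dim = 2, 3, 8, 4
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)

# Manual causal attention, same steps as in our implementation
scores = (q @ k.transpose(-2, -1)) * (head_dim ** -0.5)
future = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()
manual = scores.masked_fill(future, float("-inf")).softmax(dim=-1) @ v

# Fused version: scaling, causal masking, softmax and the weighted sum in one call
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))
```

The two outputs should agree up to small floating point differences, so inside the module you could replace the manual attention lines with the single fused call.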