# Coding attention mechanisms

<img src="../images/figure-3.1-three-main-stages-of-llm-chapter-3-focus-on-stage-1-step-2.webp" width="800px">

- We will look at <span style="color:#4ea9fb"><b>attention mechanisms</b> in isolation and focus on them at a mechanistic level</span>.
- We will implement <b>4 different variants of attention mechanisms</b>:
  - <span style="color:#4ea9fb"><b>Simplified self-attention</b></span>
  - <span style="color:#4ea9fb"><b>Self-attention with trainable weights</b></span>
  - <span style="color:#4ea9fb"><b>Causal attention</b></span>
    - Adds a mask to self-attention so that the model can only consider previous and current inputs in a sequence.
  - <span style="color:#4ea9fb"><b>Multi-head attention</b></span>
    - Organizes the attention mechanism into multiple heads.

<img src="../images/figure-3.2-four-different-variants-of-attention-mechanism.webp" width="700px">

## 3.1 The problem with modeling long sequences

Say we are in the <b>pre-LLM era</b> and want to <b>develop a language translation model</b>.
- <span style="color:red">We cannot simply translate a text word by word because of the differing <b>grammatical structures</b> of the source and target languages and the need for <b>contextual understanding</b>.</span>
- To address this problem, deep neural networks (DNNs) generally use two submodules:
  - <span style="color:#4ea9fb"><b>encoder</b></span> (first reads and processes the entire text) and
  - <span style="color:#4ea9fb"><b>decoder</b></span> (then produces the translated text).

<img src="../images/figure-3.3-german-to-english-problem-with-word-for-word-translation.webp" width="700px">

<b>What are RNNs, and why were they popular before transformers?</b>
- Before the advent of transformers, <span style="color:green"><b><i>recurrent neural networks</i> (RNNs) were the most popular encoder-decoder architecture for language translation</b></span>.
- <span style="color:#4ea9fb">An RNN is a type of neural network in which outputs from the previous step are fed as inputs to the current step, making it well suited to sequential data like text.</span>

<b>What does the RNN (encoder-decoder) do?</b>
- <span style="color:#4ea9fb">The encoder processes a sequence of words/tokens from the source language as input, using a hidden state (an intermediate neural network layer of the encoder) to generate a condensed (encoded) representation of the entire input sequence.</span>
- <span style="color:#4ea9fb">The decoder then uses this encoded representation (hidden state) to generate the translated text, one word at a time (i.e., token by token).</span>

<p style="color:black; background-color:#F5C780; padding:15px">💡 <b>Key idea of encoder-decoder RNNs</b><br><b>- Encoder</b>: Processes the entire input text into a hidden state (memory cell).<br><b>- Decoder</b>: Takes in this hidden state to produce the output, one word at a time.<br><b>- Hidden state:</b> Roughly analogous to the embedding vector in Chapter 2.</p>

<b>What's the problem with encoder-decoder RNNs?</b>
- <span style="color:red">RNNs have a hard time capturing long-range dependencies in complex sentences.</span>
  - The decoder cannot directly access earlier hidden states from the encoder during the decoding phase.
  - Consequently, the decoder relies solely on the current hidden state, which must encapsulate all relevant information and may not be sufficient to generate the correct translation.
  - This leads to a loss of context.
    - Although RNNs work fine for short sentences, <span style="color:red">they struggle with longer sentences as they don't have direct access to previous words in the input sequence.</span>
- <span style="color:#4ea9fb">This motivated the design of attention mechanisms.</span>
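
The bottleneck described above can be made concrete with a minimal sketch. It assumes PyTorch's `nn.RNN` as a stand-in encoder; the sizes `embed_dim` and `hidden_dim` and the random inputs are hypothetical values chosen only for illustration, not taken from the book.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)

# Hypothetical sizes, chosen only for illustration
embed_dim, hidden_dim = 3, 4
encoder = nn.RNN(input_size=embed_dim, hidden_size=hidden_dim, batch_first=True)

short_src = torch.rand(1, 3, embed_dim)   # batch of 1, 3 source tokens
long_src = torch.rand(1, 50, embed_dim)   # batch of 1, 50 source tokens

all_states_short, final_short = encoder(short_src)
all_states_long, final_long = encoder(long_src)

print(all_states_long.shape)  # torch.Size([1, 50, 4]): one hidden state per token
print(final_short.shape)      # torch.Size([1, 1, 4])
print(final_long.shape)       # torch.Size([1, 1, 4]): same size for 50 tokens
```

However long the source sentence is, a decoder that receives only the final hidden state sees a vector of the same fixed size, so a 50-token sentence has to be squeezed into the same space as a 3-token one.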

<img src="../images/figure-3.4-german-to-english-translation-using-RNNs-encoder-decoder.webp" width="700px">

## 3.2 Capturing data dependencies with attention mechanisms

<b>Why attention mechanisms?</b>
- <span style="color:red">One major shortcoming of the RNNs above is that the encoder must compress the entire encoded input into a single hidden state before passing it to the decoder (figure 3.4).</span>
- <span style="color:#4ea9fb">Attention mechanisms address this issue by allowing the decoder to focus on (i.e., selectively access) different parts of the input sequence at each decoding step, implying that <b>certain input tokens hold more significance than others in the generation of a specific output token</b> (figure 3.5).</span>
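
As a rough sketch of what "selectively access" can mean, the snippet below weights a set of encoder hidden states by their relevance to one decoder state. The dot-product scoring, the tensor sizes, and the random stand-in states are all assumptions made for illustration; the notes do not specify how the original RNN attention computed its scores.

```python
import torch

torch.manual_seed(123)

# Hypothetical sizes and random stand-ins for real RNN states
T, hidden_dim = 5, 4
encoder_states = torch.rand(T, hidden_dim)  # one hidden state per source token
decoder_state = torch.rand(hidden_dim)      # decoder state at the current step

# Assumed scoring function: dot product between decoder and encoder states
scores = encoder_states @ decoder_state     # shape (T,)
weights = torch.softmax(scores, dim=0)      # positive values that sum to 1
context = weights @ encoder_states          # weighted mix, shape (hidden_dim,)

print(weights)
print(context.shape)  # torch.Size([4])
```

Source tokens with larger weights contribute more to `context`, which is the sense in which some input tokens matter more than others for a given output token.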

<p style="color:black; background-color:#F5C780; padding:15px">💡 Interestingly, <b>3 years after researchers developed the <i>attention</i> mechanism for RNNs, they found that RNN architectures are not required for building deep neural networks for natural language processing and proposed the original <i>transformer</i> architecture.</b></p>

<img src="../images/figure-3.5-german-to-english-translation-using-RNNs-encoder-decoder-with-attention-mechanism.webp" width="700px">

<b>Why self-attention?</b>
- <span style="color:#4ea9fb">It allows each position in the input sequence to <b>"attend to"</b> (i.e., compute the relevance of) all positions in the same sequence when computing the representation of the sequence.</span>
- <span style="color:#4ea9fb">Self-attention is a <b>key component of contemporary LLMs based on the transformer architecture</b>, such as BERT, GPT-2, and T5.</span>

<img src="../images/figure-3.6-self-attention-mechanism-topic-of-the-current-chapter.webp" width="500px">

## 3.3 Attending to different parts of the input with self-attention

<p style="color:black; background-color:#F5C780; padding:15px">💡❗ Once we <b>grasp the fundamentals of self-attention</b>, we will have <b>conquered one of the toughest aspects of this book and LLM implementation</b> in general.</p>

<p style="color:black; background-color:#F5C780; padding:15px">💡<b>The "self" in self-attention.</b><br>
- <span style="color:#4ea9fb">"Self" in self-attention refers to the mechanism's <b>ability to compute attention weights by relating different positions within a single input sequence</b></span>.<br>
- It assesses and <b>learns the relationships and dependencies between various parts of the input itself</b>, <i>such as words in a sentence</i> or <i>pixels in an image</i>.<br><br>
- <span style="color:red">This is in contrast to traditional attention mechanisms, where the focus is on the relationship between two different sequences, such as in sequence-to-sequence models (for machine translation), where the attention might be between a source (input) and a target (output) sentence </span>(e.g., figure 3.5).
</p>

### 3.3.1 A simple self-attention mechanism without trainable weights

<p style="color:black; background-color:#F5C780; padding:15px">⚠️ This section is purely for illustration purposes and NOT the attention mechanism used in transformers.</p>

```mermaid
graph LR
    A[Input embeddings] --> B[Attention scores]
    B --> C[Attention weights]
    C --> D[Context vectors]
```

[![](https://mermaid.ink/img/pako:eNpVz80OgjAMAOBXWXqWF-BgInAx8aQ3nYfJKiy6lozOnxDe3YEhxp768zVpB6jZIuTQBNO1arfXpFJsThq21EVR6C9oraOm13BWWbZWRZptRJDEMam-5oDT7LtYzKT8I090TSs_U86mSqZkEnyJemAtHCYBK_AYvHE23TRMXoO06FFDnlJrwk2DpjE5E4UPb6ohlxBxBYFj0y5F7KwRrJxJf_ml2Rk6Mqfyau49jh-9N1KY?type=png)](https://mermaid.live/edit#pako:eNpVz80OgjAMAOBXWXqWF-BgInAx8aQ3nYfJKiy6lozOnxDe3YEhxp768zVpB6jZIuTQBNO1arfXpFJsThq21EVR6C9oraOm13BWWbZWRZptRJDEMam-5oDT7LtYzKT8I090TSs_U86mSqZkEnyJemAtHCYBK_AYvHE23TRMXoO06FFDnlJrwk2DpjE5E4UPb6ohlxBxBYFj0y5F7KwRrJxJf_ml2Rk6Mqfyau49jh-9N1KY)

$$x^{(i)} \longrightarrow \omega^{(i)} \longrightarrow \alpha^{(i)} \longrightarrow z^{(i)}$$

- Input sequence, $x^{(1)}$ to $x^{(T)}$ 
  - The input sequence is a text (e.g., <i>"Your journey starts with one step"</i>) that has already been converted into token embeddings.
    - For instance, $x^{(1)} = [0.4, 0.1, 0.8]$ is a $d$-dimensional (here, $d = 3$) vector that represents the word <i>"Your"</i>.
- Goal: In self-attention, our goal is to <span style="color:#4ea9fb"><b>compute context vectors (a.k.a. enriched embedding vectors) $z^{(i)}$ for each element $x^{(i)}$ in the input sequence</b> by incorporating information from all other elements in the sequence</span> (figure 3.7).
  - <span style="color:#4ea9fb">Context vectors play a crucial role in self-attention</span>.
  - <span style="color:#4ea9fb">A context vector</span> $z^{(i)}$ <span style="color:#4ea9fb">is a weighted sum of all inputs</span>, $x^{(1)}$ to $x^{(T)}$.
  - For example, context vector $z^{(2)}$ is an embedding that contains information about $x^{(2)}$ and all other input elements, $x^{(1)}$ to $x^{(T)}$.

<p style="color:black; background-color:#F5C780; padding:15px">⚠️ The formulas and table below explain how the different components are computed.</p>

- <span style="color:#4ea9fb">Query</span>, $x^{(2)}$
- <span style="color:#4ea9fb">Attention score</span>, $\omega^{(ij)} = x^{(i)} \cdot  x^{(j)}$
  - $\implies \omega^{(21)} = x^{(2)} \cdot  x^{(1)}$
  -  $\implies \omega^{(2)} = \left[ \omega^{(21)} \quad \omega^{(22)} \quad \omega^{(23)} \quad \dots \quad \omega^{(26)} \right] = \left[ x^{(2)} \cdot  x^{(1)} \quad  x^{(2)} \cdot  x^{(2)} \quad  x^{(2)} \cdot  x^{(3)} \quad  \dots \quad  x^{(2)} \cdot  x^{(6)} \right]$
- <span style="color:#4ea9fb">Attention weights</span>, $\alpha^{(i)} = \text{softmax}(\omega^{(i)})$
  -  $\implies \alpha^{(2)} = \left[ \alpha^{(21)} \quad \alpha^{(22)} \quad \alpha^{(23)} \quad \dots \quad \alpha^{(26)} \right] = \text{softmax}(\omega^{(2)}) $
-  <span style="color:#4ea9fb">Context vector</span>, $z^{(i)} = \sum_{j=1}^{T} \alpha^{(ij)} x^{(j)}$
   -  $\implies z^{(2)} = \sum_{j=1}^{T} \alpha^{(2j)} x^{(j)} = \alpha^{(21)} x^{(1)} + \alpha^{(22)} x^{(2)} + \alpha^{(23)} x^{(3)} + \dots + \alpha^{(26)} x^{(6)}$

| Notation | Shape | Dimensions |
| :-----------: |:------------: | ------------|
| $x$        | 6 × 3       |    2D tensor with 6 rows and 3 columns |
| $x^{(2)}$        | 3       | 1D tensor with 3 elements |
| $\omega^{(2)}$        | 6       | 1D tensor with 6 elements |
| $\alpha^{(2)}$        | 6       | 1D tensor with 6 elements |
| $\alpha^{(2j)}$        | scalar       | 0D tensor (a single value) |
| $z^{(2)}$        | 3       | 1D tensor with 3 elements |
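
The formulas above are written for a single query, but they can be applied to every position at once. The following is a minimal sketch, assuming matrix multiplication is used in place of per-element loops; the names `attn_scores`, `attn_weights`, and `context_vecs` are chosen here for illustration. It uses the six 3-dimensional token embeddings that serve as this section's running example and checks the shapes listed in the table.

```python
import torch

# Token embeddings from this section's running example
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your    (x^1)
        [0.55, 0.87, 0.66],  # journey (x^2)
        [0.57, 0.85, 0.64],  # starts  (x^3)
        [0.22, 0.58, 0.33],  # with    (x^4)
        [0.77, 0.25, 0.10],  # one     (x^5)
        [0.05, 0.80, 0.55],  # step    (x^6)
    ]
)

attn_scores = inputs @ inputs.T                    # omega^(ij) = x^(i) . x^(j)
attn_weights = torch.softmax(attn_scores, dim=-1)  # each row alpha^(i) sums to 1
context_vecs = attn_weights @ inputs               # row i holds z^(i)

print(inputs.shape)           # torch.Size([6, 3])
print(attn_scores[1].shape)   # torch.Size([6])  -> omega^(2)
print(attn_weights[1].shape)  # torch.Size([6])  -> alpha^(2)
print(context_vecs[1].shape)  # torch.Size([3])  -> z^(2)
print(context_vecs[1])        # tensor([0.4419, 0.6515, 0.5683])
```

Row index 1 corresponds to the second token, so `context_vecs[1]` is $z^{(2)}$.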

<img src="../images/figure-3.7-goal-of-self-attention-compute-context-vector.webp" width="800px">

<p style="color:black; background-color:#F5C780; padding:15px">📝 Let's implement a simplified self-attention mechanism to compute the weights and resulting context vector, step-by-step.</p>

In the example below:
- `context_length = 6`
- `embed_size = 3`

```python
import torch

### Input token embeddings ###
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your    (x^1)
        [0.55, 0.87, 0.66],  # journey (x^2)
        [0.57, 0.85, 0.64],  # starts  (x^3)
        [0.22, 0.58, 0.33],  # with    (x^4)
        [0.77, 0.25, 0.10],  # one     (x^5)
        [0.05, 0.80, 0.55],  # step    (x^6)
    ]
)
```

<img src="../images/figure-3.8-computation-of-context-vector-illustration.webp" width="700px">

```python
### Input embeddings to attention scores ###
query = inputs[1]  # "journey" (x^2), the second input token embedding
print(f"query: {query}")
attn_scores_2 = torch.empty(inputs.shape[0])
print("\ninput idx | input_element.input_query:\t\t\t\t\t\t| attn_scores_2")
for i, x_i in enumerate(inputs):
    # Attention score = dot product between each input element and the query
    attn_scores_2[i] = torch.dot(x_i, query)
    print(f"{i:2d}        | {x_i}.{query}   | {attn_scores_2[i]:.3f}")
print(f"\nattn_scores_2 = {attn_scores_2}")
```

```
query: tensor([0.5500, 0.8700, 0.6600])

input idx | input_element.input_query:						| attn_scores_2
 0        | tensor([0.4300, 0.1500, 0.8900]).tensor([0.5500, 0.8700, 0.6600])   | 0.954
 1        | tensor([0.5500, 0.8700, 0.6600]).tensor([0.5500, 0.8700, 0.6600])   | 1.495
 2        | tensor([0.5700, 0.8500, 0.6400]).tensor([0.5500, 0.8700, 0.6600])   | 1.475
 3        | tensor([0.2200, 0.5800, 0.3300]).tensor([0.5500, 0.8700, 0.6600])   | 0.843
 4        | tensor([0.7700, 0.2500, 0.1000]).tensor([0.5500, 0.8700, 0.6600])   | 0.707
 5        | tensor([0.0500, 0.8000, 0.5500]).tensor([0.5500, 0.8700, 0.6600])   | 1.087

attn_scores_2 = tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
```


<p style="color:black; background-color:#F5C780; padding:15px">💡 <b>Understanding dot products</b><br>
- It's a concise way of multiplying two vectors element-wise and then summing their products (i.e., sum of element-wise multiplication)<br><br>
- It's also a measure of similarity between two vectors, where higher dot product implies higher similarity, and vice versa.<br>
- <span style="color:green">In the context of self-attention mechanisms, <b>the dot product determines the extent to which each element in a sequence "attends to" (focuses on) any other element</b>; the higher the dot product, the higher the similarity and attention score between two elements, thus more attention is paid to that element.</span>
</p>
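
To make the two readings of the dot product concrete, here is a small sketch using three of this section's token embeddings. The variable names are chosen for illustration only.

```python
import torch

journey = torch.tensor([0.55, 0.87, 0.66])  # x^2, the query
your = torch.tensor([0.43, 0.15, 0.89])     # x^1
starts = torch.tensor([0.57, 0.85, 0.64])   # x^3
one = torch.tensor([0.77, 0.25, 0.10])      # x^5

# Reading 1: multiply element-wise, then sum the products
manual = (journey * your).sum()
builtin = torch.dot(journey, your)
print(manual, builtin)  # both are approximately 0.9544

# Reading 2: a larger dot product indicates greater similarity to the query
print(torch.dot(journey, starts))  # approximately 1.4754
print(torch.dot(journey, one))     # approximately 0.7070
```

The embedding for "starts" is nearly identical to the one for "journey", so its dot product with the query is much larger than that of "one".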

<b>Why normalize the attention scores, $\omega^{(i)}$?</b>
- To obtain attention weights that sum to 1.
- This is useful for interpretation and for maintaining training stability in an LLM.

```python
### Attention scores to attention weights ###
print("Attention scores to Attention weights (Rudimentary approach)")
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print(f"attn_weights_2 = {attn_weights_2_tmp}")
print(f"Sum: {attn_weights_2_tmp.sum():.3f}")

print("\nAttention scores to Attention weights (Softmax)")


def softmax_naive(x):
    # Exponentiate each score, then divide by the sum of the exponentials
    return torch.exp(x) / torch.exp(x).sum(dim=0)


attn_weights_2_naive = softmax_naive(attn_scores_2)
print(f"attn_weights_2 = {attn_weights_2_naive}")
print(f"Sum: {attn_weights_2_naive.sum():.3f}")

print("\nAttention scores to Attention weights (Softmax) - PyTorch")
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print(f"attn_weights_2 = {attn_weights_2}")
print(f"Sum: {attn_weights_2.sum():.3f}")
```

```
Attention scores to Attention weights (Rudimentary approach)
attn_weights_2 = tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: 1.000

Attention scores to Attention weights (Softmax)
attn_weights_2 = tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: 1.000

Attention scores to Attention weights (Softmax) - PyTorch
attn_weights_2 = tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: 1.000
```


```python
### Attention weights to context vector ###
query = inputs[1]  # "journey" (x^2), the second input token embedding
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    # Weighted sum of all inputs, using the PyTorch softmax weights
    context_vec_2 += attn_weights_2[i] * x_i
print(f"\ncontext_vec_2 = {context_vec_2}")
```

```
context_vec_2 = tensor([0.4419, 0.6515, 0.5683])
```


<p style="color:black; background-color:#F5C780; padding:15px">💡 <b>Use softmax function for normalization</b><br>
- The softmax function is better at managing extreme values and offers more favorable gradient properties during training.<br>
- <b><span style="color:green">The softmax function ensures that the attention weights are always positive</span>, <span style="color:red">unlike simple normalization</span></b>.<br>
  - This <b>makes the output interpretable as probabilities, where higher weights indicate greater importance</b>.<br><br>
💡 <b>Why use the PyTorch implementation of softmax?</b><br>
- <span style="color:red">A naive softmax implementation can lead to numerical instability, especially when dealing with large or small input values</span>.<br>
- <span style="color:green">The PyTorch implementation of softmax has been extensively tested and optimized for numerical stability.</span>
</p>
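
The instability mentioned above is easy to reproduce. The sketch below feeds deliberately large scores to a naive softmax and compares the result with `torch.softmax`. The helper `softmax_stable` is hypothetical and shows one common remedy (subtracting the maximum before exponentiating); it is an illustration, not a description of PyTorch's internals.

```python
import torch


def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)


def softmax_stable(x):
    # Hypothetical helper: shifting by the max does not change the result
    # mathematically, but keeps every exponent at or below zero
    shifted = x - x.max()
    return torch.exp(shifted) / torch.exp(shifted).sum(dim=0)


large_scores = torch.tensor([1000.0, 1001.0, 1002.0])

print(softmax_naive(large_scores))         # tensor([nan, nan, nan])
print(softmax_stable(large_scores))        # tensor([0.0900, 0.2447, 0.6652])
print(torch.softmax(large_scores, dim=0))  # tensor([0.0900, 0.2447, 0.6652])
```

In 32-bit floating point, `exp(1000)` overflows to infinity, so the naive version divides infinity by infinity and returns `nan`.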


```python
# Normalization without and with softmax
attn_scores_temp = torch.tensor([1.0, -0.5, 1.5])

# Simple normalization can produce negative weights when a score is negative
attn_weights_tmp = attn_scores_temp / attn_scores_temp.sum()
print(f"[Simple normalization]        attn_weights_temp = {attn_weights_tmp}")

# Softmax always produces positive weights
attn_weights_tmp = softmax_naive(attn_scores_temp)
print(f"[Softmax based normalization] attn_weights_temp = {attn_weights_tmp}")
```

```
[Simple normalization]        attn_weights_temp = tensor([ 0.5000, -0.2500,  0.7500])
[Softmax based normalization] attn_weights_temp = tensor([0.3482, 0.0777, 0.5741])
```


<img src="../images/figure-3.9-obtain-attention-weights-example.webp" width="800px">


<img src="../images/figure-3.10-compute-context-vector-example.webp" width="700px">


<img src="../images/figure-3.11-attention-weights-heatmap-example.webp" width="700px">


<img src="../images/figure-3.12-three-step-process-to-compute-context-vectors.webp" width="800px">

