```{contents}
```

## Self Attention & Cross Attention

### 1. Why Attention Exists (Core Intuition)

Before explaining self-attention and cross-attention, understand the **problem** they solve.

Neural Machine Translation (NMT) requires:

* Reading a source sentence (English)
* Understanding it as a whole
* Generating a target sentence (French)

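Early RNN and LSTM encoder-decoder models struggled because:

* They compress the entire meaning of a sentence into a *single* fixed-size hidden vector.
* Long sentences are hard to fit into that one vector, so information about early words tends to fade.
* They process words sequentially, which prevents parallel computation across positions and slows down training.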

**Attention** solved this by allowing the model to:

* Look back at specific words it needs,
* Weigh them differently depending on context,
* And, once recurrence is removed entirely (as in Transformers), process all words of a sentence in parallel.

This idea became the foundation of Transformers.

---

### 2. What Q, K, V Represent (Intuitive View)

In all attention mechanisms, we project token embeddings into:

* **Query (Q)** → What I am looking for
* **Key (K)** → What information I offer
* **Value (V)** → The actual information content

Analogy:
Imagine researching in a library.

* Query = the question you're trying to answer
* Key = the index of each book
* Value = the content inside the book

Attention computes similarity between Query and Key, and uses that to decide how much of Value to read.

The formula:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

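In words:

1. Compare each Query with every Key using a dot product, then divide by $\sqrt{d_k}$ so the scores do not grow with the key dimension and saturate the softmax
2. Convert each Query's scores into probabilities that sum to 1 (softmax)
3. Blend the Value vectors using these probabilities as weights

This blending produces **contextual embeddings**: each output is a weighted average of the Value vectors.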
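The three steps can be sketched as a small function. This is a minimal single-head version without masking or batching; the function name `scaled_dot_product_attention` and the toy shapes are assumptions chosen for illustration, not something taken from a particular library.

```python
import math
import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V):
    """Return (output, weights) for single-head attention without masking.

    Q: (n_queries, d_k), K: (n_keys, d_k), V: (n_keys, d_v)
    """
    d_k = Q.size(-1)
    # Step 1: compare every query with every key (scaled dot products).
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # (n_queries, n_keys)
    # Step 2: turn each row of scores into a probability distribution.
    weights = F.softmax(scores, dim=-1)                  # each row sums to 1
    # Step 3: blend the value vectors using those probabilities.
    output = weights @ V                                 # (n_queries, d_v)
    return output, weights


torch.manual_seed(0)
Q = torch.randn(3, 8)   # 3 queries
K = torch.randn(5, 8)   # 5 keys
V = torch.randn(5, 6)   # 5 values, one per key

output, weights = scaled_dot_product_attention(Q, K, V)
print(weights.shape)   # torch.Size([3, 5]): one row of weights per query
print(output.shape)    # torch.Size([3, 6]): one blended value vector per query
```

Note that the number of queries and the number of keys do not have to match; only the number of keys and values must agree.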

---

### 3. Self-Attention (Detailed, Practical Intuition)

Self-attention means:
**Each token looks at every token in the same sentence (itself included) to build its contextual meaning.**

All Q, K, V are projections of the *same* sequence.

#### Why do we need this?

Words change meaning based on context:

* *“bank”* (river bank vs monetary bank)
* *“trains”* (verb or noun)

Self-attention adjusts the embedding of each word based on the other words around it.

#### How it works in translation

Example:
“The boy **trains** the puppy.”

The raw word embedding of “trains” is ambiguous.
Self-attention allows “trains” to check:

* "boy" → subject
* "puppy" → object
* "the" → determiner

Because it sees these words, a trained layer can encode that “trains” is a *verb* here, not a noun.

**Effect:**
A new, context-aware embedding of “trains” is produced.
After the final encoder layer, these enriched embeddings are what the encoder passes to the decoder.
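A small experiment makes the effect concrete. The sketch below feeds “trains” through the same self-attention layer in two different sentences. The vocabulary, the second sentence, and the helper `self_attention` are hypothetical, and the embeddings and projections are random rather than trained, so the outputs carry no linguistic meaning. The point is only that the same input embedding yields different outputs in different contexts.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy vocabulary; embeddings and projections are random, not trained.
vocab = {"the": 0, "boy": 1, "trains": 2, "puppy": 3, "are": 4, "late": 5}
embed_dim = 8
embedding = nn.Embedding(len(vocab), embed_dim)
W_q = nn.Linear(embed_dim, embed_dim, bias=False)
W_k = nn.Linear(embed_dim, embed_dim, bias=False)
W_v = nn.Linear(embed_dim, embed_dim, bias=False)


def self_attention(words):
    """Single-head self-attention over one sentence (no positional encoding)."""
    ids = torch.tensor([vocab[w] for w in words])
    x = embedding(ids)                                   # (seq_len, embed_dim)
    Q, K, V = W_q(x), W_k(x), W_v(x)
    weights = F.softmax(Q @ K.T / embed_dim ** 0.5, dim=-1)
    return x, weights @ V                                # raw and contextual embeddings


sent_a = ["the", "boy", "trains", "the", "puppy"]   # "trains" used as a verb
sent_b = ["the", "trains", "are", "late"]           # "trains" used as a noun

with torch.no_grad():
    x_a, out_a = self_attention(sent_a)
    x_b, out_b = self_attention(sent_b)

i_a, i_b = sent_a.index("trains"), sent_b.index("trains")
same_input = torch.equal(x_a[i_a], x_b[i_b])         # identical raw embeddings
same_output = torch.allclose(out_a[i_a], out_b[i_b])  # contextual outputs differ
print(same_input, same_output)
```

A real model would also add positional information to the inputs; it is left out here to keep the sketch short.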

#### Bidirectional vs Masked

* **Encoder self-attention**: Each position can look both left and right (bidirectional).
* **Decoder self-attention**: Each position attends only to itself and earlier positions (causal mask), so the model cannot cheat by seeing the future words it is supposed to predict.
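The difference between the two comes down to one mask applied before the softmax. The following is a minimal sketch with random scores; the variable names (`future`, `encoder_weights`, `decoder_weights`) are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_k = 4, 8
Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)

scores = Q @ K.T / d_k ** 0.5                        # (seq_len, seq_len)

# True above the diagonal marks "future" positions that must stay hidden.
future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# Bidirectional (encoder-style): every position sees every position.
encoder_weights = F.softmax(scores, dim=-1)

# Causal (decoder-style): future scores become -inf, so softmax gives them weight 0.
decoder_weights = F.softmax(scores.masked_fill(future, float("-inf")), dim=-1)

print(decoder_weights)   # lower-triangular: row i has non-zero weights only for columns 0..i
```

The first row of `decoder_weights` puts all of its weight on position 0, because the first token has nothing earlier to look at.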

---

### 4. Cross-Attention (Detailed, Practical Intuition)

Cross-attention connects the decoder with the encoder.

**The decoder uses its own Query, and attends to the encoder's Key and Value.**

#### Why?

When generating a translation, the decoder needs to look back at the source sentence.

Example:
Translating to French:

“The boy trains the puppy.” → “Le garçon entraîne le chiot.”

When the decoder is about to output the French equivalent of “trains”:

* Query = a projection of the decoder’s current hidden state
* Keys = a projection of the encoder’s representation of each English word
* Values = a second, separate projection of those same encoder representations

Cross-attention determines which source word is most relevant.

#### What happens internally?

Decoder asks:

> “Which English word should I focus on now?”

In a well-trained model, the attention weight is typically highest for the source word “trains”.

The decoder then retrieves mostly that part of the encoder's output (a weighted average dominated by “trains”) and uses it to output the correct French verb form “entraîne”.
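To see this mechanically, here is a hand-built sketch of a single decoding step. The keys and the query are constructed by hand so that the query matches the key for “trains”; in a real model they would come from learned projections, so treat every number here as an assumption made for illustration.

```python
import torch
import torch.nn.functional as F

source = ["The", "boy", "trains", "the", "puppy"]
d_k = len(source)

# Hand-built stand-ins for learned keys: each source token gets its own direction.
K = 4.0 * torch.eye(d_k)                  # (src_len, d_k)
torch.manual_seed(0)
V = torch.randn(len(source), 3)           # (src_len, d_v), arbitrary value vectors

# Hypothetical decoder query that points along the key for "trains".
q = K[source.index("trains")].clone()     # (d_k,)

scores = K @ q / d_k ** 0.5               # (src_len,): one score per source word
weights = F.softmax(scores, dim=-1)       # attention over the source sentence
context = weights @ V                     # (d_v,): mostly the value vector of "trains"

focus = source[int(weights.argmax())]
print(focus)     # trains
print(weights)   # almost all of the weight sits on index 2
```

Because the lookup is soft, `context` is not exactly the value vector of “trains”: the other source words still contribute a small amount.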

#### Why cross-attention is critical

Without cross-attention:

* The decoder would have no view of the source sentence and would generate output blindly
* Translation quality would collapse
* Dependencies between distant source and target words would be lost

Cross-attention is a learnable, soft lookup into encoder memory.
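In PyTorch, one way to express this lookup is `nn.MultiheadAttention`, passing decoder states as the query and the encoder output as both key and value. The sketch below uses random tensors as stand-ins for real encoder and decoder states, and the sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads = 8, 2
src_len, tgt_len = 5, 3

# The module holds the learnable Q, K, V and output projections.
cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

memory = torch.randn(1, src_len, embed_dim)          # stand-in for encoder output
decoder_states = torch.randn(1, tgt_len, embed_dim)  # stand-in for decoder hidden states

with torch.no_grad():
    # Queries come from the decoder; keys and values come from the encoder memory.
    context, weights = cross_attn(query=decoder_states, key=memory, value=memory)

print(context.shape)   # torch.Size([1, 3, 8]): one context vector per target position
print(weights.shape)   # torch.Size([1, 3, 5]): weights over source positions, averaged over heads
```

Each row of `weights` is a distribution over the five source positions, which is what makes the lookup "soft".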

---

### 5. Putting Both Together (Full Translation Process)

#### Step 1: Encoder (Self-Attention)

The encoder reads the English sentence.

Self-attention refines each word:

* “trains” becomes a verb representation
* “boy” becomes a subject representation
* “puppy” becomes an object representation

It outputs one context-aware embedding per source token; together they represent the whole sentence.

---

#### Step 2: Decoder (Masked Self-Attention)

When generating output token by token:

* The decoder uses masked self-attention over the tokens generated so far, so each new word stays consistent with the partial translation.

---

#### Step 3: Cross-Attention (Connecting encoder and decoder)

At each decoding step:

* Decoder Q looks at encoder K, V
* Retrieves most relevant part of the source sentence
* Uses that to produce the next word

This is how alignment between languages emerges.
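The three steps map onto PyTorch's built-in layers roughly as follows. This is a sketch with a single encoder layer and a single decoder layer, random inputs in place of real token embeddings, and `dropout=0.0` so the result is deterministic; real models stack several layers and add embeddings, positional encodings, and an output projection.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, nhead = 16, 4
src_len, tgt_len = 5, 4

encoder_layer = nn.TransformerEncoderLayer(
    d_model, nhead, dim_feedforward=32, dropout=0.0, batch_first=True)
decoder_layer = nn.TransformerDecoderLayer(
    d_model, nhead, dim_feedforward=32, dropout=0.0, batch_first=True)

src = torch.randn(1, src_len, d_model)   # stand-in for embedded source tokens
tgt = torch.randn(1, tgt_len, d_model)   # stand-in for embedded target tokens so far

# Float mask: -inf above the diagonal hides future target positions.
causal_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)

with torch.no_grad():
    memory = encoder_layer(src)                             # Step 1: encoder self-attention
    out = decoder_layer(tgt, memory, tgt_mask=causal_mask)  # Steps 2 and 3: masked self-attention, then cross-attention

    # Changing only the last target token must not affect earlier positions.
    tgt_changed = tgt.clone()
    tgt_changed[:, -1] = torch.randn(d_model)
    out_changed = decoder_layer(tgt_changed, memory, tgt_mask=causal_mask)

print(out.shape)   # torch.Size([1, 4, 16])
earlier_unchanged = torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-5)
print(earlier_unchanged)   # True: the causal mask blocks information from the future
```

The final check is a quick way to confirm that the mask is doing its job: outputs at earlier positions do not depend on later target tokens.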

---

**Summary Table**

| Concept                      | Source of Q    | Source of K & V | Purpose                             |
| ---------------------------- | -------------- | --------------- | ----------------------------------- |
| **Self-attention (encoder)** | Encoder tokens | Encoder tokens  | Understand source sentence context  |
| **Self-attention (decoder)** | Decoder tokens | Decoder tokens  | Understand partial output so far    |
| **Cross-attention**          | Decoder tokens | Encoder tokens  | Link source meaning to output words |
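The table can be read as three calls to one function that differ only in where the queries and the keys/values come from. The sketch below skips the learned projections (it uses the inputs directly as Q and K) to keep the focus on the sources; `attention_weights` is a hypothetical helper, not a library function.

```python
import torch
import torch.nn.functional as F


def attention_weights(q_source, kv_source, causal=False):
    """Attention weights with identity projections (a simplification for illustration)."""
    d_k = q_source.size(-1)
    scores = q_source @ kv_source.T / d_k ** 0.5
    if causal:
        # Hide positions to the right of each query position.
        future = torch.triu(torch.ones_like(scores, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
    return F.softmax(scores, dim=-1)


torch.manual_seed(0)
encoder_tokens = torch.randn(5, 4)   # 5 source tokens
decoder_tokens = torch.randn(3, 4)   # 3 target tokens generated so far

enc_self = attention_weights(encoder_tokens, encoder_tokens)               # row 1 of the table
dec_self = attention_weights(decoder_tokens, decoder_tokens, causal=True)  # row 2 of the table
cross = attention_weights(decoder_tokens, encoder_tokens)                  # row 3 of the table

for name, w in [("encoder self", enc_self), ("decoder self", dec_self), ("cross", cross)]:
    print(f"{name:13s} weights shape: {tuple(w.shape)}")
```

The shapes follow the table: rows correspond to the source of Q, columns to the source of K and V.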

---

**Most Important Intuition**

* **Self-attention helps each word understand its meaning by looking at surrounding words.**
* **Cross-attention helps the decoder retrieve the right source information at the right time.**


```python
# Single-head self-attention on a toy input: Q, K, V all come from x.
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----- Toy input: 5 words, embedding dim = 4 -----
x = torch.randn(5, 4)          # shape: (seq_len, embed_dim)

embed_dim = x.size(1)
d_k = embed_dim                # key dimension used for scaling (single head)

# ----- Learnable projection matrices -----
W_q = nn.Linear(embed_dim, embed_dim, bias=False)
W_k = nn.Linear(embed_dim, embed_dim, bias=False)
W_v = nn.Linear(embed_dim, embed_dim, bias=False)

# ----- Compute Q, K, V -----
Q = W_q(x)   # (5,4)
K = W_k(x)   # (5,4)
V = W_v(x)   # (5,4)

# ----- Scaled attention scores: QK^T / sqrt(d_k) -----
scores = Q @ K.transpose(0, 1) / (d_k ** 0.5)   # (5,5)

# ----- Softmax along keys dimension -----
attn_weights = F.softmax(scores, dim=-1)        # (5,5)

# ----- Weighted sum of Values -----
output = attn_weights @ V                      # (5,4)

print("Input embeddings:\n", x)
print("\nSelf-Attention weights:\n", attn_weights)
print("\nContextualized representations:\n", output)
```


Example output from one run (values differ on every run because the input and the projection weights are random):

```text
Input embeddings:
 tensor([[ 0.4867, -0.7886, -0.3865, -0.2724],
        [-0.5073, -1.5010,  0.9716,  0.0105],
        [-1.3607, -0.7603,  0.0541,  0.3580],
        [ 0.1713, -0.1785,  0.4401,  1.4594],
        [ 1.6816, -0.4191,  0.7418,  0.0577]])

Self-Attention weights:
 tensor([[0.2117, 0.2329, 0.2196, 0.1565, 0.1792],
        [0.2151, 0.2241, 0.2617, 0.1501, 0.1490],
        [0.1838, 0.2075, 0.2144, 0.1938, 0.2005],
        [0.2282, 0.1933, 0.2238, 0.1869, 0.1679],
        [0.2539, 0.2129, 0.2398, 0.1509, 0.1424]], grad_fn=<SoftmaxBackward0>)

Contextualized representations:
 tensor([[ 0.1459, -0.2501,  0.0917, -0.2328],
        [ 0.1493, -0.3187,  0.1532, -0.2441],
        [ 0.1338, -0.2090,  0.0313, -0.2046],
        [ 0.1243, -0.2514,  0.0796, -0.1968],
        [ 0.1358, -0.3046,  0.1489, -0.2246]], grad_fn=<MmBackward0>)
```


```python
# Single-head cross-attention on toy inputs: Q comes from the decoder, K and V from the encoder.
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----- Encoder output: 5 tokens, emb_dim = 4 -----
encoder_out = torch.randn(5, 4)   # (src_seq_len, embed_dim)

# ----- Decoder input: 3 tokens -----
decoder_in = torch.randn(3, 4)    # (tgt_seq_len, embed_dim)

embed_dim = 4
d_k = embed_dim                   # key dimension used for scaling (single head)

# ----- Projection matrices -----
W_q = nn.Linear(embed_dim, embed_dim, bias=False)   # applied to decoder states
W_k = nn.Linear(embed_dim, embed_dim, bias=False)   # applied to encoder output
W_v = nn.Linear(embed_dim, embed_dim, bias=False)   # applied to encoder output

# ----- Create Q (decoder) and K,V (encoder) -----
Q = W_q(decoder_in)           # (tgt_seq_len, embed_dim)
K = W_k(encoder_out)          # (src_seq_len, embed_dim)
V = W_v(encoder_out)          # (src_seq_len, embed_dim)

# ----- Compute scaled attention scores -----
scores = Q @ K.transpose(0, 1) / (d_k ** 0.5)   # (tgt_seq_len, src_seq_len)

# ----- Convert to attention weights (each row sums to 1 over source tokens) -----
attn_weights = F.softmax(scores, dim=-1)        # (tgt_seq_len, src_seq_len)

# ----- Compute one context vector per decoder position -----
context = attn_weights @ V                      # (tgt_seq_len, embed_dim)

print("Decoder Q:\n", Q)
print("\nEncoder K:\n", K)
print("\nCross-attention weights:\n", attn_weights)
print("\nCross-attention context vectors:\n", context)
```


Example output from one run (values differ on every run because the inputs and the projection weights are random):

```text
Decoder Q:
 tensor([[ 0.0743, -0.0228, -0.1231, -0.1664],
        [-0.4445,  0.1650,  0.5854,  0.3921],
        [-0.4303,  0.2577,  1.3873,  0.8675]], grad_fn=<MmBackward0>)

Encoder K:
 tensor([[ 0.1421, -0.3960,  0.2355, -0.1776],
        [ 0.2310,  0.1648,  0.5230,  0.1143],
        [-1.0206,  0.4274, -0.3964, -0.1257],
        [ 0.6489,  0.1754,  0.3166,  0.2640],
        [-0.1461, -0.2791, -0.0909, -0.1741]], grad_fn=<MmBackward0>)

Cross-attention weights:
 tensor([[0.2034, 0.1944, 0.1998, 0.1975, 0.2050],
        [0.1862, 0.2202, 0.2166, 0.1947, 0.1823],
        [0.1817, 0.2655, 0.1712, 0.2247, 0.1568]], grad_fn=<SoftmaxBackward0>)

Cross-attention context vectors:
 tensor([[ 0.0272,  0.0640, -0.1257,  0.0343],
        [ 0.0811,  0.0557, -0.1466,  0.0161],
        [ 0.1628,  0.0298, -0.2224,  0.0031]], grad_fn=<MmBackward0>)
```
