In [6]:
# Self attention from scratch using PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

The attention mechanism, the core layer of a Transformer, lets the model focus on the relevant parts of an input sequence (such as the words in a sentence) instead of treating every position equally. It replaces the step-by-step recurrence of models like RNNs, which allows parallel computation across positions and makes long-range dependencies easier to capture. This idea, introduced in the 2017 paper "Attention Is All You Need", powers models such as GPT and BERT. [youtube](https://www.youtube.com/watch?v=KMHkbXzHn7s)

## Core Idea
Imagine reading the sentence "The cat, which sat on the mat, ran away." To understand "ran", the model needs to know that it refers to the cat, not the mat. Attention computes a score for how strongly each word (like "cat") relates to the word being processed (like "ran") and uses those scores to weight each word's influence dynamically. [sciencedirect](https://www.sciencedirect.com/topics/computer-science/self-attention-mechanism)

## Query, Key, Value
Each input word is projected into three vectors: a Query (what this word is looking for), a Key (what each word offers to others), and a Value (the information it actually passes along). A word's query is compared with every key via a dot product to get raw scores, which are then scaled by dividing by the square root of the key dimension d_k so they do not grow with vector size; this stabilizes training. [youtube](https://www.youtube.com/watch?v=KMHkbXzHn7s)
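The effect of the scaling step can be checked numerically. This is a minimal sketch that assumes random unit-variance queries and keys (not the notebook's learned projections): raw dot products over `head_size` dimensions have a variance of roughly `head_size`, and dividing by `sqrt(head_size)` brings it back to roughly 1.

```python
import torch

torch.manual_seed(0)
B, T, head_size = 4, 8, 16

# Hypothetical random queries and keys with unit-variance entries
q = torch.randn(B, T, head_size)
k = torch.randn(B, T, head_size)

# Raw scores: every query dotted with every key -> (B, T, T)
raw = q @ k.transpose(-2, -1)

# Scaled scores: divide by sqrt(d_k), where d_k = head_size
scaled = raw / head_size ** 0.5

# Raw variance is roughly head_size; scaled variance is roughly 1
print(f"raw variance:    {raw.var().item():.2f}")
print(f"scaled variance: {scaled.var().item():.2f}")
```

Without the scaling, large scores push the softmax toward a near one-hot distribution, which leaves very small gradients for the other positions.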

## Computing Attention
A softmax turns each query's scores into weights that are non-negative and sum to 1; these weights are then used to take a weighted sum of the value vectors. The result is a context-rich representation in which relevant words, even distant ones, contribute more. [geeksforgeeks](https://www.geeksforgeeks.org/nlp/transformer-attention-mechanism-in-nlp/)
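The softmax-then-weighted-sum step can be sketched on small random tensors (an assumption for illustration; no masking and no learned projections here). The check at the end confirms that one output row is literally the weighted sum of all value vectors.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, head_size = 5, 4

# Hypothetical queries, keys and values, one row per token
q = torch.randn(T, head_size)
k = torch.randn(T, head_size)
v = torch.randn(T, head_size)

scores = (q @ k.T) / head_size ** 0.5   # (T, T) scaled scores
weights = F.softmax(scores, dim=-1)      # each row is a probability distribution
out = weights @ v                        # (T, head_size): weighted mix of value rows

# Row 0 of the output equals the weighted sum of all value vectors
manual = sum(weights[0, j] * v[j] for j in range(T))
print(weights.sum(dim=-1))               # every row sums to 1
print(torch.allclose(out[0], manual))
```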

## Multi-Head Attention
Instead of a single attention pass, multi-head attention splits the computation into several "heads" (for example, 8 to 16), each of which can learn a different kind of relationship, such as syntax in one head and semantics in another. The heads' outputs are concatenated and passed through a linear projection, giving a richer representation of context. [youtube](https://www.youtube.com/watch?v=BvZS6PDUtD4)
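A minimal sketch of causal multi-head attention in PyTorch, assuming the notebook's shapes (`C = 32`, `T = 8`) and 4 heads of size 8 so the concatenated output is 32 wide again. The class names `Head` and `MultiHeadAttention` are hypothetical, there is no dropout, and PyTorch also ships a built-in `nn.MultiheadAttention` that can be used instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    """One causal self-attention head (sketch; no dropout)."""
    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask stored as a buffer, not a trainable parameter
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        T = x.shape[1]
        k, q, v = self.key(x), self.query(x), self.value(x)   # each (B, T, head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5   # (B, T, T), scaled scores
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        return wei @ v                                        # (B, T, head_size)

class MultiHeadAttention(nn.Module):
    """Run several heads in parallel, concatenate, then project back to n_embd."""
    def __init__(self, n_embd, num_heads, block_size):
        super().__init__()
        assert n_embd % num_heads == 0, "n_embd must divide evenly across heads"
        head_size = n_embd // num_heads
        self.heads = nn.ModuleList([Head(n_embd, head_size, block_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, num_heads * head_size)
        return self.proj(out)                                # (B, T, n_embd)

torch.manual_seed(1337)
x = torch.randn(4, 8, 32)
mha = MultiHeadAttention(n_embd=32, num_heads=4, block_size=8)
print(mha(x).shape)  # torch.Size([4, 8, 32])
```

Splitting `n_embd` across heads keeps the total compute similar to one wide head while letting each head form its own attention pattern.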

In [8]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)  # (batch, time, channels)
print(x.shape)
print(x[:1])

# single-head self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)    # (B, T, head_size)
q = query(x)  # (B, T, head_size)
# raw scores, not yet scaled by head_size ** -0.5
wei = q @ k.transpose(-2, -1)  # (B, T, head_size) @ (B, head_size, T) --> (B, T, T)

# causal mask: position t may only attend to positions 0..t
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)  # each row now sums to 1
v = value(x)   # (B, T, head_size)
out = wei @ v  # (B, T, T) @ (B, T, head_size) --> (B, T, head_size)

print(out.shape)


torch.Size([4, 8, 32])
tensor([[[ 1.8077e-01, -6.9988e-02, -3.5962e-01, -9.1520e-01,  6.2577e-01,
           2.5510e-02,  9.5451e-01,  6.4349e-02,  3.6115e-01,  1.1679e+00,
          -1.3499e+00, -5.1018e-01,  2.3596e-01, -2.3978e-01, -9.2111e-01,
           1.5433e+00,  1.3488e+00, -1.3964e-01,  2.8580e-01,  9.6512e-01,
          -2.0371e+00,  4.9314e-01,  1.4870e+00,  5.9103e-01,  1.2603e-01,
          -1.5627e+00, -1.1601e+00, -3.3484e-01,  4.4777e-01, -8.0164e-01,
           1.5236e+00,  2.5086e+00],
         [-6.6310e-01, -2.5128e-01,  1.0101e+00,  1.2155e-01,  1.5840e-01,
           1.1340e+00, -1.1539e+00, -2.9840e-01, -5.0754e-01, -9.2392e-01,
           5.4671e-01, -1.4948e+00, -1.2057e+00,  5.7182e-01, -5.9735e-01,
          -6.9368e-01,  1.6455e+00, -8.0299e-01,  1.3514e+00, -2.7592e-01,
          -1.5108e+00,  2.1048e+00,  2.7630e+00, -1.7465e+00,  1.4516e+00,
          -1.5103e+00,  8.2115e-01, -2.1153e-01,  7.7890e-01,  1.5333e+00,
           1.6097e+00, -4.0323e-01],
   

In [10]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)
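The matrix above has two properties worth checking directly: every row sums to 1, and every entry above the diagonal is 0, so no token attends to a later one. This sketch recomputes `wei` with the notebook's seed and layer order, so it should reproduce the matrix above, though exact values may differ across PyTorch versions or devices.

```python
import torch
import torch.nn as nn

torch.manual_seed(1337)
B, T, C, head_size = 4, 8, 32, 16
x = torch.randn(B, T, C)
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

wei = query(x) @ key(x).transpose(-2, -1)  # (B, T, T) raw scores
tril = torch.tril(torch.ones(T, T))
wei = torch.softmax(wei.masked_fill(tril == 0, float("-inf")), dim=-1)

# 1) Each row is a probability distribution over the positions it may see
print(torch.allclose(wei.sum(dim=-1), torch.ones(B, T)))
# 2) Nothing above the diagonal: no token attends to a future token
print(torch.all(torch.triu(wei, diagonal=1) == 0).item())
# 3) The first token can only attend to itself
print(wei[:, 0, 0])
```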

## Understanding Q, K, V - Simple Version

### Think of it Like a Library Search

Imagine you're in a library looking for books about "cats":

- **Query (Q)**: The question you ask the librarian: *"I want books about cats"*
- **Key (K)**: Labels on each book that describe what it's about: *"This book is about cats", "This book is about dogs", etc.*
- **Value (V)**: The actual content inside each book

### How Attention Works in 3 Steps

**Step 1: Compare (Q @ K)**
```
Your query "cats" matches against all book labels:
- Book 0 (about dogs): No match → low score
- Book 3 (about animals): Medium match → medium score
- Book 6 (about cats): Perfect match → HIGH score
- Book 7 (about pets): Good match → HIGH score
```
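Step 1 can be made concrete with toy vectors. The 2-D embeddings below are hypothetical (dimension 0 means "about cats", dimension 1 means "about dogs"); real models learn much higher-dimensional keys and queries, but the score is still one dot product per key.

```python
import torch

# Hypothetical 2-D embeddings: dim 0 = "about cats", dim 1 = "about dogs"
query = torch.tensor([1.0, 0.0])    # "I want books about cats"
keys = torch.tensor([
    [0.0, 1.0],   # Book 0: about dogs
    [0.5, 0.5],   # Book 3: about animals
    [1.0, 0.0],   # Book 6: about cats
    [0.7, 0.7],   # Book 7: about pets
])

scores = keys @ query   # one dot product per book: dogs 0.0, animals 0.5, cats 1.0, pets 0.7
print(scores)
```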

**Step 2: Normalize (softmax)**
Convert scores into percentages that add up to 100%:
```
Book 0: 2.1%
Book 3: 22.97%
Book 6: 24.23%  ← Most relevant
Book 7: 23.91%  ← Very relevant
Books 1, 2, 4, 5: 8.43%, 5.55%, 5.73%, 7.09%
```
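A sketch of the normalization itself, applied to hypothetical raw scores for the four books from Step 1 (these scores are illustrative, so the resulting percentages differ from the ones above, which come from the notebook's untrained model). Softmax exponentiates each score and divides by the total, which keeps the ranking and makes the shares sum to 100%.

```python
import torch
import torch.nn.functional as F

# Hypothetical raw scores for Books 0, 3, 6, 7 (dogs, animals, cats, pets)
scores = torch.tensor([0.0, 0.5, 1.0, 0.7])

weights = F.softmax(scores, dim=-1)   # exp(score) / sum of exp(scores)
percent = weights * 100

for book, p in zip([0, 3, 6, 7], percent.tolist()):
    print(f"Book {book}: {p:.2f}%")
print(f"Total: {percent.sum().item():.2f}%")
```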

**Step 3: Get Information (weights @ V)**
Take the content from each book, weighted by relevance:
```
Final result = 2.1% of book 0's content
             + 22.97% of book 3's content
             + 24.23% of book 6's content  ← gets most weight
             + 23.91% of book 7's content
             + ...
```
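Step 3 as a sketch: the weights are the row discussed in this section, while the value vectors are random stand-ins (an assumption; they play the role of each book's "content"). A single vector-matrix product computes exactly the weighted sum written out above.

```python
import torch

torch.manual_seed(0)
weights = torch.tensor([0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391])
values = torch.randn(8, 16)   # hypothetical content: one 16-dim value vector per position

result = weights @ values     # (8,) @ (8, 16) -> (16,)

# The same thing written as an explicit weighted sum
manual = sum(w * v for w, v in zip(weights, values))
print(torch.allclose(result, manual, atol=1e-6))
```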

### Real Example from Your Code

Your attention weights (the last row of `wei[0]`, i.e. how much position 7 attends to positions 0 through 7): `[0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]`

This says: *"Position 6 is most relevant (24.23%), position 7 is very relevant (23.91%), position 3 is relevant (22.97%), but position 0 is barely relevant (2.1%)"*

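### Why Are Some Values High and Others Low?

In this notebook, `x` is random noise and the `key`/`query` layers are untrained, so these particular weights are effectively arbitrary: they come from random dot products, not from meaning. In a trained model the pattern becomes meaningful:

- **High weights (here 0.24, 0.24, 0.23)** = the query and key vectors point in similar directions
  - The model has learned that these tokens carry information the current token needs

- **Low weights (here 0.0210, 0.0555)** = the query and key vectors are dissimilar
  - The model has learned that these tokens help little in predicting the next token
  - They contribute little to the output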

### The Key Insight

Self-attention is like asking: *"Which previous words matter for understanding THIS word?"*

- High attention weight → "This word is very important for me"
- Low attention weight → "This word doesn't help me"

**That's how transformers understand context and relationships!** 🎯

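- Attention is a communication mechanism: each token gathers information from the tokens it is allowed to see, weighted by data-dependent scores.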
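The simplest form of this communication is the special case where every allowed score is equal. In this sketch (reusing the notebook's shapes, with uniform scores as an assumption) the masked softmax gives token t a weight of 1/(t+1) on each of positions 0..t, so `wei @ x` is just a running average of the past. Learned `q @ k` scores replace those equal weights with data-dependent ones.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

tril = torch.tril(torch.ones(T, T))
wei = torch.zeros(T, T)                           # equal scores for every pair
wei = wei.masked_fill(tril == 0, float("-inf"))   # block the future
wei = F.softmax(wei, dim=-1)                      # row t: 1/(t+1) on positions 0..t
out = wei @ x                                     # (T, T) @ (B, T, C) -> (B, T, C) via broadcasting

# Each output token is the mean of itself and all earlier tokens
running_mean = torch.cumsum(x, dim=1) / torch.arange(1, T + 1).view(1, T, 1)
print(torch.allclose(out, running_mean, atol=1e-6))
```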