## The Components:

- **Query ($Q$):** What I am looking for. (e.g., "The cat sat on the...")
- **Key ($K$):** The label or "tag" of the information stored. (e.g., "mat")
- **Value ($V$):** The actual content/information. (e.g., The vector embedding for "mat")

## The Process:

1. **Similarity:** Compare $Q$ with every $K$ (Dot Product).
2. **Probability:** Turn scores into percentages (Softmax).
3. **Retrieval:** Multiply percentages by $V$.
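
The three steps above can be traced by hand on a tiny example. This is a minimal sketch in plain Python (no PyTorch) so the arithmetic stays visible; the `query`, `keys`, and `values` lists are made-up toy data, and the $\sqrt{d_k}$ scaling is left out here because the list above does not include it either.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical toy data: one query, three key/value pairs, dimension 2.
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

# 1. Similarity: dot product of the query with every key.
scores = [sum(q * k for q, k in zip(query, key)) for key in keys]

# 2. Probability: softmax turns the scores into weights that sum to 1.
weights = softmax(scores)

# 3. Retrieval: weighted sum of the value vectors.
output = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(2)]

print("scores: ", scores)
print("weights:", [round(w, 3) for w in weights])
print("output: ", [round(o, 3) for o in output])
```

The first key points in the same direction as the query, so it gets the largest weight and its value dominates the output.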

## The Math

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

- **$QK^T$:** The Dot Product. Measures alignment. High score = High similarity.
- **$\sqrt{d_k}$:** The Scaling Factor. (We will verify why this is needed below).
- **$\text{Softmax}$:** Normalizes each row of scores so the weights in that row sum to 1.0.
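
It can help to write the same formula out for a single query. The notation here is an assumption introduced for illustration: $q_i$, $k_j$, and $v_j$ are rows of $Q$, $K$, and $V$, and $\alpha_{ij}$ is the attention weight that query $i$ puts on key $j$.

$$\alpha_{ij} = \frac{\exp\left(q_i \cdot k_j / \sqrt{d_k}\right)}{\sum_{j'} \exp\left(q_i \cdot k_{j'} / \sqrt{d_k}\right)}, \qquad \text{Attention}(Q, K, V)_i = \sum_j \alpha_{ij}\, v_j$$

Each output row is a weighted average of the value rows, with weights that are non-negative and sum to 1 over $j$.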

In [1]:
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(q, k, v, mask=None):
    """
    Computes the scaled dot product attention.
    Args:
        q: Queries (Batch, Heads, Seq_Len_Q, Dim_Head)
        k: Keys    (Batch, Heads, Seq_Len_K, Dim_Head)
        v: Values  (Batch, Heads, Seq_Len_K, Dim_Head)
        mask: Optional mask (e.g., for causal masking)
    """
    d_k = q.size(-1) # Dimension of the key head
    
    # 1. Similarity Scores (Q @ K_transpose)
    # Shape: (Batch, Heads, Seq_Len_Q, Seq_Len_K)
    # We transpose the last two dimensions of K to align for multiplication
    scores = torch.matmul(q, k.transpose(-2, -1))
    
    # 2. Scaling (The Stability Key)
    scores = scores / math.sqrt(d_k)
    
    # 3. Masking (Optional - we will use this later)
    if mask is not None:
        # Replace masked positions (mask == 0) with a large negative number
        # (-1e9, standing in for -infinity) so softmax drives them to ~0
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # 4. Attention Weights (Softmax)
    # Convert scores to probabilities (0.0 to 1.0)
    attn_weights = F.softmax(scores, dim=-1)
    
    # 5. Weighted Aggregation (Weights @ V)
    output = torch.matmul(attn_weights, v)
    
    return output, attn_weights

# --- VERIFICATION LAB ---
def verify_attention():
    print("--- 2.1 Attention Math Check ---")
    
    # Setup standard dimensions
    batch_size = 1
    heads = 1
    seq_len = 3
    d_k = 4 # Tiny dimension for easy reading
    
    # Create fake data
    # Query: Looking for something specific
    q = torch.randn(batch_size, heads, seq_len, d_k)
    k = torch.randn(batch_size, heads, seq_len, d_k)
    v = torch.randn(batch_size, heads, seq_len, d_k)
    
    output, weights = scaled_dot_product_attention(q, k, v)
    
    print(f"Scores Shape: {weights.shape}")
    print(f"Output Shape: {output.shape}")
    print("\nAttention Weights (Row sum must be 1.0):")
    print(weights[0,0]) 
    print("Row Sum:", weights[0,0].sum(dim=-1))

if __name__ == "__main__":
    verify_attention()

--- 2.1 Attention Math Check ---
Scores Shape: torch.Size([1, 1, 3, 3])
Output Shape: torch.Size([1, 1, 3, 4])

Attention Weights (Row sum must be 1.0):
tensor([[0.2053, 0.3439, 0.4508],
        [0.2956, 0.3737, 0.3308],
        [0.5547, 0.2591, 0.1862]])
Row Sum: tensor([1.0000, 1.0000, 1.0000])
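
The masking step is not exercised by the check above. As a rough sketch of what it does, the plain-Python version below applies the same convention (mask value 0 means "hidden", filled with -1e9 before softmax) to a lower-triangular causal mask. The `masked_softmax` helper and the score values are hypothetical and exist only for this illustration.

```python
import math

def masked_softmax(scores, mask, fill=-1e9):
    """Softmax over one row of scores; positions where mask is 0 are filled first."""
    filled = [s if m else fill for s, m in zip(scores, mask)]
    top = max(filled)
    exps = [math.exp(s - top) for s in filled]
    total = sum(exps)
    return [e / total for e in exps]

seq_len = 3
# Lower-triangular causal mask: position i may attend to positions 0..i only.
causal_mask = [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]

# Hypothetical scores for a 3-token sequence (assumed to be scaled already).
scores = [[0.5, 1.2, -0.3],
          [0.1, 0.4, 2.0],
          [1.0, 1.0, 1.0]]

weights = [masked_softmax(row, m) for row, m in zip(scores, causal_mask)]
for row in weights:
    print([round(w, 4) for w in row])
```

Every masked position ends up with weight 0, and each row still sums to 1 over the positions that remain visible. With the PyTorch function above, a mask of the same shape could be built with `torch.tril(torch.ones(seq_len, seq_len))` and passed as `mask`.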


## Why Scaling? (The "Hot" Softmax Problem)

You asked why we divide by $\sqrt{d_k}$. Let's use simple numbers.

### The Softmax Function:

Softmax turns numbers into probabilities.

- **Input:** [2, 1] → **Softmax:** [0.73, 0.27] (Nice, soft mix)
- **Input:** [20, 10] → **Softmax:** [0.99995, 0.00005] (Extreme, hard spike)
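
The two rows above are easy to reproduce. This is a small sketch in plain Python; the `softmax` helper is written here for illustration and is not the PyTorch one. It also checks a third, hypothetical input, [102, 101], to show that only the gap between scores matters, not their absolute size.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

soft = softmax([2.0, 1.0])         # gap of 1
hard = softmax([20.0, 10.0])       # gap of 10
shifted = softmax([102.0, 101.0])  # big numbers, but still a gap of 1

print("gap 1:          ", [round(p, 5) for p in soft])
print("gap 10:         ", [round(p, 5) for p in hard])
print("gap 1, shifted: ", [round(p, 5) for p in shifted])
```

The shifted input gives the same mix as [2, 1], so the spike in the second row comes from the scores being ten units apart rather than from the scores being large.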

### The Problem with High Dimensions:

In deep learning, our vectors are long (e.g., 512 numbers).

When you do a Dot Product of two long vectors, the result is a **Sum**.

$1 \times 1 + 1 \times 1 + ...$ (512 times) = 512.

The numbers get huge just because the vectors are long.

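If we feed huge numbers (like 512 next to much smaller scores) into Softmax, it saturates. What matters is the gap between scores: once the gaps reach tens or hundreds, it outputs ~1.0 for the winner and ~0.0 for everyone else.

**Result:** The model becomes "arrogant." It only looks at one thing and ignores everything else. A saturated Softmax is nearly flat, so its gradients are close to zero and learning stalls.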
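
To put a number on the gradient claim: the derivative of a softmax output $p_i$ with respect to its own score is $p_i (1 - p_i)$. The sketch below evaluates that quantity in plain Python for a mild and an extreme score vector. The helper names are made up for this illustration, and it looks only at the diagonal of the softmax Jacobian, not at a full backward pass.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def self_gradients(scores):
    """d p_i / d s_i = p_i * (1 - p_i), the diagonal of the softmax Jacobian."""
    return [p * (1.0 - p) for p in softmax(scores)]

soft_grads = self_gradients([2.0, 1.0, 0.5])
hard_grads = self_gradients([200.0, 100.0, 50.0])

print("mild scores:   ", soft_grads)
print("extreme scores:", hard_grads)
```

For the mild scores every entry is a usable size. For the extreme scores every entry is zero or vanishingly small, so a gradient step barely changes the attention pattern.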

### The Fix (Scaling):

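We divide the huge number by the square root of the length ($\sqrt{512} \approx 22.6$).

$512 / 22.6 \approx 22.6$. (Still big, but manageable).

More generally, if the components of $Q$ and $K$ are independent with mean 0 and variance 1, their dot product has variance $d_k$, and dividing by $\sqrt{d_k}$ brings the variance back down to ~1.0.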
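
The variance argument can be checked with a quick simulation. This sketch assumes the vector components are independent draws from a standard normal distribution, which is an idealization (trained queries and keys need not look like this). The function names and the trial count are arbitrary choices for the illustration.

```python
import math
import random
import statistics

random.seed(0)  # fixed seed so repeated runs give the same numbers

def random_dot(d):
    """Dot product of two vectors whose components are independent N(0, 1) draws."""
    return sum(random.gauss(0.0, 1.0) * random.gauss(0.0, 1.0) for _ in range(d))

def dot_variances(d, trials=2000):
    """Return (variance of raw dot products, variance after dividing by sqrt(d))."""
    dots = [random_dot(d) for _ in range(trials)]
    raw = statistics.pvariance(dots)
    scaled = statistics.pvariance([x / math.sqrt(d) for x in dots])
    return raw, scaled

results = {d: dot_variances(d) for d in (4, 64, 512)}
for d, (raw, scaled) in results.items():
    print(f"d_k={d:4d}  raw variance={raw:8.2f}  scaled variance={scaled:.3f}")
```

Under these assumptions the raw variance grows roughly in proportion to $d_k$, while the scaled variance stays near 1 at every dimension.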

In [3]:
import torch
import torch.nn.functional as F
import math

def simple_scaling_demo():
    print("--- Why Scaling Matters ---")
    
    # 1. Imagine a Dot Product result (Similarity Scores)
    # Let's say we have 3 keys.
    
    # CASE A: Small numbers (Simulating a small network)
    small_scores = torch.tensor([2.0, 1.0, 0.5])
    print(f"\n1. Small Scores: {small_scores.tolist()}")
    
    # Apply Softmax
    probs_small = F.softmax(small_scores, dim=0)
    print(f"   Softmax Output: {probs_small.tolist()}")
    print("   -> Result: Nice distribution. The model considers all options.")

    # CASE B: Huge numbers (Simulating a real Transformer without scaling)
    # In 512-dim space, unscaled dot products can become large; the values
    # below are exaggerated to make the effect obvious
    huge_scores = torch.tensor([200.0, 100.0, 50.0])
    print(f"\n2. Huge Scores (Unscaled): {huge_scores.tolist()}")
    
    # Apply Softmax
    probs_huge = F.softmax(huge_scores, dim=0)
    print(f"   Softmax Output: {probs_huge.tolist()}")
    print("   -> Result: [1.0, 0.0, 0.0]. The model is 'arrogant'. Gradient is DEAD.")
    
    # CASE C: The Fix (Scaling)
    # We divide by a factor (let's say 100 to simulate sqrt(d_k))
    scale_factor = 100.0
    scaled_scores = huge_scores / scale_factor
    print(f"\n3. Scaled Scores: {scaled_scores.tolist()}")
    
    # Apply Softmax
    probs_scaled = F.softmax(scaled_scores, dim=0)
    print(f"   Softmax Output: {probs_scaled.tolist()}")
    print("   -> Result: Back to a nice distribution! Gradients can flow.")

if __name__ == "__main__":
    simple_scaling_demo()

--- Why Scaling Matters ---

1. Small Scores: [2.0, 1.0, 0.5]
   Softmax Output: [0.6285316944122314, 0.23122389614582062, 0.14024437963962555]
   -> Result: Nice distribution. The model considers all options.

2. Huge Scores (Unscaled): [200.0, 100.0, 50.0]
   Softmax Output: [1.0, 3.783505853677006e-44, 0.0]
   -> Result: [1.0, 0.0, 0.0]. The model is 'arrogant'. Gradient is DEAD.

3. Scaled Scores: [2.0, 1.0, 0.5]
   Softmax Output: [0.6285316944122314, 0.23122389614582062, 0.14024437963962555]
   -> Result: Back to a nice distribution! Gradients can flow.
