# Scaled Dot-Product Attention Walkthrough

## Problem Setup and Input Representation
We start with a token sequence $x = (x_1, x_2, \dots, x_T)$, where each token is embedded into $e_i \in \mathbb{R}^{d_{\text{model}}}$.  

Token embeddings come from a matrix $E \in \mathbb{R}^{|V| \times d_{\text{model}}}$ via $e_i = E[x_i]$ (row $x_i$ of $E$). Self-attention has no recurrence and is permutation-equivariant: reordering the tokens only reorders the outputs. Positional encodings therefore inject order, $z_i^{(0)} = e_i + p_i$, where $p_i$ can be sinusoidal or learned.  
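Here is a minimal sketch of the lookup $e_i = E[x_i]$ and the sum $z_i^{(0)} = e_i + p_i$. NumPy integer indexing `E[x]` gathers one row per token id. The vocabulary size, the token ids and the random table `P` (a stand-in for learned positions) are illustrative assumptions, not values from this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

vocab_size, d_model, seq_len = 10, 4, 3        # illustrative sizes (assumed)
E = rng.normal(size=(vocab_size, d_model))     # embedding matrix, one row per token id
P = rng.normal(size=(seq_len, d_model))        # stand-in for a learned position table

x = np.array([7, 2, 5])                        # hypothetical token ids x_1..x_T
e = E[x]                                       # row lookup e_i = E[x_i], shape (seq_len, d_model)
Z0 = e + P                                     # z_i^(0) = e_i + p_i

print(e.shape, Z0.shape)  # (3, 4) (3, 4)
```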

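For sinusoidal encodings, $p_{i,2k} = \sin\left(\tfrac{i}{10000^{2k/d_{\text{model}}}}\right)$ and $p_{i,2k+1} = \cos\left(\tfrac{i}{10000^{2k/d_{\text{model}}}}\right)$. The rest of this notebook walks through how attention processes such vectors. To keep the arithmetic readable, the demo feeds in a small hand-written matrix `X` in place of $Z^{(0)}$.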
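The sinusoidal formulas can be vectorized with broadcasting. The sketch below assumes an even $d_{\text{model}}$. The function name `sinusoidal_positions` is our own label, not a library API.

```python
import numpy as np

def sinusoidal_positions(T: int, d_model: int) -> np.ndarray:
    """Return a (T, d_model) array with p[i, 2k] = sin(i / 10000^(2k/d_model))
    and p[i, 2k+1] = cos(i / 10000^(2k/d_model)). Assumes d_model is even."""
    assert d_model % 2 == 0, "this sketch assumes an even d_model"
    positions = np.arange(T)[:, None]               # i, shape (T, 1)
    two_k = np.arange(0, d_model, 2)[None, :]       # 2k, shape (1, d_model // 2)
    angles = positions / np.power(10000.0, two_k / d_model)
    P = np.zeros((T, d_model))
    P[:, 0::2] = np.sin(angles)                     # even dimensions
    P[:, 1::2] = np.cos(angles)                     # odd dimensions
    return P

P = sinusoidal_positions(T=3, d_model=4)
print(P.round(4))
```

Position 0 always encodes to alternating zeros and ones. Each dimension pair then rotates at its own frequency, from one cycle per $2\pi$ positions (for $k = 0$) down to much slower ones.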

In [1]:
# Imports and utilities
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np

np.set_printoptions(precision=4, suppress=True)

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically-stable softmax."""
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def pretty(a: np.ndarray, name: str, decimals: int = 4) -> None:
    """Pretty-print a numpy array with a label, without changing global print options."""
    with np.printoptions(precision=decimals, suppress=True):
        print(f"\n{name} (shape={a.shape}):\n{a}")
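The `softmax` above subtracts the row maximum before exponentiating. The sketch below shows why, by comparing it with a naive version on large logits. `naive_softmax` and `stable_softmax` are illustrative names, not part of this notebook.

```python
import numpy as np

def naive_softmax(x: np.ndarray) -> np.ndarray:
    # No max-subtraction: np.exp overflows to inf for large logits.
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def stable_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # softmax(x) == softmax(x - c) for any constant c; choosing c = max(x)
    # keeps every exponent <= 0, so exp() cannot overflow.
    z = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)

logits = np.array([1000.0, 1001.0])
with np.errstate(over="ignore", invalid="ignore"):  # silence overflow / invalid warnings
    naive_out = naive_softmax(logits)                 # inf / inf -> nan
stable_out = stable_softmax(logits)

print(naive_out)    # [nan nan]
print(stable_out)   # approximately [0.2689, 0.7311]
```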

## Core Attention Math
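This cell defines a container for the intermediate tensors and the core operations: Q/K/V projection, additive masking, scaled dot-product attention, and the causal mask.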

In [2]:
@dataclass
class AttentionOutputs:
    """Container holding intermediate values useful for explanation/debugging."""

    Q: np.ndarray
    K: np.ndarray
    V: np.ndarray
    scores: np.ndarray
    attn_weights: np.ndarray
    output: np.ndarray


def make_qkv(
    X: np.ndarray,
    W_Q: np.ndarray,
    W_K: np.ndarray,
    W_V: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Project input embeddings X to Q, K, V."""
    Q = X @ W_Q
    K = X @ W_K
    V = X @ W_V
    return Q, K, V


def apply_attention_mask(scores: np.ndarray, mask: Optional[np.ndarray]) -> np.ndarray:
    """Add a large negative value to scores where mask == 1 (1 = masked, 0 = allowed)."""
    if mask is None:
        return scores

    if mask.shape != scores.shape:
        raise ValueError(f"Mask shape {mask.shape} must match scores shape {scores.shape}")

    neg_inf = -1e9  # finite stand-in for -inf; masked entries end up with (numerically) zero weight after softmax
    masked_scores = scores + (mask.astype(np.float32) * neg_inf)
    return masked_scores


def scaled_dot_product_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    mask: Optional[np.ndarray] = None,
) -> AttentionOutputs:
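    """Compute softmax(Q K^T / sqrt(d_k) + mask) V for 2-D Q (T_q, d_k), K (T_k, d_k), V (T_k, d_v).

    The returned `scores` field holds the masked, pre-softmax scores.
    """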
    d_k = Q.shape[-1]
    scores = (Q @ K.T) / np.sqrt(d_k)
    scores_masked = apply_attention_mask(scores, mask)
    attn_weights = softmax(scores_masked, axis=-1)
    output = attn_weights @ V
    return AttentionOutputs(Q=Q, K=K, V=V, scores=scores_masked, attn_weights=attn_weights, output=output)


def causal_mask(T: int) -> np.ndarray:
    """Causal mask: disallow looking ahead (upper triangle)."""
    return np.triu(np.ones((T, T), dtype=np.int32), k=1)

## Demo Setup
A tiny 3-token example with hand-picked, human-readable numbers ($d_{\text{model}} = 4$, $d_k = d_v = 2$).

In [3]:
# Step 0: Example data and projection matrices
X = np.array(
    [
        [1.0, 0.0, 1.0, 0.0],  # token 1
        [0.0, 2.0, 0.0, 2.0],  # token 2
        [1.0, 1.0, 0.0, 0.0],  # token 3
    ],
    dtype=np.float32,
)

W_Q = np.array(
    [
        [1.0, 0.0],
        [0.0, 1.0],
        [1.0, 0.0],
        [0.0, 1.0],
    ],
    dtype=np.float32,
)

W_K = np.array(
    [
        [1.0, 1.0],
        [1.0, 0.0],
        [0.0, 1.0],
        [1.0, 0.0],
    ],
    dtype=np.float32,
)

W_V = np.array(
    [
        [1.0, 0.0],
        [0.0, 1.0],
        [1.0, 1.0],
        [0.0, 1.0],
    ],
    dtype=np.float32,
)

pretty(X, "Step 0: Input embeddings X")
pretty(W_Q, "W_Q")
pretty(W_K, "W_K")
pretty(W_V, "W_V")


Step 0: Input embeddings X (shape=(3, 4)):
[[1. 0. 1. 0.]
 [0. 2. 0. 2.]
 [1. 1. 0. 0.]]

W_Q (shape=(4, 2)):
[[1. 0.]
 [0. 1.]
 [1. 0.]
 [0. 1.]]

W_K (shape=(4, 2)):
[[1. 1.]
 [1. 0.]
 [0. 1.]
 [1. 0.]]

W_V (shape=(4, 2)):
[[1. 0.]
 [0. 1.]
 [1. 1.]
 [0. 1.]]


## Step 1 — Project to Q, K, V
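Compute queries, keys, and values with linear projections. In a trained model, $W_Q$, $W_K$ and $W_V$ are learned; here they are the hand-picked matrices from Step 0.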

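### Attention Math Used in Steps 1–5
For each position $i$, treat $z_i$ as a row vector and compute $q_i = z_i W_Q$, $k_i = z_i W_K$ and $v_i = z_i W_V$. Here $W_Q, W_K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W_V \in \mathbb{R}^{d_{\text{model}} \times d_v}$; this demo uses $d_v = d_k = 2$. Stacking the $z_i$ as the rows of $Z$ (called `X` in the code) gives $Q = Z W_Q$, $K = Z W_K$ and $V = Z W_V$.  

Scaled dot-product attention computes the score matrix $A = \tfrac{Q K^\top}{\sqrt{d_k}}$ and applies a row-wise softmax $\alpha = \text{softmax}(A)$. It returns $\text{Attention}(Q,K,V) = \alpha V$, so output row $i$ is $\sum_j \alpha_{ij} v_j$.  

For causal (decoder) use, we mask future tokens: $A_{ij} = \tfrac{q_i \cdot k_j}{\sqrt{d_k}}$ if $j \le i$, and $-\infty$ otherwise. The code uses $-10^9$ as a finite stand-in for $-\infty$. After the softmax, $\alpha_{ij} = 0$ for $j > i$, so the output at position $i$ depends only on $x_1, \dots, x_i$. This lets a language model be trained to predict $x_{i+1}$, that is, to model $P(x_{i+1} \mid x_1, \dots, x_i)$, without the target token leaking into its own context.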
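One way to see the "depends only on $x_1, \dots, x_i$" property is to perturb the last token and check that earlier outputs do not move. This sketch makes several assumptions:

- It uses random toy sizes and weights.
- It applies the mask with `np.where` and an exact `-np.inf` rather than the notebook's `-1e9`.
- It uses fresh variable names so the notebook's `X`, `W_Q` and so on are not overwritten.

`causal_attention` is a hypothetical helper.

```python
import numpy as np

def causal_attention(X_in, Wq, Wk, Wv):
    """Single-head causal self-attention for a (T, d_model) input (illustrative sketch)."""
    Qh, Kh, Vh = X_in @ Wq, X_in @ Wk, X_in @ Wv
    n, dh = Qh.shape
    scores = Qh @ Kh.T / np.sqrt(dh)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)    # True where j > i
    scores = np.where(future, -np.inf, scores)            # exact -inf instead of -1e9
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax; the diagonal is never masked
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ Vh

rng = np.random.default_rng(0)
n_tok, d_in, d_head = 4, 8, 2                  # toy sizes (assumed)
X_toy = rng.normal(size=(n_tok, d_in))
Wq_toy, Wk_toy, Wv_toy = (rng.normal(size=(d_in, d_head)) for _ in range(3))

out_toy = causal_attention(X_toy, Wq_toy, Wk_toy, Wv_toy)

X_changed = X_toy.copy()
X_changed[-1] += 10.0                          # perturb only the last token
out_changed = causal_attention(X_changed, Wq_toy, Wk_toy, Wv_toy)

# Earlier positions never attend to the last token, so their outputs are unchanged.
print(np.allclose(out_toy[:-1], out_changed[:-1]))  # True
```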

In [4]:
Q, K, V = make_qkv(X, W_Q, W_K, W_V)

pretty(Q, "Step 1: Queries Q = X @ W_Q")
pretty(K, "Step 1: Keys    K = X @ W_K")
pretty(V, "Step 1: Values  V = X @ W_V")


Step 1: Queries Q = X @ W_Q (shape=(3, 2)):
[[2. 0.]
 [0. 4.]
 [1. 1.]]

Step 1: Keys    K = X @ W_K (shape=(3, 2)):
[[1. 2.]
 [4. 0.]
 [2. 1.]]

Step 1: Values  V = X @ W_V (shape=(3, 2)):
[[2. 1.]
 [0. 4.]
 [1. 1.]]
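Check the first row by hand. Multiplying the row vector $z_1 = (1, 0, 1, 0)$ (the first row of `X`) by a projection matrix adds up the rows of that matrix picked out by the nonzero entries of $z_1$, here rows 1 and 3:

$$
\begin{aligned}
q_1 &= z_1 W_Q = (1, 0) + (1, 0) = (2, 0) \\
k_1 &= z_1 W_K = (1, 1) + (0, 1) = (1, 2) \\
v_1 &= z_1 W_V = (1, 0) + (1, 1) = (2, 1)
\end{aligned}
$$

These agree with the first rows of `Q`, `K` and `V` printed above.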


## Step 2 — Compute Attention Scores
Calculate scaled dot products between queries and keys.

In [12]:
d_k = Q.shape[-1]
raw_scores = (Q @ K.T) / np.sqrt(d_k)
pretty(raw_scores, f"Step 2: Raw scores = (Q @ K^T) / sqrt(d_k={d_k})")
print("d_k: ", d_k)


Step 2: Raw scores = (Q @ K^T) / sqrt(d_k=2) (shape=(3, 3)):
[[1.4142 5.6569 2.8284]
 [5.6569 0.     2.8284]
 [2.1213 2.8284 2.1213]]
d_k:  2
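Each entry above is one dot product divided by $\sqrt{2}$. For example, $A_{12} = (2 \cdot 4 + 0 \cdot 0)/\sqrt{2} = 8/\sqrt{2} \approx 5.6569$.

The reason for dividing by $\sqrt{d_k}$ is statistical. Suppose the entries of $q$ and $k$ were independent with mean 0 and variance 1. Then $q \cdot k$ would have variance $d_k$, so raw scores would grow with the head size. Large scores push the softmax toward a nearly one-hot output with very small gradients, and dividing by $\sqrt{d_k}$ brings the variance back to 1.

The simulation below checks this under that i.i.d. Gaussian assumption, which real activations need not satisfy. The sample count and the `d_k` values are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 20_000
raw_vars, scaled_vars = {}, {}

for dim in (2, 16, 128):
    # Idealizing assumption: entries of q and k are i.i.d. N(0, 1).
    q = rng.normal(size=(n_samples, dim))
    k = rng.normal(size=(n_samples, dim))
    dots = np.einsum("nd,nd->n", q, k)      # one dot product q . k per sample
    raw_vars[dim] = dots.var()
    scaled_vars[dim] = (dots / np.sqrt(dim)).var()
    print(f"d_k={dim:3d}  var(q.k)={raw_vars[dim]:7.2f}  var(q.k / sqrt(d_k))={scaled_vars[dim]:.3f}")
```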


## Step 3 — Apply Causal Mask
Disallow attending to future positions before softmax.

In [6]:
T = X.shape[0]
mask = causal_mask(T)
pretty(mask, "Step 3: Causal mask (1=masked/future, 0=allowed)")

masked_scores = apply_attention_mask(raw_scores, mask)
pretty(masked_scores, "Step 3: Masked scores (future positions -> very negative)")


Step 3: Causal mask (1=masked/future, 0=allowed) (shape=(3, 3)):
[[0 1 1]
 [0 0 1]
 [0 0 0]]

Step 3: Masked scores (future positions -> very negative) (shape=(3, 3)):
[[ 1.4142e+00 -1.0000e+09 -1.0000e+09]
 [ 5.6569e+00  0.0000e+00 -1.0000e+09]
 [ 2.1213e+00  2.8284e+00  2.1213e+00]]
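The same causal pattern can be built by comparing position indices ($j > i$ means key $j$ lies in query $i$'s future). It can also be applied as a boolean mask with an exact $-\infty$ instead of the additive $-10^9$.

The sketch below checks that both approaches give identical attention weights for this walkthrough's `Q` and `K`, which are hard-coded from the Step 1 output. The variable names are chosen so they do not overwrite the notebook's own.

```python
import numpy as np

# Q and K copied from the Step 1 output of this walkthrough.
Q_hand = np.array([[2.0, 0.0], [0.0, 4.0], [1.0, 1.0]])
K_hand = np.array([[1.0, 2.0], [4.0, 0.0], [2.0, 1.0]])
n_tok, head_dim = Q_hand.shape
scores_hand = Q_hand @ K_hand.T / np.sqrt(head_dim)

# Key j is in the future of query i when j > i.
query_pos = np.arange(n_tok)[:, None]   # shape (n_tok, 1)
key_pos = np.arange(n_tok)[None, :]     # shape (1, n_tok)
future = key_pos > query_pos            # boolean, same pattern as causal_mask(n_tok)

def softmax_rows(a: np.ndarray) -> np.ndarray:
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

w_additive = softmax_rows(scores_hand + future * -1e9)          # additive, like apply_attention_mask
w_where = softmax_rows(np.where(future, -np.inf, scores_hand))  # boolean mask with exact -inf

print(np.allclose(w_additive, w_where))  # True
print(w_where.round(4))
```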


## Step 4 — Softmax to Get Attention Weights
Convert masked logits into probabilities and verify rows sum to 1.

In [7]:
weights = softmax(masked_scores, axis=-1)
pretty(weights, "Step 4: Attention weights = softmax(masked_scores)")

row_sums = weights.sum(axis=-1)
pretty(row_sums, "(Check) Row sums of attention weights")


Step 4: Attention weights = softmax(masked_scores) (shape=(3, 3)):
[[1.     0.     0.    ]
 [0.9965 0.0035 0.    ]
 [0.2483 0.5035 0.2483]]

(Check) Row sums of attention weights (shape=(3,)):
[1. 1. 1.]
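Row 2 shows what the mask does. Its masked entry contributes $e^{-10^9}$, which underflows to 0, so the softmax reduces to a two-way split between tokens 1 and 2:

$$
\alpha_{21} = \frac{e^{5.6569}}{e^{5.6569} + e^{0}} = \frac{1}{1 + e^{-5.6569}} \approx \frac{1}{1.00349} \approx 0.9965,
\qquad
\alpha_{22} = 1 - \alpha_{21} \approx 0.0035,
\qquad
\alpha_{23} = 0.
$$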


## Step 5 — Weighted Sum of Values
Multiply weights by values to produce the attended output.

In [8]:
output = weights @ V
pretty(output, "Step 5: Attention output = weights @ V")


Step 5: Attention output = weights @ V (shape=(3, 2)):
[[2.     1.    ]
 [1.993  1.0104]
 [0.7448 2.5105]]
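Row 3 can be checked by hand. Tokens 1 and 3 received equal scores, so $\alpha_{31} = \alpha_{33}$, and the output depends on only two distinct weights:

$$
o_3 = \alpha_{31} v_1 + \alpha_{32} v_2 + \alpha_{33} v_3
= \alpha_{31}(2, 1) + \alpha_{32}(0, 4) + \alpha_{31}(1, 1)
= (3\alpha_{31},\; 2\alpha_{31} + 4\alpha_{32})
\approx (0.7448,\; 2.5105)
$$

This uses the unrounded weights $\alpha_{31} \approx 0.24826$ and $\alpha_{32} \approx 0.50349$.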


## Step 6 — One-Call API Check
Call the packaged `scaled_dot_product_attention` and confirm it matches the manual steps.

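In [9]:
out = scaled_dot_product_attention(Q, K, V, mask=mask)

assert np.allclose(out.scores, masked_scores)
assert np.allclose(out.attn_weights, weights)
assert np.allclose(out.output, output)
print("\n✅ Done. The outputs from manual steps match the function output.")


✅ Done. The outputs from manual steps match the function output.


## Interpretation Example
Inspect one token's attention distribution over the sequence. Token 3 puts the most weight on token 2 because $q_3 \cdot k_2 = 4$ is larger than $q_3 \cdot k_1 = q_3 \cdot k_3 = 3$.

In [10]:
i = 2  # token 3 (0-indexed)
print("\nInterpretation example:")
print(f"Token {i+1} attention weights over tokens 1..{T}: {weights[i]}")
value_names = ", ".join(f"V{j+1}" for j in range(T))
print(f"This means token {i+1} forms its representation as a weighted sum of {value_names}.")


Interpretation example:
Token 3 attention weights over tokens 1..3: [0.2483 0.5035 0.2483]
This means token 3 forms its representation as a weighted sum of V1, V2, V3.

## Beyond This Single-Head Demo
**Multi-Head Attention:**  
Use $h$ heads, each with its own projections: $\text{head}_j = \text{Attention}(Z W_Q^{(j)}, Z W_K^{(j)}, Z W_V^{(j)})$. Then $\text{MHA}(Z) = \text{Concat}(\text{head}_1,\dots,\text{head}_h) W_O$, with $W_O \in \mathbb{R}^{(h d_v) \times d_{\text{model}}}$; commonly $d_k = d_v = d_{\text{model}} / h$. Different heads can learn to attend to different relationships, such as syntactic dependencies, coreference, or long-range context.  

**Feed-Forward Network (per position):**  
$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$, with $W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$, $W_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$, and typically $d_{\text{ff}} \approx 4 d_{\text{model}}$. The same weights are applied to every position independently.  

**Residuals + LayerNorm:**  
Each block computes $\tilde{Z} = \text{LayerNorm}(Z + \text{MHA}(Z))$ and then $Z' = \text{LayerNorm}(\tilde{Z} + \text{FFN}(\tilde{Z}))$. Residual connections give gradients a direct path through deep stacks, and LayerNorm keeps activations on a stable scale. This is the post-LN arrangement of the original Transformer. GPT-2 and many later models instead apply LayerNorm to the input of each sublayer (pre-LN). In a decoder, $\text{MHA}$ uses the causal mask from Step 3.  

**Full Stack:**  
For $L$ layers, $Z^{(l+1)} = \text{TransformerBlock}(Z^{(l)})$, starting from $Z^{(0)} = E(x) + P$ and ending with $H = Z^{(L)}$.  

**Output Layer (LM):**  
At each time step, $\hat{y}_t = \text{softmax}(H_t W_{\text{out}} + b)$ with $W_{\text{out}} \in \mathbb{R}^{d_{\text{model}} \times |V|}$. This is a distribution over the next token $x_{t+1}$.  

**Training Objective (GPT-style):**  
The objective is autoregressive maximum likelihood, i.e. cross-entropy: $\mathcal{L} = - \sum_{t=1}^{T} \log P(x_t \mid x_1, \dots, x_{t-1})$. Trained at scale, this single objective yields models that capture a great deal of syntax and semantics and show some reasoning ability. Instruction-following is usually strengthened by further fine-tuning.  

**Why Scaling Works:**  
Without recurrence, all positions of a training sequence are processed in parallel. Every attention layer also gives each position direct access to every other position it is allowed to attend to. Together, these properties make it practical to grow depth, width, and training data. Larger models have shown marked capability gains (for example, from GPT-2 to GPT-4), though why some abilities appear abruptly with scale is still debated.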
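The NumPy sketch below wires the block equations from the "Beyond This Single-Head Demo" section into a tiny end-to-end model. It covers:

- causal multi-head attention;
- the ReLU FFN;
- post-LN residual blocks stacked $L$ times;
- the output softmax;
- the autoregressive loss.

Several things here are assumptions for illustration:

- all sizes, the token ids, and the random untrained weights;
- the per-head weight layout `(h, d_model, d_k)`;
- LayerNorm without learned gain and bias;
- helper names such as `transformer_block`.

This is not a trained model or any library's API.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax_rows(x: np.ndarray) -> np.ndarray:
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Normalize each position's vector; the learned gain and bias are omitted.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def multi_head_attention(Z, W_Q, W_K, W_V, W_O):
    """Causal MHA. W_Q, W_K, W_V: (h, d_model, d_k); W_O: (h * d_k, d_model)."""
    n_tok = Z.shape[0]
    Q = np.einsum("td,hdk->htk", Z, W_Q)          # (h, n_tok, d_k): Z @ W_Q^(j) for each head j
    K = np.einsum("td,hdk->htk", Z, W_K)
    V = np.einsum("td,hdk->htk", Z, W_V)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])  # (h, n_tok, n_tok)
    future = np.triu(np.ones((n_tok, n_tok), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)    # same causal mask for every head
    heads = softmax_rows(scores) @ V              # (h, n_tok, d_k)
    concat = heads.transpose(1, 0, 2).reshape(n_tok, -1)  # Concat(head_1, ..., head_h)
    return concat @ W_O

def ffn(x, W1, b1, W2, b2):
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2  # max(0, x W1 + b1) W2 + b2

def transformer_block(Z, p):
    # Post-LN block, matching the two LayerNorm equations above.
    Z_tilde = layer_norm(Z + multi_head_attention(Z, p["W_Q"], p["W_K"], p["W_V"], p["W_O"]))
    return layer_norm(Z_tilde + ffn(Z_tilde, p["W1"], p["b1"], p["W2"], p["b2"]))

def init_block(d_model: int, h: int, d_ff: int, scale: float = 0.1) -> dict:
    # Random, untrained weights; assumes d_k = d_v = d_model // h.
    d_k = d_model // h
    return {
        "W_Q": rng.normal(scale=scale, size=(h, d_model, d_k)),
        "W_K": rng.normal(scale=scale, size=(h, d_model, d_k)),
        "W_V": rng.normal(scale=scale, size=(h, d_model, d_k)),
        "W_O": rng.normal(scale=scale, size=(h * d_k, d_model)),
        "W1": rng.normal(scale=scale, size=(d_model, d_ff)),
        "b1": np.zeros(d_ff),
        "W2": rng.normal(scale=scale, size=(d_ff, d_model)),
        "b2": np.zeros(d_model),
    }

# Toy sizes (assumptions): vocabulary 11, d_model 8, 2 heads, d_ff = 4 * d_model, 2 layers.
vocab, d_model, h, n_layers = 11, 8, 2, 2
d_ff = 4 * d_model
tokens = np.array([3, 1, 4, 1, 5])                  # hypothetical token ids x_1..x_T
n_tok = len(tokens)

E = rng.normal(scale=0.1, size=(vocab, d_model))
P = rng.normal(scale=0.1, size=(n_tok, d_model))   # stand-in for learned positions
Z = E[tokens] + P                                   # Z^(0) = E(x) + P

for _ in range(n_layers):                           # Z^(l+1) = TransformerBlock(Z^(l))
    Z = transformer_block(Z, init_block(d_model, h, d_ff))
H = Z

W_out = rng.normal(scale=0.1, size=(d_model, vocab))
b_out = np.zeros(vocab)
y_hat = softmax_rows(H @ W_out + b_out)             # (n_tok, vocab)

# Row t of y_hat scores the token at index t + 1, so the loss covers x_2..x_T;
# scoring x_1 would need a start-of-sequence token, omitted here.
nll = -np.log(y_hat[np.arange(n_tok - 1), tokens[1:]])
loss = nll.sum()
print(y_hat.shape, loss)
```

Training would adjust `E`, `P`, the block weights and `W_out` to minimize `loss` by gradient descent. NumPy alone does not provide the automatic differentiation that would normally be used for that.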
