# Lecture 6 Tutorial: Attention 与 LLaMA 内部解析

本 notebook 对应 2025 版第 6 讲的课堂内容，通过 PyTorch 演示注意力模块的关键步骤，帮助在课堂上连贯展示：
- 复现标准注意力的逐步计算
- 观察 Multi-Head 架构的张量变换
- 演示 Grouped Query Attention (GQA) 对 KV 缓存的压缩
- 用简化的 Blocked Attention 框架链接理论与实现


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Seed PyTorch's global RNG so the random tensors and layer weights created
# later in the notebook are the same on every run.
torch.manual_seed(42)
# Print tensors with 4 decimal places and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)
print(f"PyTorch version: {torch.__version__}")


PyTorch version: 2.8.0


## 1. 从 hidden states 得到 Q/K/V

在 LLaMA 的 decoder layer 中，`q_proj`/`k_proj`/`v_proj` 是线性层，负责把上一层的 hidden states 变换为查询、键和值张量。下面我们用一个可视化规模的示例来模拟这个过程。


In [2]:
# Toy dimensions, kept small so printed tensors stay readable.
batch_size, seq_len = 2, 4
num_heads = 4
head_dim = 8
# The model width is the concatenation of all heads: 4 * 8 = 32.
hidden_size = num_heads * head_dim

# Random stand-in for the previous layer's output,
# shaped (batch_size, seq_len, hidden_size).
hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# Bias-free linear projections, each mapping hidden_size -> hidden_size.
# NOTE: the creation order matters for reproducibility because every layer
# draws its initial weights from the seeded global RNG.
q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
k_proj_full = nn.Linear(hidden_size, hidden_size, bias=False)
v_proj_full = nn.Linear(hidden_size, hidden_size, bias=False)
o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

# Project the same hidden states into query, key and value tensors.
q = q_proj(hidden_states)
k = k_proj_full(hidden_states)
v = v_proj_full(hidden_states)

print(f"hidden_states shape: {hidden_states.shape}")
print(f"q shape: {q.shape}, k shape: {k.shape}, v shape: {v.shape}")


hidden_states shape: torch.Size([2, 4, 32])
q shape: torch.Size([2, 4, 32]), k shape: torch.Size([2, 4, 32]), v shape: torch.Size([2, 4, 32])


## 2. 单头注意力 (Scaled Dot-Product)

下面按照讲义的推导，逐步计算 $P$、$\text{softmax}$ 与输出 $O$。为简洁起见，这里只展示因果 mask（decoder 的自回归约束）。


In [11]:
# q has not been split into heads yet, so its last dimension equals
# hidden_size and this is the usual 1/sqrt(d_k) factor for one wide head.
scale = 1.0 / math.sqrt(hidden_size)
# Scaled similarity of every query position with every key position:
# (batch_size, seq_len, seq_len).
scores = torch.matmul(q, k.transpose(-2, -1)) * scale
# True strictly above the diagonal, i.e. wherever the key position lies
# after the query position (the "future" a decoder must not see).
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
# Masked entries become -inf so that softmax assigns them zero weight.
scores_masked = scores.masked_fill(causal_mask, float('-inf'))
attn_weights = F.softmax(scores_masked, dim=-1)
# Each output position is the attention-weighted sum of the value vectors.
context_single = torch.matmul(attn_weights, v)

print('scores shape:', scores.shape)
# Every row of attention weights should sum to 1.
print('attention weights sums:', attn_weights.sum(dim=-1))
print('context_single shape:', context_single.shape)
print(causal_mask)
print(scores_masked)

scores shape: torch.Size([2, 4, 4])
attention weights sums: tensor([[1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000]], grad_fn=<SumBackward1>)
context_single shape: torch.Size([2, 4, 32])
tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])
tensor([[[ 0.1864,    -inf,    -inf,    -inf],
         [-0.2687, -0.1780,    -inf,    -inf],
         [ 0.2653, -0.3983,  0.2157,    -inf],
         [-0.1443, -0.0368,  0.2042,  0.0041]],

        [[ 0.0577,    -inf,    -inf,    -inf],
         [ 0.1112,  0.0621,    -inf,    -inf],
         [ 0.0612,  0.0685, -0.3430,    -inf],
         [-0.5373, -0.0201, -0.4858,  0.3036]]], grad_fn=<MaskedFillBackward0>)


In [None]:
def stable_softmax(x: torch.Tensor) -> torch.Tensor:
    """Return the softmax of ``x`` over its last dimension.

    The per-row maximum is subtracted before exponentiating so that large
    scores cannot overflow ``exp``. ``nan_to_num`` replaces a non-finite
    maximum with a finite value before the subtraction.
    """
    row_max, _ = torch.max(x, dim=-1, keepdim=True)
    shift = torch.nan_to_num(row_max)
    exponentials = (x - shift).exp()
    normalizer = exponentials.sum(dim=-1, keepdim=True)
    return exponentials / normalizer


# Cross-check the hand-written softmax against the attention weights that
# F.softmax produced from the same masked scores.
manual_weights = stable_softmax(scores_masked)
print('manual == torch.softmax?', torch.allclose(manual_weights, attn_weights, atol=1e-6))


tensor([[[ 0.1864],
         [-0.1780],
         [ 0.2653],
         [ 0.2042]],

        [[ 0.0577],
         [ 0.1112],
         [ 0.0685],
         [ 0.3036]]], grad_fn=<NanToNumBackward0>)
manual == torch.softmax? True


## 3. Multi-Head Attention 的张量变换

把 $Q/K/V$ 拆成多个头可以让模型在不同的子空间里捕捉特征。注意 reshape、transpose 的顺序与 `Concat` 聚合回原维度的细节。


In [5]:
# Split the last dimension into (num_heads, head_dim), then swap the sequence
# and head axes: (batch, seq, hidden) -> (batch, num_heads, seq, head_dim).
q_heads = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
k_heads = k.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
v_heads = v.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

print('q_heads shape:', q_heads.shape)

# Two leading singleton dims let the (seq_len, seq_len) mask broadcast over
# the batch and head axes.
causal_mask_h = causal_mask.unsqueeze(0).unsqueeze(0)
# Each head attends independently, so the scale uses head_dim rather than
# hidden_size.
scores_multi = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / math.sqrt(head_dim)
scores_multi = scores_multi.masked_fill(causal_mask_h, float('-inf'))
weights_multi = F.softmax(scores_multi, dim=-1)
context_multi = torch.matmul(weights_multi, v_heads)
# "Concat" step: move the head axis back next to head_dim and flatten both
# into hidden_size. contiguous() is called first because view() needs a
# compatible memory layout after the transpose.
context_multi = context_multi.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
# Final output projection mixes information across heads.
output_multi = o_proj(context_multi)

print('context_multi shape:', context_multi.shape)
print('output_multi shape:', output_multi.shape)


q_heads shape: torch.Size([2, 4, 4, 8])
context_multi shape: torch.Size([2, 4, 32])
output_multi shape: torch.Size([2, 4, 32])


## 4. Grouped Query Attention (GQA)

LLaMA-2（34B/70B）与 LLaMA-3 等模型把 Query 头分成若干组，每组内的 Query 头共享同一个 Key/Value 头，因此只需要更少数量的 Key/Value 投影，从而减少 KV cache。下面的示例展示了 KV 张量的尺寸如何缩小，并通过 repeat 把它扩展回 Query 头数以完成计算。


In [6]:
# Use fewer key/value heads than query heads (num_heads is 4 in this notebook).
num_key_value_heads = 2
# Width of the K/V projections: one head_dim slice per key/value head.
kv_proj_dim = num_key_value_heads * head_dim

# The K/V projections now output kv_proj_dim features instead of hidden_size.
k_proj_gqa = nn.Linear(hidden_size, kv_proj_dim, bias=False)
v_proj_gqa = nn.Linear(hidden_size, kv_proj_dim, bias=False)

k_gqa_raw = k_proj_gqa(hidden_states)
v_gqa_raw = v_proj_gqa(hidden_states)

# Reshape to (batch, num_key_value_heads, seq, head_dim); this smaller tensor
# is what the notebook treats as the cached K/V.
k_kv = k_gqa_raw.view(batch_size, seq_len, num_key_value_heads, head_dim).transpose(1, 2).contiguous()
v_kv = v_gqa_raw.view(batch_size, seq_len, num_key_value_heads, head_dim).transpose(1, 2).contiguous()


def repeat_kv(hidden_states: torch.Tensor, target_heads: int) -> torch.Tensor:
    """Expand key/value heads so that their count matches the query heads.

    Each head along dim 1 is repeated consecutively, so with 2 key/value
    heads and 4 target heads the result is ordered ``[kv0, kv0, kv1, kv1]``.

    Args:
        hidden_states: Key or value tensor whose dim 1 is the head axis,
            e.g. ``(batch, num_kv_heads, seq_len, head_dim)``.
        target_heads: Number of heads wanted after expansion. Must be a
            positive multiple of ``hidden_states.size(1)``.

    Returns:
        A tensor shaped like ``hidden_states`` except that dim 1 has size
        ``target_heads``.

    Raises:
        ValueError: If ``target_heads`` is not a positive multiple of the
            number of key/value heads. Floor division would otherwise
            silently produce the wrong head count or an empty tensor.
    """
    num_kv_heads = hidden_states.size(1)
    if num_kv_heads == 0 or target_heads <= 0 or target_heads % num_kv_heads != 0:
        raise ValueError(
            f"target_heads ({target_heads}) must be a positive multiple of "
            f"the number of key/value heads ({num_kv_heads})"
        )
    return hidden_states.repeat_interleave(target_heads // num_kv_heads, dim=1)


# Expand the key/value heads to num_heads so they line up with q_heads.
k_gqa_expanded = repeat_kv(k_kv, num_heads)
v_gqa_expanded = repeat_kv(v_kv, num_heads)

# From here on the computation mirrors the multi-head cell, using the
# expanded K/V tensors.
scores_gqa = torch.matmul(q_heads, k_gqa_expanded.transpose(-2, -1)) / math.sqrt(head_dim)
scores_gqa = scores_gqa.masked_fill(causal_mask_h, float('-inf'))
weights_gqa = F.softmax(scores_gqa, dim=-1)
context_gqa = torch.matmul(weights_gqa, v_gqa_expanded)
# Merge the heads back into a single hidden_size-wide vector per position.
context_gqa = context_gqa.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)

# Compare the element counts of the full key tensor and the GQA key tensor
# (before expansion), which is the part a KV cache would need to store.
print('k_heads (full) shape:', k_heads.shape)
print('k_kv (GQA cached) shape:', k_kv.shape)
print(f'Full KV cache elements: {k_heads.numel():,}')
print(f'GQA KV cache elements: {k_kv.numel():,}')
print('context_gqa shape:', context_gqa.shape)


k_heads (full) shape: torch.Size([2, 4, 4, 8])
k_kv (GQA cached) shape: torch.Size([2, 2, 4, 8])
Full KV cache elements: 256
GQA KV cache elements: 128
context_gqa shape: torch.Size([2, 4, 32])


## 5. Blocked Attention (简化版)

BlockedAttention 通过把序列切成块来局部化 QK 计算，从而更好地适配 GPU 并行与缓存。下面用 Python 循环搭建一个教学用的原型，帮助理解讲义里的流程图。注意：该原型会把一个 Q 块对所有 K 块的分数拼接起来之后再做一次 softmax，并没有实现 FlashAttention 的在线（online）softmax，因此它的输出应与完整注意力一致。


In [7]:
def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
    dim = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dim)
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v)


def blocked_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_size: int, mask: torch.Tensor | None = None) -> torch.Tensor:
    bs, seq_len, dim = q.shape
    outputs = torch.zeros_like(q)

    for start in range(0, seq_len, block_size):
        end = min(start + block_size, seq_len)
        q_block = q[:, start:end, :]

        scores_blocks = []
        value_blocks = []
        for k_start in range(0, seq_len, block_size):
            k_end = min(k_start + block_size, seq_len)
            k_block = k[:, k_start:k_end, :]
            scores = torch.matmul(q_block, k_block.transpose(-2, -1)) / math.sqrt(dim)
            if mask is not None:
                mask_block = mask[:, start:end, k_start:k_end]
                scores = scores.masked_fill(mask_block, float('-inf'))
            scores_blocks.append(scores)
            value_blocks.append(v[:, k_start:k_end, :])

        scores_concat = torch.cat(scores_blocks, dim=-1)
        weights = F.softmax(scores_concat, dim=-1)
        values_concat = torch.cat(value_blocks, dim=1)
        output_block = torch.matmul(weights, values_concat)
        outputs[:, start:end, :] = output_block

    return outputs


In [8]:
# Give the (seq_len, seq_len) causal mask a batch dimension; expand() only
# creates a broadcast view, it does not copy the data.
mask = causal_mask.unsqueeze(0).expand(batch_size, -1, -1)

# Run both implementations on the same un-split q/k/v and the same mask.
full_attention = scaled_dot_product_attention(q, k, v, mask)
blocked_attention_out = blocked_attention(q, k, v, block_size=2, mask=mask)

print('full_attention shape:', full_attention.shape)
print('blocked_attention shape:', blocked_attention_out.shape)
# The largest element-wise difference should be (numerically) zero.
print('max diff:', (full_attention - blocked_attention_out).abs().max())


full_attention shape: torch.Size([2, 4, 32])
blocked_attention shape: torch.Size([2, 4, 32])
max diff: tensor(    0.0000, grad_fn=<MaxBackward1>)


## 6. 课堂提示

- 把每个代码块与讲义页码对应起来，课堂上可以逐段执行并解释张量尺寸的变化。
- 如果要深入 FlashAttention，可以在此基础上展示 CUDA kernel 或调用 `torch.nn.functional.scaled_dot_product_attention` 的对比。
- 本 notebook 用 `repeat_interleave` 实现的 GQA repeat 逻辑，与 Hugging Face `LlamaAttention` 所调用的 `repeat_kv`（通过 `expand` + `reshape` 实现）在结果上等价，可对照源码进一步解读。
