# Multi-Head Attention

In [1]:
import math
import torch
from torch import nn
from d2l import torch as d2l

```
基础注意力机制 → Seq2Seq注意力 → 多头注意力 → Self-Attention → Transformer
     ↑              ↑              ↑           ↑            ↑
  点积/加性      用加性注意力     用点积注意力   自注意力    最终形态
```

Self-attention 革命性的地方就在于打破了序列处理的限制，实现了真正的并行计算。

```
MultiHeadAttention ⊃ Multi-head Self-attention
      通用框架        特殊应用
```

## Model

In [2]:
class MultiHeadAttention(d2l.Module):
    """Multi-head attention.

    Queries, keys and values are each passed through a linear projection of
    width ``num_hiddens``. The projected features are split evenly into
    ``num_heads`` slices, the slices are folded into the batch axis so that
    all heads go through a single attention call, and the per-head results
    are concatenated again and passed through an output projection.

    The helpers ``transpose_qkv`` and ``transpose_output`` used by
    ``forward`` are not defined in the class body; they are attached to the
    class separately.

    Args:
        num_hiddens: Output width of the four linear projections. Must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads. Must be positive.
        dropout: Passed unchanged to ``d2l.DotProductAttention``.
        bias: Whether the four linear projections carry a bias term.
        **kwargs: Accepted for call-site compatibility; not used here.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``num_hiddens``.
    """

    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        # Fail fast: the projected feature axis is later reshaped into
        # (num_heads, num_hiddens / num_heads), which needs an even split.
        # Without this check the mistake only shows up as a reshape error
        # during the first forward pass.
        if num_heads <= 0 or num_hiddens % num_heads != 0:
            raise ValueError(
                f"num_hiddens ({num_hiddens}) must be divisible by a "
                f"positive num_heads ({num_heads})")
        super().__init__()
        self.num_heads = num_heads  # number of parallel attention heads
        self.attention = d2l.DotProductAttention(dropout)

        # One projection each for queries, keys, values and the output. The
        # projections are shared by all heads; every head later works on its
        # own slice of the num_hiddens projected features.
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)  # query projection
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)  # key projection
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)  # value projection
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)  # output projection

    def forward(self, queries, keys, values, valid_lens):
        """Apply multi-head attention.

        Args:
            queries: Tensor of shape (batch_size, no. of queries, features).
            keys: Tensor of shape (batch_size, no. of key-value pairs,
                features).
            values: Tensor of shape (batch_size, no. of key-value pairs,
                features).
            valid_lens: Tensor indexed by batch element along axis 0, or
                ``None``. It is forwarded to the attention module.

        Returns:
            Tensor of shape (batch_size, no. of queries, num_hiddens).
        """
        # Step 1: project, then fold the heads into the batch axis, giving
        # (batch_size * num_heads, sequence length, num_hiddens / num_heads).
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        # Step 2: after folding, the num_heads slices of one batch element
        # sit next to each other along axis 0, so every entry of valid_lens
        # is repeated num_heads times in place to stay aligned with them.
        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # Step 3: a single attention call covers all heads at once.
        output = self.attention(queries, keys, values, valid_lens)
        # Step 4: unfold the heads and concatenate them per position, then
        # mix them with the output projection.
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)

为了能够使多个头并行计算，上面的 `MultiHeadAttention` 类将使用下面定义的两个转置函数。具体来说，`transpose_output` 函数反转了 `transpose_qkv` 函数的操作。

In [6]:
@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_qkv(self, X):
    """Fold the attention heads into the batch axis.

    Args:
        X: Tensor of shape (batch_size, no. of queries or key-value pairs,
            num_hiddens).

    Returns:
        Tensor of shape (batch_size * num_heads, no. of queries or
        key-value pairs, num_hiddens / num_heads). Row ``b * num_heads + h``
        holds the h-th feature slice of batch element ``b``.
    """
    batch_size, seq_len = X.shape[0], X.shape[1]
    # Split the feature axis into one slice per head:
    # (batch_size, seq_len, num_heads, num_hiddens / num_heads)
    per_head = X.reshape(batch_size, seq_len, self.num_heads, -1)
    # Move the head axis in front of the sequence axis:
    # (batch_size, num_heads, seq_len, num_hiddens / num_heads)
    heads_first = per_head.transpose(1, 2)
    head_dim = heads_first.shape[3]
    # Merge batch and head axes so every head looks like a batch element.
    return heads_first.reshape(-1, seq_len, head_dim)

@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_output(self, X):
    """Undo ``transpose_qkv``: unfold the heads and concatenate them.

    Args:
        X: Tensor of shape (batch_size * num_heads, sequence length,
            per-head feature size).

    Returns:
        Tensor of shape (batch_size, sequence length,
        num_heads * per-head feature size).
    """
    seq_len, head_dim = X.shape[1], X.shape[2]
    # Recover separate batch and head axes:
    # (batch_size, num_heads, seq_len, head_dim)
    per_head = X.reshape(-1, self.num_heads, seq_len, head_dim)
    # Put the sequence axis back in front of the head axis:
    # (batch_size, seq_len, num_heads, head_dim)
    seq_first = per_head.transpose(1, 2)
    batch_size = seq_first.shape[0]
    # Concatenate the heads along the feature axis.
    return seq_first.reshape(batch_size, seq_len, -1)

```
transpose_qkv的作用：
"大向量" → "多个小向量" → "便于并行计算"
(2,4,100) → (2,4,5,20) → (2,5,4,20) → (10,4,20)
 原始输入    分成5组      重排便于处理   10个并行任务
```

```
transpose_output的作用：
"多个小结果" → "大结果"
(10,4,20) → (2,5,4,20) → (2,4,5,20) → (2,4,100)
10个结果     重新分组      恢复维度      合并输出
```

下面使用键和值相同的小例子来测试我们编写的 `MultiHeadAttention` 类。多头注意力输出的形状是 (`batch_size`, `num_queries`, `num_hiddens`)。

In [7]:
# Shape check for MultiHeadAttention: the same tensor Y is used as both keys
# and values, while the queries come from X.
num_hiddens = 100
num_heads = 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)  # 0.5 = dropout

batch_size = 2
num_queries = 4
num_kvpairs = 6
valid_lens = torch.tensor([3, 2])  # one entry per batch element
X = torch.ones(batch_size, num_queries, num_hiddens)
Y = torch.ones(batch_size, num_kvpairs, num_hiddens)

# The output shape is compared against
# (batch_size, num_queries, num_hiddens).
d2l.check_shape(
    attention(X, Y, Y, valid_lens),
    (batch_size, num_queries, num_hiddens),
)