# 多头注意力机制

# In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义一个 MultiHeadAttention 类，它继承自 nn.Module
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each passed through their own linear
    layer, split into ``heads`` sub-spaces of size ``d_model // heads``,
    attended independently per head, concatenated back to ``d_model`` and
    passed through a final output linear layer.

    Args:
        heads: Number of attention heads. Must be positive and divide
            ``d_model`` evenly.
        d_model: Size of the last dimension of the inputs and of the output.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        # Fail early with a clear message: otherwise d_model // heads is
        # truncated and the reshape in forward() breaks with an opaque error.
        if heads <= 0 or d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"number of heads (got heads={heads})"
            )

        self.d_model = d_model
        self.d_k = d_model // heads  # per-head feature size
        self.h = heads  # number of heads

        # Separate projections for query, key and value.
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        # Regularisation applied to the attention weights.
        self.dropout = nn.Dropout(dropout)
        # Final projection applied to the concatenated heads.
        self.out = nn.Linear(d_model, d_model)

    def attention(self, q, k, v, mask=None):
        """Compute scaled dot-product attention.

        Args:
            q: Query tensor; its last dimension is contracted with the last
                dimension of ``k``.
            k: Key tensor.
            v: Value tensor, weighted by the attention distribution.
            mask: Optional tensor broadcastable to the score tensor
                ``q @ k^T``. Positions where ``mask == 0`` are excluded from
                attention.

        Returns:
            The attention-weighted combination of ``v``.
        """
        # Similarity scores, scaled by sqrt(d_k).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # A large negative score becomes ~0 after softmax. Clamp to the
            # dtype's minimum so the constant stays representable in reduced
            # precision (float16 cannot hold -1e9); float32/float64 still
            # use exactly -1e9.
            fill_value = max(-1e9, torch.finfo(scores.dtype).min)
            scores = scores.masked_fill(mask == 0, fill_value)
        # Normalise over the key axis to obtain attention weights.
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        return torch.matmul(weights, v)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q: Query input whose last dimension is ``d_model`` and whose
                first dimension is the batch size.
            k: Key input with the same layout as ``q``.
            v: Value input with the same layout as ``q``.
            mask: Optional mask forwarded unchanged to :meth:`attention`. It
                must broadcast against scores of shape
                ``(batch, heads, query_len, key_len)``.

        Returns:
            Tensor of shape ``(batch, query_len, d_model)``.
        """
        batch_size = q.size(0)
        # Project, then reshape to (batch, heads, seq_len, d_k).
        q = self.q_linear(q).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        # Per-head attention output, shape (batch, heads, query_len, d_k).
        attended = self.attention(q, k, v, mask)
        # Merge the heads back into a single d_model-sized feature axis;
        # contiguous() is required before view() after the transpose.
        concat = attended.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out(concat)

# 主函数，用于测试 MultiHeadAttention 类
if __name__ == "__main__":
    # Hyper-parameters for the smoke test; d_model must be a multiple of heads.
    heads, d_model, dropout = 4, 128, 0.1

    # Build the module under test.
    model = MultiHeadAttention(heads=heads, d_model=d_model, dropout=dropout)

    # Random query / key / value inputs, drawn in that order.
    batch_size, seq_len = 2, 5
    input_shape = (batch_size, seq_len, d_model)
    q = torch.rand(input_shape)  # Query
    k = torch.rand(input_shape)  # Key
    v = torch.rand(input_shape)  # Value

    # Forward pass; the result should be [batch_size, seq_len, d_model].
    output = model(q, k, v)
    print("Output shape:", output.shape)

    # Reduce the output to a scalar and back-propagate through the module;
    # reaching the final print means no error was raised along the way.
    loss = output.mean()
    loss.backward()
    print("Backward pass completed.")


# Output shape: torch.Size([2, 5, 128])
# Backward pass completed.
