In [1]:
import torch
import torch.nn.functional as F

class MultiHeadAttention(torch.nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each passed through their own linear
    projection, split into ``num_heads`` heads of size
    ``embed_dim // num_heads``, attended independently, concatenated back to
    ``embed_dim`` and passed through a final output projection.

    Args:
        embed_dim: Size of the last dimension of the inputs and the output.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Explicit raises instead of ``assert`` so that the validation is
        # still performed when Python runs with ``-O``.
        if num_heads <= 0:
            raise ValueError("num_heads must be a positive integer")
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query_linear = torch.nn.Linear(embed_dim, embed_dim)
        self.key_linear = torch.nn.Linear(embed_dim, embed_dim)
        self.value_linear = torch.nn.Linear(embed_dim, embed_dim)
        self.out_linear = torch.nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, embed_dim) to (batch, num_heads, seq, head_dim)."""
        return projected.view(
            batch_size, -1, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query: Tensor of shape (batch_size, query_len, embed_dim).
            key: Tensor of shape (batch_size, key_len, embed_dim).
            value: Tensor of shape (batch_size, key_len, embed_dim).
            mask: Optional mask broadcastable to
                (batch_size, num_heads, query_len, key_len); see
                ``scaled_dot_product_attention`` for its semantics.

        Returns:
            A tuple ``(output, attn_weights)`` where ``output`` has shape
            (batch_size, query_len, embed_dim) and ``attn_weights`` has shape
            (batch_size, num_heads, query_len, key_len).
        """
        batch_size = query.size(0)

        # Linear projections followed by the split into heads.
        Q = self._split_heads(self.query_linear(query), batch_size)
        K = self._split_heads(self.key_linear(key), batch_size)
        V = self._split_heads(self.value_linear(value), batch_size)

        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate the heads again; ``contiguous`` is required because
        # ``view`` cannot operate on the transposed (non-contiguous) layout.
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.embed_dim)
        )

        # Final output projection.
        output = self.out_linear(attn_output)

        return output, attn_weights

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``softmax(Q @ K^T / sqrt(head_dim) + mask) @ V``.

        Args:
            Q: Queries of shape (batch_size, num_heads, query_len, head_dim).
            K: Keys of shape (batch_size, num_heads, key_len, head_dim).
            V: Values of shape (batch_size, num_heads, key_len, head_dim).
            mask: Optional tensor broadcastable to the score shape
                (batch_size, num_heads, query_len, key_len). A boolean mask
                marks the positions that may be attended to with ``True``
                (the same convention as
                ``torch.nn.functional.scaled_dot_product_attention``); any
                other dtype is added to the scores, e.g. ``-inf`` to block a
                position. A query row whose keys are all blocked produces NaN
                weights.

        Returns:
            A tuple ``(attn_output, attn_weights)``.
        """
        # A plain Python float is used for the scale so that no tensor has to
        # be allocated on every call.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            if mask.dtype == torch.bool:
                # Adding a boolean mask would merely shift the allowed scores
                # by 1.0, so blocked positions are set to -inf instead.
                scores = scores.masked_fill(~mask, float("-inf"))
            else:
                # Cast so that e.g. a float64 mask does not promote the
                # scores to a dtype that no longer matches V.
                scores = scores + mask.to(scores.dtype)

        # Normalise over the key dimension.
        attn_weights = F.softmax(scores, dim=-1)

        # Weighted sum of the values.
        attn_output = torch.matmul(attn_weights, V)

        return attn_output, attn_weights

# Example usage
batch_size = 2
seq_len = 10
embed_dim = 64
num_heads = 8

# Random inputs of shape (batch_size, seq_len, embed_dim)
query = torch.rand(batch_size, seq_len, embed_dim)
key = torch.rand(batch_size, seq_len, embed_dim)
value = torch.rand(batch_size, seq_len, embed_dim)

multi_head_attn = MultiHeadAttention(embed_dim, num_heads)

# Optimizer and loss function; both are loop-invariant, so they are created
# once here instead of on every training step.
optimizer = torch.optim.Adam(multi_head_attn.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()

# Random regression targets standing in for real labels
target = torch.rand(batch_size, seq_len, embed_dim)

# Training loop
for epoch in range(100):
    # Forward pass
    output, attn_weights = multi_head_attn(query, key, value)

    # Compute the loss
    loss = loss_fn(output, target)

    # Reset the gradients accumulated by the previous step
    optimizer.zero_grad()

    # Backward pass
    loss.backward()

    # Update the parameters
    optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item()}")


Epoch 0, Loss: 0.37541353702545166
Epoch 1, Loss: 0.34171804785728455
Epoch 2, Loss: 0.31143707036972046
Epoch 3, Loss: 0.28395161032676697
Epoch 4, Loss: 0.2588028311729431
Epoch 5, Loss: 0.23568689823150635
Epoch 6, Loss: 0.2144397497177124
Epoch 7, Loss: 0.19500429928302765
Epoch 8, Loss: 0.17739588022232056
Epoch 9, Loss: 0.1616668999195099
Epoch 10, Loss: 0.14786039292812347
Epoch 11, Loss: 0.13595713675022125
Epoch 12, Loss: 0.12584929168224335
Epoch 13, Loss: 0.11736734211444855
Epoch 14, Loss: 0.11033036559820175
Epoch 15, Loss: 0.10457366704940796
Epoch 16, Loss: 0.09994816780090332
Epoch 17, Loss: 0.09630793333053589
Epoch 18, Loss: 0.09350096434354782
Epoch 19, Loss: 0.09136834740638733
Epoch 20, Loss: 0.08975155651569366
Epoch 21, Loss: 0.08850657939910889
Epoch 22, Loss: 0.08751830458641052
Epoch 23, Loss: 0.08670497685670853
Epoch 24, Loss: 0.0860128253698349
Epoch 25, Loss: 0.08540843427181244
Epoch 26, Loss: 0.08487293124198914
Epoch 27, Loss: 0.08439697325229645
Epoch 