In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Fixed seed so the randomly initialised layer weights and the random inputs
# created further down are reproducible on a fresh kernel.
SEED = 42
torch.manual_seed(SEED);  # trailing ';' hides the Generator repr in the cell output

<torch._C.Generator at 0x76d1b004ae70>

In [7]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The last axis of each input is split into ``heads`` chunks of
    ``head_dim`` features. A single ``head_dim -> head_dim`` linear map
    (no bias) is applied to that last axis for values, keys and queries
    respectively, so every head shares the same projection weights. The
    per-head results are concatenated and passed through ``fc_out``.

    Args:
        embedding_size: Size of the last axis of the inputs and of the output.
        heads: Number of attention heads; must divide ``embedding_size``.

    Raises:
        AssertionError: If ``embedding_size`` is not a multiple of ``heads``.
    """

    def __init__(self, embedding_size: int, heads: int):
        super().__init__()
        self.embedding_size = embedding_size
        self.heads = heads
        self.head_dim = embedding_size // heads

        assert(
            self.head_dim * heads == embedding_size
        ), "Embedding Sieze 需要是heads的整数倍"
        # Linear maps that produce the V, K and Q tensors (applied per head).
        self.values = nn.Linear(in_features=self.head_dim, out_features=self.head_dim, bias=False)
        self.keys = nn.Linear(in_features=self.head_dim, out_features=self.head_dim, bias=False)
        # NOTE: the attribute name `queies` (sic) is kept as-is; subclasses and
        # state_dict keys refer to it, so renaming would break them.
        self.queies = nn.Linear(in_features=self.head_dim, out_features=self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embedding_size)

    def forward(self, values, keys, query, mask=None):
        """Apply attention and return the projected output.

        Args:
            values: Tensor viewed as ``(N, value_len, heads, head_dim)``.
            keys: Tensor viewed as ``(N, key_len, heads, head_dim)``;
                ``key_len`` has to match ``value_len`` for the weighted sum.
            query: Tensor viewed as ``(N, query_len, heads, head_dim)``;
                its first axis defines ``N``.
            mask: Optional tensor that broadcasts against the
                ``(N, heads, query_len, key_len)`` score tensor. Positions
                where ``mask == 0`` are excluded from attention. ``None``
                (the default) disables masking.

        Returns:
            Tensor of shape ``(N, query_len, embedding_size)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        # Split the embedding axis into (heads, head_dim).
        values = values.view(N, value_len, self.heads, self.head_dim)
        keys = keys.view(N, key_len, self.heads, self.head_dim)
        queries = query.view(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queies(queries)
        # Dot product of Q and K per head, divided by sqrt(head_dim).
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys) / (self.head_dim ** 0.5)
        if mask is not None:
            # Use the most negative finite value of the score dtype. A literal
            # -1e20 is not representable in float16, so the fill would fail
            # with an overflow error under half precision / autocast; for
            # float32/float64 the softmax result is unchanged.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)
        # Attention weights: softmax over the key axis.
        attention = torch.softmax(energy, dim=-1)
        # Weighted sum of V, then merge the heads back into one axis.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )
        return self.fc_out(out)

In [8]:
# Hyperparameters for the demo
embedding_size: int = 128  # size of the embedding axis
heads: int = 8             # number of attention heads
seq_length: int = 10       # sequence length of the random inputs
batch_size: int = 2        # batch size of the random inputs

In [9]:
# Create random inputs; values, keys and queries all use the same shape.
# The draw order (values, keys, queries) is kept so the seeded numbers match.
input_shape = (batch_size, seq_length, embedding_size)
values = torch.rand(input_shape)
keys = torch.rand(input_shape)
queries = torch.rand(input_shape)

In [10]:
# Instantiate the self-attention layer
self_attention_layer = MultiHeadSelfAttention(
    embedding_size=embedding_size,
    heads=heads,
)

In [11]:
# Forward pass (no mask)
output = self_attention_layer(
    values=values,
    keys=keys,
    query=queries,
    mask=None,
)
print("输出的形状：", output.shape)  # expected: torch.Size([2, 10, 128])
print("自注意力的输出：", output)

输出的形状： torch.Size([2, 10, 128])
自注意力的输出： tensor([[[-0.0785, -0.1186, -0.0190,  ..., -0.0991,  0.1303, -0.2674],
         [-0.0786, -0.1176, -0.0183,  ..., -0.0985,  0.1302, -0.2675],
         [-0.0783, -0.1177, -0.0180,  ..., -0.0991,  0.1306, -0.2683],
         ...,
         [-0.0778, -0.1172, -0.0187,  ..., -0.0975,  0.1294, -0.2678],
         [-0.0771, -0.1180, -0.0183,  ..., -0.0981,  0.1296, -0.2677],
         [-0.0776, -0.1186, -0.0188,  ..., -0.0989,  0.1297, -0.2675]],

        [[-0.1401, -0.1019,  0.0584,  ..., -0.0881,  0.1732, -0.2754],
         [-0.1425, -0.1028,  0.0614,  ..., -0.0898,  0.1736, -0.2751],
         [-0.1413, -0.1036,  0.0605,  ..., -0.0869,  0.1745, -0.2743],
         ...,
         [-0.1406, -0.1026,  0.0611,  ..., -0.0874,  0.1745, -0.2758],
         [-0.1405, -0.1027,  0.0588,  ..., -0.0880,  0.1739, -0.2755],
         [-0.1410, -0.1030,  0.0598,  ..., -0.0869,  0.1757, -0.2748]]],
       grad_fn=<ViewBackward0>)


In [12]:
# 注意力权重计算
class MultiHeadSelfAttentionWithWeights(MultiHeadSelfAttention):
    """Variant of ``MultiHeadSelfAttention`` that also returns the attention weights.

    The computation repeats the parent's forward pass; the only difference
    in the result is the extra ``attention`` tensor in the return value.
    """

    def forward(self, values, keys, query, mask=None):
        """Apply attention and return both the output and the weights.

        Args:
            values: Tensor viewed as ``(N, value_len, heads, head_dim)``.
            keys: Tensor viewed as ``(N, key_len, heads, head_dim)``;
                ``key_len`` has to match ``value_len`` for the weighted sum.
            query: Tensor viewed as ``(N, query_len, heads, head_dim)``;
                its first axis defines ``N``.
            mask: Optional tensor that broadcasts against the
                ``(N, heads, query_len, key_len)`` score tensor. Positions
                where ``mask == 0`` are excluded from attention. ``None``
                (the default) disables masking.

        Returns:
            A tuple ``(out, attention)`` where ``out`` has shape
            ``(N, query_len, heads * head_dim)`` after ``fc_out`` and
            ``attention`` has shape ``(N, heads, query_len, key_len)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        # Split the embedding axis into (heads, head_dim).
        values = values.view(N, value_len, self.heads, self.head_dim)
        keys = keys.view(N, key_len, self.heads, self.head_dim)
        queries = query.view(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queies(queries)  # attribute name (sic) comes from the parent class
        # Dot product of Q and K per head, divided by sqrt(head_dim).
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys) / (self.head_dim ** 0.5)
        if mask is not None:
            # Use the most negative finite value of the score dtype. A literal
            # -1e20 is not representable in float16, so the fill would fail
            # with an overflow error under half precision / autocast; for
            # float32/float64 the softmax result is unchanged.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)
        # Attention weights: softmax over the key axis.
        attention = torch.softmax(energy, dim=-1)
        # Weighted sum of V, then merge the heads back into one axis.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )
        return self.fc_out(out), attention

In [13]:
# Instantiate the attention layer that also returns its weights
self_attention_layer = MultiHeadSelfAttentionWithWeights(
    embedding_size=embedding_size,
    heads=heads,
)
# Forward pass (no mask)
output, attention_weights = self_attention_layer(
    values=values,
    keys=keys,
    query=queries,
    mask=None,
)
# Expected torch.Size([2, 10, 128]): same shape as the input, which makes stacking layers easy.
print("输出的形状：", output.shape)
print("自注意力的输出：", output)
# Expected torch.Size([2, 8, 10, 10])
print("注意力权重的形状：", attention_weights.shape)
print("注意力权重：", attention_weights)

输出的形状： torch.Size([2, 10, 128])
自注意力的输出： tensor([[[ 0.1922, -0.1950,  0.1733,  ...,  0.2255,  0.1212, -0.0773],
         [ 0.1943, -0.1952,  0.1743,  ...,  0.2263,  0.1215, -0.0771],
         [ 0.1932, -0.1947,  0.1725,  ...,  0.2266,  0.1221, -0.0769],
         ...,
         [ 0.1928, -0.1929,  0.1728,  ...,  0.2259,  0.1231, -0.0769],
         [ 0.1951, -0.1950,  0.1735,  ...,  0.2264,  0.1248, -0.0794],
         [ 0.1941, -0.1933,  0.1739,  ...,  0.2258,  0.1211, -0.0783]],

        [[ 0.1323, -0.0954,  0.2114,  ...,  0.2396,  0.0258, -0.0900],
         [ 0.1307, -0.0982,  0.2128,  ...,  0.2388,  0.0257, -0.0892],
         [ 0.1314, -0.0977,  0.2126,  ...,  0.2375,  0.0265, -0.0898],
         ...,
         [ 0.1310, -0.0961,  0.2101,  ...,  0.2384,  0.0250, -0.0900],
         [ 0.1319, -0.0962,  0.2115,  ...,  0.2381,  0.0262, -0.0896],
         [ 0.1315, -0.0954,  0.2109,  ...,  0.2403,  0.0252, -0.0905]]],
       grad_fn=<ViewBackward0>)
注意力权重的形状： torch.Size([2, 8, 10, 10])
注意力权重：