In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"

        self.W_V = nn. Linear(self.embed_size, self.embed_size, bias=False)
        self.W_K = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.W_Q = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into self.heads pieces
        values = self.W_V(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.W_K(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.W_Q(query).reshape(N, query_len, self.heads, self.head_dim)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        out = self.fc_out(out)
        return out, attention

# 示例使用
embed_size = 256
heads = 8
seq_length = 10
batch_size = 1

mha = MultiHeadAttention(embed_size, heads)
x = torch.randn(batch_size, seq_length, embed_size)
mask = torch.ones(batch_size, 1, seq_length, seq_length)

output, attention = mha(x, x, x, mask)

# 可视化不同头的注意力分布
plt.figure(figsize=(15, 8))
for i in range(heads):
    plt.subplot(2, 4, i+1)
    plt.imshow(attention[0, i].detach().numpy(), cmap='viridis')
    plt.title(f'Head {i+1}')
    plt.axis('off')
plt.tight_layout()
plt.show()

In [14]:
batch_size = 1
seq_length = 5
embed_size = 6
heads = 2

In [None]:
torch.manual_seed(0)

linear = nn.Linear(5, 4)
#linear.weight.data = torch.randn(4, 5)
print(linear.weight.shape)
print(linear.weight)

x = torch.tensor([1.0, 2.0, -1.0, 0.0, 3.0])
print(x.shape)
print(linear(x))

In [15]:
x = torch.randn(batch_size, seq_length, embed_size)
print(x.shape[0])
print(x.shape[1])
print(x.shape[2])

1
5
6


In [16]:
torch.ones(batch_size, 1, seq_length, seq_length)

tensor([[[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]]])

In [17]:
W_V = nn.Linear(embed_size, embed_size, bias=False)
v = W_V(x)
print(v)
print(v.shape)
v = v.reshape(x.shape[0], x.shape[1], heads, embed_size // heads)
print(v)
print(v.shape)

tensor([[[ 0.4418,  0.2320, -0.6752, -0.3111,  0.5736, -0.5544],
         [ 0.5012,  0.5117, -1.6920, -0.9251, -0.2557, -1.0512],
         [ 0.8153, -0.3434, -0.4991, -0.2057, -0.0357, -0.5482],
         [-0.6561,  0.2849,  0.7725,  0.2777,  0.5643,  0.3041],
         [-0.5946, -0.1228,  0.6577,  0.1047, -0.6076,  0.8863]]],
       grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 5, 6])
tensor([[[[ 0.4418,  0.2320, -0.6752],
          [-0.3111,  0.5736, -0.5544]],

         [[ 0.5012,  0.5117, -1.6920],
          [-0.9251, -0.2557, -1.0512]],

         [[ 0.8153, -0.3434, -0.4991],
          [-0.2057, -0.0357, -0.5482]],

         [[-0.6561,  0.2849,  0.7725],
          [ 0.2777,  0.5643,  0.3041]],

         [[-0.5946, -0.1228,  0.6577],
          [ 0.1047, -0.6076,  0.8863]]]], grad_fn=<ViewBackward0>)
torch.Size([1, 5, 2, 3])


In [25]:
torch.matmul(v[0,0,0], v[0,0,0])
torch.dot(v[0,0,0], v[0,0,0])

tensor(0.7050, grad_fn=<DotBackward0>)

In [31]:
energy = torch.einsum("nqhd,nkhd->nhqk", [v, v])
print(energy)
print(energy.shape)

tensor([[[[ 0.7050,  1.4826,  0.6176, -0.7454, -0.7353],
          [ 1.4826,  3.3759,  1.0774, -1.4901, -1.4737],
          [ 0.6176,  1.0774,  1.0317, -1.0183, -0.7709],
          [-0.7454, -1.4901, -1.0183,  1.1084,  0.8632],
          [-0.7353, -1.4737, -0.7709,  0.8632,  0.8012]],

         [[ 0.7332,  0.7239,  0.3474,  0.0687, -0.8725],
          [ 0.7239,  2.0262,  0.7756, -0.7209, -0.8732],
          [ 0.3474,  0.7756,  0.3441, -0.2440, -0.4857],
          [ 0.0687, -0.7209, -0.2440,  0.4880, -0.0442],
          [-0.8725, -0.8732, -0.4857, -0.0442,  1.1657]]]],
       grad_fn=<ViewBackward0>)
torch.Size([1, 2, 5, 5])


In [30]:
attention = torch.softmax(energy / (embed_size ** (1/2)), dim=3)
print(attention)
print(attention.shape)

tensor([[[[0.2249, 0.3089, 0.2170, 0.1244, 0.1249],
          [0.2169, 0.4699, 0.1838, 0.0645, 0.0649],
          [0.2237, 0.2699, 0.2649, 0.1147, 0.1269],
          [0.1494, 0.1103, 0.1337, 0.3185, 0.2882],
          [0.1534, 0.1135, 0.1512, 0.2946, 0.2873]],

         [[0.2420, 0.2411, 0.2067, 0.1845, 0.1256],
          [0.2084, 0.3546, 0.2128, 0.1155, 0.1086],
          [0.2134, 0.2541, 0.2131, 0.1676, 0.1518],
          [0.2107, 0.1526, 0.1854, 0.2500, 0.2012],
          [0.1455, 0.1455, 0.1704, 0.2041, 0.3345]]]],
       grad_fn=<SoftmaxBackward0>)
torch.Size([1, 2, 5, 5])


In [32]:
torch.einsum("nhql,nlhd->nqhd", [attention, v])

tensor([[[[ 0.2752,  0.1558, -0.6045],
          [-0.2764,  0.0975, -0.3335]],

         [[ 0.4003,  0.2380, -0.9408],
          [-0.3932,  0.0205, -0.4737]],

         [[ 0.2993,  0.1161, -0.5678],
          [-0.2828,  0.0521, -0.3167]],

         [[-0.1501,  0.1005,  0.0814],
          [-0.1544,  0.0940, -0.1246]],

         [[-0.1162,  0.0904,  0.0455],
          [-0.1232, -0.0479,  0.0314]]]], grad_fn=<ViewBackward0>)