In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are linearly projected, split into ``heads`` chunks of size
    ``head_dim = embed_size // heads``, attended over independently per head,
    concatenated again and passed through a final linear layer.

    Args:
        embed_size: Size of the last dimension of the inputs and of the output.
        heads: Number of attention heads; must divide ``embed_size`` evenly.

    Raises:
        AssertionError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"

        # Bias-free projections for values, keys and queries; the per-head
        # split happens in forward() via reshape.
        self.W_V = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.W_K = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.W_Q = nn.Linear(self.embed_size, self.embed_size, bias=False)
        # Output projection applied to the concatenated heads.
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask=None):
        """Compute attention of ``query`` over ``keys``/``values``.

        Args:
            values: Tensor of shape ``(N, value_len, embed_size)``.
            keys: Tensor of shape ``(N, key_len, embed_size)``; ``key_len``
                must equal ``value_len``.
            query: Tensor of shape ``(N, query_len, embed_size)``.
            mask: Optional tensor broadcastable to
                ``(N, heads, query_len, key_len)``. Positions where
                ``mask == 0`` are excluded from attention. ``None`` (the
                default) disables masking.

        Returns:
            Tuple ``(out, attention)`` where ``out`` has shape
            ``(N, query_len, embed_size)`` and ``attention`` holds the
            softmax weights with shape ``(N, heads, query_len, key_len)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project, then split the embedding into self.heads pieces.
        values = self.W_V(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.W_K(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.W_Q(query).reshape(N, query_len, self.heads, self.head_dim)

        # Dot products are taken over head_dim, so that is the dimension whose
        # square root normalises the scores (not the full embed_size).
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys) / (self.head_dim ** 0.5)
        if mask is not None:
            # Use the most negative finite value of the score dtype so the
            # fill also works for reduced-precision dtypes such as float16.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        # Normalise over the key axis.
        attention = torch.softmax(energy, dim=-1)

        # Weighted sum of values per head, then concatenate the heads.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )
        out = self.fc_out(out)
        return out, attention

# Example usage
embed_size = 256
heads = 8
seq_length = 10
batch_size = 1

mha = MultiHeadAttention(embed_size, heads)
x = torch.randn(batch_size, seq_length, embed_size)
# All-ones mask: no position is blocked. The singleton second axis
# broadcasts across the heads.
mask = torch.ones(batch_size, 1, seq_length, seq_length)

# Self-attention: the same tensor serves as values, keys and queries.
output, attention = mha(x, x, x, mask)

# Visualize the attention distribution of each head.
# The grid is derived from `heads` (at most 4 panels per row) so the cell
# keeps working when the number of heads changes; 8 heads gives 2 x 4.
n_cols = min(heads, 4)
n_rows = (heads + n_cols - 1) // n_cols  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows), squeeze=False)
for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i >= heads:
        continue  # leave surplus panels of the last row blank
    ax.imshow(attention[0, i].detach().numpy(), cmap='viridis')
    ax.set_title(f'Head {i+1}')
fig.tight_layout()
plt.show()

In [2]:
# Toy dimensions shared by the exploratory cells below.
batch_size = 1    # number of sequences in a batch
seq_length = 10   # tokens per sequence
heads = 8         # number of attention heads
embed_size = 256  # embedding width

In [7]:
# Draw a random input batch and print the size of each of its three axes,
# in order: batch_size, seq_length, embed_size.
x = torch.randn(batch_size, seq_length, embed_size)
for axis_size in x.shape:
    print(axis_size)

1
10
256


In [None]:
# All-ones tensor of shape (batch_size, 1, seq_length, seq_length), shown as the cell output.
torch.ones((batch_size, 1, seq_length, seq_length))

In [10]:
# Bias-free linear projection applied to the last axis of x.
W_V = nn.Linear(embed_size, embed_size, bias=False)
v = W_V(x)
print(v)
# Keep the two leading axes of x and split the last axis into
# (heads, embed_size // heads); the result is the cell output.
v.reshape(*x.shape[:2], heads, embed_size // heads)

tensor([[[-0.0514,  0.5724,  0.7738,  ..., -0.6581,  1.0623, -0.5047],
         [-0.1851, -0.7963, -0.1730,  ..., -0.1564,  0.1908,  1.0888],
         [ 0.3316, -0.1187,  1.3206,  ..., -0.0506, -0.9239, -0.2433],
         ...,
         [ 0.7064, -0.4227,  0.4947,  ...,  0.0073, -0.2617, -0.2478],
         [-0.5140,  0.2559, -0.2099,  ...,  0.3654, -1.2443, -0.0594],
         [-0.3713, -0.0506,  0.5962,  ..., -0.8006,  0.2858, -0.4019]]],
       grad_fn=<UnsafeViewBackward0>)


tensor([[[[-5.1351e-02,  5.7236e-01,  7.7379e-01,  ..., -4.7329e-01,
           -3.5231e-01,  9.3953e-01],
          [-7.2955e-02, -8.7796e-02, -7.3498e-01,  ...,  2.0978e-01,
           -5.2034e-02,  5.1850e-01],
          [ 4.2092e-01, -3.1959e-01, -1.4015e+00,  ...,  2.6888e-01,
            9.2742e-01, -1.0992e+00],
          ...,
          [ 4.4064e-01,  1.9320e-01,  5.2276e-01,  ..., -6.8857e-01,
           -1.1678e+00,  8.3032e-01],
          [-2.5219e-01,  8.6043e-02, -1.7928e-01,  ..., -5.9504e-01,
            2.6388e-01,  1.0667e+00],
          [ 1.4543e-01, -7.2450e-01,  7.8694e-01,  ..., -6.5808e-01,
            1.0623e+00, -5.0473e-01]],

         [[-1.8512e-01, -7.9630e-01, -1.7301e-01,  ...,  2.7680e-01,
           -6.6447e-01,  4.4468e-01],
          [ 4.3197e-01, -5.0777e-01, -2.0941e-01,  ...,  1.4700e-01,
            5.7759e-01, -4.4414e-01],
          [ 1.1049e+00,  4.5713e-01, -1.1900e-03,  ..., -3.2168e-01,
            1.1543e+00,  3.1784e-01],
          ...,
     

In [24]:
import torch 
import torch.nn as nn

# Fix the RNG seed so the layer's random initialisation is reproducible.
torch.manual_seed(0)

# nn.Linear stores its weight as (out_features, in_features) = (4, 5).
linear = nn.Linear(in_features=5, out_features=4)
# To experiment with hand-picked weights, overwrite linear.weight.data
# with a tensor of shape (4, 5).
print(linear.weight.shape)
print(linear.weight)

# A single un-batched input vector with 5 features maps to 4 outputs.
x = torch.tensor([1.0, 2.0, -1.0, 0.0, 3.0])
print(x.shape)
print(linear(x))

torch.Size([4, 5])
Parameter containing:
tensor([[-0.0033,  0.2399, -0.3681, -0.3291, -0.1722],
        [ 0.1199, -0.0089,  0.3546, -0.0397,  0.1183],
        [-0.1352, -0.0879, -0.4272, -0.2962, -0.1844],
        [ 0.0166,  0.1768,  0.2683, -0.3032, -0.1947]], requires_grad=True)
torch.Size([5])
tensor([ 0.4902,  0.4740, -0.5288, -0.1478], grad_fn=<ViewBackward0>)
