# 1. 简单自注意力

In [1]:
import torch
import torch.nn.functional as F

# Minimal self-attention: the input itself acts as query, key and value.
x = torch.randn(2, 3, 4)  # (batch=2, seq_len=3, feature_dim=4)

# Pairwise dot-product scores between positions: (2, 3, 4) x (2, 4, 3) -> (2, 3, 3).
raw_weights = torch.bmm(x, x.permute(0, 2, 1))

# Normalise each row of scores so the weights over positions sum to 1.
attn_weights = raw_weights.softmax(dim=-1)

# Every output vector is the attention-weighted sum of the input vectors.
attn_outputs = torch.bmm(attn_weights, x)
print(attn_outputs)

tensor([[[ 0.2478, -0.5643,  0.6200,  0.2377],
         [ 0.2983, -0.8605,  0.8004,  0.3084],
         [ 0.3253,  0.2624,  0.0781,  0.0707]],

        [[-0.1451,  0.4304,  0.4068, -1.3426],
         [-0.6821, -0.1725, -1.1295,  0.8898],
         [-1.4510, -0.3237,  0.5355, -0.1539]]])


# 2. 标准自注意力

In [2]:
# Scaled dot-product self-attention with learned query/key/value projections.
x = torch.randn(2, 3, 4)  # (batch=2, seq_len=3, feature_dim=4)

# Three independent 4 -> 4 linear layers, instantiated in q, k, v order.
linear_q, linear_k, linear_v = (torch.nn.Linear(4, 4) for _ in range(3))
Q, K, V = linear_q(x), linear_k(x), linear_v(x)

# Query-key dot products for every pair of positions: shape (2, 3, 3).
raw_weights = torch.bmm(Q, K.permute(0, 2, 1))
print(raw_weights)

# Divide the scores by sqrt(size of K's last axis) before the softmax.
scale_factor = K.shape[-1] ** 0.5
scale_weights = raw_weights / scale_factor
print(scale_weights)

# Softmax over the last axis, then mix the value vectors with those weights.
attn_weights = scale_weights.softmax(dim=-1)
attn_outputs = torch.bmm(attn_weights, V)
print(f"加权自注意力: {attn_outputs}")

tensor([[[-0.2342,  0.1263, -0.3857],
         [ 0.5356,  0.1055,  0.8644],
         [ 0.7031,  0.8428,  0.7917]],

        [[-0.6063,  0.7535,  0.4628],
         [-0.0733,  1.0862,  0.8365],
         [ 0.2903, -0.4481, -0.4775]]], grad_fn=<BmmBackward0>)
tensor([[[-0.1171,  0.0631, -0.1929],
         [ 0.2678,  0.0527,  0.4322],
         [ 0.3515,  0.4214,  0.3958]],

        [[-0.3032,  0.3767,  0.2314],
         [-0.0366,  0.5431,  0.4183],
         [ 0.1451, -0.2240, -0.2387]]], grad_fn=<DivBackward0>)
加权自注意力: tensor([[[ 0.3119,  0.4292,  0.0073,  0.5235],
         [ 0.5184,  0.4523, -0.0223,  0.5037],
         [ 0.3849,  0.4396, -0.0026,  0.5178]],

        [[ 0.0177,  0.0955, -0.0859,  0.6672],
         [ 0.0196,  0.0938, -0.0775,  0.6586],
         [ 0.0174,  0.0821,  0.0317,  0.5703]]], grad_fn=<BmmBackward0>)


# 3. 多头自注意力

In [3]:
import torch
import torch.nn.functional as F

# Multi-head setup: the 4 input features are shared between 2 heads of width 2.
x = torch.randn(2, 3, 4)  # (batch=2, seq_len=3, feature_dim=4)

num_heads, head_dim = 2, 2

# The heads must tile the feature dimension exactly.
assert x.shape[-1] == num_heads * head_dim

# Full-width 4 -> 4 projections, instantiated in q, k, v order.
linear_q, linear_k, linear_v = (torch.nn.Linear(4, 4) for _ in range(3))
Q, K, V = linear_q(x), linear_k(x), linear_v(x)

def split_heads(tensor, num_heads):
    """Split the last axis of a 3-D tensor into separate attention heads.

    Args:
        tensor: Tensor of shape (batch_size, seq_len, feature_dim).
        num_heads: Number of heads to divide ``feature_dim`` into.

    Returns:
        Tensor of shape (batch_size, num_heads, seq_len, feature_dim // num_heads).
    """
    n_batch, n_tokens, width = tensor.shape
    per_head = width // num_heads
    # (batch, seq, feature) -> (batch, seq, heads, per_head) -> (batch, heads, seq, per_head)
    return tensor.view(n_batch, n_tokens, num_heads, per_head).permute(0, 2, 1, 3)

# Rearrange each projection into per-head form via split_heads.
Q, K, V = (split_heads(proj, num_heads) for proj in (Q, K, V))

# Per-head query-key scores; matmul broadcasts over the leading axes.
raw_weights = Q @ K.transpose(-1, -2)
print(raw_weights)

# The scale is the square root of the size of K's last axis after the split.
scale_factor = K.shape[-1] ** 0.5
scale_weights = raw_weights / scale_factor
print(scale_weights)

# Softmax over the last axis, then weight the per-head values.
attn_weights = scale_weights.softmax(dim=-1)
attn_outputs = attn_weights @ V

def combine_heads(tensor):
    """Merge per-head tensors back into a single feature axis.

    Args:
        tensor: Tensor of shape (batch_size, num_heads, seq_len, head_dim).

    Returns:
        Tensor of shape (batch_size, seq_len, num_heads * head_dim), where each
        position holds its head vectors concatenated in head order.
    """
    n_batch, n_heads, n_tokens, per_head = tensor.shape
    # Put the head axis next to head_dim, then collapse the two into one axis.
    per_token = tensor.permute(0, 2, 1, 3)
    return per_token.reshape(n_batch, n_tokens, n_heads * per_head)

# Merge the heads with combine_heads, then apply the 4 -> 4 output projection.
linear_out = torch.nn.Linear(4, 4)
attn_outputs = linear_out(combine_heads(attn_outputs))

print(f"多头自注意力: {attn_outputs}")



tensor([[[[ 0.0499, -0.1048, -0.1749],
          [ 0.1599, -0.2591, -0.4439],
          [ 0.0460, -0.0812, -0.1379]],

         [[-0.6317, -0.1134,  0.0739],
          [ 0.5791,  0.2177,  0.0511],
          [ 0.2978, -0.0355, -0.1278]]],


        [[[ 0.5702,  0.0161, -0.1847],
          [-0.6763, -0.0492,  0.2023],
          [-0.6830, -0.1313,  0.1590]],

         [[-1.3637,  0.3557, -0.7735],
          [ 0.6436, -0.3130,  0.3675],
          [-1.4432,  0.9831, -0.8288]]]], grad_fn=<UnsafeViewBackward0>)
tensor([[[[ 0.0353, -0.0741, -0.1236],
          [ 0.1130, -0.1832, -0.3139],
          [ 0.0325, -0.0574, -0.0975]],

         [[-0.4467, -0.0802,  0.0522],
          [ 0.4095,  0.1539,  0.0361],
          [ 0.2106, -0.0251, -0.0904]]],


        [[[ 0.4032,  0.0114, -0.1306],
          [-0.4782, -0.0348,  0.1431],
          [-0.4830, -0.0929,  0.1124]],

         [[-0.9643,  0.2515, -0.5470],
          [ 0.4551, -0.2213,  0.2599],
          [-1.0205,  0.6951, -0.5861]]]], grad_fn=<Di