# 1. 简单自注意力

In [4]:
import torch
import torch.nn.functional as F

# Fixed seed so "Restart Kernel -> Run All" reproduces the printed output.
torch.manual_seed(0)

# Toy input of shape (batch=2, seq_len=3, dim=4).
x = torch.randn(2, 3, 4)

# Parameter-free self-attention: the input is used directly as query, key and value.
# raw_weights[b, i, j] = <x[b, i], x[b, j]>, shape (2, 3, 3).
raw_weights = torch.bmm(x, x.transpose(1, 2))
# Normalise over the last axis so every row of weights sums to 1.
attn_weights = F.softmax(raw_weights, dim=2)
# Each output row is the weight-averaged combination of the rows of x, shape (2, 3, 4).
attn_outputs = torch.bmm(attn_weights, x)
print(attn_outputs)

tensor([[[-0.2774,  1.3247, -0.6158, -1.9739],
         [-0.9119,  1.0185, -0.8377,  0.3188],
         [-1.0841,  0.8404, -1.2366,  0.8150]],

        [[-0.1870,  0.0952,  1.3785,  0.0742],
         [-0.3248, -0.1043,  0.0555, -0.7568],
         [-0.3077, -0.1424,  0.2153, -0.6645]]])


# 2. 标准自注意力

In [6]:
# Fixed seed: makes both the random input and the Linear initialisations reproducible.
torch.manual_seed(0)

# Input of shape (batch=2, seq_len=3, dim=4).
x = torch.randn(2, 3, 4)
# Projection width follows the input feature size instead of a hard-coded 4.
d_model = x.size(-1)

# Learned projections of the same input into queries, keys and values.
linear_q = torch.nn.Linear(d_model, d_model)
linear_k = torch.nn.Linear(d_model, d_model)
linear_v = torch.nn.Linear(d_model, d_model)

Q = linear_q(x)
K = linear_k(x)
V = linear_v(x)

# Query-key dot products, shape (2, 3, 3).
raw_weights = torch.bmm(Q, K.transpose(1, 2))
print(raw_weights)

# Scaled dot-product attention: divide the logits by sqrt(d_k), the key width.
scale_factor = K.size(-1)**0.5
scale_weights = raw_weights / scale_factor
print(scale_weights)

# Softmax over the last axis, then average the value vectors with those weights.
attn_weights = F.softmax(scale_weights, dim=2)
attn_outputs = torch.bmm(attn_weights, V)
print(f"加权自注意力: {attn_outputs}")

tensor([[[ 4.4115e-02, -4.3930e-01, -6.8740e-01],
         [-3.9619e-01,  3.0652e-01, -1.3978e+00],
         [ 1.4006e-01, -5.7866e-01, -3.7079e-01]],

        [[ 1.1333e-02, -5.1178e-01, -4.0106e-01],
         [ 1.9789e-03, -7.4621e-01, -2.6430e+00],
         [-3.7205e-02,  4.7986e-01, -8.4291e-01]]], grad_fn=<BmmBackward0>)
tensor([[[ 2.2057e-02, -2.1965e-01, -3.4370e-01],
         [-1.9809e-01,  1.5326e-01, -6.9889e-01],
         [ 7.0030e-02, -2.8933e-01, -1.8539e-01]],

        [[ 5.6663e-03, -2.5589e-01, -2.0053e-01],
         [ 9.8946e-04, -3.7310e-01, -1.3215e+00],
         [-1.8603e-02,  2.3993e-01, -4.2145e-01]]], grad_fn=<DivBackward0>)
加权自注意力: tensor([[[ 0.4822,  0.0936, -0.8211, -0.0261],
         [ 0.6444, -0.1040, -1.0462,  0.3048],
         [ 0.4439,  0.1453, -0.7726, -0.1027]],

        [[ 0.5080,  0.4865, -0.5447, -0.6990],
         [ 0.6557,  0.5236, -0.4017, -0.7416],
         [ 0.6158,  0.5346, -0.5460, -0.8147]]], grad_fn=<BmmBackward0>)


# 3. 多头自注意力

In [8]:
import torch
import torch.nn.functional as F

# Fixed seed: reproducible input and Linear initialisations on every re-run.
torch.manual_seed(0)

num_heads = 2
head_dim = 2
# Model width is derived from the head configuration instead of a hard-coded 4,
# so editing num_heads / head_dim reconfigures the whole cell consistently.
d_model = num_heads * head_dim

# Input of shape (batch=2, seq_len=3, d_model).
x = torch.randn(2, 3, d_model)

# Guard in case x is later replaced by hand with a tensor of a different width.
assert x.size(-1) == num_heads * head_dim

linear_q = torch.nn.Linear(d_model, d_model)
linear_k = torch.nn.Linear(d_model, d_model)
linear_v = torch.nn.Linear(d_model, d_model)

Q = linear_q(x)
K = linear_k(x)
V = linear_v(x)

def split_heads(tensor, num_heads):
    """Reshape (batch, seq_len, feature_dim) into (batch, num_heads, seq_len, head_dim).

    Args:
        tensor: 3-D tensor whose last dimension is split evenly across the heads.
        num_heads: number of attention heads.

    Returns:
        Tensor of shape (batch, num_heads, seq_len, feature_dim // num_heads).

    Raises:
        ValueError: if feature_dim is not divisible by num_heads.
    """
    batch_size, seq_len, feature_dim = tensor.size()
    if feature_dim % num_heads != 0:
        raise ValueError(
            f"feature_dim ({feature_dim}) must be divisible by num_heads ({num_heads})"
        )
    head_dim = feature_dim // num_heads
    # (b, s, h*d) -> (b, s, h, d) -> (b, h, s, d): moves heads next to batch so
    # matmul treats (batch, head) as independent batch dimensions.
    output = tensor.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
    return output

Q = split_heads(Q, num_heads)
K = split_heads(K, num_heads)
V = split_heads(V, num_heads)

# Per-head query-key logits, shape (batch, num_heads, seq_len, seq_len).
raw_weights = torch.matmul(Q, K.transpose(-2, -1))
print(raw_weights)

# After split_heads, K.size(-1) is head_dim, so logits are scaled by sqrt(head_dim).
scale_factor = K.size(-1)**0.5
scale_weights = raw_weights / scale_factor
print(scale_weights)

attn_weights = F.softmax(scale_weights, dim=-1)
attn_outputs = torch.matmul(attn_weights, V)

def combine_heads(tensor):
    """Inverse of split_heads: (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, num_heads * head_dim)."""
    batch_size, num_heads, seq_len, head_dim = tensor.size()
    feature_dim = num_heads * head_dim
    # transpose() yields a non-contiguous tensor; .contiguous() is required before .view().
    output = tensor.transpose(1, 2).contiguous().view(batch_size, seq_len, feature_dim)
    return output

attn_outputs = combine_heads(attn_outputs)

# Final output projection that mixes information across heads.
linear_out = torch.nn.Linear(d_model, d_model)
attn_outputs = linear_out(attn_outputs)

print(f"多头自注意力: {attn_outputs}")



tensor([[[[-0.0835,  0.0063, -0.0904],
          [ 0.0716, -0.3325,  0.0660],
          [-0.1550,  0.3327, -0.1564]],

         [[-0.0286, -0.1234, -0.1101],
          [ 0.0554, -1.0449, -0.1753],
          [-0.0283, -0.0818, -0.0967]]],


        [[[-0.0226,  0.0611, -0.0630],
          [-0.5410, -0.5744, -0.0994],
          [-0.3347, -0.3095, -0.0932]],

         [[-0.0369, -0.1251, -0.1534],
          [ 0.0999,  0.3243,  0.3613],
          [ 0.1801,  0.4295,  0.0500]]]], grad_fn=<UnsafeViewBackward0>)
tensor([[[[-0.0591,  0.0045, -0.0639],
          [ 0.0506, -0.2351,  0.0466],
          [-0.1096,  0.2353, -0.1106]],

         [[-0.0202, -0.0873, -0.0779],
          [ 0.0392, -0.7388, -0.1239],
          [-0.0200, -0.0578, -0.0684]]],


        [[[-0.0160,  0.0432, -0.0446],
          [-0.3825, -0.4062, -0.0703],
          [-0.2367, -0.2189, -0.0659]],

         [[-0.0261, -0.0884, -0.1085],
          [ 0.0706,  0.2293,  0.2555],
          [ 0.1274,  0.3037,  0.0354]]]], grad_fn=<Di