In [8]:
import torch
import torch.nn as nn
import math

输入张量 x 的形状 (batch_size, sequence_length, dimension)：
batch_size = 16
sequence_length = 64
dimension = 512

In [9]:
# Seed torch's RNG so Restart & Run All reproduces this input (and any later
# random draws, such as layer weight initialisation, when cells run in order).
torch.manual_seed(0)

# Input tensor sizes, matching the description in the markdown cell above.
batch_size, sequence_length, dimension = 16, 64, 512
x = torch.randn(batch_size, sequence_length, dimension)

In [10]:
# Inspect the input shape: (batch_size, sequence_length, dimension)
x.size()

torch.Size([16, 64, 512])

In [11]:
# Attention hyper-parameters: model width and number of heads
# (each head works on d_model // n_head features).
d_model, n_head = 512, 8

In [14]:
class multi_head_attention(nn.Module):
    """Multi-head scaled dot-product self-attention with an optional causal mask.

    Args:
        d_model: feature width of the input and output; must be divisible by
            ``n_head``.
        n_head: number of attention heads; each head works on
            ``d_model // n_head`` features.
        causal: if True (default), position ``i`` may only attend to positions
            ``<= i`` (lower-triangular mask). If False, no mask is applied.

    Raises:
        ValueError: if ``n_head`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, n_head, causal=True):
        super().__init__()
        if n_head <= 0:
            raise ValueError(f"n_head must be positive, got {n_head}")
        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )
        self.d_model = d_model
        self.n_head = n_head
        self.causal = causal

        # Q, K, V projection weights
        self.w_q = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.w_k = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.w_v = nn.Linear(in_features=self.d_model, out_features=self.d_model)

        self.softmax = nn.Softmax(dim=-1)

        # Output projection applied after the heads are concatenated
        self.w_o = nn.Linear(in_features=self.d_model, out_features=self.d_model)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape (B, T, d_model).

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, _ = x.shape

        # Per-head feature width
        n_d = self.d_model // self.n_head

        # Project, then split into heads: (B, T, d_model) -> (B, n_head, T, n_d)
        q = self.w_q(x).view(B, T, self.n_head, n_d).transpose(1, 2)
        k = self.w_k(x).view(B, T, self.n_head, n_d).transpose(1, 2)
        v = self.w_v(x).view(B, T, self.n_head, n_d).transpose(1, 2)

        # Scaled dot-product scores: (B, n_head, T, T)
        score = q @ k.transpose(-2, -1) / math.sqrt(n_d)

        if self.causal:
            # Lower-triangular mask hides future positions. It is created on x's
            # device so the module also works on GPU. The diagonal is always kept,
            # so no row is fully masked and -inf cannot produce NaNs.
            mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            score = score.masked_fill(~mask, float("-inf"))

        attn = self.softmax(score)
        heads = attn @ v  # (B, n_head, T, n_d)

        # Concatenate the heads and apply the output projection: (B, T, d_model)
        x_concat = heads.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.w_o(x_concat)
        
        
        
        

In [15]:
mha = multi_head_attention(d_model, n_head)
out = mha(x)

# Self-attention maps (batch, seq, d_model) -> (batch, seq, d_model); since
# d_model equals the input width here, the output shape matches the input.
assert out.shape == x.shape

# Display the output shape as the cell's last expression
out.shape

torch.Size([16, 64, 512])
