In [5]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [6]:
# Hyper-parameters shared by the cells below.
d_model = 512
n_head = 8

# Random demo input laid out as (batch, sequence length, feature size);
# the feature size is the model width.
x = torch.rand(128, 32, d_model)

In [7]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected with their own linear layer,
    split into ``n_head`` heads of width ``d_model // n_head``, attended
    per head, merged back to ``d_model`` features and passed through a final
    linear layer.

    Args:
        d_model: Feature size of the inputs and of the output.
        n_head: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head):
        super().__init__()

        # Fail early: an uneven split would otherwise only surface later as an
        # obscure shape error inside ``forward``.
        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )

        self.d_model = d_model
        self.n_head = n_head
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)
        # Feature width handled by each individual head.
        self.d_k = d_model // n_head

    def _split_heads(self, tensor):
        """Reshape (batch, length, d_model) into (batch, n_head, length, d_k)."""
        # Use the tensor's own batch/length so keys and values may have a
        # different sequence length than the queries.
        batch, length = tensor.size(0), tensor.size(1)
        return tensor.view(batch, length, self.n_head, self.d_k).permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q: Query tensor of shape (batch, q_len, d_model).
            k: Key tensor of shape (batch, kv_len, d_model).
            v: Value tensor of shape (batch, kv_len, d_model).
            mask: Optional tensor broadcastable to
                (batch, n_head, q_len, kv_len). Positions where ``mask == 0``
                are filled with -10000 before the softmax.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch, q_len, _ = q.shape

        # Project the inputs, then split the feature axis into heads.
        q = self._split_heads(self.w_q(q))
        k = self._split_heads(self.w_k(k))
        v = self._split_heads(self.w_v(v))

        # Scaled dot-product scores: (batch, n_head, q_len, kv_len).
        score = q @ k.transpose(2, 3) / math.sqrt(self.d_k)
        if mask is not None:
            # A large negative score drives the softmax weight towards zero.
            score = score.masked_fill(mask == 0, -10000)

        # Weighted sum of the values, then merge the heads back together.
        context = self.softmax(score) @ v
        context = context.permute(0, 2, 1, 3).contiguous().view(batch, q_len, self.d_model)
        return self.w_combine(context)
    
# Build the attention layer from the hyper-parameters defined earlier.
attention = MultiHeadAttention(d_model=d_model, n_head=n_head)

In [8]:
# Self-attention demo: the same tensor serves as query, key and value.
out = attention(q=x, k=x, v=x)
print(out)

tensor([[[ 0.2157, -0.0139, -0.1372,  ..., -0.1058, -0.0572,  0.0882],
         [ 0.2146, -0.0134, -0.1360,  ..., -0.1048, -0.0578,  0.0881],
         [ 0.2152, -0.0134, -0.1359,  ..., -0.1054, -0.0571,  0.0874],
         ...,
         [ 0.2150, -0.0137, -0.1360,  ..., -0.1053, -0.0575,  0.0872],
         [ 0.2157, -0.0142, -0.1361,  ..., -0.1046, -0.0571,  0.0880],
         [ 0.2150, -0.0141, -0.1362,  ..., -0.1054, -0.0565,  0.0887]],

        [[ 0.2465, -0.0294, -0.1742,  ..., -0.1299, -0.0583,  0.1152],
         [ 0.2477, -0.0278, -0.1743,  ..., -0.1299, -0.0574,  0.1148],
         [ 0.2480, -0.0280, -0.1736,  ..., -0.1305, -0.0583,  0.1155],
         ...,
         [ 0.2471, -0.0285, -0.1741,  ..., -0.1303, -0.0579,  0.1143],
         [ 0.2475, -0.0285, -0.1744,  ..., -0.1312, -0.0586,  0.1151],
         [ 0.2477, -0.0278, -0.1732,  ..., -0.1299, -0.0578,  0.1152]],

        [[ 0.2451,  0.0038, -0.1495,  ..., -0.0586, -0.0730,  0.0898],
         [ 0.2451,  0.0030, -0.1499,  ..., -0