# Transformer 中的多头注意力

In [16]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [17]:
# Dummy input batch laid out as (batch_size, seq_len, dimension).
X = torch.randn((128, 64, 512))
print(X.shape)

torch.Size([128, 64, 512])


In [18]:
# Model width and number of attention heads (512 / 8 = 64 features per head).
d_model, n_head = 512, 8

In [19]:
# Demonstrate a causal (lower-triangular) mask on a random 4x4 score matrix.
atten_score = torch.randn(4, 4)
print(atten_score)
# True on and below the diagonal: position i may only attend to positions <= i.
mask = torch.ones(4, 4, dtype=torch.bool).tril()
# In-place: every position above the diagonal becomes -inf (zero weight after softmax).
atten_score.masked_fill_(~mask, -math.inf)
print(mask)
print(atten_score)

tensor([[-1.5881,  0.2290,  0.2096,  0.7041],
        [-0.4473,  0.2256, -1.2405, -0.1582],
        [-0.8031,  0.4126, -1.2328,  1.0075],
        [-0.7545, -0.9136,  0.6013, -1.2179]])
tensor([[ True, False, False, False],
        [ True,  True, False, False],
        [ True,  True,  True, False],
        [ True,  True,  True,  True]])
tensor([[-1.5881,    -inf,    -inf,    -inf],
        [-0.4473,  0.2256,    -inf,    -inf],
        [-0.8031,  0.4126, -1.2328,    -inf],
        [-0.7545, -0.9136,  0.6013, -1.2179]])


In [21]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional causal mask.

    Queries, keys and values are each projected by their own linear layer,
    split into ``n_head`` heads of size ``d_model // n_head``, attended per
    head, merged back to ``d_model`` features and passed through an output
    projection (``w_combine``).

    Args:
        d_model: Size of the last (feature) dimension of inputs and output.
        n_head: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head) -> None:
        super().__init__()
        # Fail early with a clear message instead of an obscure `view` error
        # inside forward().
        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )
        self.d_model = d_model
        self.n_head = n_head
        self.d_k = d_model // n_head  # feature size of each head
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, d_model) into (batch, n_head, seq, d_k)."""
        batch_size, seq_len, _ = x.shape
        return x.view(batch_size, seq_len, self.n_head, self.d_k).transpose(1, 2)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        *,
        causal: bool = True,
    ) -> torch.Tensor:
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query tensor of shape (batch, q_len, d_model).
            k: Key tensor of shape (batch, k_len, d_model).
            v: Value tensor of shape (batch, k_len, d_model).
            causal: If True (the default, matching the original behaviour),
                query position i may only attend to key positions <= i.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size, q_len, _ = q.shape
        # Each tensor is split using its OWN sequence length, so keys/values
        # may be longer or shorter than the queries (cross-attention).
        q = self._split_heads(self.w_q(q))
        k = self._split_heads(self.w_k(k))
        v = self._split_heads(self.w_v(v))

        # (batch, n_head, q_len, k_len)
        attn_score = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        if causal:
            k_len = k.shape[-2]
            # Build the mask on the scores' device so CUDA inputs work.
            mask = torch.tril(
                torch.ones(q_len, k_len, dtype=torch.bool, device=attn_score.device)
            )
            attn_score = attn_score.masked_fill(~mask, float("-inf"))

        # Weighted sum of values: (batch, n_head, q_len, d_k)
        context = self.softmax(attn_score) @ v
        # Merge heads back: (batch, q_len, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
        return self.w_combine(context)
    
# Run the module as self-attention: the same tensor serves as query, key and value.
multi_head_atten = MultiHeadAttention(d_model, n_head)
output = multi_head_atten(X, X, X)
# Shape first, then the (truncated) tensor contents, each on its own line.
print(output.shape, output, sep="\n")


torch.Size([128, 64, 512])
tensor([[[-0.7563,  0.2732, -0.1377,  ...,  0.6893, -0.0595, -0.0735],
         [-0.5859,  0.0887, -0.0161,  ...,  0.2119, -0.0728,  0.2601],
         [-0.3679, -0.2490,  0.0336,  ...,  0.1061,  0.1394,  0.2043],
         ...,
         [ 0.0496, -0.0639,  0.0130,  ...,  0.0110, -0.0519, -0.0672],
         [ 0.0323, -0.0598,  0.0200,  ...,  0.0174, -0.0532, -0.0883],
         [ 0.0031, -0.0553,  0.0055,  ...,  0.0079, -0.0467, -0.0916]],

        [[ 0.2163, -0.2435,  0.1143,  ..., -0.0651, -0.4114, -0.1521],
         [ 0.0284, -0.1171,  0.2203,  ...,  0.1296, -0.1959,  0.0220],
         [-0.0879, -0.1645,  0.1377,  ...,  0.1798, -0.0933, -0.0904],
         ...,
         [ 0.0325, -0.0391,  0.0850,  ...,  0.0449, -0.1511,  0.0230],
         [ 0.0219, -0.0243,  0.0599,  ...,  0.0269, -0.1197,  0.0151],
         [-0.0336, -0.0580,  0.0721,  ...,  0.0541, -0.1078, -0.0214]],

        [[ 0.1049, -0.1835, -0.6369,  ..., -0.2131, -0.0796, -0.0694],
         [ 0.1828,