In [6]:
import math
from typing import Optional

import torch
import torch.nn as nn


class SelfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are produced by three independent linear layers
    that each map ``hidden_dim`` features to ``hidden_dim`` features.

    Args:
        hidden_dim: Size of the last axis of the input; also used as the
            scaling factor ``sqrt(hidden_dim)`` for the attention scores.
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        # One projection per role, registered in query / key / value order.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X):
        """Attend over the second-to-last axis of ``X``.

        Args:
            X: Tensor whose last axis has ``hidden_dim`` features (required
                by the linear projections).

        Returns:
            Tensor with the same shape as ``X``: each position is the
            attention-weighted sum of the value vectors.

        Side effects:
            Prints the attention weights to stdout on every call.
        """
        queries, keys, values = self.q_proj(X), self.k_proj(X), self.v_proj(X)

        # Pairwise query/key similarity, scaled to keep the softmax well-conditioned.
        scale = math.sqrt(self.hidden_dim)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale

        # Normalise over the key axis so each query's weights sum to 1.
        probs = scores.softmax(dim=-1)
        print(probs)

        return probs @ values


# Smoke test for SelfAttentionV1 on a random (3, 2, 4) tensor; the last axis
# matches hidden_dim=4. No seed is set in this cell, so values vary per run.
X = torch.rand(3, 2, 4)
net = SelfAttentionV1(4)
# forward() itself prints the attention weights, so this call produces output too.
my_output = net(X)

print(my_output)


tensor([[[0.5180, 0.4820],
         [0.4755, 0.5245]],

        [[0.4997, 0.5003],
         [0.4944, 0.5056]],

        [[0.4919, 0.5081],
         [0.4977, 0.5023]]], grad_fn=<SoftmaxBackward0>)
tensor([[[-0.2245,  0.1684, -0.4920, -0.3165],
         [-0.2429,  0.1510, -0.4933, -0.2871]],

        [[-0.1365,  0.1325, -0.4653, -0.3849],
         [-0.1353,  0.1325, -0.4645, -0.3861]],

        [[-0.2538,  0.0532, -0.4771, -0.1982],
         [-0.2528,  0.0526, -0.4754, -0.1988]]], grad_fn=<UnsafeViewBackward0>)


In [36]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    The input is projected once to ``3 * dim`` features and split into
    query, key and value tensors. Each is divided into ``head_num`` heads of
    ``dim // head_num`` features, attention is computed per head, and the
    heads are concatenated back to ``dim`` features. No output projection is
    applied after the concatenation.

    Args:
        dim: Feature size of the input's last axis.
        head_num: Number of attention heads; must be positive and divide ``dim``.
        dropout_rate: Probability given to the ``nn.Dropout`` applied to the
            attention weights.

    Raises:
        ValueError: If ``head_num`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim: int, head_num: int, dropout_rate: float) -> None:
        super().__init__()
        # Fail at construction time instead of with an opaque view() error
        # (or ZeroDivisionError) on the first forward pass.
        if head_num <= 0:
            raise ValueError(f"head_num must be positive, got {head_num}")
        if dim % head_num != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by head_num ({head_num})"
            )
        self.dim = dim
        self.head_num = head_num
        self.dropout = nn.Dropout(dropout_rate)
        # Single linear layer producing Q, K and V side by side on the last axis.
        self.proj = nn.Linear(dim, 3 * dim)
        self.head_dim = dim // head_num

    def forward(
        self, X: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply multi-head self-attention to ``X``.

        Args:
            X: Tensor of shape ``(batch_size, seq_len, dim)``.
            attention_mask: Optional tensor broadcastable to
                ``(batch_size, head_num, seq_len, seq_len)``. Positions equal
                to 0 are excluded from attention. If every key of a query row
                is masked, the softmax of that row is NaN.

        Returns:
            Tensor of shape ``(batch_size, seq_len, dim)``.

        Side effects:
            Prints the pre-merge shape ``(batch_size, seq_len, head_num,
            head_dim)`` to stdout on every call.
        """
        batch_size, seq_len, _ = X.shape
        QKV = self.proj(X)
        Q, K, V = torch.split(QKV, self.dim, -1)
        # (batch_size, seq_len, dim) -> (batch_size, head_num, seq_len, head_dim)
        Q = Q.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)

        # Scale by the per-head size, not the full model dim.
        attention_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            # -inf scores become exactly 0 after the softmax.
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0, float("-inf")
            )
        attention_weight = torch.softmax(attention_weight, -1)
        attention_weight = self.dropout(attention_weight)
        output_mid = attention_weight @ V
        # Back to (batch_size, seq_len, head_num, head_dim); contiguous() is
        # required before view() can merge the two trailing axes.
        output_mid = output_mid.transpose(1, 2).contiguous()
        # Shape trace that the notebook's recorded output relies on.
        print(output_mid.shape)

        return output_mid.view(batch_size, seq_len, -1)


X = torch.rand(3, 2, 128)
net = MultiHeadAttention(128, 8, 0.1)
# Per-sequence key mask (1 = keep, 0 = mask out), lifted from
# (batch_size, seq_len) to (batch_size, head_num, seq_len, seq_len):
# two singleton axes are inserted, then expanded over heads and query positions.
mask = torch.tensor([[1, 0], [0, 1], [1, 1]])[:, None, None, :].expand(-1, 8, 2, 2)
print(mask.shape)
output = net(X, mask)
print(output.shape)


torch.Size([3, 8, 2, 2])
torch.Size([3, 2, 8, 16])
torch.Size([3, 2, 128])
