In [6]:
import math
import torch
import torch.nn as nn


class SelfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention with separate Q/K/V layers.

    Computes ``softmax(Q K^T / sqrt(hidden_dim)) V`` where Q, K and V are
    independent linear projections of the input. There is no output
    projection, masking or dropout in this version.

    Args:
        hidden_dim: size of the input's last axis and of each of Q, K, V.
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        # The three layers are created (and registered) in the order q, k, v.
        self.q_proj, self.k_proj, self.v_proj = (
            nn.Linear(hidden_dim, hidden_dim) for _ in range(3)
        )

    def forward(self, X):
        """Attend over the second-to-last axis of ``X``.

        Side effect: prints the attention-weight tensor.
        """
        query = self.q_proj(X)
        key = self.k_proj(X)
        value = self.v_proj(X)

        scores = query @ key.transpose(-1, -2) / math.sqrt(self.hidden_dim)
        weights = torch.softmax(scores, dim=-1)
        print(weights)
        return weights @ value


# Fixed seed so the random input, the weight init and the printed tensors are
# reproducible on Restart & Run All.
torch.manual_seed(0)
X = torch.rand(3, 2, 4)  # (batch_size, seq_len, hidden_dim)
net = SelfAttentionV1(4)
my_output = net(X)
# Self-attention preserves the input shape.
assert my_output.shape == X.shape

print(my_output)


tensor([[[0.5180, 0.4820],
         [0.4755, 0.5245]],

        [[0.4997, 0.5003],
         [0.4944, 0.5056]],

        [[0.4919, 0.5081],
         [0.4977, 0.5023]]], grad_fn=<SoftmaxBackward0>)
tensor([[[-0.2245,  0.1684, -0.4920, -0.3165],
         [-0.2429,  0.1510, -0.4933, -0.2871]],

        [[-0.1365,  0.1325, -0.4653, -0.3849],
         [-0.1353,  0.1325, -0.4645, -0.3861]],

        [[-0.2538,  0.0532, -0.4771, -0.1982],
         [-0.2528,  0.0526, -0.4754, -0.1988]]], grad_fn=<UnsafeViewBackward0>)


In [1]:
import math
import torch
import torch.nn as nn


class MySelfAttentionV2(nn.Module):
    """Single-head self-attention whose Q, K and V come from one fused projection.

    A single ``Linear(hidden_dim, 3 * hidden_dim)`` produces ``[Q | K | V]``
    concatenated on the last axis; the attended values are passed through an
    output projection.

    Args:
        hidden_dim: size of the input's last axis and of each of Q, K, V.
    """

    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.proj = nn.Linear(hidden_dim, hidden_dim * 3)
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return ``output_proj(softmax(Q K^T / sqrt(d)) V)``; prints the weights."""
        # X (batch_size, seq_len, hidden_dim)
        # The fused output is exactly 3 * hidden_dim wide, so three equal
        # chunks are the Q, K and V slices.
        query, key, value = self.proj(X).chunk(3, dim=-1)
        scores = query @ key.transpose(-1, -2) / math.sqrt(self.hidden_dim)
        weights = scores.softmax(dim=-1)
        print(weights)
        return self.output_proj(weights @ value)


# Fixed seed so input, weight init and printed tensors are reproducible.
torch.manual_seed(0)
X = torch.rand(3, 2, 4)  # (batch_size, seq_len, hidden_dim)
net = MySelfAttentionV2(4)
output = net(X)
# Self-attention with an output projection preserves the input shape.
assert output.shape == X.shape
print(output)

tensor([[[0.4728, 0.5272],
         [0.4700, 0.5300]],

        [[0.5052, 0.4948],
         [0.4900, 0.5100]],

        [[0.5311, 0.4689],
         [0.5651, 0.4349]]], grad_fn=<SoftmaxBackward0>)
tensor([[[-0.0595,  0.3484, -0.2286,  0.1935],
         [-0.0587,  0.3479, -0.2298,  0.1930]],

        [[-0.1897,  0.3886, -0.0556,  0.2985],
         [-0.1905,  0.3861, -0.0509,  0.2971]],

        [[-0.1905,  0.4036, -0.0889,  0.3350],
         [-0.1808,  0.3952, -0.0947,  0.3235]]], grad_fn=<ViewBackward0>)


In [2]:
class MySelfAttentionV3(nn.Module):
    """Single-head self-attention with an optional mask and attention dropout.

    Args:
        dim: size of the input's last axis and of each of Q, K, V.
        dropout_rate: dropout probability applied to the attention weights.
    """

    def __init__(self, dim: int, dropout_rate: float = 0.1) -> None:
        super().__init__()
        self.dim = dim
        self.attention_dropout = nn.Dropout(dropout_rate)
        self.proj = nn.Linear(dim, dim * 3)
        self.output_proj = nn.Linear(dim, dim)

    def forward(
        self, X: torch.Tensor, attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Attend over ``X``; prints the shape of the attention weights.

        Args:
            X: input whose last axis has size ``dim``.
            attention_mask: optional; score positions where it equals 0 are set
                to ``-inf`` before the softmax (zero weight). Must broadcast
                against the ``(..., seq_len, seq_len)`` score tensor. A row that
                is masked entirely produces NaN.
        """
        query, key, value = self.proj(X).split(self.dim, dim=-1)
        scores = query @ key.transpose(-1, -2) / math.sqrt(self.dim)
        if attention_mask is not None:
            blocked = attention_mask == 0
            scores = scores.masked_fill(blocked, float("-inf"))
        weights = torch.softmax(scores, -1)
        print(weights.shape)
        return self.output_proj(self.attention_dropout(weights) @ value)


# Fixed seed: input, weight init and (training-mode) dropout become reproducible.
torch.manual_seed(0)
X = torch.rand(3, 4, 2)  # (batch_size, seq_len, dim)
# One key mask per batch element: 1 = may be attended to, 0 = masked out
# (the module sets scores to -inf where mask == 0).
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 0, 0, 0]])
# (batch_size, seq_len) -> (batch_size, seq_len, seq_len): same key mask for
# every query row.
mask = mask.unsqueeze(1).repeat(1, 4, 1)
net = MySelfAttentionV3(2)
output = net(X, mask)
assert output.shape == X.shape
print(output)

torch.Size([3, 4, 4])
tensor([[[ 0.0316, -0.5146],
         [ 0.1702, -0.2773],
         [ 0.0308, -0.5160],
         [ 0.2350, -0.2210]],

        [[ 0.1761, -0.4255],
         [ 0.3927, -0.0598],
         [ 0.1761, -0.4256],
         [ 0.3878, -0.0675]],

        [[ 0.1601, -0.3420],
         [ 0.5714,  0.2222],
         [ 0.1601, -0.3420],
         [ 0.1601, -0.3420]]], grad_fn=<ViewBackward0>)


In [36]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to Q, K and V by one fused linear layer, split into
    ``head_num`` heads of size ``dim // head_num``, attended per head,
    re-concatenated and mixed across heads by an output projection (W^O),
    matching the other attention variants in this notebook.

    Args:
        dim: model dimension; must be divisible by ``head_num``.
        head_num: number of attention heads.
        dropout_rate: dropout probability applied to the attention weights.
    """

    def __init__(self, dim: int, head_num: int, dropout_rate: float) -> None:
        super().__init__()
        # forward() reshapes dim into head_num * head_dim, which must be exact.
        assert dim % head_num == 0, "dim must be divisible by head_num"
        self.dim = dim
        self.head_num = head_num
        self.dropout = nn.Dropout(dropout_rate)
        self.proj = nn.Linear(dim, 3 * dim)
        self.head_dim = dim // head_num
        # Without W^O the heads would only be concatenated, never mixed.
        self.output_proj = nn.Linear(dim, dim)

    def forward(
        self, X: torch.Tensor, attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Attend over ``X`` and return a tensor of the same shape.

        Args:
            X: (batch_size, seq_len, dim) input.
            attention_mask: optional; score positions where it equals 0 get
                ``-inf`` (zero weight after softmax). Must broadcast against
                (batch_size, head_num, seq_len, seq_len). A row that is masked
                entirely produces NaN.

        Side effect: prints the shape of the per-head output before merging.
        """
        batch_size, seq_len, _ = X.shape
        Q, K, V = torch.split(self.proj(X), self.dim, -1)
        # (batch_size, seq_len, dim) -> (batch_size, head_num, seq_len, head_dim)
        Q = Q.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)

        attention_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0, float("-inf")
            )
        attention_weight = torch.softmax(attention_weight, -1)
        attention_weight = self.dropout(attention_weight)
        # (batch_size, head_num, seq_len, head_dim) -> (batch_size, seq_len, head_num, head_dim)
        output_mid = (attention_weight @ V).transpose(1, 2).contiguous()
        print(output_mid.shape)

        return self.output_proj(output_mid.view(batch_size, seq_len, -1))


# Fixed seed: input, weight init and (training-mode) dropout become reproducible.
torch.manual_seed(0)
X = torch.rand(3, 2, 128)  # (batch_size, seq_len, dim)
net = MultiHeadAttention(128, 8, 0.1)
# One key mask per batch element: 1 = may be attended to, 0 = masked out.
mask = torch.tensor([[1, 0], [0, 1], [1, 1]])
# (batch_size, seq_len) -> (batch_size, head_num, seq_len, seq_len);
# sizes derived from the module and input instead of hard-coded.
mask = mask.unsqueeze(1).unsqueeze(1).expand(-1, net.head_num, X.shape[1], X.shape[1])
print(mask.shape)
output = net(X, mask)
assert output.shape == X.shape
print(output.shape)


torch.Size([3, 8, 2, 2])
torch.Size([3, 2, 8, 16])
torch.Size([3, 2, 128])


In [4]:
import math
import torch
import torch.nn as nn


class GroupQueryAttention(nn.Module):
    """Grouped-query attention (GQA), without masking or dropout.

    ``head_num`` query heads share ``key_value_num`` key/value heads. Each K/V
    head is repeated ``head_num // key_value_num`` times along the head axis,
    so consecutive query heads within a group read the same K/V head.

    Args:
        hidden_dim: model dimension; must be divisible by ``head_num``.
        head_num: number of query heads.
        key_value_num: number of key/value heads; must divide ``head_num``.
    """

    def __init__(self, hidden_dim: int, head_num: int, key_value_num: int):
        super().__init__()
        assert hidden_dim % head_num == 0
        assert head_num % key_value_num == 0

        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.key_value_num = key_value_num
        self.head_dim = hidden_dim // head_num

        # K and V project to fewer heads than Q, hence the narrower width.
        kv_width = key_value_num * self.head_dim
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, kv_width)
        self.v_proj = nn.Linear(hidden_dim, kv_width)

        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

    def _to_heads(self, t: torch.Tensor, n_heads: int) -> torch.Tensor:
        """(batch, seq, n_heads * head_dim) -> (batch, n_heads, seq, head_dim)."""
        batch_size, seq_len, _ = t.size()
        return t.view(batch_size, seq_len, n_heads, self.head_dim).transpose(1, 2)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Attend over ``X`` (assumed (batch_size, seq_len, hidden_dim)) and project."""
        batch_size, seq_len, _ = X.size()
        group_size = self.head_num // self.key_value_num

        q = self._to_heads(self.q_proj(X), self.head_num)
        # Expand K/V: each K/V head is copied to every query head in its group.
        k = self._to_heads(self.k_proj(X), self.key_value_num).repeat_interleave(
            group_size, 1
        )
        v = self._to_heads(self.v_proj(X), self.key_value_num).repeat_interleave(
            group_size, 1
        )

        # (batch_size, head_num, seq_len, seq_len)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        # (batch_size, head_num, seq_len, head_dim)
        context = torch.softmax(scores, -1) @ v
        merged = context.transpose(1, 2).reshape(batch_size, seq_len, -1)
        return self.o_proj(merged)


# Fixed seed so input and weight init are reproducible.
torch.manual_seed(0)
X = torch.rand(3, 2, 128)  # (batch_size, seq_len, hidden_dim)
net = GroupQueryAttention(128, 8, 4)  # 8 query heads sharing 4 key/value heads
output = net(X)
assert output.shape == X.shape
print(output.shape)

torch.Size([3, 2, 128])
