In [4]:
import torch
import torch.nn as nn
import math

# 注意力机制实现

## self-attention实现
> 参考: https://www.bilibili.com/video/BV19YbFeHETz?spm_id_from=333.788.videopod.sections&vd_source=b419802666550a8f77628730aa29c06b

`Attention`计算公式:
$$
Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V
$$

### self-attention实现(第一层)
> 只考虑公式

In [21]:
# self-attention 实现
class SelfAttentionV1(nn.Module):
    def __init__(self, hidden_dim: int = 728) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        # Q,K,V 权重矩阵 hidden_dim*hidden_dim
        self.query_proj = nn.Linear(hidden_dim,hidden_dim, bias=False)
        self.key_proj = nn.Linear(hidden_dim,hidden_dim, bias=False)
        self.value_proj = nn.Linear(hidden_dim,hidden_dim, bias=False)
    def forward(self, X):
        # X.shape -> (batch_size, seq, dim)
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)
        # Q,K,V -> (batch_size, seq, dim)
        # 计算注意力评分函数
        attention_value = torch.bmm(Q, K.permute(0, 2, 1))
        # 计算 softmax
        # attention_weights.shape -> (batch_size, seq, seq)
        self._attention_weights = torch.softmax(attention_value/math.sqrt(self.hidden_dim), dim=-1)
        Y = torch.bmm(self._attention_weights, V)
        return Y

    @property
    def attention_weights(self):
        return self._attention_weights

In [26]:
# 测试 self_attention
X = torch.rand((3, 2, 4))
self_attr_net1 = SelfAttentionV1(hidden_dim=4)
Y = self_attr_net1(X)
print('attention_weights = \t', self_attr_net1.attention_weights)
print('Y = \t', Y)
print('X.shape: ', X.shape, '\t Y.shape: ', Y.shape, '\tattention_weights: ', self_attr_net1.attention_weights.shape)

attention_weights = 	 tensor([[[0.4867, 0.5133],
         [0.4865, 0.5135]],

        [[0.4910, 0.5090],
         [0.4893, 0.5107]],

        [[0.4978, 0.5022],
         [0.5054, 0.4946]]], grad_fn=<SoftmaxBackward0>)
Y = 	 tensor([[[-0.0028, -0.1970,  0.2460,  0.4847],
         [-0.0028, -0.1970,  0.2461,  0.4847]],

        [[ 0.2811, -0.1148, -0.0239,  0.1229],
         [ 0.2816, -0.1148, -0.0239,  0.1228]],

        [[ 0.1856, -0.1672,  0.2208,  0.2608],
         [ 0.1864, -0.1670,  0.2212,  0.2580]]], grad_fn=<BmmBackward0>)
X.shape:  torch.Size([3, 2, 4]) 	 Y.shape:  torch.Size([3, 2, 4]) 	attention_weights:  torch.Size([3, 2, 2])


### Self-Attention实现(第二层)
> QKV矩阵优化, 使用一个矩阵表示

In [39]:
class SelfAttentionV2(nn.Module):
    def __init__(self, dim: int = 512) -> None:
        super().__init__()
        self.dim = dim
        self.proj = nn.Linear(dim, dim*3)
    def forward(self, X):
        # X.shape -> (batch_size, seq, dim)
        QKV = self.proj(X)
        # QKV.shape -> (batch_size, seq, dim*3)
        Q, K, V = torch.split(QKV, self.dim, dim=2)
        # Q,K,V.shape -> (batch_size, seq, dim*3)
        attention_value = torch.bmm(Q,K.permute(0, 2, 1)) / math.sqrt(self.dim)
        # attention_value.shape -> (batch_size, seq, seq)
        self._attention_weights = torch.softmax(attention_value,dim=-1)
        Y = torch.bmm(self._attention_weights, V)
        return Y
    @property
    def attention_weights(self):
        return self._attention_weights

In [41]:
X = torch.rand(3, 2, 4)
self_attr_net = SelfAttentionV2(4)
Y = self_attr_net(X)
print('X = ', X)
print('attention_weights = ', self_attr_net.attention_weights)
print('Y = ', Y)

X =  tensor([[[0.8800, 0.3124, 0.4686, 0.3256],
         [0.1465, 0.7232, 0.9188, 0.6566]],

        [[0.0145, 0.5094, 0.8818, 0.4162],
         [0.1500, 0.8649, 0.2281, 0.2649]],

        [[0.2055, 0.6865, 0.6145, 0.5852],
         [0.5847, 0.8555, 0.4184, 0.1247]]])
attention_weights =  tensor([[[0.5123, 0.4877],
         [0.5264, 0.4736]],

        [[0.4727, 0.5273],
         [0.4830, 0.5170]],

        [[0.4786, 0.5214],
         [0.4957, 0.5043]]], grad_fn=<SoftmaxBackward0>)
Y =  tensor([[[ 0.3297,  0.2641, -0.2786, -0.5755],
         [ 0.3288,  0.2607, -0.2785, -0.5748]],

        [[ 0.2304,  0.3561, -0.3570, -0.4942],
         [ 0.2295,  0.3578, -0.3562, -0.4936]],

        [[ 0.3708,  0.2249, -0.3171, -0.5494],
         [ 0.3687,  0.2293, -0.3175, -0.5511]]], grad_fn=<BmmBackward0>)


### Self-Attention实现(第三层)
优化内容如下:
1. 加入 `Dropout` 层(`Softmax`之后)
2. 加入 `attention_mask`
3. 加入 `output`矩阵

In [52]:
class SelfAttentionV3(nn.Module):
    def __init__(self, dim: int=512, dropout_rate:float =0.1) -> None:
        super().__init__()
        self.dim = dim
        self.proj = nn.Linear(dim, dim*3)
        self.dropout = nn.Dropout(dropout_rate)
        self.output_proj = nn.Linear(dim, dim)
    def forward(self, X, attention_mask=None):
        # X.shape -> (batch_size, seq, dim)
        QKV = self.proj(X)
        Q,K,V = torch.split(QKV, self.dim, dim=-1)
        # 计算 attention_weights
        attention_value = Q @ K.permute(0, 2, 1) / math.sqrt(self.dim)
        # mask 操作
        if attention_mask is not None:
            attention_value = attention_value.masked_fill(
                attention_mask == 0,
                float("-1e20")
            )
        attention_weights = torch.softmax(attention_value,dim=-1)
        # 注意 nn.Dropout 的位置
        attention_weights = self.dropout(attention_weights)
        output = attention_weights @ V
        output = self.output_proj(output)
        return output

In [53]:
X = torch.rand(3, 4, 2)
attention_mask = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 0, 0, 0],
])
# 注意 mask 需要和 attention_value 的形状一样 -> (batch_size, seq, seq)
# (batch_size, seq) -> (batch_size, seq, seq)
attention_mask = attention_mask.unsqueeze(1).repeat(1, 4, 1)
self_attr_net = SelfAttentionV3(dim=2)
Y = self_attr_net(X, attention_mask)
print('Y.shape = ', Y.shape)

Y.shape =  torch.Size([3, 4, 2])


### Self-Attention面试写法

In [54]:
class SelfAttentionInterView(nn.Module):
    def __init__(self, dim: int = 512, dropout_rate: float = 0.1) -> None:
        super().__init__()
        self.dim = dim
        # Q,K,V 矩阵
        self.query_proj = nn.Linear(dim, dim)
        self.key_proj = nn.Linear(dim, dim)
        self.value_proj = nn.Linear(dim, dim)
        # Dropout 层
        self.attention_dropout = nn.Dropout(dropout_rate)
        # Output 映射
        self.output_proj = nn.Linear(dim, dim)
        self._attention_weights = None
    def forward(self, X, attention_mask=None):
        # X.shape -> (batch_size, seq, dim)
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)
        # Q,K,V shape -> (batch_size, seq, dim)
        attention_value = Q @ K.permute(0, 2, 1) / math.sqrt(self.dim)
        # attention_value.shape -> (batch_size, seq, seq)
        # mask 操作
        if attention_mask is not None:
            attention_value = attention_value.masked_fill(
                attention_mask == 0,
                float("-inf")
            )
        # softmax 操作
        self._attention_weights = self.attention_dropout(
            torch.softmax(attention_value, dim=-1)
        )
        # 计算 output 
        output = self._attention_weights @ V
        output = self.output_proj(output)
        return output
    @property
    def attention_weights(self):
        return self._attention_weights


In [56]:
X = torch.rand(3, 4, 2)
mask = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 0, 0, 0],
])
mask = mask.unsqueeze(1).repeat(1, 4, 1)
self_attr_net = SelfAttentionInterView(2)
Y = self_attr_net(X, mask)
print('Y = ', Y)
print('Y.shape = ', Y.shape)
print('attention_weights = ', self_attr_net.attention_weights)

Y =  tensor([[[-0.7204, -0.2448],
         [-0.8157,  0.1936],
         [-0.8218, -0.0176],
         [-0.8137,  0.1958]],

        [[-0.7438,  0.0593],
         [-0.7438,  0.0593],
         [-0.6876, -0.1786],
         [-0.7438,  0.0593]],

        [[-0.9652,  0.1250],
         [-0.9652,  0.1250],
         [-0.9652,  0.1250],
         [-0.9652,  0.1250]]], grad_fn=<ViewBackward0>)
Y.shape =  torch.Size([3, 4, 2])
attention_weights =  tensor([[[0.0000, 0.3630, 0.0000, 0.0000],
         [0.3486, 0.3919, 0.3706, 0.0000],
         [0.0000, 0.3886, 0.3748, 0.0000],
         [0.3516, 0.3676, 0.3919, 0.0000]],

        [[0.5519, 0.5592, 0.0000, 0.0000],
         [0.5524, 0.5587, 0.0000, 0.0000],
         [0.0000, 0.5596, 0.0000, 0.0000],
         [0.5521, 0.5590, 0.0000, 0.0000]],

        [[1.1111, 0.0000, 0.0000, 0.0000],
         [1.1111, 0.0000, 0.0000, 0.0000],
         [1.1111, 0.0000, 0.0000, 0.0000],
         [1.1111, 0.0000, 0.0000, 0.0000]]], grad_fn=<MulBackward0>)


## Multi-Head Attention 实现
多头注意力结构如下:
![img1](img/2025-07-23_18-12.png)

In [66]:
# 多头注意力机制实现
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads, dropout_rate=0.1):
        super().__init__()
        self.head_dim = hidden_dim // num_heads
        self.query_proj = nn.Linear(hidden_dim, hidden_dim) # 这里其实是 num_heads * head_dim
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim) 
        self.attention_dropout = nn.Dropout(dropout_rate)
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)
    def forward(self, X, attention_mask=None):
        # X.shape -> (batch_size, seq, hidden_dim)
        batch_size, seq_len, _ = X.shape
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)
        # Q,K,V.shape -> (batch_size, seq, hidden_dim)
        # Q,K,V 分层 -> (batch_size, num_heads, seq, head_dim)
        Q = Q.reshape(batch_size, seq_len, -1, self.head_dim).permute(0, 2, 1, 3)
        K = K.reshape(batch_size, seq_len, -1, self.head_dim).permute(0, 2, 1, 3)
        V = V.reshape(batch_size, seq_len, -1, self.head_dim).permute(0, 2, 1, 3)
        # Q,K,V.shape -> (batch_size, num_heads, seq, head_dim)
        attention_value = Q @ K.permute(0, 1, 3, 2) / math.sqrt(self.head_dim)
        # attention_value 形状 -> (batch_size, num_heads, seq, seq)
        if attention_mask is not None:
            attention_value = attention_value.masked_fill(
                attention_mask == 0,
                float("-1e20")
            )
        # softmax 操作 + Dropout 操作
        attention_weights = torch.softmax(attention_value, dim=-1)
        print('attention_weights.shape = ', attention_weights.shape)
        attention_weights = self.attention_dropout(attention_weights)
        # 输出数据
        # attention_weights -> (batch_size, num_heads, seq, seq)
        # V                 -> (batch_size, num_heads, seq, num_hiddens)
        # (batch_size, num_heads, seq, num_hiddens) -> (batch_size, seq, num_hiddens)
        output = (attention_weights @ V).permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1) 
        output = self.output_proj(output)
        return output

In [67]:
# 测试多头注意力机制
X = torch.rand(3, 2, 128)
# (2, 2) -> (3, 8, 2, 2)
mask = torch.tensor([
    [1, 1],
    [1, 0]
])
mask = mask.unsqueeze(0).unsqueeze(1).repeat(3, 8, 1, 1)
multi_attention = MultiHeadAttention(128, 8)
Y = multi_attention(X)
Y.shape

attention_weights.shape =  torch.Size([3, 8, 2, 2])


torch.Size([3, 2, 128])

## 总结
无论是对于 `Self-Attention`还是`Multi-Head Attention`, 注意可以采用的优化策略都有:
1. 使用一个大矩阵来包含变换矩阵
2. 注意矩阵变换过程中的维度变换(相邻的维度可以合并, 同时一个维度可以拆分为相邻的两个维度的和, 使用 `view`变换需要操作连续内存空间, 但是使用 `reshape`不需要, 使用 `transpose`进行矩阵两个维度的交换, 使用 `permute` 进行维度的变换)