# Self Attention

## 一、简化版本

In [None]:
# 导入相关需要的包
import math
import torch
import torch.nn as nn
from thop import profile

import warnings
warnings.filterwarnings(action="ignore")


class SelfAttV1(nn.Module):
    """Minimal single-head self-attention: no mask, no dropout, no output projection.

    Args:
        hidden_dim: size of the input features; Q, K and V are all projected
            to this same size.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # nn.Linear carries a bias by default. Because every projection is
        # nn.Linear(hidden_dim, hidden_dim), the last dim of the input must
        # equal hidden_dim.
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X):
        """Run scaled dot-product self-attention over X.

        Args:
            X: tensor of shape (batch, seq_len, hidden_dim).

        Returns:
            Tensor of shape (batch, seq_len, hidden_dim).
        """
        # Projections are evaluated in the order query, key, value.
        q, k, v = (
            proj(X) for proj in (self.query_proj, self.key_proj, self.value_proj)
        )

        # Swap only the last two axes: (batch, seq, dim) -> (batch, dim, seq).
        # (k.T would reverse *all* axes of a 3-D tensor, which is not what we want.)
        scale = math.sqrt(self.hidden_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale  # (batch, seq, seq)
        weights = scores.softmax(dim=-1)
        print("attention_weight:", weights)

        # (batch, seq, seq) @ (batch, seq, hidden_dim) -> (batch, seq, hidden_dim)
        return torch.matmul(weights, v)


X = torch.rand(3, 2, 4)
net = SelfAttV1(4)
net(X)
# Profile SelfAttV1 with thop.
# NOTE: thop.profile returns (MACs, params), not FLOPs (1 MAC ~= 2 FLOPs).
# Only the registered nn.Linear layers are counted: 3 * (3 * 2 * 4 * 4) = 288.
# The Q @ K^T and weights @ V matmuls and the softmax are not included.
# The variable name flops_v1 is kept for compatibility; it holds MACs.
flops_v1, params_v1 = profile(net, inputs=(X,))
print(f"SelfAttV1 MACs: {flops_v1 / 1e6:.6f} M")
print(f"SelfAttV1 Parameters: {params_v1 / 1e6:.6f} M")

attention_weight: tensor([[[0.4809, 0.5191],
         [0.4735, 0.5265]],

        [[0.5140, 0.4860],
         [0.5096, 0.4904]],

        [[0.5013, 0.4987],
         [0.5014, 0.4986]]], grad_fn=<SoftmaxBackward0>)
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
attention_weight: tensor([[[0.4809, 0.5191],
         [0.4735, 0.5265]],

        [[0.5140, 0.4860],
         [0.5096, 0.4904]],

        [[0.5013, 0.4987],
         [0.5014, 0.4986]]])
SelfAttV1 FLOPs: 0.000288 MFLOPs
SelfAttV1 Parameters: 0.000060 M


- 参数量：

每个线性层 nn.Linear(hidden_dim, hidden_dim) 有 hidden_dim * hidden_dim + hidden_dim 个参数（权重和偏置）。
总共有三个这样的线性层，因此总参数量为 3 * (hidden_dim * hidden_dim + hidden_dim)。

- 计算量：

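Q = self.query_proj(X), K = self.key_proj(X), V = self.value_proj(X)：每个操作涉及 batch_size * seq_len * hidden_dim * hidden_dim 次乘法；加法方面，矩阵乘法的累加需要 batch_size * seq_len * hidden_dim * (hidden_dim - 1) 次，再加上 batch_size * seq_len * hidden_dim 次偏置加法，合计 batch_size * seq_len * hidden_dim * hidden_dim 次加法。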
attention_value = Q @ K.transpose(-1, -2)：涉及 batch_size * seq_len * seq_len * hidden_dim 次乘法和 batch_size * seq_len * seq_len * (hidden_dim - 1) 次加法。
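output = torch.matmul(attention_wight, V)：涉及 batch_size * seq_len * hidden_dim * seq_len 次乘法和 batch_size * seq_len * hidden_dim * (seq_len - 1) 次加法。

注意：thop 的 `profile` 返回的是 MACs（乘加次数），而不是 FLOPs（FLOPs 约为 MACs 的 2 倍）。并且这里只统计了 nn.Linear 层：3 × (3 × 2 × 4 × 4) = 288，正好对应上面输出的 0.000288 M。两次注意力矩阵乘法（Q @ K^T 以及 权重 @ V）和 softmax 都没有被计入。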

In [7]:
X = torch.rand(3, 2, 4)
# X.T on a tensor whose ndim != 2 reverses *all* dimensions and is deprecated
# in PyTorch (the deprecation warning is hidden by the warnings filter set at
# the top of the notebook). Spell out the full reversal explicitly instead.
# (batch, seq, dim) -> (dim, seq, batch) also shows why `K.T` cannot replace
# `K.transpose(-1, -2)` for batched attention scores.
print(X.permute(2, 1, 0).shape)

torch.Size([4, 2, 3])


##  QKV 矩阵计算的时候，可以合并成一个大矩阵计算

In [None]:
# 导入相关需要的包
import math
import torch
import torch.nn as nn
from thop import profile

class SelfAttV2(nn.Module):
    """Single-head self-attention with a fused Q/K/V projection and an output projection.

    Args:
        dim: size of the input/hidden features.
    """

    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim
        # One (dim -> 3 * dim) Linear replaces three (dim -> dim) Linears.
        # For small models that fit on one device (BERT-sized), one big matmul
        # should beat three small ones (fewer memory reads/writes). For large
        # LLMs whose tensors are sharded anyway, the fusion buys little.
        self.proj = nn.Linear(dim, dim * 3)

        self.output_proj = nn.Linear(dim, dim)

    def forward(self, X):
        """Attend over X of shape (batch, seq, dim); returns (batch, seq, dim)."""
        fused = self.proj(X)  # (batch, seq, 3 * dim)
        # Cut the fused projection into three equal (batch, seq, dim) pieces.
        q, k, v = fused.chunk(3, dim=-1)

        logits = q @ k.transpose(-1, -2) / math.sqrt(self.dim)  # (batch, seq, seq)
        probs = torch.softmax(logits, dim=-1)
        return self.output_proj(probs @ v)


X = torch.rand(3, 2, 4)
net = SelfAttV2(4)
net(X).shape
# Profile SelfAttV2 with thop (the labels previously said "SelfAttV1").
# NOTE: thop.profile returns MACs, not FLOPs, and only counts nn.Linear here:
# proj 3*2*4*12 = 288 plus output_proj 3*2*4*4 = 96, for 384 in total. The
# attention matmuls are not included. The variable name flops_v2 is kept for
# compatibility; it holds MACs.
flops_v2, params_v2 = profile(net, inputs=(X,))
print(f"SelfAttV2 MACs: {flops_v2 / 1e6:.6f} M")
print(f"SelfAttV2 Parameters: {params_v2 / 1e6:.6f} M")

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
SelfAttV1 FLOPs: 0.000384 MFLOPs
SelfAttV1 Parameters: 0.000080 M


## 加入 dropout 、 attention_mask 、 output_proj

In [18]:
# 导入相关需要的包
import math
import torch
import torch.nn as nn
from thop import profile

class SelfAttV3(nn.Module):
    """Single-head self-attention with fused QKV, optional mask, attention dropout
    and an output projection.

    Args:
        dim: size of the input/hidden features.
    """

    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim
        # Fused Q/K/V projection: one matmul instead of three (faster for small models).
        self.proj = nn.Linear(dim, dim * 3)
        # Usually 0.1, typically read from config.attention_probs_dropout_prob
        # (hidden_dropout_prob is usually 0.1 as well).
        self.att_drop = nn.Dropout(0.1)

        # Optional for single-head attention; it mainly matters in MultiHeadAttention.
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, X, attention_mask=None):
        """Attend over X.

        Args:
            X: tensor of shape (batch, seq, dim).
            attention_mask: optional; positions equal to 0 are masked out.
                Accepts a (batch, seq, seq) mask, or a (batch, seq) key mask
                that is broadcast over all query positions.

        Returns:
            Tensor of shape (batch, seq, dim).
        """
        QKV = self.proj(X)  # (batch, seq, dim * 3)
        # Split the fused projection back into Q, K, V, each (batch, seq, dim).
        Q, K, V = torch.split(QKV, self.dim, dim=-1)

        att_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)  # (batch, seq, seq)
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # (batch, seq) -> (batch, 1, seq) so the key mask applies to
                # every query row. Without this, broadcasting would align the
                # batch axis with the query axis (a shape error, or silently
                # wrong when batch == seq).
                attention_mask = attention_mask.unsqueeze(1)
            # Use the dtype's most negative finite value. A literal such as
            # -1e20 cannot be represented in float16, and masked_fill raises.
            att_weight = att_weight.masked_fill(
                attention_mask == 0, torch.finfo(att_weight.dtype).min
            )

        att_weight = torch.softmax(att_weight, dim=-1)
        print(att_weight)

        # Dropout is applied to the attention probabilities, not to att_weight @ V.
        att_weight = self.att_drop(att_weight)

        output = att_weight @ V
        ret = self.output_proj(output)
        return ret


X = torch.rand(3, 4, 2)
# Per-sample key pattern: a 0 marks a key position that must not be attended to.
b = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 0, 0, 0]])
print(b.shape)
# (batch, seq) -> (batch, 1, seq) -> (batch, seq, seq):
# insert a query axis, then tile the pattern over all seq query rows.
mask = b[:, None, :].repeat(1, b.size(1), 1)
print(mask, mask.shape)

net = SelfAttV3(2)
net(X, mask).shape

torch.Size([3, 4])
tensor([[[1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 0]],

        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 0, 0]],

        [[1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0]]]) torch.Size([3, 4, 4])
tensor([[[0.3324, 0.3369, 0.3307, 0.0000],
         [0.3390, 0.3296, 0.3315, 0.0000],
         [0.3435, 0.3298, 0.3267, 0.0000],
         [0.3307, 0.3395, 0.3299, 0.0000]],

        [[0.4698, 0.5302, 0.0000, 0.0000],
         [0.4999, 0.5001, 0.0000, 0.0000],
         [0.4922, 0.5078, 0.0000, 0.0000],
         [0.4609, 0.5391, 0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)


torch.Size([3, 4, 2])

## 面试写法

In [None]:
# 导入相关需要的包
import math
import torch
import torch.nn as nn

# import warnings

# warnings.filterwarnings(action="ignore")

class SelfAttV4(nn.Module):
    """Interview-style single-head self-attention with separate Q/K/V projections,
    optional mask, attention dropout and an output projection.

    Args:
        dim: size of the input/hidden features.
    """

    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim

        # Separate projections: the clearest form to write out.
        self.query_proj = nn.Linear(dim, dim)
        self.key_proj = nn.Linear(dim, dim)
        self.value_proj = nn.Linear(dim, dim)
        # Usually 0.1, typically read from config.attention_probs_dropout_prob
        # (hidden_dropout_prob is usually 0.1 as well).
        self.att_drop = nn.Dropout(0.1)

        # Optional (discuss with the interviewer); it mainly belongs to
        # MultiHeadAttention and could be left to it.
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, X, attention_mask=None):
        """Attend over X.

        Args:
            X: tensor of shape (batch, seq, dim).
            attention_mask: optional; positions equal to 0 are masked out.
                Accepts a (batch, seq, seq) mask, or a (batch, seq) key mask
                that is broadcast over all query positions.

        Returns:
            Tensor of shape (batch, seq, dim).
        """
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)

        att_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)  # (batch, seq, seq)
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # (batch, seq) -> (batch, 1, seq) so the key mask applies to
                # every query row. Without this, broadcasting would align the
                # batch axis with the query axis (a shape error, or silently
                # wrong when batch == seq).
                attention_mask = attention_mask.unsqueeze(1)
            # Use the dtype's most negative finite value. A literal such as
            # -1e20 cannot be represented in float16, and masked_fill raises.
            att_weight = att_weight.masked_fill(
                attention_mask == 0, torch.finfo(att_weight.dtype).min
            )

        att_weight = torch.softmax(att_weight, dim=-1)
        print(f"att_weight:\n {att_weight}")

        # Dropout is applied to the attention probabilities, not to att_weight @ V.
        att_weight = self.att_drop(att_weight)

        output = att_weight @ V
        ret = self.output_proj(output)
        return ret


X = torch.rand(3, 4, 2)
# Per-sample key pattern: a 0 marks a key position that must not be attended to.
b = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 0, 0, 0]])
# (batch, seq) -> (batch, 1, seq) -> (batch, seq, seq):
# insert a query axis, then tile the pattern over all seq query rows.
mask = b[:, None, :].repeat(1, b.size(1), 1)

net = SelfAttV4(2)
net(X, mask).shape

att_weight:
 tensor([[[0.3460, 0.3163, 0.3377, 0.0000],
         [0.3147, 0.3550, 0.3303, 0.0000],
         [0.3287, 0.3368, 0.3345, 0.0000],
         [0.3105, 0.3601, 0.3294, 0.0000]],

        [[0.4935, 0.5065, 0.0000, 0.0000],
         [0.4896, 0.5104, 0.0000, 0.0000],
         [0.4887, 0.5113, 0.0000, 0.0000],
         [0.4988, 0.5012, 0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)


torch.Size([3, 4, 2])

# MHSA

In [3]:
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention (MHSA).

    Args:
        hidden_dim: model feature size; must be divisible by ``nums_head``.
        nums_head: number of attention heads; each head works on
            ``hidden_dim // nums_head`` features.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``nums_head``.
    """

    def __init__(self, hidden_dim, nums_head) -> None:
        super().__init__()
        if hidden_dim % nums_head != 0:
            # Fail early: otherwise the per-head view() in forward() raises an
            # opaque shape error.
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by nums_head ({nums_head})"
            )
        self.nums_head = nums_head
        self.head_dim = hidden_dim // nums_head
        self.hidden_dim = hidden_dim

        # Linear layers have a bias by default. Since hidden_dim == head_dim * nums_head,
        # each projection can be read as nums_head stacked per-head projections.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

        # GPT-2 and BERT-style models use attention dropout; LLaMA does not.
        self.att_dropout = nn.Dropout(0.1)
        # Output projection applied after the heads are merged.
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X, attention_mask=None):
        """Apply multi-head self-attention.

        Args:
            X: tensor of shape (batch, seq, hidden_dim).
            attention_mask: optional; positions equal to 0 are masked out.
                Accepted shapes: (batch, seq) key mask, (batch, seq, seq)
                mask shared by all heads, or anything broadcastable to
                (batch, nums_head, seq, seq).

        Returns:
            Tensor of shape (batch, seq, hidden_dim).
        """
        batch_size, seq_len, _ = X.size()

        Q = self.q_proj(X)
        K = self.k_proj(X)
        V = self.v_proj(X)

        print(f"Q: {Q.shape}")
        # (batch, seq, hidden) -> (batch, seq, nums_head, head_dim)
        #                      -> (batch, nums_head, seq, head_dim)
        q_state = Q.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(
            1, 2
        )
        print(f"q_state: {q_state.shape}")
        k_state = K.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(
            1, 2
        )
        v_state = V.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(
            1, 2
        )
        # Scale by sqrt(head_dim), not sqrt(hidden_dim): each dot product spans head_dim features.
        attention_weight = (
            q_state @ k_state.transpose(-1, -2) / math.sqrt(self.head_dim)
        )  # (batch, nums_head, seq, seq)
        print("attention_weight:", attention_weight.shape)

        # The mask must be applied before the softmax.
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # (batch, seq) key mask -> (batch, 1, 1, seq)
                attention_mask = attention_mask[:, None, None, :]
            elif attention_mask.dim() == 3:
                # (batch, seq, seq) -> (batch, 1, seq, seq), shared by all heads
                attention_mask = attention_mask.unsqueeze(1)
            # Use the dtype's most negative finite value. A literal such as
            # -1e20 cannot be represented in float16, and masked_fill raises.
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0, torch.finfo(attention_weight.dtype).min
            )

        # Softmax over the key axis of (batch, nums_head, seq, seq).
        attention_weight = torch.softmax(attention_weight, dim=-1)

        # Dropout is applied to the attention probabilities, not to their product with V.
        attention_weight = self.att_dropout(attention_weight)
        output_mid = attention_weight @ v_state
        print(output_mid.shape)

        # (batch, nums_head, seq, head_dim) -> (batch, seq, nums_head, head_dim).
        # transpose() returns a non-contiguous tensor and view() needs contiguous
        # memory, hence contiguous() (reshape() would copy implicitly instead).
        output_mid = output_mid.transpose(1, 2).contiguous()

        # Merge heads: (batch, seq, hidden_dim)
        output = output_mid.view(batch_size, seq_len, -1)
        output = self.o_proj(output)
        return output


# Random 0/1 mask already laid out as (batch, nums_head, q_len, k_len), so it
# broadcasts directly against the attention scores. A padding-style mask can
# instead be built from a (batch, k_len) 0/1 tensor via
# [:, None, None, :].expand(batch, nums_head, q_len, k_len).
attention_mask = torch.randint(low=0, high=2, size=(3, 8, 2, 2))

x = torch.rand((3, 2, 128))  # (batch, seq, hidden_dim)
net = MultiHeadAttention(hidden_dim=128, nums_head=8)
net(x, attention_mask).shape

Q: torch.Size([3, 2, 128])
q_state: torch.Size([3, 8, 2, 16])
attention_weight: torch.Size([3, 8, 2, 2])
torch.Size([3, 8, 2, 16])


torch.Size([3, 2, 128])

# Cross Attention

In [None]:
import math
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from one sequence, keys/values from another.

    Args:
        hidden_dim: feature size of the Q, K and V inputs; must be divisible by
            ``nums_head``.
        nums_head: number of attention heads.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``nums_head``.
    """

    def __init__(self, hidden_dim, nums_head) -> None:
        super().__init__()
        if hidden_dim % nums_head != 0:
            # Fail early instead of hitting an opaque view() error in forward().
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by nums_head ({nums_head})"
            )
        self.nums_head = nums_head
        self.head_dim = hidden_dim // nums_head
        self.hidden_dim = hidden_dim

        # Linear projections for query, key and value.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

        # Dropout on the attention probabilities.
        self.att_dropout = nn.Dropout(0.1)
        # Output projection.
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, Q, K, V, attention_mask=None):
        """Attend from Q to (K, V).

        Args:
            Q: (batch, seq_len_q, hidden_dim) query sequence.
            K: (batch, seq_len_k, hidden_dim) key sequence.
            V: (batch, seq_len_k, hidden_dim) value sequence.
            attention_mask: optional; positions equal to 0 are masked out.
                Accepted shapes: (batch, seq_len_k) key mask,
                (batch, seq_len_q, seq_len_k) mask shared by all heads, or
                anything broadcastable to (batch, nums_head, seq_len_q, seq_len_k).

        Returns:
            Tensor of shape (batch, seq_len_q, hidden_dim).
        """
        batch_size, seq_len_q, _ = Q.size()
        _, seq_len_k, _ = K.size()

        # Project query, key and value.
        Q = self.q_proj(Q)
        K = self.k_proj(K)
        V = self.v_proj(V)

        print(f"Q after projection: {Q.shape}")
        print(f"K after projection: {K.shape}")
        print(f"V after projection: {V.shape}")

        # Split into heads: (batch, seq, hidden) -> (batch, nums_head, seq, head_dim)
        q_state = Q.view(batch_size, seq_len_q, self.nums_head, self.head_dim).transpose(1, 2)
        k_state = K.view(batch_size, seq_len_k, self.nums_head, self.head_dim).transpose(1, 2)
        v_state = V.view(batch_size, seq_len_k, self.nums_head, self.head_dim).transpose(1, 2)

        print(f"q_state: {q_state.shape}")
        print(f"k_state: {k_state.shape}")
        print(f"v_state: {v_state.shape}")

        # Scaled dot-product scores: (batch, nums_head, seq_len_q, seq_len_k)
        attention_weight = (q_state @ k_state.transpose(-1, -2)) / math.sqrt(self.head_dim)

        print(f"attention_weight before mask: {attention_weight.shape}")

        # Apply the mask (before the softmax).
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # (batch, seq_len_k) key mask -> (batch, 1, 1, seq_len_k)
                attention_mask = attention_mask[:, None, None, :]
            elif attention_mask.dim() == 3:
                # (batch, seq_len_q, seq_len_k) -> (batch, 1, seq_len_q, seq_len_k)
                attention_mask = attention_mask.unsqueeze(1)
            # Use the dtype's most negative finite value. A literal such as
            # -1e20 cannot be represented in float16, and masked_fill raises.
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0, torch.finfo(attention_weight.dtype).min
            )

        print(f"attention_weight after mask: {attention_weight.shape}")

        # Softmax over the key axis.
        attention_weight = torch.softmax(attention_weight, dim=-1)

        print(f"attention_weight after softmax: {attention_weight.shape}")

        # Dropout on the attention probabilities.
        attention_weight = self.att_dropout(attention_weight)

        # Weighted sum of values: (batch, nums_head, seq_len_q, head_dim)
        output_mid = attention_weight @ v_state

        print(f"output_mid: {output_mid.shape}")

        # Merge heads back into (batch, seq_len_q, hidden_dim); contiguous() is
        # needed because view() cannot operate on the transposed layout.
        output_mid = output_mid.transpose(1, 2).contiguous()
        output = output_mid.view(batch_size, seq_len_q, -1)
        output = self.o_proj(output)

        print(f"output: {output.shape}")

        return output

# 测试用例
if __name__ == "__main__":
    # 定义输入序列 
    # Cross-Attention中，查询（Query）通常来自于一个序列（如文本序列），
    # 而键（Key）和值（Value）来自于另一个序列（如另一个文本序列或图像特征）
    Q = torch.rand(3, 4, 128)  # 查询序列，形状为 (batch_size, seq_len_q, hidden_dim)
    K = torch.rand(3, 6, 128)  # 键序列，形状为 (batch_size, seq_len_k, hidden_dim)
    V = torch.rand(3, 6, 128)  # 值序列，形状为 (batch_size, seq_len_k, hidden_dim)

    # 定义注意力掩码
    attention_mask = torch.randint(0, 2, (3, 8, 4, 6))

    print(f"Q shape: {Q.shape}")
    print(f"K shape: {K.shape}")
    print(f"V shape: {V.shape}")
    print(f"attention_mask shape: {attention_mask.shape}")

    # 创建CrossAttention实例
    cross_attention = CrossAttention(hidden_dim=128, nums_head=8)

    # 前向传播
    output = cross_attention(Q, K, V, attention_mask)

    # 打印输出形状以验证实现是否正确
    print(f"Output shape: {output.shape}")

Q shape: torch.Size([3, 4, 128])
K shape: torch.Size([3, 6, 128])
V shape: torch.Size([3, 6, 128])
attention_mask shape: torch.Size([3, 8, 4, 6])
Q after projection: torch.Size([3, 4, 128])
K after projection: torch.Size([3, 6, 128])
V after projection: torch.Size([3, 6, 128])
q_state: torch.Size([3, 8, 4, 16])
k_state: torch.Size([3, 8, 6, 16])
v_state: torch.Size([3, 8, 6, 16])
attention_weight before mask: torch.Size([3, 8, 4, 6])
attention_weight after mask: torch.Size([3, 8, 4, 6])
attention_weight after softmax: torch.Size([3, 8, 4, 6])
output_mid: torch.Size([3, 8, 4, 16])
output: torch.Size([3, 4, 128])
Output shape: torch.Size([3, 4, 128])


In [2]:

import math
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    """Multi-head cross-attention (quiet version, no debug prints).

    Args:
        hidden_dim: feature size of the Q, K and V inputs; must be divisible by
            ``nums_head``.
        nums_head: number of attention heads.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``nums_head``.
    """

    def __init__(self, hidden_dim, nums_head) -> None:
        super().__init__()
        if hidden_dim % nums_head != 0:
            # Fail early instead of hitting an opaque view() error in forward().
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by nums_head ({nums_head})"
            )
        self.nums_head = nums_head
        self.head_dim = hidden_dim // nums_head
        self.hidden_dim = hidden_dim

        # Linear projections for query, key and value.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

        # Dropout on the attention probabilities.
        self.att_dropout = nn.Dropout(0.1)
        # Output projection.
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, Q, K, V, attention_mask=None):
        """Attend from Q to (K, V).

        Args:
            Q: (batch, seq_len_q, hidden_dim) query sequence.
            K: (batch, seq_len_k, hidden_dim) key sequence.
            V: (batch, seq_len_k, hidden_dim) value sequence.
            attention_mask: optional (default None = no masking); positions
                equal to 0 are masked out. Accepted shapes: (batch, seq_len_k)
                key mask, (batch, seq_len_q, seq_len_k) mask shared by all
                heads, or anything broadcastable to
                (batch, nums_head, seq_len_q, seq_len_k).

        Returns:
            Tensor of shape (batch, seq_len_q, hidden_dim).
        """
        batch_size, seq_len_q, _ = Q.size()
        _, seq_len_k, _ = K.size()

        # Project query, key and value.
        Q = self.q_proj(Q)
        K = self.k_proj(K)
        V = self.v_proj(V)

        # Split into heads: (batch, seq, hidden) -> (batch, nums_head, seq, head_dim)
        q_state = Q.view(batch_size, seq_len_q, self.nums_head, self.head_dim).transpose(1, 2)
        k_state = K.view(batch_size, seq_len_k, self.nums_head, self.head_dim).transpose(1, 2)
        v_state = V.view(batch_size, seq_len_k, self.nums_head, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: (batch, nums_head, seq_len_q, seq_len_k)
        attention_weight = (q_state @ k_state.transpose(-1, -2)) / math.sqrt(self.head_dim)

        # Apply the mask (before the softmax).
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # (batch, seq_len_k) key mask -> (batch, 1, 1, seq_len_k)
                attention_mask = attention_mask[:, None, None, :]
            elif attention_mask.dim() == 3:
                # (batch, seq_len_q, seq_len_k) -> (batch, 1, seq_len_q, seq_len_k)
                attention_mask = attention_mask.unsqueeze(1)
            # Most negative finite value of the dtype; -1e20 would overflow float16.
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0, torch.finfo(attention_weight.dtype).min
            )

        # Softmax over the key axis.
        attention_weight = torch.softmax(attention_weight, dim=-1)

        # Dropout on the attention probabilities.
        attention_weight = self.att_dropout(attention_weight)

        # Weighted sum of values: (batch, nums_head, seq_len_q, head_dim)
        output_mid = attention_weight @ v_state

        # Merge heads back into (batch, seq_len_q, hidden_dim).
        output_mid = output_mid.transpose(1, 2).contiguous()
        output = output_mid.view(batch_size, seq_len_q, -1)
        output = self.o_proj(output)

        return output

# 测试用例
if __name__ == "__main__":
    # 定义输入序列 
    # Cross-Attention中，查询（Query）通常来自于一个序列（如文本序列），
    # 而键（Key）和值（Value）来自于另一个序列（如另一个文本序列或图像特征）
    Q = torch.rand(3, 4, 128)  # 查询序列，形状为 (batch_size, seq_len_q, hidden_dim)
    K = torch.rand(3, 6, 128)  # 键序列，形状为 (batch_size, seq_len_k, hidden_dim)
    V = torch.rand(3, 6, 128)  # 值序列，形状为 (batch_size, seq_len_k, hidden_dim)

    # 创建CrossAttention实例
    cross_attention = CrossAttention(hidden_dim=128, nums_head=8)

    # 前向传播
    output = cross_attention(Q, K, V)

    # 打印输出形状以验证实现是否正确
    print(f"Output shape: {output.shape}")

Output shape: torch.Size([3, 4, 128])
