In [148]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from typing import Optional, Tuple, Union 
import math
import warnings
from einops import rearrange

# 单头注意力机制 (ScaledDotProductAttention)

<img src="./images/attention.png" alt="示例图片" width="700">

$$
Attention(Q, K, V) = [softmax(\frac{QK^{\top}}{\sqrt{d_k}})] V
$$

In [10]:
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query, key, value: tensors whose last two dimensions are
            (sequence, features); any leading dimensions are treated as batch
            dimensions by ``matmul``. ``query`` and ``key`` must share their
            feature size.
        mask: optional tensor broadcastable to the score matrix. Positions
            where ``mask == 0`` are filled with -1e9 before the softmax.
        dropout: optional callable applied to the attention weights.

    Returns:
        Tuple ``(context, weights)``: the weighted sum of ``value`` and the
        (possibly dropped-out) attention weights.
    """
    head_dim = query.size(-1)
    logits = (query @ key.transpose(-2, -1)) / math.sqrt(head_dim)
    if mask is not None:
        # A large negative logit drives the softmax weight to (numerically) zero.
        logits = logits.masked_fill(mask == 0, -1e9)
    weights = torch.softmax(logits, dim=-1)
    weights = weights if dropout is None else dropout(weights)
    context = weights @ value
    return context, weights

In [11]:
class ScaledDotProductAttention(nn.Module):
    """Single-head scaled dot-product attention on batched 3-D tensors.

    Args:
        scale: divisor applied to the raw Q.K^T scores (callers in this
            notebook pass sqrt(d_k)).
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        """Attend ``q`` over ``k`` / ``v``.

        Args:
            q: [b, l_q, d] queries.
            k: [b, l_k, d] keys.
            v: [b, l_k, d_v] values.
            mask: optional boolean tensor for the [b, l_q, l_k] score matrix;
                True marks positions to exclude (filled with -inf).

        Returns:
            Tuple ``(attn, output)`` with ``attn`` of shape [b, l_q, l_k] and
            ``output`` of shape [b, l_q, d_v].
        """
        # torch.bmm multiplies matching batch entries:
        # [b, l_q, d] x [b, d, l_k] -> [b, l_q, l_k]; then apply the scaling.
        logits = torch.bmm(q, k.transpose(1, 2)) / self.scale

        if mask is not None:
            logits = logits.masked_fill(mask, -np.inf)

        # Normalise over the key axis, then mix the values with those weights.
        attn = self.softmax(logits)
        return attn, torch.bmm(attn, v)

In [12]:
if __name__ == "__main__":
    # Smoke test for ScaledDotProductAttention on random inputs.
    n_q, n_k, n_v = 2, 4, 4          # sequence lengths of queries / keys / values
    d_q, d_k, d_v = 128, 128, 64     # feature widths
    batch = 4

    q = torch.randn(batch, n_q, d_q)
    k = torch.randn(batch, n_k, d_k)
    v = torch.randn(batch, n_v, d_v)
    # True would mark a position as masked out; an all-False mask hides nothing.
    mask = torch.zeros(batch, n_q, n_k).bool()

    # Bind the module to its own name. Calling it `attention` would overwrite
    # the notebook-level attention() function, which MultiHeadAttention_ori
    # calls, and break that class on any later run.
    sdpa = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
    attn, output = sdpa(q, k, v, mask=mask)

    print(attn.shape)
    print(attn)
    print(output.shape)
    print(output)
    

torch.Size([4, 2, 4])
tensor([[[0.1412, 0.5310, 0.1467, 0.1812],
         [0.1449, 0.0711, 0.6258, 0.1582]],

        [[0.2215, 0.2577, 0.3899, 0.1310],
         [0.5084, 0.2397, 0.0741, 0.1778]],

        [[0.3534, 0.2216, 0.0810, 0.3440],
         [0.3795, 0.3106, 0.1051, 0.2048]],

        [[0.2082, 0.0021, 0.7299, 0.0598],
         [0.1500, 0.2821, 0.0438, 0.5241]]])
torch.Size([4, 2, 64])
tensor([[[-3.9275e-01,  2.6675e-01,  1.2438e+00,  1.0520e+00,  1.1428e-01,
           3.7404e-01, -7.1114e-01, -5.8611e-01, -2.2759e-01,  2.0379e-01,
           4.8823e-01, -3.5633e-01, -6.4568e-01,  3.4941e-01, -5.0592e-01,
           5.2288e-01, -1.5472e-01,  6.6828e-01, -9.9948e-01, -7.6306e-01,
           3.9175e-01,  5.2225e-01, -5.6732e-01, -8.5301e-02, -2.0100e-01,
          -5.2807e-01,  1.3071e-01, -2.2853e+00, -4.6790e-01, -3.4336e-01,
          -4.7688e-01,  1.7989e-01, -2.8195e-01, -9.1272e-01, -7.9047e-01,
           1.8103e-01, -2.8280e-01, -7.2264e-02,  3.7008e-01, -4.7586e-01,
   

# 多头注意力机制 (Multi Head Attention)

<img src="./images/MultiHeadAttention.png" alt="示例图片" width="700">

$MultiHead(Q, K, V) = Concat(head_1, head_2, ... , head_h)W^O\text{, where }head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)$

$W_i^Q \in \mathbb{R}^{d_{model} \times d_k}, \quad W_i^K \in \mathbb{R}^{d_{model} \times d_k}, \quad W_i^V \in \mathbb{R}^{d_{model} \times d_v}, \quad W^O \in \mathbb{R}^{hd_v \times d_{model}}$

In [13]:
class MultiHeadAttention_ori(nn.Module):
    """Multi-head attention with equal per-head key and value widths.

    The model width ``d_model`` is split evenly across ``h`` heads
    (``d_k = d_model // h``). Relies on the notebook-level ``attention``
    function for the per-head scaled dot-product attention.

    Args:
        h: number of heads; must divide ``d_model``.
        d_model: feature width of the inputs and of the output.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadAttention_ori, self).__init__()

        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        # Four independent d_model -> d_model projections, in order:
        # W^Q, W^K, W^V and the output projection W^O.
        self.linears = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(4)])
        # Attention weights of the most recent forward pass, kept for inspection.
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention.

        Args:
            query, key, value: [batch, seq, d_model] tensors.
            mask: optional mask; a head dimension is inserted at dim 1 so the
                same mask is applied to every head.

        Returns:
            Tensor of shape [batch, seq_q, d_model].
        """
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Project, then split d_model into h heads of width d_k:
        #    [b, l, d_model] -> [b, l, h, d_k] -> [b, h, l, d_k].
        #    zip() stops after the first three projections; the fourth is W^O.
        query, key, value = [
            lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.linears, (query, key, value))
        ]

        # 2) Apply attention to all heads at once.
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" the heads back into d_model. transpose() returns a
        #    non-contiguous view, so contiguous() is required before view().
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)

        return self.linears[-1](x)  # W^O

In [16]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ScaledDotProductAttention.

    Args:
        n_head: number of attention heads.
        d_k_: feature width of the incoming queries and keys (both are
            projected by layers that take ``d_k_`` inputs).
        d_v_: feature width of the incoming values.
        d_k: per-head width of the projected queries and keys.
        d_v: per-head width of the projected values.
        d_o: feature width of the output.
    """

    def __init__(self, n_head, d_k_, d_v_, d_k, d_v, d_o):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # Input projections that produce all heads at once.
        self.fc_q = nn.Linear(d_k_, n_head * d_k)
        self.fc_k = nn.Linear(d_k_, n_head * d_k)
        self.fc_v = nn.Linear(d_v_, n_head * d_v)

        self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))

        # Output projection applied to the concatenated heads.
        self.fc_o = nn.Linear(n_head * d_v, d_o)

    def forward(self, q, k, v, mask=None):
        """Run multi-head attention.

        Args:
            q: [b, l_q, d_k_] queries.
            k: [b, l_k, d_k_] keys.
            v: [b, l_k, d_v_] values.
            mask: optional boolean [b, l_q, l_k] mask, True = masked out.

        Returns:
            Tuple ``(attn, output)`` with ``attn`` of shape [h*b, l_q, l_k]
            (head-major: all batch entries of head 0 first) and ``output`` of
            shape [b, l_q, d_o].
        """
        heads = self.n_head
        _, n_q, _ = q.size()
        _, n_k, _ = k.size()
        batch, n_v, _ = v.size()

        def to_heads(t, length, width):
            # [b, l, h*w] -> [b, l, h, w] -> [h, b, l, w] -> [h*b, l, w]
            return t.view(batch, length, heads, width).permute(2, 0, 1, 3).reshape(-1, length, width)

        q = to_heads(self.fc_q(q), n_q, self.d_k)
        k = to_heads(self.fc_k(k), n_k, self.d_k)
        v = to_heads(self.fc_v(v), n_v, self.d_v)

        if mask is not None:
            # Tile along dim 0 so the mask follows the head-major [h*b, ...] layout.
            mask = mask.repeat(heads, 1, 1)

        # With heads folded into the batch axis this is plain single-head attention.
        attn, output = self.attention(q, k, v, mask=mask)

        # Concatenate heads: [h*b, l_q, d_v] -> [h, b, l_q, d_v] -> [b, l_q, h, d_v] -> [b, l_q, h*d_v]
        output = output.view(heads, batch, n_q, self.d_v).permute(1, 2, 0, 3).reshape(batch, n_q, -1)

        # [..., h*d_v] -> [..., d_o]
        output = self.fc_o(output)

        return attn, output

In [22]:
if __name__ == "__main__":
    # Smoke test for MultiHeadAttention on random inputs.
    n_q, n_k, n_v = 2, 4, 4             # sequence lengths of queries / keys / values
    d_q_, d_k_, d_v_ = 128, 128, 64     # feature widths before projection
    batch = 4

    q = torch.randn(batch, n_q, d_q_)
    k = torch.randn(batch, n_k, d_k_)
    v = torch.randn(batch, n_v, d_v_)
    # All-False boolean mask: no position is masked out.
    mask = torch.zeros(batch, n_q, n_k, dtype=torch.bool)

    mha = MultiHeadAttention(n_head=8, d_k_=128, d_v_=64, d_k=256, d_v=128, d_o=128)
    attn, output = mha(q, k, v, mask=mask)

    # One item per line: the two shapes first, then the tensors.
    print(attn.shape, output.shape, attn, output, sep="\n")

torch.Size([32, 2, 4])
torch.Size([4, 2, 128])
tensor([[[0.2009, 0.3082, 0.2627, 0.2283],
         [0.3128, 0.3074, 0.2294, 0.1505]],

        [[0.3394, 0.2519, 0.2555, 0.1533],
         [0.2826, 0.1722, 0.3127, 0.2325]],

        [[0.1779, 0.4567, 0.1665, 0.1989],
         [0.2524, 0.2534, 0.3248, 0.1694]],

        [[0.2794, 0.3056, 0.2340, 0.1810],
         [0.2835, 0.1834, 0.2774, 0.2557]],

        [[0.2392, 0.3608, 0.2344, 0.1657],
         [0.2105, 0.3367, 0.2348, 0.2181]],

        [[0.2490, 0.1739, 0.2463, 0.3308],
         [0.2697, 0.3590, 0.2173, 0.1540]],

        [[0.2851, 0.2632, 0.2858, 0.1658],
         [0.2566, 0.3526, 0.2540, 0.1368]],

        [[0.1978, 0.1414, 0.2378, 0.4230],
         [0.3562, 0.2322, 0.1648, 0.2468]],

        [[0.1987, 0.2720, 0.2849, 0.2444],
         [0.1918, 0.1703, 0.3279, 0.3099]],

        [[0.3124, 0.3550, 0.1565, 0.1760],
         [0.2070, 0.2443, 0.3431, 0.2056]],

        [[0.2701, 0.2096, 0.2996, 0.2207],
         [0.2853, 0.2200, 0.36

# 自注意力机制 (SelfAttention)

<img src="./images/Decoder2.png" alt="示例图片" width="900">

Self-Attention, 和 Attention 类似, 它们都是一种注意力机制. 不同的是 Attention 是 source 对 target, 输入的 source 和输出的 target 内容不同. 例如英译中, 输入英文, 输出中文. 而 Self-Attention 是 source 对 source, 是 source 内部元素之间或者 target 内部元素之间发生的 Attention 机制, 也可以理解为 Target = Source 这种特殊情况下的注意力机制.

In [43]:
class SelfAttention(nn.Module):
    """Self-attention: Q, K and V are all projected from the same input ``x``.

    Args:
        n_head: number of heads of the inner MultiHeadAttention.
        d_k: width of the projected queries and keys.
        d_v: width of the projected values.
        d_x: feature width of the input.
        d_o: feature width of the output.
    """

    def __init__(self, n_head, d_k, d_v, d_x, d_o):
        # nn.Module must be initialised before any parameter is assigned.
        super().__init__()
        # Allocated uninitialised; init_parameters() fills them in below.
        self.wq = nn.Parameter(torch.Tensor(d_x, d_k))
        self.wk = nn.Parameter(torch.Tensor(d_x, d_k))
        self.wv = nn.Parameter(torch.Tensor(d_x, d_v))

        self.mha = MultiHeadAttention(n_head=n_head, d_k_=d_k, d_v_=d_v, d_k=d_k, d_v=d_v, d_o=d_o)

        self.init_parameters()

    def init_parameters(self):
        """Re-initialise every parameter uniformly in [-1/sqrt(n), 1/sqrt(n)).

        ``n`` is the size of the parameter's last dimension. The loop covers
        ``self.parameters()``, i.e. the three projection matrices and all
        parameters of the inner attention module.
        """
        for weight in self.parameters():
            bound = 1. / np.power(weight.size(-1), 0.5)
            weight.data.uniform_(-bound, bound)

    def forward(self, x, mask=None):
        """Project ``x`` to Q/K/V and delegate to the inner attention module.

        Returns:
            Tuple ``(attn, output)`` exactly as returned by ``self.mha``.
        """
        q, k, v = (torch.matmul(x, proj) for proj in (self.wq, self.wk, self.wv))

        attn, output = self.mha(q, k, v, mask=mask)
        return attn, output

In [44]:
if __name__ == "__main__":
    # Smoke test for SelfAttention on a random batch.
    n_x = 4      # sequence length
    d_x = 80     # input feature width
    batch = 4
    x = torch.randn(batch, n_x, d_x)
    # All-False boolean mask: no position is masked out.
    mask = torch.zeros(batch, n_x, n_x, dtype=torch.bool)

    selfattn = SelfAttention(n_head=8, d_k=128, d_v=64, d_x=80, d_o=80)
    attn, output = selfattn(x, mask=mask)

    # One item per line: the two shapes first, then the tensors.
    print(attn.shape, output.shape, attn, output, sep="\n")

torch.Size([32, 4, 4])
torch.Size([4, 4, 80])
tensor([[[0.2377, 0.2362, 0.2598, 0.2663],
         [0.2773, 0.2442, 0.2359, 0.2426],
         [0.2324, 0.2502, 0.2457, 0.2717],
         [0.2597, 0.2455, 0.2584, 0.2364]],

        [[0.2499, 0.2387, 0.2506, 0.2608],
         [0.2551, 0.2465, 0.2406, 0.2579],
         [0.2815, 0.2339, 0.2320, 0.2527],
         [0.2465, 0.2688, 0.2514, 0.2333]],

        [[0.2485, 0.2578, 0.2538, 0.2398],
         [0.2526, 0.2506, 0.2510, 0.2458],
         [0.2360, 0.2568, 0.2537, 0.2534],
         [0.2702, 0.2376, 0.2361, 0.2561]],

        [[0.2391, 0.2506, 0.2865, 0.2238],
         [0.2611, 0.2594, 0.2313, 0.2482],
         [0.2439, 0.2736, 0.2417, 0.2408],
         [0.2483, 0.2548, 0.2583, 0.2386]],

        [[0.2126, 0.2802, 0.2462, 0.2610],
         [0.2513, 0.2513, 0.2476, 0.2497],
         [0.2218, 0.2782, 0.2158, 0.2842],
         [0.2394, 0.2417, 0.2561, 0.2629]],

        [[0.2739, 0.2509, 0.2433, 0.2319],
         [0.2531, 0.2337, 0.2626, 0.2506]

# KV Cache

<img src="./images/KVCache.png" alt="示例图片" width="900">

In [52]:

class GPT2Attention(nn.Module):
    """Excerpt of a GPT-2 style attention layer, shown to illustrate the KV cache.

    Only ``forward`` is reproduced. The constructor and the helpers it relies
    on (``c_attn``, ``q_attn``, ``c_proj``, ``resid_dropout``, ``split_size``,
    ``num_heads``, ``head_dim``, ``reorder_and_upcast_attn``, ``_split_heads``,
    ``_merge_heads``, ``_attn``, ``_upcast_and_reordered_attn``) are elided by
    the ``...`` placeholder below.
    """

    ...

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        attention_masl: Optional[torch.FloatTensor] = None,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Attention forward pass with optional key/value caching.

        Args:
            hidden_states: input to the layer; projected to Q (and, for
                self-attention, also to K and V).
            layer_past: optional ``(past_key, past_value)`` from earlier steps;
                the new keys/values are appended along the sequence axis.
            attention_mask: mask forwarded to the attention kernel.
            head_mask: per-head mask forwarded to the attention kernel.
            encoder_hidden_states: when given, the layer performs
                cross-attention and K/V are projected from this tensor.
            encoder_attention_mask: mask used instead of ``attention_mask`` for
                cross-attention.
            use_cache: when True, return the (key, value) pair so the caller
                can pass it back as ``layer_past``.
            output_attentions: when True, also return the attention weights.
            attention_masl: deprecated misspelling of ``attention_mask``, kept
                only so existing keyword callers keep working.

        Returns:
            ``(attn_output, present)`` or ``(attn_output, present, attn_weights)``;
            ``present`` is None unless ``use_cache`` is True.

        Raises:
            ValueError: cross-attention requested but ``q_attn`` is not defined.
        """
        if attention_mask is None:
            attention_mask = attention_masl

        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )
            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            # Self-attention: Q, K and V all come from the layer's own input.
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        # Everything below applies to both the cross- and self-attention paths.
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Prepend the cached keys/values from previous steps.
        if layer_past is not None:
            past_key, past_value = layer_past
            # Keys and values grow along the same (second-to-last) axis.
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        # Hand the extended cache back to the caller when requested.
        present = (key, value) if use_cache is True else None

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights, )

        return outputs  # a, present, (attentions)

In [142]:
class SelfAttention_KVCache(nn.Module):
    """Self-attention with an optional key/value cache for step-by-step decoding.

    Args:
        n_head: number of heads of the inner MultiHeadAttention.
        d_k: width of the projected queries and keys.
        d_v: width of the projected values.
        d_x: feature width of the input.
        d_o: feature width of the output.
    """

    def __init__(self, n_head, d_k, d_v, d_x, d_o):
        # nn.Module must be initialised before any parameter is assigned.
        super(SelfAttention_KVCache, self).__init__()
        # Allocated uninitialised; init_parameters() fills them in below.
        self.wq = nn.Parameter(torch.empty(d_x, d_k))
        self.wk = nn.Parameter(torch.empty(d_x, d_k))
        self.wv = nn.Parameter(torch.empty(d_x, d_v))

        self.mha = MultiHeadAttention(n_head=n_head, d_k_=d_k, d_v_=d_v, d_k=d_k, d_v=d_v, d_o=d_o)

        self.init_parameters()

        # Projected keys / values of all tokens seen so far while use_cache=True.
        self.cache_k = None
        self.cache_v = None

    def init_parameters(self):
        """Re-initialise every parameter uniformly in [-1/sqrt(n), 1/sqrt(n)).

        ``n`` is the size of the parameter's last dimension; the loop covers
        ``self.parameters()``, including the inner attention module.
        """
        for param in self.parameters():
            stdv = 1. / np.power(param.size(-1), 0.5)
            param.data.uniform_(-stdv, stdv)

    def reset_cache(self):
        """Drop the cached keys/values; call this before decoding a new sequence."""
        self.cache_k = None
        self.cache_v = None

    def forward(self, x, mask=None, use_cache=False):
        """Run self-attention, optionally reusing cached keys/values.

        Args:
            x: [batch, seq_len, d_x] input.
            mask: optional mask forwarded to the inner attention module.
            use_cache: when True, only the last token of ``x`` is used as the
                query. On the first cached call every token of ``x`` is
                projected into the cache; on later calls only the last token
                is appended, so ``x`` may be either the growing prefix or just
                the newest token.

        Returns:
            ``(attn, output, cache_k, cache_v)``. With ``use_cache=True`` the
            query length is 1, so ``output`` is [batch, 1, d_o]; otherwise it
            is [batch, seq_len, d_o]. The cache entries are None until a
            cached call has been made.
        """
        if use_cache:
            # Only the newest position needs a fresh query.
            q = torch.matmul(x[:, -1:, :], self.wq)

            # An empty cache means the whole prompt still has to be projected.
            # Afterwards only the newest token is new.
            fresh = x if self.cache_k is None else x[:, -1:, :]
            k = torch.matmul(fresh, self.wk)
            v = torch.matmul(fresh, self.wv)

            if self.cache_k is None:
                self.cache_k, self.cache_v = k, v
            else:
                # Grow the cache along the sequence axis.
                self.cache_k = torch.cat([self.cache_k, k], dim=1)
                self.cache_v = torch.cat([self.cache_v, v], dim=1)

            k, v = self.cache_k, self.cache_v
        else:
            q = torch.matmul(x, self.wq)
            k = torch.matmul(x, self.wk)
            v = torch.matmul(x, self.wv)

        attn, output = self.mha(q, k, v, mask=mask)

        return attn, output, self.cache_k, self.cache_v
    

In [145]:
if __name__ == "__main__":
    # Toy problem sizes.
    batch_size = 2
    seq_len = 5
    d_x = 16    # input feature width
    d_k = 8     # key / query width
    d_v = 8     # value width
    d_o = 16    # output width
    n_head = 2  # number of heads

    # Random input of shape [batch_size, seq_len, d_x].
    x = torch.randn(batch_size, seq_len, d_x)

    model = SelfAttention_KVCache(n_head=n_head, d_k=d_k, d_v=d_v, d_x=d_x, d_o=d_o)

    # No mask in this demo; a mask could be supplied during generation.
    mask = None

    # --- Baseline: full sequence, no cache ---------------------------------
    print("No KV Cache")
    use_cache = False
    attn, output, _, _ = model(x, mask=mask, use_cache=use_cache)
    print(f"Attention Shape: {attn.shape}, Output Shape: {output.shape}")

    # --- Autoregressive decoding with the cache enabled ---------------------
    print("\n\nKV Cache")
    use_cache = True

    # Feed a prefix that grows by one token per step and record the cache sizes.
    size = []
    for i in range(1, seq_len + 1):
        partial_x = x[:, :i, :]  # first i tokens only
        attn, output, k_cache, v_cache = model(partial_x, mask=mask, use_cache=use_cache)
        print(f"Step {i}: Attention Shape: {attn.shape}, Output Shape: {output.shape}")
        size.append(f"Step {i}: K Cache Shape: {k_cache.shape}, V Cache Shape: {v_cache.shape}")

    # Print the collected cache shapes after the per-step attention shapes.
    print("因为每次只查询一个 token, 所以 attention 和 output 的 dim=1 一直等于1\n")
    print(*size, sep="\n")

No KV Cache
Attention Shape: torch.Size([4, 5, 5]), Output Shape: torch.Size([2, 5, 16])


KV Cache
Step 1: Attention Shape: torch.Size([4, 1, 1]), Output Shape: torch.Size([2, 1, 16])
Step 2: Attention Shape: torch.Size([4, 1, 2]), Output Shape: torch.Size([2, 1, 16])
Step 3: Attention Shape: torch.Size([4, 1, 3]), Output Shape: torch.Size([2, 1, 16])
Step 4: Attention Shape: torch.Size([4, 1, 4]), Output Shape: torch.Size([2, 1, 16])
Step 5: Attention Shape: torch.Size([4, 1, 5]), Output Shape: torch.Size([2, 1, 16])
因为每次只查询一个 token, 所以 attention 和 output 的 dim=1 一直等于1

Step 1: K Cache Shape: torch.Size([2, 1, 8]), V Cache Shape: torch.Size([2, 1, 8])
Step 2: K Cache Shape: torch.Size([2, 2, 8]), V Cache Shape: torch.Size([2, 2, 8])
Step 3: K Cache Shape: torch.Size([2, 3, 8]), V Cache Shape: torch.Size([2, 3, 8])
Step 4: K Cache Shape: torch.Size([2, 4, 8]), V Cache Shape: torch.Size([2, 4, 8])
Step 5: K Cache Shape: torch.Size([2, 5, 8]), V Cache Shape: torch.Size([2, 5, 8])


# 多查询注意力机制 (Multi-Query Attention, MQA)

<img src="./images/MQA.png" alt="示例图片" width="900">

In [151]:
def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
    """Multi-head (or multi-query) scaled dot-product attention on packed inputs.

    Args:
        query: [b, s_q, n_heads * d] packed queries.
        key: [b, s_k, n_heads * d] packed keys, or [b, s_k, d] when
            ``multiquery`` is True (one key head shared by all query heads).
        value: packed values, laid out like ``key``.
        n_heads: number of query heads.
        softmax_scale: multiplier for the raw scores; defaults to 1/sqrt(d).
        attn_bias: optional additive bias; its last two dimensions must each
            be 1 or match (s_q, s_k).
        key_padding_mask: optional boolean [b, s_k] tensor; False marks keys
            that must be ignored.
        is_causal: when True, a query may not attend to later keys. For
            s_q != s_k the trailing (s_q, s_k) window of the square mask is used.
        dropout_p: dropout probability for the attention weights.
        training: whether dropout is active.
        needs_weights: when True, also return the attention weights.
        multiquery: see ``key``.

    Returns:
        ``(out, attn_weight)`` with ``out`` of shape [b, s_q, n_heads * d_v];
        ``attn_weight`` ([b, n_heads, s_q, s_k]) is None unless
        ``needs_weights`` is True.

    Raises:
        RuntimeError: ``attn_bias`` cannot be broadcast to the score matrix.
    """
    kv_heads = 1 if multiquery else n_heads
    b, s_q, _ = query.shape
    s_k = key.size(1)

    # Unpack the heads: 'b s (h d) -> b h s d'; keys go straight to 'b h d s'
    # so q @ k needs no extra transpose. With kv_heads == 1 the single key /
    # value head broadcasts against all query heads in matmul.
    q = query.reshape(b, s_q, n_heads, -1).permute(0, 2, 1, 3)
    k = key.reshape(b, s_k, kv_heads, -1).permute(0, 2, 3, 1)
    v = value.reshape(b, value.size(1), kv_heads, -1).permute(0, 2, 1, 3)

    min_val = torch.finfo(q.dtype).min
    d = q.size(-1)
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(d)
    attn_weight = q.matmul(k) * softmax_scale

    if attn_bias is not None:
        if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
            raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
        attn_weight = attn_weight + attn_bias

    if key_padding_mask is not None:
        if attn_bias is not None:
            warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
        attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)

    if is_causal:
        s = max(s_q, s_k)
        # Boolean mask, True strictly above the diagonal (future positions).
        # Built from index comparison so no half-precision tril is needed.
        positions = torch.arange(s, device=attn_weight.device)
        causal_mask = positions[None, :] > positions[:, None]
        # Keep the trailing window so the last query lines up with the last key.
        causal_mask = causal_mask[-s_q:, -s_k:]
        attn_weight = attn_weight.masked_fill(causal_mask, min_val)

    attn_weight = torch.softmax(attn_weight, dim=-1)
    if dropout_p:
        # Out of place: the softmax output is needed for its backward pass.
        attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training)

    out = attn_weight.matmul(v)
    # Re-pack the heads: 'b h s d -> b s (h d)'.
    out = out.transpose(1, 2).reshape(b, s_q, -1)
    if needs_weights:
        return (out, attn_weight)
    return (out, None)

In [152]:
class MultiQueryAttention(nn.Module):
    '''
    Multi-Query self attention.

    All ``n_heads`` query heads share a single key head and a single value
    head. The attention itself is delegated to the notebook-level
    ``scaled_multihead_dot_product_attention`` (called with ``multiquery=True``),
    which also supports an additive attention bias.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of query heads.
        attn_impl: name of the requested attention backend. It is stored only;
            this class always uses the torch implementation above.
        clip_qkv: when set, the fused Q/K/V projection is clamped to
            ``[-clip_qkv, clip_qkv]``.
        qk_ln: apply LayerNorm to the queries and keys before attention.
        softmax_scale: score multiplier; defaults to 1/sqrt(head_dim).
        attn_pdrop: dropout probability for the attention weights.
        low_precision_layernorm: accepted for interface compatibility; unused here.
        verbose: accepted for interface compatibility; unused here.
        device: device passed to the created layers.

    Raises:
        ValueError: ``d_model`` is not divisible by ``n_heads``.
    '''
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_impl: str = 'triton',
        clip_qkv: Optional[float] = None,
        qk_ln: bool = False,
        softmax_scale: Optional[float] = None,
        attn_pdrop: float = 0.0,
        low_precision_layernorm: bool = False,
        verbose: int = 0,
        device: Optional[str] = None
    ):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f'd_model ({d_model}) must be divisible by n_heads ({n_heads}).')
        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv
        self.qk_ln = qk_ln
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.softmax_scale = softmax_scale
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.head_dim)
        self.attn_dropout_p = attn_pdrop

        # One fused projection producing Q (d_model wide, all heads) followed
        # by the single shared K and V heads (head_dim wide each).
        self.Wqkv = nn.Linear(
            d_model,
            d_model + 2 * self.head_dim,
            device=device,
        )

        # Non-standard marker attribute recording where the fused output
        # splits into Q | K | V along dim 0 of the weight. Nothing in this
        # class reads it. The `# type: ignore` silences static type checkers
        # about the ad-hoc attribute.
        fuse_splits = (d_model, d_model + self.head_dim)
        self.Wqkv._fused = (0, fuse_splits) # type: ignore

        if self.qk_ln:
            # Queries span all heads (d_model); the shared key is one head wide.
            self.q_ln = nn.LayerNorm(d_model, device=device)
            self.k_ln = nn.LayerNorm(self.head_dim, device=device)

        self.attn_fn = scaled_multihead_dot_product_attention
        # Maps the concatenated head outputs back to the model width.
        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)

        # Marker attribute only (flags this layer as feeding a residual
        # branch); it does not add a residual connection by itself.
        self.out_proj._is_residual = True # type: ignore

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weoghts=False,
        needs_weights=False,
    ):
        """Run multi-query self-attention.

        Args:
            x: [batch, seq_len, d_model] input.
            past_key_value: KV cache. ``None`` disables caching; an empty
                tuple starts a new cache; a ``(key, value)`` pair from a
                previous call is extended along the sequence axis.
            attn_bias: optional additive attention bias.
            attention_mask: optional boolean [batch, s_k] key-padding mask
                (False = ignore that key).
            is_causal: forbid attending to later positions.
            needs_weoghts: deprecated misspelling of ``needs_weights``, kept so
                existing callers keep working.
            needs_weights: also return the attention weights.

        Returns:
            ``(output, attn_weights, past_key_value)``: ``output`` is
            [batch, seq_len, d_model]; ``attn_weights`` is None unless
            requested; ``past_key_value`` is the updated cache, or None when
            caching is disabled.
        """
        # [batch, seq_len, d_model + 2 * head_dim],
        # e.g. (1, 512, 960) for d_model=768 with 8 heads.
        qkv = self.Wqkv(x)

        # Bound the projected values to avoid very large activations.
        if self.clip_qkv:
            qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)

        # Split the fused projection along the feature axis, e.g.
        # q -> [1, 512, 768], k -> [1, 512, 96], v -> [1, 512, 96].
        query, key, value = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)

        # The attention function expects the padding mask under this name.
        key_padding_mask = attention_mask

        if self.qk_ln:
            # LayerNorm may change the dtype; cast back to the original one.
            dtype = query.dtype
            query = self.q_ln(query).to(dtype)
            key = self.k_ln(key).to(dtype)

        if past_key_value is not None:
            if len(past_key_value) != 0:
                # Prepend the cached keys/values of earlier positions.
                key = torch.cat([past_key_value[0], key], dim=1)
                value = torch.cat([past_key_value[1], value], dim=1)
            past_key_value = (key, value)

        context, attn_weights = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights or needs_weoghts,
            multiquery=True,
        )
        return self.out_proj(context), attn_weights, past_key_value

In [188]:
class MyMultiQueryAttention(nn.Module):
    """Multi-query attention: many query heads, one shared key/value head.

    Args:
        n_head: number of query heads.
        d_k_: feature width of the incoming queries and keys (both are
            projected by layers that take ``d_k_`` inputs).
        d_v_: feature width of the incoming values.
        d_k: per-head width of the projected queries; also the width of the
            single shared key head.
        d_v: width of the single shared value head.
        d_o: feature width of the output.
    """

    def __init__(self, n_head, d_k_, d_v_, d_k, d_v, d_o):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        # Queries get one projection per head; keys and values get a single
        # head that every query head shares.
        self.fc_q = nn.Linear(d_k_, n_head * d_k)
        self.fc_k = nn.Linear(d_k_, d_k)
        self.fc_v = nn.Linear(d_v_, d_v)

        self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))

        self.fc_o = nn.Linear(n_head * d_v, d_o)

    def forward(self, q, k, v, mask=None):
        """Run multi-query attention.

        Args:
            q: [b, l_q, d_k_] queries.
            k: [b, l_k, d_k_] keys.
            v: [b, l_k, d_v_] values.
            mask: optional boolean [b, l_q, l_k] mask, True = masked out.

        Returns:
            Tuple ``(attn, output)`` with ``attn`` of shape [h*b, l_q, l_k]
            (head-major: all batch entries of head 0 first) and ``output`` of
            shape [b, l_q, d_o].
        """
        n_head, d_k, d_v = self.n_head, self.d_k, self.d_v
        batch, n_q, _ = q.size()

        q = self.fc_q(q)
        k = self.fc_k(k)
        v = self.fc_v(v)

        # Queries: [b, l, h*d] -> [b, l, h, d] -> [h, b, l, d] -> [h*b, l, d]
        q = q.view(batch, n_q, n_head, d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_q, d_k)
        # Keys / values: tile the single head along dim 0 so entry i holds
        # sample i % b. That is the same head-major order as q and the mask,
        # so each query is paired with keys/values of its own sample.
        k = k.repeat(n_head, 1, 1)
        v = v.repeat(n_head, 1, 1)

        if mask is not None:
            mask = mask.repeat(n_head, 1, 1)

        # With heads folded into the batch axis this is plain single-head attention.
        attn, output = self.attention(q, k, v, mask=mask)

        # Concatenate heads: [h*b, l_q, d_v] -> [h, b, l_q, d_v] -> [b, l_q, h, d_v] -> [b, l_q, h*d_v]
        output = output.view(n_head, batch, n_q, d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1)

        # [..., h*d_v] -> [..., d_o]
        output = self.fc_o(output)

        return attn, output

In [189]:
if __name__ == "__main__":
    # Smoke test for MyMultiQueryAttention on random inputs.
    n_q, n_k, n_v = 2, 4, 4             # sequence lengths of queries / keys / values
    d_q_, d_k_, d_v_ = 128, 128, 64     # feature widths before projection
    batch = 4

    q = torch.randn(batch, n_q, d_q_)
    k = torch.randn(batch, n_k, d_k_)
    v = torch.randn(batch, n_v, d_v_)
    # All-False boolean mask: no position is masked out.
    mask = torch.zeros(batch, n_q, n_k, dtype=torch.bool)

    mha = MyMultiQueryAttention(n_head=8, d_k_=128, d_v_=64, d_k=256, d_v=128, d_o=128)
    attn, output = mha(q, k, v, mask=mask)

    # One item per line: the two shapes first, then the tensors.
    print(attn.shape, output.shape, attn, output, sep="\n")

torch.Size([32, 2, 4])
torch.Size([4, 2, 128])
tensor([[[0.2673, 0.1969, 0.1328, 0.4030],
         [0.2837, 0.2374, 0.2473, 0.2316]],

        [[0.1549, 0.2106, 0.3558, 0.2787],
         [0.1554, 0.3733, 0.3000, 0.1713]],

        [[0.1432, 0.4309, 0.2799, 0.1459],
         [0.1951, 0.3352, 0.2062, 0.2635]],

        [[0.2496, 0.1815, 0.3300, 0.2390],
         [0.2416, 0.2436, 0.1758, 0.3390]],

        [[0.2967, 0.2253, 0.2807, 0.1973],
         [0.2050, 0.2747, 0.2827, 0.2376]],

        [[0.2191, 0.2262, 0.1897, 0.3651],
         [0.1095, 0.2302, 0.2499, 0.4104]],

        [[0.1709, 0.2289, 0.2501, 0.3501],
         [0.1842, 0.1808, 0.2313, 0.4037]],

        [[0.2063, 0.3573, 0.2533, 0.1831],
         [0.2624, 0.1467, 0.3159, 0.2750]],

        [[0.2486, 0.1880, 0.3166, 0.2467],
         [0.2750, 0.2199, 0.2350, 0.2701]],

        [[0.3261, 0.2748, 0.2305, 0.1686],
         [0.2672, 0.2331, 0.2529, 0.2469]],

        [[0.2777, 0.0860, 0.3919, 0.2444],
         [0.2631, 0.1802, 0.16