# Self-Attention代码：

# In [4]:
import torch
import torch.nn as nn
class Attention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Args:
        d_model: size of the last dimension of the inputs. Q, K and V are all
            projected to ``d_model`` features and the scores are scaled by
            ``sqrt(d_model)``.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """Project ``query``/``key``/``value`` and apply attention.

        Args:
            query: tensor whose last dimension is ``d_model``
                (e.g. ``(batch, q_len, d_model)``).
            key: tensor whose last dimension is ``d_model``.
            value: tensor whose last dimension is ``d_model``; same sequence
                length as ``key``.
            mask: optional tensor broadcastable to the score shape
                ``(..., q_len, k_len)``. Positions where ``mask == 0`` are
                excluded from attention.

        Returns:
            ``(attention_output, attention_weights)`` from
            :meth:`scaled_dot_product_attention`.
        """
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        return self.scaled_dot_product_attention(Q, K, V, mask)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``softmax(Q K^T / sqrt(d_model)) V``.

        Returns:
            ``(output, attention_weights)``; the weights sum to 1 over the
            last (key) dimension.
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_model ** 0.5)

        if mask is not None:
            # masked_fill replaces entries where its condition is True, so the
            # positions with mask == 0 are blocked. Use the most negative finite
            # value of the scores' dtype: a fixed -1e9 does not fit in float16
            # and makes masked_fill raise under half precision.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        return output, attention_weights

# MHA代码：

# In [2]:
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The model dimension is split into ``num_heads`` heads of size
    ``d_k = d_model // num_heads``. Each head attends independently, then the
    head outputs are concatenated and mixed by the output projection ``W_o``.

    Args:
        d_model: size of the last dimension of the inputs and of the output.
        num_heads: number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Without this check the .view() calls in forward() can silently
        # reinterpret the data with a wrong sequence length instead of failing.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # nn.Linear transforms only the last dimension, so the same projection
        # is applied to every token vector whatever the batch/sequence sizes.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        For self-attention, pass the same tensor as ``query``, ``key`` and
        ``value``. For cross-attention, ``key``/``value`` may have a different
        sequence length than ``query``.

        Args:
            query: tensor of shape ``(batch, q_len, d_model)``.
            key: tensor of shape ``(batch, k_len, d_model)``.
            value: tensor of shape ``(batch, k_len, d_model)``.
            mask: optional tensor broadcastable to
                ``(batch, num_heads, q_len, k_len)``. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            ``(output, attention_weights)`` with shapes
            ``(batch, q_len, d_model)`` and ``(batch, num_heads, q_len, k_len)``.
        """
        batch_size = query.size(0)

        # (batch, len, d_model) -> (batch, len, num_heads, d_k) -> (batch, num_heads, len, d_k)
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate the heads back into (batch, q_len, d_model). transpose()
        # only changes the logical indexing, and view() requires contiguous
        # memory, so contiguous() first copies the data into the new order.
        concat_attention = (
            attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        )

        output = self.W_o(concat_attention)
        return output, attention_weights

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V`` independently per head.

        Returns:
            ``(output, attention_weights)``; the weights sum to 1 over the
            last (key) dimension.
        """
        # d_k is a plain Python constant; it needs no tensor or gradient.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        if mask is not None:
            # masked_fill replaces entries where its condition is True, so the
            # positions with mask == 0 are blocked. The dtype's most negative
            # finite value is used because a fixed -1e9 overflows float16 and
            # makes masked_fill raise.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        return output, attention_weights


# GQA代码：

# In [3]:
import torch
import torch.nn as nn
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention (GQA).

    ``num_heads`` query heads share ``num_kv_heads`` key/value heads. Each
    consecutive group of ``head_group_size = num_heads // num_kv_heads`` query
    heads uses the same key/value head (query head ``h`` reads key/value head
    ``h // head_group_size``). This shrinks the K/V projections to
    ``num_kv_heads * d_k`` features.

    Args:
        d_model: size of the last dimension of the inputs and of the output.
        num_heads: number of query heads; must divide ``d_model``.
        num_kv_heads: number of key/value heads; must divide ``num_heads``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads`` or
            ``num_heads`` is not divisible by ``num_kv_heads``.
    """

    def __init__(self, d_model, num_heads, num_kv_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.d_k = d_model // num_heads
        self.head_group_size = num_heads // num_kv_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, self.d_k * num_kv_heads)
        self.W_v = nn.Linear(d_model, self.d_k * num_kv_heads)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """Apply grouped-query attention.

        Args:
            query: tensor of shape ``(batch, q_len, d_model)``.
            key: tensor of shape ``(batch, k_len, d_model)``; ``k_len`` may
                differ from ``q_len`` (cross-attention).
            value: tensor of shape ``(batch, k_len, d_model)``.
            mask: optional tensor broadcastable to
                ``(batch, num_heads, q_len, k_len)``. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            ``(output, attention_weights)`` with shapes
            ``(batch, q_len, d_model)`` and ``(batch, num_heads, q_len, k_len)``.
        """
        batch_size = query.size(0)
        seq_len = query.size(1)
        kv_len = key.size(1)

        # Project and split heads:
        #   Q: (batch, num_heads,    q_len, d_k)
        #   K/V: (batch, num_kv_heads, k_len, d_k)
        Q = self.W_q(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, kv_len, self.num_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, kv_len, self.num_kv_heads, self.d_k).transpose(1, 2)

        # Repeat each K/V head head_group_size times along the head axis so that
        # query head h pairs with key/value head h // head_group_size.
        K = K.repeat_interleave(self.head_group_size, dim=1)
        V = V.repeat_interleave(self.head_group_size, dim=1)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        if mask is not None:
            # The dtype's most negative finite value blocks the positions with
            # mask == 0; a fixed -1e9 overflows float16 and makes masked_fill raise.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)

        # Merge heads back to (batch, q_len, d_model), then mix them with W_o.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(output)

        return output, attention_weights
