In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LlamaMHA(nn.Module):
    """Multi-head self-attention with an optional key/value cache.

    Note: despite the name, this module applies no rotary position embedding
    (RoPE) and no built-in causal mask; pass ``attention_mask`` for causality.

    Args:
        dim: Input/output feature dimension.
        num_heads: Number of attention heads.
        head_dim: Per-head feature size; defaults to ``dim // num_heads``.
    """

    def __init__(self, dim, num_heads, head_dim=None):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = head_dim if head_dim is not None else dim // num_heads
        inner_dim = self.num_heads * self.head_dim
        # Projection matrices (bias-free).
        self.q_proj = nn.Linear(dim, inner_dim, bias=False)
        self.k_proj = nn.Linear(dim, inner_dim, bias=False)
        self.v_proj = nn.Linear(dim, inner_dim, bias=False)
        self.o_proj = nn.Linear(inner_dim, dim, bias=False)
        self.scale = 1.0 / math.sqrt(self.head_dim)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq, heads*head_dim) -> (batch, heads, seq, head_dim)."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, attention_mask=None, cache=None, use_cache=False):
        """Run attention over ``x``, optionally extending a key/value cache.

        Args:
            x: Input of shape (batch, seq_len, dim).
            attention_mask: Optional tensor ADDED to the raw scores before
                softmax; must broadcast to (batch, heads, seq_len, kv_len).
                Presumably 0 for allowed and -inf/large negative for masked
                positions — confirm against callers.
            cache: Optional ``(past_k, past_v)`` tuple, each of shape
                (batch, heads, past_len, head_dim); new keys/values are
                appended along the sequence axis.
            use_cache: If True, return the updated cache even when ``cache``
                is None, so a fresh cache can be started from the first call.

        Returns:
            ``output`` of shape (batch, seq_len, dim), or
            ``(output, (k, v))`` when ``cache`` is given or ``use_cache`` is True.
        """
        batch_size, seq_len, _ = x.shape
        q = self._split_heads(self.q_proj(x), batch_size, seq_len)
        k = self._split_heads(self.k_proj(x), batch_size, seq_len)
        v = self._split_heads(self.v_proj(x), batch_size, seq_len)

        # KV cache (inference acceleration): prepend past keys/values on dim 2 (seq axis).
        if cache is not None:
            past_k, past_v = cache
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        # Scaled dot-product scores: (batch, heads, seq_len, kv_len).
        attn_scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask
        attn_weights = F.softmax(attn_scores, dim=-1)

        output = torch.matmul(attn_weights, v)  # (batch, heads, seq_len, head_dim)
        # Merge heads back and apply the output projection.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self.o_proj(output)

        if cache is not None or use_cache:
            return output, (k, v)
        return output

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention (no mask, no cache).

    Args:
        head: Number of attention heads.
        head_dim: Per-head feature size.
        dim: Input/output feature dimension.
    """

    def __init__(self, head, head_dim, dim):
        super().__init__()
        self.head = head
        self.head_dim = head_dim
        self.dim = dim
        inner_dim = head * head_dim
        # Q, K and V all map the model dimension to heads * head_dim.
        self.proj_q = nn.Linear(dim, inner_dim, bias=False)
        self.proj_k = nn.Linear(dim, inner_dim, bias=False)
        self.proj_v = nn.Linear(dim, inner_dim, bias=False)
        # Output projection maps the concatenated heads back to the model dimension.
        self.proj_o = nn.Linear(inner_dim, dim, bias=False)
        self.scaling = self.head_dim ** -0.5

    def forward(self, hidden_states):
        """Attend over ``hidden_states`` of shape (bs, seq_len, dim); returns the same shape."""
        bs, seq_len, _ = hidden_states.shape
        # (bs, seq_len, head*head_dim) -> (bs, head, seq_len, head_dim)
        q = self.proj_q(hidden_states).view(bs, seq_len, self.head, self.head_dim).transpose(1, 2)
        k = self.proj_k(hidden_states).view(bs, seq_len, self.head, self.head_dim).transpose(1, 2)
        v = self.proj_v(hidden_states).view(bs, seq_len, self.head, self.head_dim).transpose(1, 2)
        # Scores over token pairs: (bs, head, seq_len, seq_len).
        scores = (q @ k.transpose(-2, -1)) * self.scaling
        weights = F.softmax(scores, dim=-1)
        context = weights @ v  # (bs, head, seq_len, head_dim)
        context = context.transpose(1, 2).contiguous().reshape(bs, seq_len, -1)
        return self.proj_o(context)
# Smoke test: 8 heads of size 128 over a 1024-dim model, batch of 8 x 128 tokens.
attn = Attention(head=8, head_dim=128, dim=1024)
sample = torch.randn(8, 128, 1024)  # (batch, seq_len, dim)
attn(sample)