In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from typing import Optional, Tuple
import math
from dataclasses import dataclass 

In [2]:
# 前置代码
class DeepseekV2RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    The statistics are computed in float32; the normalised activations are cast
    back to the input dtype before being multiplied by ``weight``.

    Args:
        hidden_size: Size of the last dimension (length of the gain vector).
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Gain starts at one, so a freshly built layer only divides by the RMS.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Return ``hidden_states`` divided by its RMS (last dim), scaled by ``weight``."""
        original_dtype = hidden_states.dtype
        # Up-cast so the squared mean is accumulated in float32.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = (upcast * inv_rms).to(original_dtype)
        return self.weight * normalised


# 位置编码
class DeepseekV2RotaryEmbedding(nn.Module):
    """Cached cos/sin tables for rotary position embeddings (RoPE).

    The tables have shape ``[cached_len, dim]``: row ``p`` holds ``cos``/``sin``
    of ``p * inv_freq``, with the ``dim // 2`` angles repeated twice along the
    last dimension (the layout expected by ``rotate_half``-style rotation).

    Args:
        dim: Width of the rotary sub-space (number of channels that get rotated).
        max_position_embeddings: Number of positions precomputed at construction.
        base: Base of the geometric frequency progression.
        device: Optional device on which ``inv_freq`` is created.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = base ** (-2i / dim): index 0 has the highest frequency
        # (1.0) and the frequency decreases as the index grows.
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        # The cache is kept (max_seq_len_cached is NOT reset afterwards), so the
        # first forward pass reuses it instead of recomputing the tables.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        # freqs[p, i] = p * inv_freq[i]
        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables of shape ``[seq_len, dim]`` matching ``x``.

        Args:
            x: Tensor used for its dtype and device (and its second-to-last
                dimension when ``seq_len`` is omitted).
                Layout: [bs, num_attention_heads, seq_len, head_size].
            seq_len: Number of positions required. Defaults to ``x.shape[-2]``.
        """
        if seq_len is None:
            # The default used to crash inside torch.arange(None); fall back to
            # the sequence axis of x instead.
            seq_len = x.shape[-2]

        # Grow the cache only when more positions are requested than are stored.
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(device=x.device, dtype=x.dtype),
            self.sin_cached[:seq_len].to(device=x.device, dtype=x.dtype),
        )


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Split the last dimension into halves ``[a, b]`` and return ``[-b, a]``."""
    midpoint = x.shape[-1] // 2
    front = x[..., :midpoint]
    back = x[..., midpoint:]
    return torch.cat((back.neg(), front), dim=-1)

# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embeddings to a query and a key tensor.

    Args:
        q: 4-D query tensor; its last dimension holds interleaved rotation pairs.
        k: 4-D key tensor with the same convention.
        cos: Cosine table, indexed by ``position_ids`` along its first dimension.
        sin: Sine table, indexed the same way.
        position_ids: Index tensor selecting the table rows to use.
        unsqueeze_dim: Axis inserted into the gathered tables so they broadcast
            against ``q`` and ``k``.

    Returns:
        ``(q_embed, k_embed)``, each with the same shape as its input.
    """
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(t):
        bsz, heads, seq, width = t.shape
        # Reorder the last dim from interleaved (x0, x1, x2, x3, ...) to
        # evens-then-odds (x0, x2, ..., x1, x3, ...) so rotate_half pairs them up.
        t = (
            t.view(bsz, heads, seq, width // 2, 2)
            .transpose(4, 3)
            .reshape(bsz, heads, seq, width)
        )
        return (t * cos_sel) + (rotate_half(t) * sin_sel)

    return _rotate(q), _rotate(k)


def apply_rotary_pos_emb_v2(q: torch.Tensor, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embeddings to a single 4-D tensor.

    Args:
        q: 4-D tensor whose last dimension holds interleaved rotation pairs.
        cos: Cosine table, indexed by ``position_ids`` along its first dimension.
        sin: Sine table, indexed the same way.
        position_ids: Index tensor selecting the table rows to use.
        unsqueeze_dim: Axis inserted into the gathered tables so they broadcast
            against ``q``.

    Returns:
        The rotated tensor, same shape as ``q``.
    """
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)

    bsz, heads, seq, width = q.shape
    # Interleaved (x0, x1, x2, x3, ...) -> evens-then-odds (x0, x2, ..., x1, x3, ...).
    pairs = q.view(bsz, heads, seq, width // 2, 2)
    regrouped = pairs.transpose(4, 3).reshape(bsz, heads, seq, width)

    return (regrouped * cos_sel) + (rotate_half(regrouped) * sin_sel)

In [3]:
@dataclass
class DeepseekConfig:
    """Hyper-parameters read by the MLA attention module in this notebook."""

    # Model width and attention layout.
    hidden_size: int
    num_heads: int
    max_position_embeddings: int
    rope_theta: float

    attention_dropout: float

    # Rank of the low-rank (compressed) query projection.
    q_lora_rank: int
    # Per-head width of the rotary (position-carrying) part of q and k.
    qk_rope_head_dim: int

    # Rank of the shared low-rank (compressed) key/value projection.
    kv_lora_rank: int

    # Per-head width of v after up-projection.
    v_head_dim: int
    # Per-head width of the non-rotary part of q and k after up-projection.
    qk_nope_head_dim: int
    # Whether the down-projection linear layers carry a bias term.
    attention_bias: bool

    # Initial training-mode flag handed to the attention module.
    training: bool = True




In [4]:
class MLA(nn.Module):
    """Multi-head Latent Attention with the key/value up-projections absorbed.

    Queries and keys/values are first compressed to low-rank latents. Instead of
    expanding the latent ``c_kv`` into per-head keys and values, the two slices
    of ``kv_up_proj.weight`` are folded into the query side (``W_UK``) and the
    output side (``W_UV``), so attention is computed directly against ``c_kv``.
    A separate rotary sub-space (``qk_rope_head_dim`` channels) carries the
    positional information; the key's rotary part is shared by all heads.
    """

    def __init__(self, config: DeepseekConfig):
        super().__init__()

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim

        # Maps the concatenated heads [num_heads * v_head_dim] back to hidden_size.
        self.out_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False
        )

        # ---- compression (down-projection) ----
        self.q_lora_rank = config.q_lora_rank
        self.kv_lora_rank = config.kv_lora_rank

        self.q_down_proj = nn.Linear(
            self.hidden_size,
            self.q_lora_rank,
            bias=config.attention_bias
        )

        self.q_down_norm = DeepseekV2RMSNorm(self.q_lora_rank)

        # The key's rotary part is produced together with the KV latent, hence
        # the extra qk_rope_head_dim output channels.
        self.kv_down_proj = nn.Linear(
            self.hidden_size,
            self.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias
        )

        # NOTE(review): this norm spans the latent AND the rope slice. The
        # DeepSeek-V2 reference normalises only the kv_lora_rank latent; the
        # width is kept as-is here so parameter shapes stay unchanged — confirm
        # whether normalising the rope slice is intended.
        self.kv_down_norm = DeepseekV2RMSNorm(self.kv_lora_rank + self.qk_rope_head_dim)

        # ---- expansion (up-projection) ----
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        # The query's rotary part is produced by the up-projection.
        self.q_up_proj = nn.Linear(
            self.q_lora_rank,
            self.num_heads * self.q_head_dim,
            bias=False,
        )

        # One matrix up-projects the latent to both the non-rotary key part and
        # the value, so each head owns qk_nope_head_dim + v_head_dim rows.
        self.kv_up_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (config.qk_nope_head_dim + config.v_head_dim),
            bias=False
        )

        self.rotary_emb = DeepseekV2RotaryEmbedding(
            config.qk_rope_head_dim,
            config.max_position_embeddings,
            config.rope_theta
        )

        # Set the initial mode through nn.Module.train() (rather than writing
        # self.training directly) so child modules stay in sync with this one.
        self.train(bool(config.training))

    def forward(self, hidden_states: torch.Tensor, position_ids, attention_mask=None):
        """Run absorbed-MLA self-attention over ``hidden_states``.

        Args:
            hidden_states: Input of shape ``[b, s, hidden_size]``.
            position_ids: Integer positions used to index the rotary tables;
                the tables are requested for ``s`` positions, so values are
                expected to lie in ``[0, s)``.
            attention_mask: Optional mask broadcastable to
                ``[b, num_heads, s, s]``; entries equal to 0 are masked out.

        Returns:
            ``(u, attn_weights)`` where ``u`` is ``[b, s, hidden_size]`` and
            ``attn_weights`` is ``[b, num_heads, s, s]`` (after dropout).
        """
        b, s, _ = hidden_states.shape

        # ---- query path ----
        q = self.q_down_proj(hidden_states)
        q = self.q_down_norm(q)
        q = self.q_up_proj(q)  # [b, s, num_heads * (nope_dim + rope_dim)]

        # [b, s, num_heads * (nope+rope)] -> [b, num_heads, s, nope+rope]
        q = q.view(b, s, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_rope = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # ---- key/value path ----
        c_kv_and_rope = self.kv_down_proj(hidden_states)
        c_kv_and_rope = self.kv_down_norm(c_kv_and_rope)
        c_kv, k_rope = c_kv_and_rope.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        # [b, s, rope_dim] -> [b, 1, s, rope_dim]; the single head broadcasts
        # against every query head.
        k_rope = k_rope.view(b, s, 1, self.qk_rope_head_dim).transpose(1, 2)

        # Split kv_up_proj.weight ([num_heads * (nope + v), c]) per head into
        # W_UK (first nope rows) and W_UV (remaining v rows).
        kv_b_proj = self.kv_up_proj.weight.view(
            self.num_heads, -1, self.kv_lora_rank
        )
        q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :]
        out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]

        # Rotate BOTH the query and the key rotary parts. Rotating only the
        # query (as before) makes q_rope . k_rope depend on the absolute query
        # position instead of the relative offset between query and key.
        cos, sin = self.rotary_emb(q_rope, seq_len=s)
        q_rope = apply_rotary_pos_emb_v2(q_rope, cos, sin, position_ids)
        k_rope = apply_rotary_pos_emb_v2(k_rope, cos, sin, position_ids)

        # Absorb W_UK into the query: [b, h, q, nope] -> [b, h, q, c].
        q_nope = torch.einsum('hdc, bhqd->bhqc', q_absorb, q_nope)

        # Score = rotary term + latent term, scaled by sqrt of the full head width.
        attn_weights = torch.matmul(q_rope, k_rope.transpose(-1, -2)) + torch.einsum('bhqc, blc->bhql', q_nope, c_kv)
        attn_weights = attn_weights / math.sqrt(self.q_head_dim)

        if attention_mask is not None:
            attn_weights = torch.masked_fill(
                attn_weights,
                attention_mask == 0,
                float('-inf')
            )

        attn_weights = F.softmax(attn_weights, dim=-1).to(hidden_states.dtype)

        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        # (4) attend over the latent, (5) apply W_UV per head, (6) output projection.
        o_ = torch.einsum('bhql,blc->bhqc', attn_weights, c_kv)  # (4)
        o = torch.einsum('bhqc,hdc->bhqd', o_, out_absorb)  # (5)
        # nn.Linear stores weight as [out_features, in_features] =
        # [hidden_size, num_heads * v_head_dim]; view it as [D, h, d] so this is
        # exactly out_proj applied to the concatenated heads.
        w_o = self.out_proj.weight.view(self.hidden_size, self.num_heads, self.v_head_dim)
        u = torch.einsum('Dhd,bhqd->bqD', w_o, o)  # (6)

        return u, attn_weights

In [5]:
def test():
    """Smoke-test MLA: run one forward pass and check the output shapes."""
    config = DeepseekConfig(
        hidden_size=7168,
        num_heads=16,
        max_position_embeddings=1024,
        rope_theta=128000,
        attention_dropout=0.1,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        kv_lora_rank=512,

        v_head_dim=128,
        qk_nope_head_dim=128,
        attention_bias=False,

    )

    # Derive every input dimension from one place so the activations and the
    # position ids cannot drift apart from the config.
    batch_size = 2
    seq_len = config.max_position_embeddings

    mla = MLA(config)
    x = torch.randn(batch_size, seq_len, config.hidden_size)
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)

    # Shapes only: no autograd graph is needed for this check.
    with torch.no_grad():
        attn_output, attn_weights = mla(x, position_ids=position_ids)

    assert attn_output.shape == (batch_size, seq_len, config.hidden_size)
    assert attn_weights.shape == (batch_size, config.num_heads, seq_len, seq_len)
    print(attn_output.shape, attn_weights.shape)

test()

torch.Size([2, 1024, 7168]) torch.Size([2, 16, 1024, 1024])
