In [12]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

def apply_rope_embeddings(q, k): # rotate each row in q = [b, s, d] and k
    _, s, d = q.shape
    device = q.device

    # constant from RoPE paper, originall from Transformer paper sinusoidal embeddings
    base = torch.tensor(10_000, device=device) 
    # get positions and freqs, call the joint [s, d] matrix phases
    thetas = torch.exp(torch.log(base) * -2.0 * (torch.arange(d//2, device=device))/d)
    # this line implicity broadcasts [s] * [d//2] -> [s, d//2] which we'll then alternate to get [s, d]
    phases = torch.arange(s, device=device)[:,None] * thetas 

    # sin/cosify phases 
    sin, cos = torch.sin(phases).repeat(1, 2), torch.cos(phases).repeat(1, 2)

    def _flip(v):
        even_v = v[:, 0::2] 
        neg_odd_v = v[:, 1::2] * -1.

        flipped_v = torch.zeros_like(v)
        flipped_v[:, 0::2] = neg_odd_v  
        flipped_v[:, 1::2] = even_v   
        return flipped_v
        # flip so that q is now [-q1, q0, -q3, q2...] 
        # we do this because the multiplication and addition below is mathematically equivalent
        # to applying a 2x2 rotation matrix on every pair of entries in each q_i, k_i which are [d]-vectors
        # but a sparse matmul would be more expensive compute-wise than just the elementwise ops below

    # convolve to simulate sparse matmul and return 
    q = cos * q + sin * _flip(q)
    k = cos * k + sin * _flip(k)

    return q, k


class MHSA(nn.Module): 
    def __init__(self, d=512, max_seqlen=1024, b=16): 
        super().__init__()
        self.wq = nn.Linear(d, d)
        self.wk = nn.Linear(d, d)
        self.wv = nn.Linear(d, d)
        self.wo = nn.Linear(d, d)
        assert d % 64 == 0 
        self.head_dim = 64 
        self.nheads = d//self.head_dim

    def forward(self, x): # x is [b, s, d]
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        b, s, d = q.shape; nh = self.nheads; hd = self.head_dim
        # project to heads, each should be [b, nh, s, hd]
        
        q = q.reshape(b, s, nh, hd).transpose(1, 2) # [b, nh, s, hd]
        k = k.reshape(b, s, nh, hd).transpose(1, 2)
        v = v.reshape(b, s, nh, hd).transpose(1, 2)

        q, k = apply_rope_embeddings(q, k) 
        # at this point, q and k are cached/cache is updated, ie. after rotation 

        A_logits = torch.einsum('bnik,bnjk->bnij', q, k) # [b, nh, s, hd] @ [b, nh, s, hd] -> [b, nh, s, s]
        A = F.softmax(A_logits/(hd**0.5), dim=-1)

        out = torch.einsum('bnij,bnjk->bnik', A, v) # [b, nh, s, s] @ [b, nh, s, hd] -> [b, nh, s, hd]
        out = out.transpose(1, 2).reshape(b, s, d)

        return self.wo(out)


# intuition -> mem/compute (and thus throughput) gains of a factor of (d/c) ~ 10x!
# sequence-dependent objectives are isolated from embed_dim, so we see [d] and [S, c] but never [S, d] 
# like in normal attention

# naive question: why can't we just apply_rope_embeddings in the same way as above in MLA? 
    # why the need for new "decoupled RoPE"? 
    # the reason is because during inference, normally we cache RK ie. the rotated keys 
    # but with MLA, we aren't storing the keys, and in fact the point is to never 
    # materialize a [S, d] matrix during inference -- cache is [S, c] << [S, d]
    # so we can't just "expand up and rotate" since that would defeat the point of never 
    # materializing [S, d] memory in the first place!

    # and even if you did materialize by up projecting (q @ w_d_q) which is a [c]-vector
    # to [d] dimensions using w_u_q which is [c, d], note that the rotation matrix R which is [d, d]
    # does *not* commute with w_u_q, so it would be wrong (ie. not equivalent to R @ q in regular mhsa)

    # so this is why we need "decoupled rope" that rotates only the queries (which )

# Key intuition: the key fact is that *we never materialize -- either in cache or fwd pass -- an [S, d] key or value matrix
# during decoding at all.* We *do attention logit computation "Q@K.T" in latent space of dimension c<<d*
# ie. down-project q_t to dimension [c] then attend over the cache of dim [c, S] to get an S-vector of logits
# so "latent-attention" means *attention over projections of keys/values into a small space"

class MLA(nn.Module): 
    def __init__(self, d=512, c=64, head_dim=64): 
        super().__init__()
        # note that d is the full hidden_dim for the transformer, ie. d = num_heads * head_dim
        # and not head dimension, which is, unsurprisingly, head_dim
        self.wq = nn.Linear(d, d)
        self.wk = nn.Linear(d, d)
        self.wv = nn.Linear(d, d)
        self.wo = nn.Linear(d, d)
        assert d % head_dim == 0 
        self.head_dim = head_dim
        self.nheads = d//self.head_dim
        self.c = c # should have c << d for savings, for us 64 << 512 so approx 10x less KV cache size
        
        # self.cache would be defined if we had a full inference-ready implementation, it would be [b, max_seqlen, c]
        # as opposed to how its usually [b, max_seqlen, d]
        # and we'd use it during decoding as 
        
        # then up project at inference time
        self.w_d_kv = nn.Linear(d, c)
        self.w_d_q = nn.Linear(d, c)

        self.w_u_k = nn.Linear(c, self.head_dim * self.nheads)
        self.w_u_v = nn.Linear(c, self.head_dim * self.nheads)
        self.w_u_q = nn.Linear(c, self.head_dim * self.nheads)

        # at inference, we can precompute the below matrices since our weights are frozen 
        # so that there are no more matmuls then normal attention since up_proj and attn_proj
        # get fused into one transformation (ie. what used to be self.wq(x) becomes self.nwq(x))
        # if we didn't do this, then projecting c -> d -> d becomes extra compute cost 

        # self.nwq = self.wq @ self.w_u_q
        # self.nwk = self.wq @ self.w_u_q
        # self.nwv = self.wq @ self.w_u_q

    def forward(self, x): # x is [b, s, d]
        c_kv = self.w_d_kv(x) # [b, s, d] @ [d, c] -> [b, s, c]
        c_q = self.w_d_q(x) # [b, s, d] @ [d, c] -> [b, s, c]

        # if model.eval() and cache is nonempty (ie. prefill is done),
        # we would update cache = torch.cat([self.cache, v_kv], dim=-1) at this point
        # to grow latent cache during decoding 

        q, k, v = self.w_u_q(c_q), self.w_u_k(c_kv), self.w_u_v(c_kv) 
        # each is [b, s, hd * nh] = [b, s, d] where s=1 in decoding
        
        b, s, d = q.shape; nh = self.nheads; hd = self.head_dim
        # project to heads, each should be [b, nh, s, hd]
        
        q = q.reshape(b, s, nh, hd).transpose(1, 2) # [b, s, d] -> [b, s, nh, hd] -> [b, nh, s, hd]
        k = k.reshape(b, s, nh, hd).transpose(1, 2)
        v = v.reshape(b, s, nh, hd).transpose(1, 2)

        A_logits = torch.einsum('bnik,bnjk->bnij', q, k) # [b, nh, s, hd] @ [b, nh, s, hd] -> [b, nh, s, s]
        A = F.softmax(A_logits/(hd**0.5), dim=-1)

        out = torch.einsum('bnij,bnjk->bnik', A, v) # [b, nh, s, s] @ [b, nh, s, hd] -> [b, nh, s, hd]
        out = out.transpose(1, 2).reshape(b, s, d)

        return self.wo(out)


b, s, d = 16, 128, 256
x = torch.randn(b, s, d)
mhsa = MLA(d=d)
mhsa(x).shape


torch.Size([16, 128, 256])

In [44]:
# mla_causal.py  – Multi‑latent Attention (DeepSeek‑V2, Eq. 9‑13 + causal mask)
from __future__ import annotations 
import math, torch, torch.nn as nn, torch.nn.functional as F
from typing import Optional, Tuple

# ── helpers ────────────────────────────────────────────────────────────────
def precompute_rope(freqs: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)          # [seq, d/2]
    return freqs.cos(), freqs.sin()

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    return torch.cat([x_even * cos - x_odd * sin,
                      x_even * sin + x_odd * cos], dim=-1)



# ── MLA Components ────────────────────────────────────────────────────────
class MLAProjections(nn.Module):
    """Projections for Multi-Latent Attention"""
    def __init__(self, d_model: int, d_c_total: int, n_heads: int, d_r_head: int):
        super().__init__()
        # latent projections ── Eq.(9‑13)
        self.to_q_lat = nn.Linear(d_model, d_c_total, bias=False)   # absorbs W_U^Kᵀ W_U^Q W_D^Q
        self.to_c_kv  = nn.Linear(d_model, d_c_total, bias=False)   # W_D^KV
        self.to_v     = nn.Linear(d_c_total, d_model, bias=False)   # W_U^V
        self.wo       = nn.Linear(d_model, d_model, bias=False)

        # positional slice (decoupled RoPE) ── §2.1.3
        self.to_q_r = nn.Linear(d_model, n_heads * d_r_head, bias=False)
        self.to_k_r = nn.Linear(d_model, d_r_head, bias=False)

class RoPEPositionalEncoding(nn.Module):
    """Rotary Positional Encoding"""
    def __init__(self, d_r_head: int, rope_base: float = 10_000.0):
        super().__init__()
        inv_freq = 1.0 / (rope_base ** (torch.arange(0, d_r_head, 2).float() / d_r_head))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
    
    def forward(self, seq_len: int, device: torch.device, 
                start_pos: Optional[int] = None, end_pos: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        cos, sin = precompute_rope(self.inv_freq, seq_len)
        if start_pos is not None and end_pos is not None:
            pos_idx = torch.arange(start_pos, end_pos, device=device)
            return cos[pos_idx], sin[pos_idx]
        return cos, sin

# ── MLA ────────────────────────────────────────────────────────────────────
class MLA(nn.Module):
    r"""Multi‑head Latent Attention (DeepSeek‑V2, §2.1.2–2.1.3) *with causal mask*."""

    def __init__(self,
                 d_model:   int  = 512,
                 n_heads:   int  = 8,
                 d_c_total: int  = 128,
                 d_r_head:  int  = 32,
                 rope_base: float = 10_000.0):
        super().__init__()
        assert d_c_total % n_heads == 0, "d_c_total must divide n_heads"
        assert d_r_head  % 2       == 0, "d_r_head must be even"

        # constants
        self.h      = n_heads
        self.d_c    = d_c_total
        self.d_c_h  = d_c_total // n_heads
        self.d_r    = d_r_head
        self.scale  = 1 / math.sqrt(self.d_c_h + self.d_r)

        # components
        self.projections = MLAProjections(d_model, d_c_total, n_heads, d_r_head)
        self.rope = RoPEPositionalEncoding(d_r_head, rope_base)

        self.clear_cache()

    # ───────── cache helpers ────────────────────────────
    def clear_cache(self):
        self.register_buffer("_cache_c", None, persistent=False)  # latent KV   [b,h,L,d_c_h]
        self.register_buffer("_cache_r", None, persistent=False)  # rot‑keys    [b,L,d_r]
        self._seq_len_cached = 0

    # ───────── forward ─────────────────────────────────
    def forward(self, x: torch.Tensor, decode: bool = False) -> torch.Tensor:
        """
        If `decode=False`  → full‑sequence (training / pre‑fill, causal mask applied).
        If `decode=True`   → x is **one token**; uses & extends KV cache.
        """
        b, s, _ = x.shape
        device  = x.device

        # ── low‑rank projections ─────────────────────────
        c_kv  = self.projections.to_c_kv(x)                              # [b,s,d_c]
        q_lat = self.projections.to_q_lat(x)                             # [b,s,d_c]

        c_kv  = c_kv .view(b, s, self.h, self.d_c_h).transpose(1, 2)  # [b,h,s,d_c_h]
        q_lat = q_lat.view(b, s, self.h, self.d_c_h).transpose(1, 2)  # [b,h,s,d_c_h]

        # ── positional slice + RoPE ───────────────────────
        q_r = self.projections.to_q_r(x).view(b, s, self.h, self.d_r)    # [b,s,h,d_r]
        k_r = self.projections.to_k_r(x)                                 # [b,s,d_r]

        max_len = self._seq_len_cached + s if decode else s
        
        if decode:
            cos_p, sin_p = self.rope(max_len, device, self._seq_len_cached, self._seq_len_cached + s)
        else:
            cos, sin = self.rope(max_len, device)
            cos_p, sin_p = cos[:s], sin[:s]

        cos_q = cos_p[None, :, None, :]                      # [1,s,1,d/2]
        sin_q = sin_p[None, :, None, :]
        q_r   = apply_rope(q_r, cos_q, sin_q)

        cos_k = cos_p[None, :, :]                            # [1,s,d/2]
        sin_k = sin_p[None, :, :]
        k_r   = apply_rope(k_r, cos_k, sin_k)

        # ── update / fetch cache ─────────────────────────
        if decode:
            self._cache_c = c_kv if self._cache_c is None else torch.cat([self._cache_c, c_kv], dim=2)
            self._cache_r = k_r  if self._cache_r is None else torch.cat([self._cache_r, k_r ], dim=1)
            self._seq_len_cached += s
            k_lat = self._cache_c                              # [b,h,L,d_c_h]
            k_r   = self._cache_r                              # [b,L,d_r]
        else:
            k_lat = c_kv                                       # current sequence
            self.clear_cache()                                 # keep cache empty for training

        L = k_lat.shape[2]

        # ── attention logits  (latent + positional) ─────
        content_logits = torch.einsum('bhqd,bhkd->bhqk', q_lat, k_lat)   # [b,h,q,k]
        pos_logits     = torch.einsum('bqhd,bkd->bhqk',  q_r,   k_r)     # [b,h,q,k]
        logits         = (content_logits + pos_logits) * self.scale

        # Apply causal mask when needed
        if not decode:
            self._apply_causal_mask(logits, s, L, device)

        attn = F.softmax(logits, dim=-1)                                   # [b,h,q,k]

        # ── values & output ──────────────────────────────
        v_all = self._compute_values(k_lat, b, L)
        out   = torch.einsum('bhqk,bhkd->bhqd', attn, v_all) \
                    .transpose(1, 2).reshape(b, s, -1)
        return self.projections.wo(out)
    
    def _apply_causal_mask(self, logits: torch.Tensor, s: int, L: int, device: torch.device):
        """Apply causal mask to attention logits"""
        q_idx = torch.arange(s, device=device)
        k_idx = torch.arange(L, device=device)
        mask  = (k_idx[None, :] > q_idx[:, None])                      # [s,k]
        logits.masked_fill_(mask[None, None, ...], float('-inf'))
    
    def _compute_values(self, k_lat: torch.Tensor, b: int, L: int) -> torch.Tensor:
        """Compute values from latent keys"""
        return self.projections.to_v(k_lat.transpose(1, 2).reshape(b, L, self.d_c)) \
                    .view(b, L, self.h, -1).transpose(1, 2)                # [b,h,L,d_h]

# ── smoke‑test: causal full pass ≡ streaming decode ─────────────────────────
if __name__ == "__main__":
    torch.manual_seed(0)
    mla = MLA(d_model=256, n_heads=4, d_c_total=64, d_r_head=32)
    x   = torch.randn(2, 10, 256)

    # full‑sequence (masked)
    y_full = mla(x)                              # [2,10,256]

    # incremental decode
    mla.clear_cache()
    y_inc = torch.cat([mla(x[:, t:t+1, :], decode=True) for t in range(10)], dim=1)

    # they should now be identical
    torch.testing.assert_close(y_full, y_inc, atol=1e-5, rtol=0)
    print("✓ MLA causal: full‑sequence == incremental decode")


✓ MLA causal: full‑sequence == incremental decode


In [None]:
## my shid 
class MLAProjections(nn.Module): 
    def __init__(self, d_model: int, d_c_total: int): 
        super().__init__()
        self.to_q_lat = nn.Linear(d_model, d_c_total, bias=False) # h_t -> q_t
        self.to_c_kv = nn.Linear(d_model, d_c_total, bias=False) # h_t -> c_kv
        self.to_v = nn.Linear(d_c_total, d_model, bias=False) # c_kv -> v_t
        self.wo = nn.Linear(d_model, d_model, bias=False) # usual

class MLA(nn.Module): 
    def __init__(self, d_model: int, d_c_total: int, n_heads: int):
        super().__init__()

        # sanity check 
        assert d_model % n_heads == 0, "Error: d_model must be divisible by n_heads!"
        
        # constants 
        self.d_model = d_model 
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        self.d_c_total = d_c_total 
        self.d_c_h = d_c_total // n_heads 
        self.attn_logits_scale = 1/(math.sqrt(self.d_c_h)) # will need to modify when roping

        # subclasses, will need to add rope
        self.projections = MLAProjections(d_model, d_c_total)

        self.clear_cache()

    def clear_cache(self):
        # register buffers and clear them 
        self.register_buffer("_cache_c", None, persistent=False) # latent cache [b, s, nh, d_c]
        self._seq_len_cached = 0 # cache pointer

    def _apply_causal_mask(self, logits: torch.Tensor, q_len: int, k_len: int, device: torch.device):
        # Create causal mask to ensure tokens only attend to previous positions
        mask = torch.ones(q_len, k_len, device=device).triu_(diagonal=1).bool()
        # Apply the mask by setting masked positions to -inf
        logits.masked_fill_(mask[None, None, ...], float('-inf'))
        return logits


    # x is either [b, s, d] in prefill/training or [b, 1, d] in decoding with latent_cache populated 
    def forward(self, x: torch.Tensor, decode: bool = False): 

        b, s, _ = x.shape 
        device = x.device 

        # project inputs to low rank space 
        q_lat = self.projections.to_q_lat(x)
        c_kv = self.projections.to_c_kv(x) # [b, s, d_c]

        # reshape things to [b, s, n_heads, d_c_h] ready for attn 
        q_lat = q_lat.reshape(b, s, self.n_heads, self.d_c_h).transpose(1, 2)
        c_kv = c_kv.reshape(b, s, self.n_heads, self.d_c_h).transpose(1, 2)

        # update cache if decoding 
        if decode:
            c_kv = c_kv if self._cache_c is None else torch.cat([self._cache_c, c_kv], dim=2) # c_kv was [b, 1, d_c] and now is [b, s, d_c]
            self._cache_c = c_kv
            self._seq_len_cached += s
            k_len = c_kv.shape[1]
        else: 
            self.clear_cache()
            k_len = s

        # do attn in latent space 
        attn_logits = torch.einsum('bnij,bnkj->bnik', q_lat, c_kv) # [b, nh, qs, d_c_h] @ [b, nh, ks, d_c_h] -> [b, nh, qs, ks]
        if not decode:
            self._apply_causal_mask(attn_logits, s, s, device)
        attn = F.softmax(attn_logits * self.attn_logits_scale, dim=-1)

        # Reshape c_kv to match expected dimensions for to_v
        # c_kv_flat = c_kv.transpose(1, 2).reshape(b, k_len, self.d_c_total)
        # v = self.projections.to_v(c_kv_flat).reshape(b, k_len, self.n_heads, self.d_head).transpose(1, 2)
        v = self.projections.to_v(c_kv.transpose(1, 2).reshape(b, s, self.d_c_total)) \
                    .view(b, s, self.n_heads, -1).transpose(1, 2)                # [b,h,L,d_h]
        
        print(v.shape, 'about to einsum')
        out = torch.einsum('bnij,bnjd->bnid', attn, v) # [b, nh, qs, ks] @ [b, nh, ks, d_h] -> [b, nh, s, d_h]
        out = out.transpose(1, 2).reshape(b, s, -1)

        return self.projections.wo(out) # output [b, s, d_model] as usual 

if __name__ == "__main__":
    torch.manual_seed(0)
    mla = MLA(d_model=256, n_heads=4, d_c_total=64)
    x   = torch.randn(2, 10, 256) # b, s, d

    # full‑sequence (masked)
    y_full = mla(x)                              # [2,10,256]

    # incremental decode
    mla.clear_cache()
    y_inc = torch.cat([mla(x[:, t:t+1, :], decode=True) for t in range(10)], dim=1)

    # they should now be identical
    torch.testing.assert_close(y_full, y_inc, atol=1e-5, rtol=0)
    print("✓ MLA causal: full‑sequence == incremental decode")


torch.Size([2, 4, 10, 64]) about to einsum
torch.Size([2, 4, 1, 64]) about to einsum


RuntimeError: shape '[2, 1, 64]' is invalid for input of size 256

In [None]:
# dont understand reln to equations
# what is decode doing
# dont get decoupled rope 