In [12]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

def apply_rope_embeddings(q, k): # rotate each row in q = [b, s, d] and k
    _, s, d = q.shape
    device = q.device

    # constant from RoPE paper, originall from Transformer paper sinusoidal embeddings
    base = torch.tensor(10_000, device=device) 
    # get positions and freqs, call the joint [s, d] matrix phases
    thetas = torch.exp(torch.log(base) * -2.0 * (torch.arange(d//2, device=device))/d)
    # this line implicity broadcasts [s] * [d//2] -> [s, d//2] which we'll then alternate to get [s, d]
    phases = torch.arange(s, device=device)[:,None] * thetas 

    # sin/cosify phases 
    sin, cos = torch.sin(phases).repeat(1, 2), torch.cos(phases).repeat(1, 2)

    def _flip(v):
        even_v = v[:, 0::2] 
        neg_odd_v = v[:, 1::2] * -1.

        flipped_v = torch.zeros_like(v)
        flipped_v[:, 0::2] = neg_odd_v  
        flipped_v[:, 1::2] = even_v   
        return flipped_v
        # flip so that q is now [-q1, q0, -q3, q2...] 
        # we do this because the multiplication and addition below is mathematically equivalent
        # to applying a 2x2 rotation matrix on every pair of entries in each q_i, k_i which are [d]-vectors
        # but a sparse matmul would be more expensive compute-wise than just the elementwise ops below

    # convolve to simulate sparse matmul and return 
    q = cos * q + sin * _flip(q)
    k = cos * k + sin * _flip(k)

    return q, k


class MHSA(nn.Module): 
    def __init__(self, d=512, max_seqlen=1024, b=16): 
        super().__init__()
        self.wq = nn.Linear(d, d)
        self.wk = nn.Linear(d, d)
        self.wv = nn.Linear(d, d)
        self.wo = nn.Linear(d, d)
        assert d % 64 == 0 
        self.head_dim = 64 
        self.nheads = d//self.head_dim

    def forward(self, x): # x is [b, s, d]
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        b, s, d = q.shape; nh = self.nheads; hd = self.head_dim
        # project to heads, each should be [b, nh, s, hd]
        
        q = q.reshape(b, s, nh, hd).transpose(1, 2) # [b, nh, s, hd]
        k = k.reshape(b, s, nh, hd).transpose(1, 2)
        v = v.reshape(b, s, nh, hd).transpose(1, 2)

        q, k = apply_rope_embeddings(q, k) 
        # at this point, q and k are cached/cache is updated, ie. after rotation 

        A_logits = torch.einsum('bnik,bnjk->bnij', q, k) # [b, nh, s, hd] @ [b, nh, s, hd] -> [b, nh, s, s]
        A = F.softmax(A_logits/(hd**0.5), dim=-1)

        out = torch.einsum('bnij,bnjk->bnik', A, v) # [b, nh, s, s] @ [b, nh, s, hd] -> [b, nh, s, hd]
        out = out.transpose(1, 2).reshape(b, s, d)

        return self.wo(out)


# intuition -> mem/compute (and thus throughput) gains of a factor of (d/c) ~ 10x!
# sequence-dependent objectives are isolated from embed_dim, so we see [d] and [S, c] but never [S, d] 
# like in normal attention

# naive question: why can't we just apply_rope_embeddings in the same way as above in MLA? 
    # why the need for new "decoupled RoPE"? 
    # the reason is because during inference, normally we cache RK ie. the rotated keys 
    # but with MLA, we aren't storing the keys, and in fact the point is to never 
    # materialize a [S, d] matrix during inference -- cache is [S, c] << [S, d]
    # so we can't just "expand up and rotate" since that would defeat the point of never 
    # materializing [S, d] memory in the first place!

    # and even if you did materialize by up projecting (q @ w_d_q) which is a [c]-vector
    # to [d] dimensions using w_u_q which is [c, d], note that the rotation matrix R which is [d, d]
    # does *not* commute with w_u_q, so it would be wrong (ie. not equivalent to R @ q in regular mhsa)

    # so this is why we need "decoupled rope" that rotates only the queries (which )

# Key intuition: the key fact is that *we never materialize -- either in cache or fwd pass -- an [S, d] key or value matrix
# during decoding at all.* We *do attention logit computation "Q@K.T" in latent space of dimension c<<d*
# ie. down-project q_t to dimension [c] then attend over the cache of dim [c, S] to get an S-vector of logits
# so "latent-attention" means *attention over projections of keys/values into a small space"

class MLA(nn.Module): 
    def __init__(self, d=512, c=64, head_dim=64): 
        super().__init__()
        # note that d is the full hidden_dim for the transformer, ie. d = num_heads * head_dim
        # and not head dimension, which is, unsurprisingly, head_dim
        self.wq = nn.Linear(d, d)
        self.wk = nn.Linear(d, d)
        self.wv = nn.Linear(d, d)
        self.wo = nn.Linear(d, d)
        assert d % head_dim == 0 
        self.head_dim = head_dim
        self.nheads = d//self.head_dim
        self.c = c # should have c << d for savings, for us 64 << 512 so approx 10x less KV cache size
        
        # self.cache would be defined if we had a full inference-ready implementation, it would be [b, max_seqlen, c]
        # as opposed to how its usually [b, max_seqlen, d]
        # and we'd use it during decoding as 
        
        # then up project at inference time
        self.w_d_kv = nn.Linear(d, c)
        self.w_d_q = nn.Linear(d, c)

        self.w_u_k = nn.Linear(c, self.head_dim * self.nheads)
        self.w_u_v = nn.Linear(c, self.head_dim * self.nheads)
        self.w_u_q = nn.Linear(c, self.head_dim * self.nheads)

        # at inference, we can precompute the below matrices since our weights are frozen 
        # so that there are no more matmuls then normal attention since up_proj and attn_proj
        # get fused into one transformation (ie. what used to be self.wq(x) becomes self.nwq(x))
        # if we didn't do this, then projecting c -> d -> d becomes extra compute cost 

        # self.nwq = self.wq @ self.w_u_q
        # self.nwk = self.wq @ self.w_u_q
        # self.nwv = self.wq @ self.w_u_q

    def forward(self, x): # x is [b, s, d]
        c_kv = self.w_d_kv(x) # [b, s, d] @ [d, c] -> [b, s, c]
        c_q = self.w_d_q(x) # [b, s, d] @ [d, c] -> [b, s, c]

        # if model.eval() and cache is nonempty (ie. prefill is done),
        # we would update cache = torch.cat([self.cache, v_kv], dim=-1) at this point
        # to grow latent cache during decoding 

        q, k, v = self.w_u_q(c_q), self.w_u_k(c_kv), self.w_u_v(c_kv) 
        # each is [b, s, hd * nh] = [b, s, d] where s=1 in decoding
        
        b, s, d = q.shape; nh = self.nheads; hd = self.head_dim
        # project to heads, each should be [b, nh, s, hd]
        
        q = q.reshape(b, s, nh, hd).transpose(1, 2) # [b, s, d] -> [b, s, nh, hd] -> [b, nh, s, hd]
        k = k.reshape(b, s, nh, hd).transpose(1, 2)
        v = v.reshape(b, s, nh, hd).transpose(1, 2)

        A_logits = torch.einsum('bnik,bnjk->bnij', q, k) # [b, nh, s, hd] @ [b, nh, s, hd] -> [b, nh, s, s]
        A = F.softmax(A_logits/(hd**0.5), dim=-1)

        out = torch.einsum('bnij,bnjk->bnik', A, v) # [b, nh, s, s] @ [b, nh, s, hd] -> [b, nh, s, hd]
        out = out.transpose(1, 2).reshape(b, s, d)

        return self.wo(out)


b, s, d = 16, 128, 256
x = torch.randn(b, s, d)
mhsa = MLA(d=d)
mhsa(x).shape


torch.Size([16, 128, 256])

In [13]:
import torch, torch.nn as nn, torch.nn.functional as F

# ---------- helper: 2×2 block rotation (RoPE) ----------
def rope_rotate(x):
    """
    x: [..., L] where L is even and represents (cos, sin) pairs
    returns rotated tensor of the same shape
    """
    cos, sin = x[..., 0::2], x[..., 1::2]
    x1 = torch.stack([-sin, cos], dim=-1).reshape_as(x)
    return x1

def apply_rope(x, pos):
    """
    x : [B, T, H, d_r]     (d_r even)
    pos: [T] or [B, T]
    Rotates the last dim in‑place.
    """
    B, T, H, d_r = x.shape
    device = x.device
    base = 10_000.0

    idx = torch.arange(d_r // 2, device=device)
    theta = base ** (-2.0 * idx / d_r)                         # [d_r/2]
    phase = pos.to(device).unsqueeze(-1) * theta               # [B?,T,d_r/2]
    cos, sin = torch.cos(phase), torch.sin(phase)
    rot = torch.stack([cos, sin], dim=-1).repeat_interleave(2, dim=-1)
    x_rot = rot * x + rope_rotate(rot * x)      # element‑wise
    return x_rot
# ---------------------------------------------------------


class MLA(nn.Module):
    """
    Multi‑head Latent Attention with **decoupled RoPE**.

    Args
    ----
    d_model   : full hidden width (d)
    head_dim  : per‑head width (d_h)
    d_c       : latent width  (c  << d)
    d_r       : width of the RoPE channel per head (even, ≤ d_h)
    """
    def __init__(self, d_model: int, head_dim: int = 64,
                 d_c: int = 512, d_r: int = 32):
        super().__init__()
        assert d_model % head_dim == 0, "d_model must be H * head_dim"
        assert d_r % 2 == 0 and d_r <= head_dim

        self.d, self.hd, self.H, self.c, self.dr = (
            d_model, head_dim, d_model // head_dim, d_c, d_r
        )
        d_content = head_dim - d_r                   # un‑rotated slice per head
        self.dc_r = d_r * self.H                     # total rotated width
        self.dc_c = d_content * self.H               # total content width

        # --------- low‑rank projections ---------
        self.W_D_kv = nn.Linear(self.d, self.c, bias=False)
        self.W_D_q  = nn.Linear(self.d, self.c, bias=False)

        # up‑proj split into "rot" and "content" parts
        self.W_U_k_rot = nn.Linear(self.c, self.dc_r, bias=False)
        self.W_U_k_con = nn.Linear(self.c, self.dc_c, bias=False)
        self.W_U_v     = nn.Linear(self.c, self.d,    bias=False)

        self.W_U_q_rot = nn.Linear(self.c, self.dc_r, bias=False)
        self.W_U_q_con = nn.Linear(self.c, self.dc_c, bias=False)

        self.out_proj  = nn.Linear(self.d, self.d, bias=False)

        # ------------- cache -------------
        self.register_buffer("cache_latent", None, persistent=False)
        self.register_buffer("cache_pos",    None, persistent=False)

    # -------- attention kernel in latent space --------
    def _latent_attention(self, q_lat, C_lat):
        """
        q_lat : [B, H, c]      current query already premultiplied by W_U_k*
        C_lat : [S, c]         latent cache of keys (shared across heads)
        returns context_lat : [B, H, c]
        """
        # scores  [B,H,S]
        scores = torch.matmul(q_lat, C_lat.t())          # [B,H,S]
        attn   = F.softmax(scores / (self.c ** 0.5), dim=-1)

        # context in latent       [B,H,c]
        ctx_lat = torch.matmul(attn, C_lat)              # [B,H,c]
        return ctx_lat

    # ------------- main forward -------------
    def forward(self, x, *, decode=False):
        """
        x      : [B,T,d]  (T may be 1 during decoding)
        decode : if True → one‑token step with latent cache
        returns y : [B,T,d]
        """
        B, T, _ = x.shape
        device  = x.device

        # ---- 1. down‑project ----
        c_kv = self.W_D_kv(x)     # [B,T,c]
        c_q  = self.W_D_q (x)     # [B,T,c]

        # ---- 2. split up‑projection (rot vs content) ----
        q_rot = self.W_U_q_rot(c_q)   # [B,T,dc_r]
        q_con = self.W_U_q_con(c_q)   # [B,T,dc_c]

        k_rot = self.W_U_k_rot(c_kv)  # [B,T,dc_r]
        k_con = self.W_U_k_con(c_kv)  # [B,T,dc_c]
        v_all = self.W_U_v    (c_kv)  # [B,T,d]

        # ---- 3. reshape to heads ----
        def split_heads(t, per_head):
            return t.view(B, T, self.H, per_head).transpose(1, 2)  # [B,H,T,per_head]

        q_rot = split_heads(q_rot, self.dr)
        k_rot = split_heads(k_rot, self.dr)
        q_con = split_heads(q_con, self.hd - self.dr)
        k_con = split_heads(k_con, self.hd - self.dr)
        v     = split_heads(v_all, self.hd)

        # ---- 4. apply RoPE only on the small channel ----
        pos = (torch.arange(T, device=device)
               if not decode else self.cache_pos[-1:] + 1)   # simple counter
        q_rot = apply_rope(q_rot, pos)   # each: [B,H,T,dr]
        k_rot = apply_rope(k_rot, pos)

        # ---- 5. concatenate rotated + content ----
        q = torch.cat([q_rot, q_con], dim=-1)   # [B,H,T,d_h]
        k = torch.cat([k_rot, k_con], dim=-1)

        # ------------------------------------------------------------------
        if not decode:                       # ---------- TRAIN MODE ----------
            # materialise full K,V for efficiency
            scores = torch.einsum('bhid,bhjd->bhij', q, k)        # [B,H,T,T]
            attn   = F.softmax(scores / (self.hd ** 0.5), dim=-1)
            ctx    = torch.einsum('bhij,bhjd->bhid', attn, v)     # [B,H,T,d_h]
            y = ctx.transpose(1, 2).reshape(B, T, self.d)
            return self.out_proj(y)

        # ------------------------------------------------------------------
        else:                                 # ---------- DECODE MODE ----------
            assert T == 1, "decode expects a single new token"
            # initialise cache buffers if empty
            if self.cache_latent is None:
                self.cache_latent = c_kv[:, -1].detach().clone()  # [1,c]
                self.cache_pos    = torch.tensor([-1], device=device, dtype=torch.long)
            # append new latent & pos
            self.cache_latent = torch.cat([self.cache_latent, c_kv[:, -1]], dim=0)    # [S,c]
            self.cache_pos    = torch.cat([self.cache_pos,    pos],         dim=0)    # [S]

            # ---- operate purely in latent space ----
            Wk = torch.cat([self.W_U_k_rot.weight,  # [dc_r,c]
                            self.W_U_k_con.weight], dim=0)        # [d_h ,c]

            # pre‑multiply query once per head
            q_lat = torch.einsum('bhid,dc->bhc', q, Wk)           # [B,H,c]

            ctx_lat = self._latent_attention(q_lat, self.cache_latent)     # [B,H,c]
            # up‑project to d
            y = torch.einsum('bhc,cd->bhd', ctx_lat,
                             torch.cat([self.W_U_v.weight], dim=0))        # [B,H,d_h]
            y = y.transpose(1, 2).reshape(B, 1, self.d)
            return self.out_proj(y)


In [14]:
B, S, d = 16, 128, 512
x = torch.randn(B, S, d)

# TRAIN / PREFILL
mla = MLA(d_model=d, head_dim=64, d_c=512, d_r=32)
y = mla(x, decode=False)        # full‑sequence forward

# AUTOREGRESSIVE DECODE (one token at a time)
mla.eval()
token = torch.randn(1, 1, d)
out1  = mla(token, decode=True)  # step 1
token = torch.randn(1, 1, d)
out2  = mla(token, decode=True)  # step 2


RuntimeError: The size of tensor a (4) must match the size of tensor b (32) at non-singleton dimension 3