In [None]:
'''
[https://arxiv.org/pdf/2405.04434] Multi-Latent Attention (introduced in DeepSeek-V2)

While the attention computation involves less compute than the MLPs in terms of FLOPs for common sequence lengths, 
many important prompts that users care about are long-sequence, ie. attention handles tens or hundreds of thousands of tokens. 
These might include putting a novel or paper on in context to get information about it, 
a Deep Research agent compiling your report with lots of search results in context, and more. 

The primary bottleneck in these long-sequence workloads is the memory required to store the KV cache, 
which is linear in S, and is roughly O(bsd) where b is batch size, s is seqlen, d is mode's hidden dim. 
MLA is a way to compute attention *in a smaller latent space* which means *we never need to materialize keys and values 
in the computation at all*. The core idea is to just store latent vectors c_kv in the cache, so the cache size now 
goes from [b, nh, s, hd] to now [b, nh, s, d_c_h] where (d_c = d_c_h * nh) << (d_model = d_h * nh). 

Queries are also projected to a small latent space to help save memory during training. Some important
conceptual notes 
    - The paper introduces a lot of matrices re up/down projections. self.to_c_kv "absorbs" most of them in that 
    we don't explicitly need all of those, they are learned when we learn to_c_kv. 
    - Attn itself is computed in latent space, and this is why the full K matrix never appears. 

'''

In [2]:
### minimal implementation, ignoring decoupled RoPE for simplicity 
import torch, torch.nn as nn, torch.nn.functional as F
import math
from tqdm import tqdm 
import time
import matplotlib.pyplot as plt

class MLAProjections(nn.Module): 
    def __init__(self, d_model: int = 512, d_c_total: int = 64): 
        super().__init__()

        self.to_q_lat = nn.Linear(d_model, d_c_total, bias=False)
        self.to_c_kv = nn.Linear(d_model, d_c_total, bias=False)
        self.to_v = nn.Linear(d_c_total, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)


class MLA(nn.Module): 
    def __init__(self, d_model: int = 512, d_c_total: int = 64, n_heads: int = 16, causal: bool = True):
        super().__init__()

        assert d_model % n_heads == 0, "Error: n_heads must divide d_model for \
            each head to be of the same size!"
        assert d_c_total % n_heads == 0, "Error: n_heads must divide d_c_total!"
    
        self.d_model = d_model 
        self.d_c_total = d_c_total 
        self.n_heads = n_heads 
        self.d_h = int(self.d_model/n_heads)
        self.d_c_h = int(d_c_total/n_heads)
        self.causal = causal 
        self.scale = 1.0 / math.sqrt(self.d_c_h)

        self.projections = MLAProjections(d_model, d_c_total)

        # need to have a KV cache since that's where the gains are! 
        # see language-models/kv_cache.ipynb if you haven't already 
        self.clear_cache()

    def clear_cache(self):
        self.register_buffer("c_cache", None, persistent=False)
        self.cache_pos = 0

    def _apply_causal_mask(self, logits: torch.Tensor, s: int, L: int, device: torch.device):
        """Apply causal mask to attention logits"""
        q_idx = torch.arange(s, device=device)
        k_idx = torch.arange(L, device=device)
        mask = (k_idx[None, :] > q_idx[:, None])                      # [s,k]
        logits.masked_fill_(mask[None, None, ...], float('-inf'))

    def forward(self, x: torch.Tensor, decode: bool = False): # [b, s, d_model] -> [b, s, d_model]
        b, s, d = x.shape 
        device = x.device 

        # get q, kv (via c) in latent space 
        q_lat = self.projections.to_q_lat(x) # [b, s, d_c]
        c_kv = self.projections.to_c_kv(x)   # [b, s, d_c]

        # reshape both [b, s, d_c] ->  [b, nh, s, d_c_h]
        q_lat = q_lat.view(b, s, self.n_heads, self.d_c_h).transpose(1, 2) 
        c_kv = c_kv.view(b, s, self.n_heads, self.d_c_h).transpose(1, 2) ## <-- the key new object in MLA compared to vanilla mhsa

        # hit/update cache if decode 
        if decode: 
            if self.cache_pos == 0: # this is the only place we use .cache_pos
                self.c_cache = c_kv
                self.cache_pos += s                   
                k_len = s
            else: 
                c_kv = torch.cat([self.c_cache, c_kv], dim=2) 
                self.c_cache = c_kv 
                self.cache_pos += s
                k_len = self.c_cache.size(2)
        else:
            self.clear_cache()
            k_len = s

        # do attn in latent space 
        attn_logits = torch.einsum('bhqd,bhkd->bhqk', q_lat, c_kv) * self.scale
        
        if self.causal and not decode: 
            self._apply_causal_mask(attn_logits, s, s, device)
            
        attn = F.softmax(attn_logits, dim=-1)

        # project c_kv to values and do A@v 
        v = self.projections.to_v(c_kv.transpose(1, 2).reshape(b, k_len, self.d_c_total))
        v = v.view(b, k_len, self.n_heads, self.d_h).transpose(1, 2)

        # return wo(out)
        out = torch.einsum('bhqk,bhkd->bhqd', attn, v)
        out = out.transpose(1, 2).reshape(b, s, -1)
        
        return self.projections.wo(out)


# sanity check cache and decoding is working correctly
if __name__ == "__main__":
    torch.manual_seed(0)
    mla = MLA(d_model=256, n_heads=4, d_c_total=64)
    x   = torch.randn(2, 10, 256) # b, s, d

    # full‑sequence (masked)
    y_full = mla(x)                              

    # incremental decode
    mla.clear_cache()
    y_inc = torch.cat([mla(x[:, t:t+1, :], decode=True) for t in range(10)], dim=1)

    # they should now be identical
    torch.testing.assert_close(y_full, y_inc, atol=1e-5, rtol=0)
    print("✓ MLA causal: full‑sequence == incremental decode")


✓ MLA causal: full‑sequence == incremental decode


In [5]:
# baseline we'll compare MLA to, ignore this cell it's just defining a baseline 
class VanillaMHSA(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_h = d_model // n_heads
        self.scale = self.d_h ** -0.5
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.clear_cache()

    def clear_cache(self):
        self.k_cache = None
        self.v_cache = None

    def forward(self, x, decode=False):
        b, s, _ = x.shape
        # project q, k, v
        q = self.q_proj(x).view(b, s, self.n_heads, self.d_h).transpose(1, 2)
        k = self.k_proj(x).view(b, s, self.n_heads, self.d_h).transpose(1, 2)
        v = self.v_proj(x).view(b, s, self.n_heads, self.d_h).transpose(1, 2)

        if decode:
            # incremental decode: append new k, v to cache
            if self.k_cache is None:
                self.k_cache = k
                self.v_cache = v
            else:
                self.k_cache = torch.cat([self.k_cache, k], dim=2)
                self.v_cache = torch.cat([self.v_cache, v], dim=2)
            k_all, v_all = self.k_cache, self.v_cache
        else:
            k_all, v_all = k, v

        # scaled dot-product
        attn_scores = torch.einsum('bhqd,bhkd->bhqk', q, k_all) * self.scale
        if decode:
            # at decode time we only need last query row
            attn_scores = attn_scores[:, :, -1:, :]
            attn = F.softmax(attn_scores, dim=-1)
        else:
            # full-sequence causal mask
            mask = torch.tril(torch.ones(s, s, device=x.device)).view(1, 1, s, s)
            attn = F.softmax(attn_scores.masked_fill(mask == 0, float('-inf')), dim=-1)

        out = torch.einsum('bhqk,bhkd->bhqd', attn, v_all)
        out = out.transpose(1, 2).reshape(b, s, self.d_model)
        return self.out_proj(out)


Vanilla MHSA max sequence length: 11920
MLA max sequence length: 14900


In [6]:
# here we see the max seqlen both mhsa and mla can decode to 
# the upshot of MLA is nominally a smaller KV cache, by roughly a factor of d_c/d_model
# so MHSA should OOM earlier than MLA, and we see that indeed it does 

device = 'cuda' if torch.cuda.is_available() else 'cpu'
d_model, n_heads, d_c_total = 2048, 2048//64, 256

mla = MLA(d_model=d_model, n_heads=n_heads, d_c_total=d_c_total).to(device).eval()
vanilla = VanillaMHSA(d_model=d_model, n_heads=n_heads).to(device).eval()

def try_decode(model, seqlen):
    seqlen = int(seqlen)
    try:
        model.clear_cache()
        # create input of the specified sequence length
        x = torch.randn(1, seqlen, d_model, device=device)
        
        # process the entire input at once to build the kv cache
        model(x, decode=False)
        
        # then decode 20 additional tokens
        gen_x = torch.randn(1, 1, d_model, device=device)
        for _ in range(20):
            model(gen_x, decode=True)
            
        return True
    except RuntimeError as e:
        if "out of memory" in str(e):
            return False
        raise e
    except KeyboardInterrupt:
        print("Operation interrupted by user")
        return False

# try increasingly large sequence lengths until oom
seqlen = 1024
vanilla_max_seqlen = 0
mla_max_seqlen = 0
try:
    # test vanilla mhsa
    while True:
        if try_decode(vanilla, int(seqlen)):
            vanilla_max_seqlen = int(seqlen)
            seqlen *= 1.25
        else:
            break
    
    print(f"Vanilla MHSA max sequence length: {vanilla_max_seqlen}")
    
    # test mla
    seqlen = vanilla_max_seqlen
    while True:
        if try_decode(mla, int(seqlen)):
            mla_max_seqlen = int(seqlen)
            seqlen *= 1.25
        else:
            break
    
    print(f"MLA max sequence length: {mla_max_seqlen}")
    
except KeyboardInterrupt:
    print("\nBenchmark interrupted by user")
    if vanilla_max_seqlen > 0:
        print(f"Vanilla MHSA max sequence length: {vanilla_max_seqlen}")
    if mla_max_seqlen > 0:
        print(f"MLA max sequence length: {mla_max_seqlen}")


Vanilla MHSA max sequence length: 11920
MLA max sequence length: 14900
