In [1]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [30]:
@dataclass
class Config:
  """Hyper-parameters for MultiHeadLatentAttention.

  The per-head width is derived as d_model // n_heads, so d_model must be
  divisible by n_heads.
  """
  d_model: int = 5_120            # model (input/output) width
  n_heads: int = 32               # number of attention heads
  kv_compression_dim: int = 512   # d_c: width of the shared key/value latent
  q_compression_dim: int = 512    # d'_c: width of the query latent
  rope_dim: int = 64              # d_R: per-head width of the decoupled RoPE component

In [35]:
class MultiHeadLatentAttention(nn.Module):
  """Multi-head latent attention (MLA) with low-rank key/value and query compression.

  Keys and values are reconstructed per head from a shared d_c-dim latent
  (w_DKV -> w_UK / w_UV). Queries are reconstructed from a d'_c-dim latent
  (w_DQ -> w_UQ). A separate "decoupled" component of width d_R per head
  (w_KR for keys, w_QR for queries) is concatenated onto every query and key
  head before the dot product.

  NOTE: RoPE is not applied yet (see forward), so the d_R channels currently
  carry no positional information. Without a mask the layer is therefore
  position-agnostic.

  Args:
    config: object exposing d_model, n_heads, kv_compression_dim,
      q_compression_dim and rope_dim (e.g. a ``Config`` instance).
  """

  def __init__(self, config: "Config"):
    super().__init__()
    self.d_model = config.d_model
    self.n_heads = config.n_heads

    assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"
    self.head_dim = self.d_model // self.n_heads     # d_h (dimension per head)

    self.kv_compression_dim = config.kv_compression_dim    # d_c (KV compression dimension)
    self.q_compression_dim = config.q_compression_dim      # d'_c (query compression dimension)
    self.rope_dim = config.rope_dim        # d_R (decoupled RoPE vector dimension per head)

    # Layer creation order is kept fixed so seeded initialization and
    # state_dict layouts stay stable.

    # 1. Down projection for keys and values
    self.w_DKV = nn.Linear(self.d_model, self.kv_compression_dim)   # d_model -> d_c

    # 2. Up projection for keys
    self.w_UK = nn.Linear(self.kv_compression_dim, self.d_model)    # d_c -> d_model

    # 3. Up projection for values
    self.w_UV = nn.Linear(self.kv_compression_dim, self.d_model)    # d_c -> d_model

    # 4. Decoupled key projection (computed from x, one d_R slice per head).
    # NOTE(review): the DeepSeek-V2 MLA formulation uses a single d_R-dim RoPE
    # key shared by all heads (d_model -> d_R). This layer learns one per head.
    # Confirm this is intended; changing it alters the w_KR weight shape.
    self.w_KR = nn.Linear(self.d_model, self.n_heads * self.rope_dim)    # d_model -> nh * d_R

    # 5. Down projection for queries
    self.w_DQ = nn.Linear(self.d_model, self.q_compression_dim)   # d_model -> d'_c

    # 6. Up projection for queries
    self.w_UQ = nn.Linear(self.q_compression_dim, self.d_model)   # d'_c -> d_model

    # 7. Decoupled query projection (computed from the query latent c_q)
    self.w_QR = nn.Linear(self.q_compression_dim, self.n_heads * self.rope_dim)   # d'_c -> nh * d_R

    # 8. Output projection
    self.w_o = nn.Linear(self.d_model, self.d_model)

  def _split_heads(self, t, dim):
    """Reshape (B, T, n_heads * dim) -> (B, n_heads, T, dim)."""
    B, T, _ = t.shape
    return t.view(B, T, self.n_heads, dim).transpose(1, 2)

  def forward(self, x, attn_mask=None, is_causal=False):
    """Apply MLA self-attention.

    Args:
      x: (B, T, d_model) input.
      attn_mask: optional mask broadcastable to (B, n_heads, T, T).
        A bool mask marks positions that MAY be attended (True = keep, the
        same convention as ``F.scaled_dot_product_attention``). A float mask
        is added to the attention scores. Default None means no masking.
      is_causal: if True, query position t only attends to key positions <= t.
        Default False keeps full bidirectional attention.

    Returns:
      (B, T, d_model) tensor.

    A query row whose every key is masked out produces NaN (softmax of all -inf).
    """
    B, T, _ = x.shape

    # Keys / values: x -> shared latent c_KV -> per-head content keys and values.
    c_KV = self.w_DKV(x)                                        # (B, T, d_c)
    k_C = self._split_heads(self.w_UK(c_KV), self.head_dim)     # (B, nh, T, d_h)
    value = self._split_heads(self.w_UV(c_KV), self.head_dim)   # (B, nh, T, d_h)

    # Decoupled key component, projected from the original input x.
    k_R = self._split_heads(self.w_KR(x), self.rope_dim)        # (B, nh, T, d_R)
    # TODO: k_R = RoPE(k_R)   ...skipped for now.

    # Queries: x -> latent c_q -> per-head content and decoupled parts.
    c_q = self.w_DQ(x)                                          # (B, T, d'_c)
    q_C = self._split_heads(self.w_UQ(c_q), self.head_dim)      # (B, nh, T, d_h)
    q_R = self._split_heads(self.w_QR(c_q), self.rope_dim)      # (B, nh, T, d_R)
    # TODO: q_R = RoPE(q_R)   ...skipped for now.

    query = torch.cat([q_C, q_R], dim=-1)   # (B, nh, T, d_h + d_R)
    key = torch.cat([k_C, k_R], dim=-1)     # (B, nh, T, d_h + d_R)

    # Scale by the full concatenated key width (d_h + d_R).
    # Materializes a (B, nh, T, T) score tensor.
    scores = query @ key.transpose(-2, -1) / math.sqrt(key.shape[-1])   # (B, nh, T, T)

    if is_causal:
      # Strict upper triangle = keys in the future of each query.
      future = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(diagonal=1)
      scores = scores.masked_fill(future, float("-inf"))
    if attn_mask is not None:
      if attn_mask.dtype == torch.bool:
        scores = scores.masked_fill(~attn_mask, float("-inf"))
      else:
        scores = scores + attn_mask

    weights = scores.softmax(dim=-1)   # (B, nh, T, T)

    z = weights @ value                # (B, nh, T, d_h)
    z = z.transpose(1, 2).reshape(B, T, self.d_model)   # (B, T, nh * d_h) == (B, T, d_model)

    return self.w_o(z)    # (B, T, d_model)



In [36]:
# Seed before constructing the layer so its random weight initialization (and
# the random input drawn in the next cell) are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

config = Config()
mla_attn = MultiHeadLatentAttention(config)
mla_attn    # display the submodule summary

MultiHeadLatentAttention(
  (w_DKV): Linear(in_features=5120, out_features=512, bias=True)
  (w_UK): Linear(in_features=512, out_features=5120, bias=True)
  (w_UV): Linear(in_features=512, out_features=5120, bias=True)
  (w_KR): Linear(in_features=5120, out_features=2048, bias=True)
  (w_DQ): Linear(in_features=5120, out_features=512, bias=True)
  (w_UQ): Linear(in_features=512, out_features=5120, bias=True)
  (w_QR): Linear(in_features=512, out_features=2048, bias=True)
  (w_o): Linear(in_features=5120, out_features=5120, bias=True)
)

In [37]:
# Random input batch of shape (BATCH_SIZE, SEQ_LEN, D_MODEL).
# With the default config, each (B, n_heads, T, T) attention tensor in the
# forward pass is 1 * 32 * 4096 * 4096 elements (~2 GiB assuming float32).
BATCH_SIZE, SEQ_LEN = 1, 4096
D_MODEL = config.d_model
X = torch.randn(BATCH_SIZE, SEQ_LEN, D_MODEL)

In [38]:
# Inference-only demo: disable autograd so the forward pass does not keep
# intermediate activations (including the large attention tensors) for backward.
with torch.no_grad():
  out = mla_attn(X)

# The layer must preserve the (B, T, d_model) shape of its input.
assert out.shape == X.shape, "MLA output shape must match its input"
out.shape

torch.Size([1, 4096, 5120])