### Multi-Head Latent Attention

Multi-Head Latent Attention (**MLA**) is one of the keys to DeepSeek's efficiency at scale: it changes how the attention mechanism stores keys and values during inference. DeepSeek-V3, with 671B total parameters (37B activated per token), needed an approach that keeps attention efficient without compromising quality.

MLA builds on the key-value (KV) cache but changes how keys and values are computed. **Down-projection** projects each input embedding into a compressed latent space; **storage** keeps only this compressed representation in the KV cache; and **up-projection** reconstructs the full-size key and value matrices on the fly when attention is computed. This shrinks the cache's memory footprint while preserving model quality.
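The down-project, store, up-project flow can be sketched as a tiny latent cache. This is a minimal illustration under assumptions: the sizes (`d_model = 512`, `d_latent = 128`) mirror the notebook's example further down, the projections are bias-free `nn.Linear` layers, and `LatentKVCache` is a hypothetical helper, not DeepSeek's code. It appends one compressed vector per generated token and rebuilds full-size keys and values only when they are needed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model, d_latent = 512, 128  # assumed sizes, matching the notebook's example

# Down-projection (compress) and up-projections (reconstruct K and V).
W_dkv = nn.Linear(d_model, d_latent, bias=False)
W_uk = nn.Linear(d_latent, d_model, bias=False)
W_uv = nn.Linear(d_latent, d_model, bias=False)


class LatentKVCache:
    """Hypothetical helper: stores only the compressed latent c_KV per token."""

    def __init__(self):
        self.c_kv = None  # shape (batch, tokens_so_far, d_latent)

    def append(self, x_new):
        # x_new: (batch, 1, d_model) -> compress to (batch, 1, d_latent) and store.
        c_new = W_dkv(x_new)
        if self.c_kv is None:
            self.c_kv = c_new
        else:
            self.c_kv = torch.cat([self.c_kv, c_new], dim=1)

    def keys_values(self):
        # Reconstruct full-size K and V on the fly from the cached latent.
        return W_uk(self.c_kv), W_uv(self.c_kv)


x = torch.randn(2, 10, d_model)  # (batch, seq_len, d_model)
cache = LatentKVCache()
with torch.no_grad():
    for t in range(x.shape[1]):  # simulate token-by-token decoding
        cache.append(x[:, t : t + 1, :])
    K, V = cache.keys_values()
    K_full = W_uk(W_dkv(x))  # the same keys computed in one shot

print(cache.c_kv.shape)  # torch.Size([2, 10, 128])
print(K.shape)           # torch.Size([2, 10, 512])
```

Per token, this cache holds 128 numbers instead of the 2 × 512 = 1024 that a standard KV cache would store for the same model width.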


The **implementation** has three parts: compressing the KV cache with MLA, injecting positional awareness with Rotary Positional Encoding (RoPE), and combining MLA and RoPE in a decoupled architecture. RoPE encodes position directly in the attention computation by rotating pairs of query and key dimensions in the complex plane by position-dependent angles, so the dot product between a query and a key depends only on their relative distance, not on their absolute positions.
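The relative-position property can be checked numerically. This standalone sketch rotates a query at position m and a key at position n and compares scores for the same offset m - n at two different absolute positions. The `rotate` helper and the toy `head_size = 8` are assumptions for illustration; the frequency formula uses the standard RoPE base of 10000, as the notebook's `RoPE` class does.

```python
import torch

torch.manual_seed(0)
head_size = 8  # must be even: dimensions are rotated in pairs

# One rotation frequency per pair of dimensions (standard RoPE base 10000).
theta = 1.0 / (10000 ** (torch.arange(0, head_size, 2).float() / head_size))


def rotate(v, pos):
    """Rotate each (even, odd) pair of v by the angle pos * theta."""
    v_c = torch.view_as_complex(v.reshape(-1, 2))           # (head_size/2,) complex
    rot = torch.polar(torch.ones_like(theta), pos * theta)  # unit complex numbers
    return torch.view_as_real(v_c * rot).flatten()          # back to (head_size,)


q = torch.randn(head_size)
k = torch.randn(head_size)

# Same relative offset (m - n = 3) at two different absolute positions.
score_a = torch.dot(rotate(q, 5), rotate(k, 2))
score_b = torch.dot(rotate(q, 50), rotate(k, 47))
print(torch.allclose(score_a, score_b, atol=1e-4))  # True

# Rotation preserves vector length, so RoPE does not rescale q or k.
print(torch.allclose(rotate(q, 7).norm(), q.norm()))  # True
```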

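A further innovation concerns position representation: the architecture combines MLA with a decoupled positional encoding. Attention is split into two parallel paths, a content path (MLA) and a position path (RoPE), and their scores are added before the softmax. The split is needed because RoPE's rotation depends on each token's position. If it were applied to keys reconstructed from the latent, the rotation would sit between $W_q$ and $W_{uk}$, and the two could no longer be merged into a single fixed matrix (the absorption trick below). Giving RoPE its own small set of query and key dimensions keeps that trick intact while still supplying positional information.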
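To see what the decoupled design costs in cache memory, a back-of-the-envelope count helps. This hypothetical helper counts cached elements per token per layer. It assumes standard multi-head attention caches full K and V, and that MLA caches the latent $c_{KV}$ plus the RoPE keys; whether that RoPE key is shared across heads (as in DeepSeek-V2) or kept per head (as the `W_k_pos` layer in the notebook's `MLAAttention` does) is a parameter. The numbers use the notebook's example settings: 8 heads of size 64, `d_latent = 128`, `d_rope = 64`.

```python
def kv_cache_elements_per_token(num_heads, head_size, d_latent=None, d_rope=0,
                                rope_key_shared=True):
    """Elements cached per token per layer (hypothetical helper).

    Standard multi-head attention caches full K and V: 2 * num_heads * head_size.
    MLA caches the latent c_KV (d_latent) plus the RoPE key: d_rope if one key
    is shared across heads, num_heads * d_rope if each head has its own.
    """
    if d_latent is None:
        return 2 * num_heads * head_size
    rope = d_rope if rope_key_shared else num_heads * d_rope
    return d_latent + rope


mha = kv_cache_elements_per_token(8, 64)
mla_shared = kv_cache_elements_per_token(8, 64, d_latent=128, d_rope=64)
mla_per_head = kv_cache_elements_per_token(8, 64, d_latent=128, d_rope=64,
                                           rope_key_shared=False)
print(mha, mla_shared, mla_per_head)                             # 1024 192 640
print(f"reduction (shared RoPE key): {mha / mla_shared:.2f}x")   # 5.33x
```

Sharing the RoPE key matters: with a per-head RoPE key the positional part (512 elements) would dominate the cache, not the latent.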

#### **Working with the Latent Matrix**

- The query matrix is computed as usual, by projecting the input embeddings directly: $Q = X \cdot W_q$
- The latent KV matrix $c_{KV}$ is the input embeddings projected down into the compressed latent space. $c_{KV}$ is what gets cached: $c_{KV} = X \cdot W_{dkv}$
- The key matrix $K$ is no longer a direct projection of $X$; instead it is an up-projection of the latent matrix $c_{KV}$. Substituting $c_{KV}$ shows the full transformation from the original input: $K = c_{KV} \cdot W_{uk} = X \cdot W_{dkv} \cdot W_{uk}$
- The value matrix $V$ is also an up-projection of the same latent matrix, and substitution gives: $V = c_{KV} \cdot W_{uv} = X \cdot W_{dkv} \cdot W_{uv}$
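One way to see what "compressed" means here: because $K$ and $V$ are computed from $c_{KV}$, their rank can never exceed `d_latent`, however many columns they have. The sketch below uses plain random matrices of hypothetical sizes (with `d_latent = 32` so the effect is easy to spot) in place of trained weights.

```python
import torch

torch.manual_seed(0)
seq_len, d_model, d_latent = 64, 512, 32  # hypothetical sizes

X = torch.randn(seq_len, d_model)
W_dkv = torch.randn(d_model, d_latent) / d_model ** 0.5  # down-projection
W_uk = torch.randn(d_latent, d_model) / d_latent ** 0.5  # key up-projection
W_uv = torch.randn(d_latent, d_model) / d_latent ** 0.5  # value up-projection

c_kv = X @ W_dkv  # (seq_len, d_latent): the only thing MLA caches
K = c_kv @ W_uk   # (seq_len, d_model)
V = c_kv @ W_uv   # (seq_len, d_model)

# K is 512 columns wide but lives in a d_latent-dimensional subspace.
print(K.shape, torch.linalg.matrix_rank(K).item())  # torch.Size([64, 512]) 32
```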

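Substituting these definitions into the attention computation reveals the *absorption* trick: the up-projection matrices can be folded into neighboring weight matrices, so the full-size $K$ and $V$ never need to be materialized.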

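**Attention Score** = $Q \cdot K^T = X \cdot (W_q \cdot W_{uk}^T) \cdot (X \cdot W_{dkv})^T = X \cdot (W_q \cdot W_{uk}^T) \cdot c_{KV}^T$

The product $W_q \cdot W_{uk}^T$ is fixed after training, so it can be precomputed once, and only $c_{KV} = X \cdot W_{dkv}$ needs to be cached. With multiple heads, this holds separately for each head.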
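The score absorption can be verified numerically for a single head. In this sketch the sizes and random weights are arbitrary assumptions, and float64 is used only so that `torch.allclose` is not tripped by rounding. It compares the naive $Q \cdot K^T$ with $X \cdot (W_q \cdot W_{uk}^T) \cdot c_{KV}^T$.

```python
import torch

torch.manual_seed(0)
seq_len, d_model, d_latent, d_head = 16, 64, 16, 32  # hypothetical single-head sizes
f64 = torch.float64

X = torch.randn(seq_len, d_model, dtype=f64)
W_q = torch.randn(d_model, d_head, dtype=f64)
W_dkv = torch.randn(d_model, d_latent, dtype=f64)
W_uk = torch.randn(d_latent, d_head, dtype=f64)

c_kv = X @ W_dkv  # cached latent, (seq_len, d_latent)

# Naive: materialize Q and K, then Q K^T.
scores_naive = (X @ W_q) @ (c_kv @ W_uk).T

# Absorbed: precompute W_q W_uk^T once; K is never materialized.
W_absorbed = W_q @ W_uk.T  # (d_model, d_latent), fixed after training
scores_absorbed = X @ W_absorbed @ c_kv.T

print(torch.allclose(scores_naive, scores_absorbed))  # True
```

In the absorbed form, `X @ W_absorbed` maps the query into the `d_latent`-dimensional latent space, where it is scored directly against the cache.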

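**Context Vector Matrix**: the value side absorbs the same way. Writing $A = \mathrm{softmax}(Q \cdot K^T / \sqrt{d_h})$ for the attention weights ($d_h$ is the head size) and $W_o$ for the final **output projection layer**:

$A \cdot V \cdot W_o = A \cdot (X \cdot W_{dkv}) \cdot (W_{uv} \cdot W_o)$

Here $A$ is the attention weights, $X \cdot W_{dkv} = c_{KV}$ is the **cached latent matrix**, and $W_{uv} \cdot W_o$ is fixed at training time (with multiple heads, each head uses its own slice of $W_{uv}$ and the matching rows of $W_o$).

In short, the attention scores $Q \cdot K^T$ effectively become:
(Input $X$) $\cdot$ (a fixed, precomputed matrix $W_q \cdot W_{uk}^T$) $\cdot$ (the transpose of the latent matrix $c_{KV}$)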
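The value-side absorption can be checked the same way, but with several heads the fusion has to be done per head: head h's column slice of $W_{uv}$ pairs with the matching row slice of $W_o$. This sketch uses random float64 weights, random per-head attention weights `A`, and hypothetical sizes; it illustrates the algebra and is not DeepSeek's kernel.

```python
import torch

torch.manual_seed(0)
seq_len, d_model, d_latent, num_heads = 16, 64, 16, 4  # hypothetical sizes
d_head = d_model // num_heads
f64 = torch.float64

c_kv = torch.randn(seq_len, d_latent, dtype=f64)  # cached latent
# Per-head attention weights: each row sums to 1.
A = torch.softmax(torch.randn(num_heads, seq_len, seq_len, dtype=f64), dim=-1)
W_uv = torch.randn(d_latent, d_model, dtype=f64)
W_o = torch.randn(d_model, d_model, dtype=f64)

# Naive: up-project V, split into heads, attend, concatenate heads, project out.
V = (c_kv @ W_uv).view(seq_len, num_heads, d_head).transpose(0, 1)  # (heads, seq, d_head)
context = (A @ V).transpose(0, 1).reshape(seq_len, d_model)          # concatenated heads
out_naive = context @ W_o

# Absorbed: per head, fuse W_uv's column slice with W_o's matching row slice.
out_absorbed = torch.zeros(seq_len, d_model, dtype=f64)
for h in range(num_heads):
    cols = slice(h * d_head, (h + 1) * d_head)
    W_fused = W_uv[:, cols] @ W_o[cols, :]  # (d_latent, d_model), fixed after training
    out_absorbed += A[h] @ c_kv @ W_fused

print(torch.allclose(out_naive, out_absorbed))  # True
```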

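```python
import torch
import torch.nn as nn


class RoPE(nn.Module):
    """Rotary Positional Encoding (RoPE).

    Not added to the embeddings; applied directly to query and key vectors
    of shape (batch, num_heads, seq_len, head_size).
    """

    def __init__(self, head_size, max_seq_len=2048):
        super().__init__()
        assert head_size % 2 == 0, 'head_size must be even (dimensions are rotated in pairs)'
        # One frequency per pair of dimensions: theta_i = 10000^(-2i / head_size).
        # The parentheses matter: the whole ratio is the exponent.
        theta = 1.0 / (10000 ** (torch.arange(0, head_size, 2).float() / head_size))
        self.register_buffer('theta', theta)

        # Precompute the unit rotation e^{i * pos * theta} for every position.
        positions = torch.arange(max_seq_len).float().unsqueeze(1)
        frequencies = positions * self.theta.unsqueeze(0)  # (max_seq_len, head_size/2)
        self.register_buffer(
            'frequencies_complex',
            torch.polar(torch.ones_like(frequencies), frequencies))

    def forward(self, x):
        seq_len = x.shape[2]
        # Pair up adjacent dimensions and view each pair as one complex number.
        x_complex = x.float().contiguous().reshape(*x.shape[:-1], -1, 2)
        x_complex = torch.view_as_complex(x_complex)
        # Broadcast (seq_len, head_size/2) over the batch and head dimensions.
        frequencies_complex = (
            self.frequencies_complex[:seq_len, :].unsqueeze(0).unsqueeze(0))
        # Multiplying by a unit complex number rotates each pair.
        x_rotated = torch.view_as_real(x_complex * frequencies_complex)
        x_rotated = x_rotated.flatten(3)  # back to (batch, heads, seq_len, head_size)
        return x_rotated.type_as(x)
```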

```python
class MLAAttention(nn.Module):
    """Multi-Head Latent Attention (MLA) with decoupled Rotary Positional Encoding.

    A simplified, DeepSeek-style layer: keys and values are up-projected from a
    compressed latent c_kv, and positional information flows through a separate
    RoPE path whose scores are added to the content scores.
    """

    def __init__(self, ds_model, num_heads, d_latent, d_rope, dropout=0.0,
                 max_seq_len=2048):
        super().__init__()
        assert ds_model % num_heads == 0, 'ds_model must be divisible by num_heads'
        self.ds_model = ds_model
        self.num_heads = num_heads
        self.head_size = ds_model // num_heads
        self.d_latent = d_latent
        self.d_rope = d_rope

        # Content path: query, latent down-projection, key/value up-projections.
        self.W_query_content = nn.Linear(ds_model, ds_model)
        self.W_dkv_content = nn.Linear(ds_model, d_latent)
        self.W_uk_content = nn.Linear(d_latent, ds_model)
        self.W_uv_content = nn.Linear(d_latent, ds_model)

        # Position path: d_rope dimensions per head, rotated by RoPE.
        # (DeepSeek-V2 shares one RoPE key across heads; this layer uses one per head.)
        self.W_k_pos = nn.Linear(ds_model, d_rope * num_heads)
        self.W_q_pos = nn.Linear(ds_model, d_rope * num_heads)

        self.rope = RoPE(d_rope, max_seq_len)

        self.W_out_proj = nn.Linear(ds_model, ds_model)
        self.dropout = nn.Dropout(dropout)
        # Causal mask: True above the diagonal marks future positions to hide.
        self.register_buffer('mask', torch.triu(
            torch.ones(1, 1, max_seq_len, max_seq_len), diagonal=1).bool())

    def _split_heads(self, t, size):
        # (batch, seq_len, num_heads * size) -> (batch, num_heads, seq_len, size)
        batch_size, seq_len, _ = t.shape
        return t.view(batch_size, seq_len, self.num_heads, size).transpose(1, 2)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # Content path: only c_kv would need to be cached during generation.
        q_c = self._split_heads(self.W_query_content(x), self.head_size)
        c_kv = self.W_dkv_content(x)
        k_c = self._split_heads(self.W_uk_content(c_kv), self.head_size)
        v_c = self._split_heads(self.W_uv_content(c_kv), self.head_size)

        # Position path: split heads by d_rope (not head_size), then apply RoPE.
        q_r = self.rope(self._split_heads(self.W_q_pos(x), self.d_rope))
        k_r = self.rope(self._split_heads(self.W_k_pos(x), self.d_rope))

        # Scaled scores from each path, summed before the softmax.
        content_scores = torch.matmul(
            q_c, k_c.transpose(-2, -1)) / (self.head_size ** 0.5)
        position_scores = torch.matmul(
            q_r, k_r.transpose(-2, -1)) / (self.d_rope ** 0.5)
        attn_scores = content_scores + position_scores
        attn_scores = attn_scores.masked_fill(
            self.mask[:, :, :seq_len, :seq_len], float('-inf'))
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Merge heads back to (batch, seq_len, ds_model) and project out.
        context_vector = (attn_weights @ v_c).transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.ds_model)
        return self.W_out_proj(context_vector)
```

```python
# Example hyperparameters for a quick shape check.
ds_model = 512
num_heads = 8
d_latent = 128
d_rope = 64
batch_size = 4
seq_len = 64

deepseek_attn_layer = MLAAttention(ds_model, num_heads, d_latent, d_rope)
x = torch.randn(batch_size, seq_len, ds_model)
output = deepseek_attn_layer(x)
print('DeepSeek-style MLA layer')
print(f'input shape: {x.shape}')        # torch.Size([4, 64, 512])
print(f'output shape: {output.shape}')  # torch.Size([4, 64, 512])
```