### Inference with KV Cache

The KV cache is a fundamental optimization for inference performance when deploying LLMs such as DeepSeek.

In autoregressive generation, each new token requires an attention computation over all previous tokens in the sequence. Without optimization, the keys and values of those earlier tokens are recomputed at every step, so computation time grows quadratically with sequence length, much of the work is redundant, and the memory and compute costs become prohibitive for practical applications.

The key-value cache solves this by dramatically reducing the computation required. It is the foundational technique for efficient inference in transformer models such as DeepSeek or Llama.

With the key-value cache as the baseline, DeepSeek further optimizes inference with more advanced attention mechanisms to handle massive scale efficiently.

**KV Cache**: the Key and Value projections of past tokens are computed once and reused.


```
Algorithm
  Initial computation - compute Q, K, V for the prompt
  Storage - save the K, V tensors in the cache
  Subsequent tokens -
      compute Q, K, V only for the new token
      retrieve the previous K, V from the cache
      concatenate the new K, V with the cached ones
      store the expanded K, V in the cache
Complexity drops from O(n^2) to O(n) per token, n being the sequence length
```
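The steps above can be sketched for a single attention head. This is a minimal illustration, not a production implementation: the toy width `d_model`, the random weight matrices, and the helper names `attend`, `decode_with_cache`, and `decode_without_cache` are all assumptions introduced here. It checks that decoding with a cache gives the same outputs as recomputing K and V for the whole prefix at every step.

```python
import torch

torch.manual_seed(0)
d_model = 16  # hypothetical toy width, chosen only for illustration
dtype = torch.float64
# Random stand-ins for the learned Q, K, V projection matrices
W_q = torch.randn(d_model, d_model, dtype=dtype) / d_model ** 0.5
W_k = torch.randn(d_model, d_model, dtype=dtype) / d_model ** 0.5
W_v = torch.randn(d_model, d_model, dtype=dtype) / d_model ** 0.5


def attend(q, k, v):
    # q: (1, d_model); k, v: (t, d_model). The newest token attends to
    # every position seen so far, so no extra causal mask is needed here.
    scores = (q @ k.T) / d_model ** 0.5
    return torch.softmax(scores, dim=-1) @ v


def decode_with_cache(tokens):
    """Process tokens one at a time, reusing the cached K and V."""
    k_cache = torch.empty(0, d_model, dtype=dtype)
    v_cache = torch.empty(0, d_model, dtype=dtype)
    outputs = []
    for x in tokens:
        x = x.unsqueeze(0)                         # (1, d_model)
        q, k, v = x @ W_q, x @ W_k, x @ W_v        # new token only
        k_cache = torch.cat([k_cache, k], dim=0)   # expand the cache
        v_cache = torch.cat([v_cache, v], dim=0)
        outputs.append(attend(q, k_cache, v_cache))
    return torch.cat(outputs, dim=0)


def decode_without_cache(tokens):
    """Recompute K and V for the whole prefix at every step."""
    outputs = []
    for t in range(1, tokens.size(0) + 1):
        prefix = tokens[:t]
        q = prefix[-1:] @ W_q
        k, v = prefix @ W_k, prefix @ W_v          # redundant recomputation
        outputs.append(attend(q, k, v))
    return torch.cat(outputs, dim=0)


tokens = torch.randn(6, d_model, dtype=dtype)
cached = decode_with_cache(tokens)
uncached = decode_without_cache(tokens)
print(torch.allclose(cached, uncached))
```

The cached version projects one row per step, while the uncached version projects `t` rows at step `t`, which is where the redundant work comes from.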




**Multi-Query Attention** (MQA): MQA uses a single shared K, V projection across all attention heads and maintains a separate Q projection for each head.

The key variable here is the number of attention heads, since the KV cache memory scales linearly with it. DeepSeek-V3 (671B) has 128 attention heads, so applying MQA would reduce the KV cache size by a factor of 128, from about 400 GB down to just over 3 GB. The efficiency gain comes from the smaller Key and Value caches: instead of caching tensors of shape (batch_size, num_heads, seq_len, head_size), we cache tensors of shape (batch_size, 1, seq_len, head_size). While MQA is a memory-first solution, it fundamentally compromises the core strength of the multi-head design, because every head attends over the same keys and values.

**- KV Cache Size in MQA**: $l \times b \times 1 \times h \times s \times 2 \times 2$
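The formula can be checked with a few lines of arithmetic. The sketch below reads the symbols as $l$ = number of layers, $b$ = batch size, $h$ = head size, $s$ = sequence length, with one factor of 2 for storing both K and V and the other for 2 bytes per value (16-bit floats); that reading is an assumption. The configuration (61 layers, head size 128, a 100,000-token context, batch size 1) is also assumed, chosen because it roughly reproduces the 400 GB and 3 GB figures quoted above. The helper `kv_cache_bytes` is hypothetical.

```python
def kv_cache_bytes(layers, batch, kv_heads, head_size, seq_len, bytes_per_value=2):
    """KV cache size: l * b * kv_heads * h * s * 2 (K and V) * bytes per value."""
    return layers * batch * kv_heads * head_size * seq_len * 2 * bytes_per_value


# Assumed, illustrative configuration (not stated in the text)
layers, batch, head_size, seq_len = 61, 1, 128, 100_000
num_heads = 128

mha = kv_cache_bytes(layers, batch, num_heads, head_size, seq_len)  # one K/V head per query head
mqa = kv_cache_bytes(layers, batch, 1, head_size, seq_len)          # a single shared K/V head

print(f"MHA: {mha / 1e9:.1f} GB")
print(f"MQA: {mqa / 1e9:.2f} GB")
print(f"reduction: {mha // mqa}x")
```

Under these assumptions the full multi-head cache comes to roughly 400 GB and the MQA cache to about 3.1 GB, a reduction equal to the number of heads.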

In [None]:
import torch
import torch.nn as nn

# In the Multi-Query Attention layer, the query projection maps to the full
# model dimension, while the key and value projections map to a single head
# dimension. repeat() then duplicates the single key and value for every
# query head.
class MQA(torch.nn.Module):
    def __init__(self, ds_model, num_heads, dropout=0.0):
        super().__init__()
        assert ds_model % num_heads == 0, 'ds model divisible by num_heads'
        self.num_heads = num_heads
        self.head_size = ds_model // num_heads

        # the Query projection remains the same as in standard MHA
        self.W_query_proj = nn.Linear(ds_model, ds_model)
        # the Key and Value projections are now single, shared linear layers,
        # projecting down to the dimension of a single head (head_size)
        self.W_key_proj = nn.Linear(ds_model, self.head_size)
        self.W_value_proj = nn.Linear(ds_model, self.head_size)
        self.W_out_proj = nn.Linear(ds_model, ds_model)
        self.dropout = nn.Dropout(dropout)

        # lower-triangular causal mask (up to 1024 tokens): position i may
        # attend only to positions <= i
        self.register_buffer('mask', torch.tril(torch.ones(1, 1, 1024, 1024)))

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        q = self.W_query_proj(x).view(batch_size, seq_len,
                              self.num_heads, self.head_size).transpose(1, 2)
        k = self.W_key_proj(x).view(batch_size, seq_len, 1,
                              self.head_size).transpose(1, 2)
        v = self.W_value_proj(x).view(batch_size, seq_len, 1,
                              self.head_size).transpose(1, 2)

        # the single Key and Value heads are repeated to match the number
        # of query heads
        k = k.repeat(1, self.num_heads, 1, 1)
        v = v.repeat(1, self.num_heads, 1, 1)

        # scaled dot-product scores: (batch_size, num_heads, seq_len, seq_len)
        attn_scores = (q @ k.transpose(-2, -1)) / self.head_size**0.5

        attn_scores = attn_scores.masked_fill(
            self.mask[:, :, :seq_len, :seq_len] == 0, float('-inf'))
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vector = (attn_weights @ v).transpose(1, 2).contiguous().view(
            batch_size, seq_len, -1)
        output = self.W_out_proj(context_vector)
        return output

ds_model = 512
num_heads = 8
batch_size = 4
seq_len = 64

mqa_layer = MQA(ds_model, num_heads)
x = torch.randn(batch_size, seq_len, ds_model)
output = mqa_layer(x)
print(x.shape)
print(output.shape)

**Grouped-Query Attention** (GQA): GQA groups the attention heads into clusters (for example, 4 to 8 groups for 32 heads). Each group shares the same K, V projections, while the queries remain separate for each head.

Sacrificing model expressivity for memory efficiency, as MQA does, is not ideal, which motivated the search for a more balanced approach: a technique that offers substantial memory savings without dismantling the power of the multi-head design. GQA is that pragmatic compromise between MHA and MQA.

The core idea of Grouped-Query Attention is that, instead of forcing all attention heads to share the same Key and Value matrices, we create groups of attention heads and share the Keys and Values only within each group.
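To make the grouping concrete, the short sketch below maps each query head to the K/V group it reads from. The values 32 heads and 4 groups are illustrative, and the name `group_of_head` is introduced here only for the example.

```python
num_heads, num_groups = 32, 4            # illustrative values
heads_per_group = num_heads // num_groups

# Query head i reads the Keys and Values of group i // heads_per_group,
# so consecutive heads share one K/V group.
group_of_head = [head // heads_per_group for head in range(num_heads)]

print(group_of_head[:10])
```

With these values, heads 0 to 7 share group 0, heads 8 to 15 share group 1, and so on.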

**- KV Cache Size in GQA**: $l \times b \times g \times h \times s \times 2 \times 2$
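As a rough illustration of how the group count $g$ moves the cache size between the MQA and MHA extremes, the sketch below evaluates the formula for a hypothetical 32-head model. The configuration (32 layers, head size 128, 4096 tokens, batch size 1, 2 bytes per value) and the helper `gqa_kv_cache_bytes` are assumptions made for this example.

```python
def gqa_kv_cache_bytes(layers, batch, groups, head_size, seq_len, bytes_per_value=2):
    """KV cache size: l * b * g * h * s * 2 (K and V) * bytes per value."""
    return layers * batch * groups * head_size * seq_len * 2 * bytes_per_value


# Hypothetical 32-head configuration, for illustration only
layers, batch, head_size, seq_len, num_heads = 32, 1, 128, 4096, 32

sizes = {}
for groups in (1, 4, 8, num_heads):      # 1 = MQA, num_heads = MHA
    sizes[groups] = gqa_kv_cache_bytes(layers, batch, groups, head_size, seq_len)
    print(f"g={groups:>2}: {sizes[groups] / 2**20:.0f} MiB")
```

The cache grows linearly with the number of groups, so choosing $g$ is a direct trade between memory and the number of distinct K/V heads.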

In [None]:
# In Grouped-Query Attention, the query projection maps to the full model
# dimension, while the key and value projections map to num_groups * head_size.
# repeat_interleave then matches each key and value group with its
# corresponding query heads.
class GQA(torch.nn.Module):
    def __init__(self, ds_model, num_heads, num_groups, dropout=0.0,
                 max_seq_len: int = 0):
        super().__init__()
        assert ds_model % num_heads == 0, 'ds_model must be divisible by num_heads'
        assert num_heads % num_groups == 0, 'num_heads must be divisible by num_groups'

        self.ds_model = ds_model
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_size = ds_model // num_heads

        self.W_query_proj = nn.Linear(ds_model, ds_model)
        # Instead of a single shared K/V head (MQA) or one per query head
        # (MHA), we create num_groups K/V heads
        self.W_key_proj = nn.Linear(ds_model, num_groups * self.head_size)
        self.W_value_proj = nn.Linear(ds_model, num_groups * self.head_size)
        self.W_out_proj = nn.Linear(ds_model, ds_model)
        self.dropout = nn.Dropout(dropout)

        self._register_mask_buffer(max_seq_len)

    def _get_causal_mask(self, seq_len, device):
        # reuse the precomputed mask when it is large enough
        if self.causal_mask is not None and self.causal_mask.size(-1) >= seq_len:
            return self.causal_mask[:, :, :seq_len, :seq_len]
        # otherwise build it on the fly; True marks future positions
        return torch.triu(torch.ones(
            1, 1, seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)

    def _register_mask_buffer(self, max_seq_len):
        if max_seq_len > 0:
            mask = torch.triu(torch.ones(
                1, 1, max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
            self.register_buffer('causal_mask', mask, persistent=False)
        else:
            self.causal_mask = None

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        q = self.W_query_proj(x).view(
            batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        # the input is projected and reshaped into
        # num_groups distinct Key and Value groups
        k = self.W_key_proj(x).view(
            batch_size, seq_len, self.num_groups, self.head_size).transpose(1, 2)
        v = self.W_value_proj(x).view(
            batch_size, seq_len, self.num_groups, self.head_size).transpose(1, 2)

        heads_per_group = self.num_heads // self.num_groups
        # repeat_interleave broadcasts the K/V groups along the head dimension
        # (dim=1 after the transpose), so each group of Keys and Values is
        # shared across heads_per_group query heads
        k = k.repeat_interleave(heads_per_group, dim=1)
        v = v.repeat_interleave(heads_per_group, dim=1)

        # scaled dot-product scores: (batch_size, num_heads, seq_len, seq_len)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.head_size**0.5

        causal_mask = self._get_causal_mask(seq_len, x.device)

        # mask out future positions before the softmax
        attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context = (attn_weights @ v).transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.ds_model)
        return self.W_out_proj(context)

In [None]:
ds_model = 512
num_heads = 32
num_groups = 4
batch_size = 4
seq_len = 64

# By changing num_groups, we can move seamlessly from MQA-like behavior
# (num_groups=1) to MHA-like behavior (num_groups=num_heads)
gqa_layer = GQA(ds_model, num_heads, num_groups)
x = torch.randn(batch_size, seq_len, ds_model)
output = gqa_layer(x)
print(x.shape)
print(output.shape)