<a href="https://colab.research.google.com/github/shahabday/DSR-LLMQuantization/blob/main/01_MHA_vs_MQA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-Head Attention (MHA)

In [None]:
def mem_size(tensor):
    """Return the memory footprint of a tensor in megabytes (1 MB = 1e6 bytes)."""
    return tensor.element_size() * tensor.nelement() / 1e6

$$
\Large
\text{attention}=\text{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right)\cdot V
$$
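
The formula maps directly onto a few tensor operations. This is a minimal sketch that assumes PyTorch 2.0 or later. `attention` is our own helper, and the small random sizes are arbitrary. We check it against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, which uses the same default scale of 1/sqrt(d_k):

```python
import math
import torch
import torch.nn.functional as F

def attention(q, k, v):
    """softmax(q k^T / sqrt(d_k)) v over the last two dimensions."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # [..., S, L]
    weights = torch.softmax(scores, dim=-1)             # each row sums to 1
    return weights @ v                                  # [..., S, d_v]

torch.manual_seed(0)
q = torch.randn(2, 4, 5, 8)  # [N, h, S, d_k]
k = torch.randn(2, 4, 7, 8)  # [N, h, L, d_k]
v = torch.randn(2, 4, 7, 8)  # [N, h, L, d_k]

ours = attention(q, k, v)
ref = F.scaled_dot_product_attention(q, k, v)  # built-in reference
print(ours.shape, torch.allclose(ours, ref, atol=1e-5))
```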

In [None]:
import torch
import math

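N = 16     # mini-batch size
h = 4      # number of attention heads
S = 99     # number of query tokens
L = 99     # number of key/value tokens (S = L for self-attention over the whole sequence)
d_k = 512  # dimension of each head's query/key/value vectors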

torch.manual_seed(42)
K = torch.randn(N, h, L, d_k)
V = torch.randn(N, h, L, d_k)
Q = torch.randn(N, h, S, d_k)

In [None]:
K.shape, V.shape, Q.shape

(torch.Size([16, 4, 99, 512]),
 torch.Size([16, 4, 99, 512]),
 torch.Size([16, 4, 99, 512]))

How big are the inputs (in MB)?

In [None]:
mem_size(K), mem_size(V), mem_size(Q)

(12.976128, 12.976128, 12.976128)
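
These numbers follow directly from the shapes. A float32 tensor of shape [N, h, S, d_k] holds N·h·S·d_k elements of 4 bytes each. Here is a plain-Python check of that arithmetic. The helper `tensor_mb` is our own and is not part of PyTorch:

```python
import math

def tensor_mb(shape, bytes_per_element=4):
    """Size in MB (1e6 bytes) of a dense tensor with the given shape."""
    return math.prod(shape) * bytes_per_element / 1e6

# Q, K and V above all have shape [N, h, 99, d_k] = [16, 4, 99, 512] in float32
print(tensor_mb((16, 4, 99, 512)))     # 16*4*99*512 elements * 4 bytes
# The same tensor in float16 (2 bytes per element) needs half the memory
print(tensor_mb((16, 4, 99, 512), 2))
```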

In [None]:
logits = torch.matmul(Q, K.transpose(2, 3)) # Output shape [N, h, S, L]
softmax_out = torch.softmax(logits / math.sqrt(d_k), dim=-1) # Output shape [N, h, S, L]

attn_out = torch.matmul(softmax_out, V) # Output shape [N, h, S, d_k]

In [None]:
logits.shape, softmax_out.shape, attn_out.shape

(torch.Size([16, 4, 99, 99]),
 torch.Size([16, 4, 99, 99]),
 torch.Size([16, 4, 99, 512]))

How big are the outputs?

In [None]:
mem_size(logits), mem_size(attn_out)

(2.509056, 12.976128)
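
`logits` has shape [N, h, S, L], so its size grows with the product of the query and key lengths. When S = L, that growth is quadratic. The back-of-the-envelope sketch below assumes float32, and the helper name is our own:

```python
def logits_mb(N, h, S, L, bytes_per_element=4):
    """Size in MB of an [N, h, S, L] attention-score tensor."""
    return N * h * S * L * bytes_per_element / 1e6

# Same batch size and head count as above, growing sequence length
for seq_len in (99, 1000, 4000):
    print(seq_len, logits_mb(16, 4, seq_len, seq_len))
```

Going from 99 to 1,000 tokens multiplies the score tensor by roughly 100x, to about 256 MB for this batch.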

## KV Caching

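With KV caching, the K and V projections of all previously processed tokens are stored and reused, so they do not have to be recomputed at every generation step. Only the newest token's Q, K, and V are computed, and its K and V are then appended to the cache.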
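
Reusing the cache is safe because decoder attention is causal. Earlier tokens never attend to later ones, so their K and V never change once computed. The minimal sketch below checks this claim. Random tensors stand in for the learned projections, and the sizes are our own choice. It compares attention for the last token, computed from the cached K/V, against a full causal recomputation over the whole sequence:

```python
import math
import torch

torch.manual_seed(0)
B, H, T, D = 2, 4, 10, 16            # T = total tokens after the new one arrives
q_all = torch.randn(B, H, T, D)
k_all = torch.randn(B, H, T, D)
v_all = torch.randn(B, H, T, D)

# Full recomputation: causal attention over all T tokens
scores = q_all @ k_all.transpose(-2, -1) / math.sqrt(D)          # [B, H, T, T]
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float("-inf"))          # block future keys
full_out = torch.softmax(scores, dim=-1) @ v_all                 # [B, H, T, D]

# Cached path: K/V of the first T-1 tokens are already in the cache,
# only the newest token's q, k, v are computed
K_cache, V_cache = k_all[:, :, :-1], v_all[:, :, :-1]
q_new, k_new, v_new = q_all[:, :, -1:], k_all[:, :, -1:], v_all[:, :, -1:]
K_cache = torch.cat([K_cache, k_new], dim=-2)                    # update the cache
V_cache = torch.cat([V_cache, v_new], dim=-2)
step_scores = q_new @ K_cache.transpose(-2, -1) / math.sqrt(D)   # [B, H, 1, T]
step_out = torch.softmax(step_scores, dim=-1) @ V_cache          # [B, H, 1, D]

# The last query row needs no mask: it may attend to every key
print(torch.allclose(step_out, full_out[:, :, -1:], atol=1e-5))
```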

In [None]:
# Cached K and V values across iterations
# 99 tokens so far
torch.manual_seed(42)
K = torch.randn(N, h, L, d_k)
V = torch.randn(N, h, L, d_k)

# 100th token comes in, projections (K, V, Q) are computed for this token alone
torch.manual_seed(17)
# Single-step QKV values computed during sequence generation
Q_incr = torch.randn(N, h, 1, d_k)
K_incr = torch.randn(N, h, 1, d_k)
V_incr = torch.randn(N, h, 1, d_k)

# Update KV-cache
K = torch.cat([K, K_incr], dim=-2)
V = torch.cat([V, V_incr], dim=-2)

In [None]:
K.shape, V.shape, Q_incr.shape

(torch.Size([16, 4, 100, 512]),
 torch.Size([16, 4, 100, 512]),
 torch.Size([16, 4, 1, 512]))

How big are the inputs? What's the main difference?

In [None]:
mem_size(K), mem_size(V), mem_size(Q_incr)

(13.1072, 13.1072, 0.131072)

In [None]:
# Compute attention (L is sequence length so far)
logits = torch.matmul(Q_incr, K.transpose(2, 3)) # Output shape [N, h, 1, L]
softmax_out = torch.softmax(logits / math.sqrt(d_k), dim=-1) # Output shape [N, h, 1, L]
attn_out = torch.matmul(softmax_out, V) # Output shape [N, h, 1, d_k]

In [None]:
logits.shape, softmax_out.shape, attn_out.shape

(torch.Size([16, 4, 1, 100]),
 torch.Size([16, 4, 1, 100]),
 torch.Size([16, 4, 1, 512]))

How big are the outputs? What's the difference?

In [None]:
mem_size(logits), mem_size(attn_out)

(0.0256, 0.131072)

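KV caching mainly saves computation. Each step processes only the new token instead of re-running the whole prefix. It also shrinks the per-step query and attention-score tensors: `Q_incr` is [N, h, 1, d_k] instead of [N, h, S, d_k], and `logits` is [N, h, 1, L] instead of [N, h, S, L]. The price is the cache itself, whose size grows linearly with the sequence length.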
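
To see where the compute savings come from, count how many token positions must pass through the K and V projections while generating T tokens. Without a cache, step t recomputes K and V for all t tokens seen so far. With a cache, each step projects only the new token. The plain-Python count below ignores the prompt, and the helper names are our own:

```python
def kv_projections_without_cache(T):
    # step t (1-indexed) re-projects all t tokens: 1 + 2 + ... + T
    return sum(range(1, T + 1))

def kv_projections_with_cache(T):
    # each step projects only the newest token
    return T

for T in (10, 100, 1000):
    print(T, kv_projections_without_cache(T), kv_projections_with_cache(T))
```

For T = 1000, this is 500,500 versus 1,000 token projections, a quadratic-versus-linear gap. Similar reasoning applies to the rest of the forward pass. Real timings also depend on the attention computation itself and on the hardware.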

## Benchmarking

In [None]:
# source: https://medium.com/@joaolages/kv-caching-explained-276520203249
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

for use_cache in (True, False):
    times = []
    for _ in range(3):  # measuring 3 generations
        start = time.time()
        model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1000)
        times.append(time.time() - start)
    print(f"{'with' if use_cache else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

with KV caching: 9.613 +- 0.563 seconds


without KV caching: 41.968 +- 1.285 seconds


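Did you notice the speedup? In this run, generating 1,000 new tokens with GPT-2 took about 9.6 s with KV caching and about 42.0 s without it, roughly a 4.4x speedup.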

# Multi-Query Attention (MQA)

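In MQA, there are no separate K and V projections for every attention head. All h query heads share a single K head and a single V head, while each query head keeps its own Q projection. Every token still gets its own K and V vectors, but the K/V tensors (and hence the KV cache) are h times smaller than in MHA.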
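
In terms of layers, the only difference is the output size of the K and V projections. This is a minimal sketch under our own assumptions: a model width `d_model = n_heads * head_dim` and bias-free `nn.Linear` projections. It is not taken from a specific model:

```python
import torch
import torch.nn as nn

d_model, n_heads = 64, 4
head_dim = d_model // n_heads

# MHA: every head has its own K and V projection -> n_heads * head_dim outputs each
mha_k = nn.Linear(d_model, n_heads * head_dim, bias=False)
mha_v = nn.Linear(d_model, n_heads * head_dim, bias=False)

# MQA: one shared K head and one shared V head -> head_dim outputs each
mqa_k = nn.Linear(d_model, head_dim, bias=False)
mqa_v = nn.Linear(d_model, head_dim, bias=False)

# Queries keep n_heads heads in both variants
q_proj = nn.Linear(d_model, n_heads * head_dim, bias=False)

x = torch.randn(2, 10, d_model)                                   # [N, L, d_model]
q = q_proj(x).view(2, 10, n_heads, head_dim).transpose(1, 2)      # [N, h, L, d_k]
k_mha = mha_k(x).view(2, 10, n_heads, head_dim).transpose(1, 2)   # [N, h, L, d_k]
k_mqa = mqa_k(x).view(2, 10, 1, head_dim).transpose(1, 2)         # [N, 1, L, d_k]
print(q.shape, k_mha.shape, k_mqa.shape)
print(mha_k.weight.numel(), mqa_k.weight.numel())                 # K projection parameters
```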

In [None]:
# Cached K and V values across iterations: a single K head and a single V head
# (head dimension of size 1), shared by all h query heads
torch.manual_seed(42)
K_single = torch.randn(N, 1, L, d_k)
V_single = torch.randn(N, 1, L, d_k)

torch.manual_seed(17)
# Single-step Q values for the new token (one per query head)
Q_incr = torch.randn(N, h, 1, d_k)

In [None]:
K_single.shape, V_single.shape, Q_incr.shape

(torch.Size([16, 1, 99, 512]),
 torch.Size([16, 1, 99, 512]),
 torch.Size([16, 4, 1, 512]))

How big are the inputs? What changed now?

In [None]:
mem_size(K_single), mem_size(V_single), mem_size(Q_incr)

(3.244032, 3.244032, 0.131072)
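
The factor of 4 between these sizes and the MHA input sizes earlier in the notebook is exactly h, because MQA stores one K/V head instead of h. For one attention layer, the cache holds 2·N·h_kv·L·d_k values (K plus V), where h_kv = h for MHA and h_kv = 1 for MQA. The plain-Python sketch below assumes float32, and the helper name is our own:

```python
def kv_cache_mb(N, n_kv_heads, L, d_k, bytes_per_element=4):
    """K plus V cache size in MB for one attention layer."""
    return 2 * N * n_kv_heads * L * d_k * bytes_per_element / 1e6

mha = kv_cache_mb(16, 4, 99, 512)  # one K/V head per query head
mqa = kv_cache_mb(16, 1, 99, 512)  # one shared K/V head
print(mha, mqa, mha / mqa)
```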

In [None]:
# Compute attention (L is sequence length so far)
# NB: K_single's head dimension (size 1) is broadcast across Q_incr's `h` heads!
logits = torch.matmul(Q_incr, K_single.transpose(2, 3)) # Output shape [N, h, 1, L]
softmax_out = torch.softmax(logits / math.sqrt(d_k), dim=-1) # Output shape [N, h, 1, L]
# NB: V_single's head dimension (size 1) is broadcast across softmax_out's `h` heads!
attn_out = torch.matmul(softmax_out, V_single) # Output shape [N, h, 1, d_k]
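
`torch.matmul` broadcasts batch dimensions, so the size-1 head dimension of the shared K/V is matched to the h query heads automatically. The cache only needs to store one head. The self-contained check below uses small sizes of our own choosing. It confirms that broadcasting gives the same result as explicitly repeating the shared K/V head h times, which is what an MHA-shaped cache would hold:

```python
import math
import torch

torch.manual_seed(0)
B, H, T, D = 2, 4, 6, 8
q = torch.randn(B, H, 1, D)   # one new token, H query heads
k = torch.randn(B, 1, T, D)   # single shared K head
v = torch.randn(B, 1, T, D)   # single shared V head

def attend(q, k, v):
    w = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)), dim=-1)
    return w @ v

out_broadcast = attend(q, k, v)                                        # implicit broadcasting
out_repeated = attend(q, k.repeat(1, H, 1, 1), v.repeat(1, H, 1, 1))   # materialized per-head copies
print(out_broadcast.shape, torch.allclose(out_broadcast, out_repeated))
```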

In [None]:
logits.shape, softmax_out.shape, attn_out.shape

(torch.Size([16, 4, 1, 99]),
 torch.Size([16, 4, 1, 99]),
 torch.Size([16, 4, 1, 512]))

How big are the outputs? Did anything change?

In [None]:
mem_size(logits), mem_size(attn_out)

(0.025344, 0.131072)
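
During generation, the MQA cache is updated the same way as the MHA cache in the KV caching section. The difference is that each new token adds a single K head and a single V head rather than h of each. The sketch below uses small, arbitrary sizes of our own:

```python
import torch

torch.manual_seed(0)
B, H, T, D = 2, 4, 5, 8
K_cache = torch.randn(B, 1, T, D)   # shared K head for the T tokens so far
V_cache = torch.randn(B, 1, T, D)   # shared V head for the T tokens so far

# New token: H query heads, but only one K head and one V head
q_new = torch.randn(B, H, 1, D)
k_new = torch.randn(B, 1, 1, D)
v_new = torch.randn(B, 1, 1, D)

# Append the new token's K/V along the sequence dimension
K_cache = torch.cat([K_cache, k_new], dim=-2)   # [B, 1, T+1, D]
V_cache = torch.cat([V_cache, v_new], dim=-2)   # [B, 1, T+1, D]

# Attention for the new token; the shared head is broadcast across H
weights = torch.softmax(q_new @ K_cache.transpose(-2, -1) / D ** 0.5, dim=-1)  # [B, H, 1, T+1]
out = weights @ V_cache                                                         # [B, H, 1, D]
print(K_cache.shape, out.shape)
```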