# KV Cache with GPT2

In [1]:
import torch
import torch.nn as nn
import tiktoken
from tqdm import tqdm
import time
import numpy as np


In [2]:
# GeLU
class GELU(nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit.

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))) elementwise.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        scale = torch.sqrt(torch.tensor(2.0 / torch.pi))
        inner = scale * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))

In [3]:
# FFN
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(emb_dim -> 4*emb_dim) -> GELU -> Linear(4*emb_dim -> emb_dim)."""

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        hidden_dim = 4 * emb_dim  # GPT-2 expands the hidden width by 4x
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            GELU(),
            nn.Linear(hidden_dim, emb_dim),
        )

    def forward(self, x):
        return self.layers(x)

In [4]:
class KVCache:
    """Preallocated key/value buffers for one attention layer.

    Both buffers have shape (max_batch_size, max_seq_len, n_kv_heads, head_dim).
    They start as zeros in torch's default device/dtype and are converted to the
    device *and* dtype of the tensors passed to ``update``, so the cached values
    always match the activations they are concatenated with.
    """

    def __init__(self, max_batch_size, max_seq_len, n_kv_heads, head_dim):
        self.max_seq_len = max_seq_len
        shape = (max_batch_size, max_seq_len, n_kv_heads, head_dim)
        self.cache_k = torch.zeros(shape)
        self.cache_v = torch.zeros(shape)

    def update(self, batch_size, start_pos, xk, xv):
        """Write ``xk``/``xv`` into sequence positions [start_pos, start_pos + xk.size(1)).

        Args:
            batch_size: number of leading batch rows to write.
            start_pos: first sequence position to overwrite.
            xk, xv: tensors of shape (batch_size, new_len, n_kv_heads, head_dim).

        Raises:
            ValueError: if the write would extend past ``max_seq_len`` or
                ``batch_size`` exceeds the allocated batch dimension.
        """
        end_pos = start_pos + xk.size(1)
        if end_pos > self.max_seq_len:
            raise ValueError(
                f"KV cache overflow: positions {start_pos}:{end_pos} exceed "
                f"max_seq_len={self.max_seq_len}"
            )
        if batch_size > self.cache_k.size(0):
            raise ValueError(
                f"batch_size={batch_size} exceeds KV cache capacity {self.cache_k.size(0)}"
            )
        # Follow the incoming tensors' device AND dtype; a float32 cache mixed with
        # e.g. float16 activations would break the later cat/matmul.
        if self.cache_k.device != xk.device or self.cache_k.dtype != xk.dtype:
            self.cache_k = self.cache_k.to(device=xk.device, dtype=xk.dtype)
        if self.cache_v.device != xv.device or self.cache_v.dtype != xv.dtype:
            self.cache_v = self.cache_v.to(device=xv.device, dtype=xv.dtype)
        self.cache_k[:batch_size, start_pos:end_pos] = xk
        self.cache_v[:batch_size, start_pos:end_pos] = xv

    def get(self, batch_size, start_pos, seq_len):
        """Return views of cached keys/values for positions [0, start_pos + seq_len)."""
        keys = self.cache_k[:batch_size, :start_pos + seq_len]
        values = self.cache_v[:batch_size, :start_pos + seq_len]
        return keys, values

In [5]:
# MHA
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with an optional key/value cache.

    Without a cache, every call attends over the whole input under a causal mask.
    With ``use_kvcache=True`` the first call ("prefill") runs that full path and
    stores the per-head keys/values in ``self.kv_cache``. Later calls receive only
    the new token(s) and attend over the cached keys/values plus their own.
    Callers must set ``prefilled = False`` before starting a new sequence.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, use_kvcache=False):
        super().__init__()
        self.d_out = d_out
        assert d_out % num_heads == 0, "d_out is not divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # mask[i, j] == 1 iff position i may attend to position j (j <= i).
        self.register_buffer(
            'mask',
            torch.tril(torch.ones(context_length, context_length))
        )
        if use_kvcache:
            print("Using KV Cache")
        # Allocated for one sequence; the prefill path reallocates for larger batches.
        self.kv_cache = KVCache(1, context_length, num_heads, self.head_dim) if use_kvcache else None
        self.prefilled = False
        self.context_length = context_length
        # Number of tokens in the cache (or seen by the last full pass).
        self.num_tokens = 0

    def _split_heads(self, t, b, num_tokens):
        """Reshape (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)."""
        return t.view(b, num_tokens, self.num_heads, self.head_dim)

    def _merge_heads(self, context_vec, b, num_tokens):
        """Reshape (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, d_out), then project."""
        context_vec = context_vec.transpose(1, 2).contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)

    def _ensure_cache_batch(self, b):
        """Reallocate the cache if it cannot hold ``b`` sequences."""
        if self.kv_cache.cache_k.shape[0] < b:
            self.kv_cache = KVCache(b, self.context_length, self.num_heads, self.head_dim)

    def forward_using_kvcache(self, x):
        """Attend the new token(s) ``x`` (b, num_new, d_in) over cached keys/values.

        The attended window (cached + new) is capped at ``context_length`` tokens
        by dropping the oldest cached entries. The resulting window is written back
        to the cache.
        """
        b, num_tokens, _ = x.shape
        queries = self._split_heads(self.W_query(x), b, num_tokens).transpose(1, 2)  # (b, h, new, hd)
        keys = self._split_heads(self.W_key(x), b, num_tokens)                        # (b, new, h, hd)
        values = self._split_heads(self.W_value(x), b, num_tokens)                    # (b, new, h, hd)

        k_cache, v_cache = self.kv_cache.get(b, 0, self.num_tokens)
        # Keep only as many past tokens as fit next to the new ones. Compute the
        # start explicitly: a negative slice such as [-0:] would keep everything.
        keep = max(self.context_length - num_tokens, 0)
        start = max(k_cache.shape[1] - keep, 0)
        keys = torch.cat((k_cache[:, start:], keys), dim=1)        # (b, total, h, hd)
        values = torch.cat((v_cache[:, start:], values), dim=1)    # (b, total, h, hd)
        total = keys.shape[1]
        past_len = total - num_tokens

        keys_h = keys.transpose(1, 2)      # (b, h, total, hd)
        values_h = values.transpose(1, 2)  # (b, h, total, hd)

        attn_scores = queries @ keys_h.transpose(-1, -2)  # (b, h, new, total)
        if num_tokens > 1:
            # New token i sits at window position past_len + i and may only see keys up to it.
            causal = self.mask[past_len:total, :total]
            attn_scores = attn_scores.masked_fill(causal == 0, float('-inf'))
        attn_scores = attn_scores / (keys_h.shape[-1] ** 0.5)

        attn_weights = self.dropout(torch.softmax(attn_scores, dim=-1))
        context_vec = self._merge_heads(attn_weights @ values_h, b, num_tokens)

        self.kv_cache.update(b, 0, keys, values)
        self.num_tokens = total
        return context_vec

    def forward(self, x):
        """Causal self-attention over ``x`` (b, num_tokens, d_in) -> (b, num_tokens, d_out).

        When a cache is configured, the first call prefills it; subsequent calls
        are routed to ``forward_using_kvcache`` until ``prefilled`` is reset.
        """
        if self.prefilled and self.kv_cache is not None:
            return self.forward_using_kvcache(x)

        b, num_tokens, _ = x.shape
        queries = self._split_heads(self.W_query(x), b, num_tokens).transpose(1, 2)  # (b, h, c, hd)
        keys = self._split_heads(self.W_key(x), b, num_tokens).transpose(1, 2)        # (b, h, c, hd)
        values = self._split_heads(self.W_value(x), b, num_tokens).transpose(1, 2)    # (b, h, c, hd)

        attn_scores = queries @ keys.transpose(-1, -2)  # (b, h, c, c)
        attn_scores.masked_fill_(self.mask[:num_tokens, :num_tokens] == 0, float('-inf'))
        attn_scores = attn_scores / (keys.shape[-1] ** 0.5)

        attn_weights = self.dropout(torch.softmax(attn_scores, dim=-1))
        context_vec = self._merge_heads(attn_weights @ values, b, num_tokens)

        if self.kv_cache is not None:
            # Prefill: store (b, c, h, hd) keys/values for the incremental path.
            self._ensure_cache_batch(b)
            self.kv_cache.update(b, 0, keys.transpose(1, 2), values.transpose(1, 2))
            self.prefilled = True

        self.num_tokens = num_tokens
        return context_vec


In [6]:
# Transformer Block
class TransformerBlock(nn.Module):
    """Pre-LayerNorm GPT block.

    Each of the two sublayers (attention, feed-forward) is applied as
    ``x + Dropout(sublayer(LayerNorm(x)))``.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        self.att = MultiHeadAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            context_length=cfg["context_length"],
            dropout=cfg["drop_rate"],
            num_heads=cfg["n_heads"],
            qkv_bias=cfg["qkv_bias"],
            use_kvcache=cfg["use_kvcache"],
        )
        self.ff = FeedForward(cfg)
        self.norm1 = nn.LayerNorm(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.drop_resid = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        x = x + self.drop_resid(self.att(self.norm1(x)))
        x = x + self.drop_resid(self.ff(self.norm2(x)))
        return x


In [7]:
# GPT Class
class GPTModel(nn.Module):
    """GPT-2 style decoder.

    The pipeline is: token embeddings plus learned absolute position embeddings,
    dropout, ``n_layers`` TransformerBlocks, a final LayerNorm, and a bias-free
    projection to vocabulary logits.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg['emb_dim']
        self.tok_emb = nn.Embedding(cfg['vocab_size'], emb_dim)
        self.pos_emb = nn.Embedding(cfg['context_length'], emb_dim)

        self.drop_emb = nn.Dropout(cfg['drop_rate'])

        blocks = (TransformerBlock(cfg) for _ in range(cfg["n_layers"]))
        self.trf_blocks = nn.Sequential(*blocks)
        self.final_norm = nn.LayerNorm(emb_dim)
        # Weight tying (tok_emb.weight = out_head.weight) is not applied here.
        self.out_head = nn.Linear(emb_dim, cfg["vocab_size"], bias=False)

    def forward(self, in_idx, pos_ids):
        """Map token ids ``in_idx`` (batch, seq_len) at positions ``pos_ids`` to logits.

        ``pos_ids`` indexes the position embedding and is added to the token
        embeddings (callers in this notebook pass a 1-D ``arange``).
        """
        _batch_size, _seq_len = in_idx.shape  # also rejects inputs that are not 2-D
        x = self.tok_emb(in_idx) + self.pos_emb(pos_ids)
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        return self.out_head(self.final_norm(x))

## Setting up the tokenizer

In [8]:
# GPT-2 tokenizer (tiktoken "gpt2" encoding), shared by the generation cells below.
tokenizer = tiktoken.get_encoding('gpt2')

In [9]:
def _reset_kv_caches(model):
    """Clear ``prefilled`` on every MultiHeadAttention layer so the next forward prefills afresh.

    Layers are matched by class name so this helper does not depend on where the
    attention class is defined.
    """
    for module in model.modules():
        if module.__class__.__name__ == "MultiHeadAttention":
            module.prefilled = False


def generate_text_simple(model, idx, max_new_tokens, context_size, use_kvcache=False):
    """Greedily append ``max_new_tokens`` tokens to ``idx`` and return the extended tensor.

    Args:
        model: called as ``model(token_ids, pos_ids)`` and expected to return
            logits of shape (batch, seq_len, vocab).
        idx: (batch, seq_len) tensor of prompt token ids.
        max_new_tokens: number of tokens to generate; ``<= 0`` returns ``idx`` unchanged.
        context_size: maximum number of trailing tokens fed to the model.
        use_kvcache: if True, run the (truncated) prompt once to prefill the
            attention caches, then feed only the newest token each step.

    Attention-layer caches are reset before generation, before every full-context
    step when ``use_kvcache`` is False, and afterwards (even on error). Stale
    cache state therefore cannot leak between calls.
    """
    if max_new_tokens <= 0:
        return idx

    _reset_kv_caches(model)
    try:
        with torch.no_grad():
            for step in range(max_new_tokens):
                idx_cond = idx[:, -context_size:]
                if use_kvcache and step > 0:
                    # Decode: only the newest token, at its position within the window.
                    pos_id = torch.tensor([idx_cond.shape[-1] - 1], device=idx_cond.device)
                    logits = model(idx_cond[:, -1:], pos_id)
                else:
                    if not use_kvcache:
                        # Full recompute: make sure no attention layer takes the cached path.
                        _reset_kv_caches(model)
                    pos_id = torch.arange(idx_cond.shape[-1], device=idx_cond.device)
                    logits = model(idx_cond, pos_id)

                # argmax is invariant under softmax, so take it on the raw logits.
                idx_next = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)
    finally:
        _reset_kv_caches(model)
    return idx

In [10]:

def text_to_token_ids(text, tokenizer):
    """Encode ``text`` into a (1, num_tokens) tensor of token ids.

    ``<|endoftext|>`` is explicitly allowed as a special token.
    """
    token_ids = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    return torch.tensor([token_ids])  # leading batch dimension of size 1

def token_ids_to_text(token_ids, tokenizer):
    """Decode a (1, num_tokens) tensor of token ids back into a string."""
    ids = token_ids.squeeze(0).tolist()  # drop the batch dimension
    return tokenizer.decode(ids)

## Model Configs

In [11]:
# GPT-2 "small" (124M parameter) architecture.
GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "emb_dim": 768,          # Embedding dimension
    "n_heads": 12,           # Number of attention heads
    "n_layers": 12,          # Number of layers
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": False,       # Query-Key-Value bias
    "use_kvcache": False,    # Enable Key-Value cache during inference
}

# Tiny variant for quick correctness checks: same vocabulary and flags, but a
# short context and a single narrow layer. This is an independent dict copy.
DEBUG_GPT_CONFIG_124M = {
    **GPT_CONFIG_124M,
    "context_length": 100,
    "emb_dim": 4,
    "n_heads": 2,
    "n_layers": 1,
}

In [12]:
def setup_model(GPTModel, config, device):
    """Instantiate ``GPTModel(config)`` on ``device``, print a short summary and return it.

    The reported parameter count excludes the ``out_head`` parameters, i.e. the
    size the model would have if the output head shared the token embedding.
    """
    model = GPTModel(config).to(device)
    print(f"Using Device: {device}")
    print(f"KV Cache Enabled: {config['use_kvcache']}")
    head_params = sum(p.numel() for p in model.out_head.parameters())
    all_params = sum(p.numel() for p in model.parameters())
    print(f"Number of trainable parameters considering weight tying: {all_params - head_params:,}")
    return model

In [13]:
# Prefer the GPU when one is available, and report which one.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
    print(torch.cuda.get_device_name(0))

Tesla V100-SXM2-32GB


## Debug Mode

## Baseline

In [14]:
torch.manual_seed(123)  # fixed seed so the KV-cache model below gets identical weights
model = setup_model(GPTModel, DEBUG_GPT_CONFIG_124M, device).eval()

Using Device: cuda
KV Cache Enabled: False
Number of trainable parameters considering weight tying: 201,668


In [15]:

# Baseline: greedy decoding that re-runs the whole (truncated) window at every step.
text = "What is KV caching?"
enc_text = text_to_token_ids(text, tokenizer)
idx = enc_text.to(device)
with torch.no_grad():
    for _ in range(99):
        idx_cond = idx[:, -DEBUG_GPT_CONFIG_124M["context_length"]:]
        pos_id = torch.arange(idx_cond.shape[-1], device=idx_cond.device)
        logits = model(idx_cond, pos_id)[:, -1:, :]
        probs = torch.softmax(logits, dim=-1)
        next_idx = torch.argmax(probs, dim=-1)
        idx = torch.cat([idx, next_idx], dim=-1)

baseline_idx = idx
print(idx)
print(token_ids_to_text(idx, tokenizer))


tensor([[ 2061,   318,   509,    53, 40918,    30, 29406, 18396, 19475, 17925,
         17925,  8979, 29406, 18396, 22769, 46648, 46648,  6998,  5698, 24090,
         46648, 33004, 46648, 25028, 17925, 17925, 25028,  9891, 25891, 25891,
          6998, 17050,   721,  5698, 17050, 25028, 46648, 25028, 17925, 46538,
         25028, 19475, 17925, 17925,  1410, 17638, 46648, 33004, 18396,  6998,
         37321, 17925, 17925, 46648, 16557, 22772, 17925, 41729, 25562, 18396,
         25623, 17925, 19475, 17925,  4579, 46648, 33004, 25028, 19475,  2463,
         17477,  1426,  5698, 17050, 17050,  8979, 17925, 17050,  5698, 25028,
         45166, 33004, 17050, 17050, 12276, 46648,  7658, 46648,  5698, 18396,
         16493, 17925, 38586, 25562,  7658, 12276, 10249, 18396, 17925, 17050,
          8979, 17925, 42342, 17050,  8979]], device='cuda:0')
What is KV caching? mish dude critically Harm HarmFile mish dudecharacterreetingsreetingsRS guidegemreetingsStrongreetings Herman Harm Harm Herman 

## Using KV Cache

In [16]:
DEBUG_GPT_CONFIG_124M["use_kvcache"] = True  # rebuild the debug model with per-layer KV caches
torch.manual_seed(123)  # same seed as the baseline, so the weights match exactly
model = setup_model(GPTModel, DEBUG_GPT_CONFIG_124M, device).eval()

Using KV Cache
Using Device: cuda
KV Cache Enabled: True
Number of trainable parameters considering weight tying: 201,668


In [17]:
text = "What is KV caching?"
enc_text = text_to_token_ids(text, tokenizer)
idx = enc_text.to(device)

with torch.no_grad():
    # Prefill: run the whole prompt once; every attention layer stores its keys/values.
    idx_cond = idx[:, -DEBUG_GPT_CONFIG_124M["context_length"]:]
    print(idx_cond)
    pos_ids = torch.arange(idx_cond.shape[-1], device=idx_cond.device)
    logits = model(idx_cond, pos_ids)[:, -1:, :]
    probs = torch.softmax(logits, dim=-1)
    next_idx = torch.argmax(probs, dim=-1)
    idx = torch.cat([idx, next_idx], dim=-1)

    # Decode: feed only the newest token, with its position inside the window.
    for _ in range(98):
        idx_cond = idx[:, -DEBUG_GPT_CONFIG_124M["context_length"]:]
        pos_ids = torch.arange(idx_cond.shape[-1], device=idx_cond.device)
        logits = model(idx_cond[:, -1:], pos_ids[-1:])[:, -1:, :]
        probs = torch.softmax(logits, dim=-1)
        next_idx = torch.argmax(probs, dim=-1)
        idx = torch.cat([idx, next_idx], dim=-1)

using_kvcache_idx = idx
print("idx: ", idx)
print(token_ids_to_text(idx, tokenizer))

tensor([[ 2061,   318,   509,    53, 40918,    30]], device='cuda:0')
idx:  tensor([[ 2061,   318,   509,    53, 40918,    30, 29406, 18396, 19475, 17925,
         17925,  8979, 29406, 18396, 22769, 46648, 46648,  6998,  5698, 24090,
         46648, 33004, 46648, 25028, 17925, 17925, 25028,  9891, 25891, 25891,
          6998, 17050,   721,  5698, 17050, 25028, 46648, 25028, 17925, 46538,
         25028, 19475, 17925, 17925,  1410, 17638, 46648, 33004, 18396,  6998,
         37321, 17925, 17925, 46648, 16557, 22772, 17925, 41729, 25562, 18396,
         25623, 17925, 19475, 17925,  4579, 46648, 33004, 25028, 19475,  2463,
         17477,  1426,  5698, 17050, 17050,  8979, 17925, 17050,  5698, 25028,
         45166, 33004, 17050, 17050, 12276, 46648,  7658, 46648,  5698, 18396,
         16493, 17925, 38586, 25562,  7658, 12276, 10249, 18396, 17925, 17050,
          8979, 17925, 42342, 17050,  8979]], device='cuda:0')
What is KV caching? mish dude critically Harm HarmFile mish dudecharact

In [18]:
# torch.equal yields False (instead of raising a broadcast error) if the runs differ in length.
print("KV Cache matches Baseline: ", torch.equal(using_kvcache_idx, baseline_idx))

KV Cache matches Baseline:  True


## Benchmarking

In [19]:
def benchmark(config=None, num_runs=10, max_new_tokens=1000):
    """Time greedy generation with a freshly initialised GPT model and print the results.

    Args:
        config: model config dict; defaults to ``GPT_CONFIG_124M`` looked up at
            call time, so setting its ``"use_kvcache"`` entry before calling
            selects the cached or uncached path.
        num_runs: number of timed generation runs.
        max_new_tokens: tokens generated per timed run.

    Prints the parameter count, the device, ``mean +- std`` seconds per run, and
    a 30-token greedy sample as a sanity check.
    """
    cfg = GPT_CONFIG_124M if config is None else config
    torch.manual_seed(123)
    model = GPTModel(cfg)
    print(f"KV Cache Enabled: {cfg['use_kvcache']}")

    total_params = sum(param.numel() for param in model.parameters())
    total_params_gpt2 = total_params - sum(p.numel() for p in model.out_head.parameters())
    print(f"Number of trainable parameters considering weight tying: {total_params_gpt2:,}")

    tokenizer = tiktoken.get_encoding('gpt2')
    encoded_text = text_to_token_ids("What is KV caching?", tokenizer)

    model = model.eval()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Current Device: {device}")

    model = model.to(device)
    encoded_text = encoded_text.to(device)

    def sync():
        # CUDA kernels are launched asynchronously; wait for queued work so the
        # timer measures the full generation rather than just the launches.
        if device == 'cuda':
            torch.cuda.synchronize()

    times = []
    for _ in tqdm(range(num_runs)):
        sync()
        start = time.perf_counter()
        generate_text_simple(model, encoded_text, max_new_tokens=max_new_tokens,
                             context_size=cfg['context_length'], use_kvcache=cfg['use_kvcache'])
        sync()
        times.append(time.perf_counter() - start)
    if times:
        print(f"{round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

    # Start the sample below from a fresh cache.
    for module in model.modules():
        if module.__class__.__name__ == "MultiHeadAttention":
            module.prefilled = False

    ids = generate_text_simple(model, encoded_text, max_new_tokens=30,
                               context_size=cfg['context_length'], use_kvcache=cfg['use_kvcache'])
    print(ids)
    print(token_ids_to_text(ids, tokenizer))



### Benchmarking Baseline

In [20]:
# Baseline timing: GPT_CONFIG_124M["use_kvcache"] is still False at this point.
benchmark()

KV Cache Enabled: False
Number of trainable parameters considering weight tying: 124,412,160
Current Device: cuda


100%|██████████| 10/10 [03:09<00:00, 18.97s/it]


18.967 +- 0.134 seconds
tensor([[ 2061,   318,   509,    53, 40918,    30, 42435,  3686, 42043, 37759,
         23338, 11760,  8219, 50042, 12132, 46106, 42954, 48984, 42924, 28790,
         49285, 22055, 13968, 39926, 40144, 43959, 39492, 15209, 22104,   137,
         20272, 47586,  5255, 22437, 17404, 47565]], device='cuda:0')
What is KV caching? depletionurg Rooms Fabricن bandstenancewatching painful Maurit Rai Jah".[ pleasingCHAativity embod bribes KiddKENpeed Nort Enhanced�arus counselling decredoes Creative peach


### Benchmarking KV Cache

In [21]:
print("Benchmarking with KV Cache")
GPT_CONFIG_124M.update(use_kvcache=True)  # benchmark() builds its model from this config
benchmark()

Benchmarking with KV Cache
Using KV Cache
Using KV Cache
Using KV Cache
Using KV Cache
Using KV Cache
Using KV Cache
Using KV Cache
Using KV Cache
Using KV Cache
Using KV Cache
Using KV Cache
Using KV Cache
KV Cache Enabled: True
Number of trainable parameters considering weight tying: 124,412,160
Current Device: cuda


100%|██████████| 10/10 [01:38<00:00,  9.88s/it]


9.877 +- 0.371 seconds
tensor([[ 2061,   318,   509,    53, 40918,    30, 42435,  3686, 42043, 37759,
         23338, 11760,  8219, 50042, 12132, 46106, 42954, 48984, 42924, 28790,
         49285, 22055, 13968, 39926, 40144, 43959, 39492, 15209, 22104,   137,
         20272, 47586,  5255, 22437, 17404, 47565]], device='cuda:0')
What is KV caching? depletionurg Rooms Fabricن bandstenancewatching painful Maurit Rai Jah".[ pleasingCHAativity embod bribes KiddKENpeed Nort Enhanced�arus counselling decredoes Creative peach


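#### Summary
1. The KV cache gives roughly a 2x speed-up on the GPT-2 124M model: about 18.97 s versus 9.88 s to generate 1,000 tokens.
2. Only the newly generated token is fed to the model, so it must be given the correct position embedding. Its position id is its index in the current context window.
3. Absolute position embeddings do not combine well with a KV cache. Once the sequence exceeds the context length (`max_seq_len`), the window slides and every token moves to a new position. The cached keys and values were computed with the old position embeddings, so they no longer match what a full recomputation would produce, and the cache can no longer be used.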