# **Notebook 6.2: Understanding KV Cache for Efficient Transformer Inference 🚀**  

## **Introduction 📚**  

Welcome to **Notebook 6.2**, where we dive deep into **KV (Key-Value) Caching**, a crucial optimization technique for making transformer-based models more efficient during inference. 🎉

In **Notebook 6.1**, we explored **decoding strategies** like greedy search, beam search, and top-k sampling while running inference on a transformer model. However, we noticed a key challenge: **as sequence length increases, inference slows down significantly** due to repeated attention computations.  

This is where **KV Caching** comes to the rescue! 🚀 Instead of recomputing attention keys and values for every token in the sequence, KV caching **stores and reuses** previously computed states—leading to **massive speedups** in autoregressive decoding (like in GPT models).  

![Decoding Strategies Overview](images/kv.jpg)  


### **What’s Inside? 🔍**  

1️⃣ **Reviewing Standard Inference (from Notebook 6.1) ⏳**  
   - A quick recap of how transformers generate text **without KV caching**.  
   - Understanding why inference **becomes slower** as the sequence grows.  

2️⃣ **How KV Caching Works: Storing and Reusing Attention States 📦**  
   - We'll break down **how transformers compute attention** and **where KV caching fits in**.  
   - You'll see how **storing past keys and values** helps speed up token generation.  

3️⃣ **Implementing KV Cache in a Transformer Decoder ⚡**  
   - We'll modify our model to **store past keys & values** in a cache.  
   - Instead of recomputing everything, the model will **only process new tokens** efficiently.  

4️⃣ **Slicing and Updating KV Cache: Hands-on Exploration 🔬**  
   - Understanding how to **slice, update, and retrieve** keys/values from the cache.  
   - We’ll visualize **tensor slicing** and its role in maintaining an efficient cache.  

5️⃣ **Benchmarking Speed: With and Without KV Caching 🚀**  
   - We’ll compare inference speeds **with and without KV caching** to see the real impact.  
   - Expect **significant improvements**, especially for long sequences!  

---  

### **Why This Notebook Matters 💡**  

KV caching is one of the most important optimizations for **deploying transformers in real-time applications**. By the end of this notebook, you'll:  

✅ Understand **why inference slows down** in transformers without caching.  
✅ Learn how **KV caching reduces redundant computations**.  
✅ Implement **a transformer with KV caching** step-by-step.  
✅ Benchmark and **see massive speed improvements** for text generation.  

🚀 **Let’s unlock faster inference with KV Cache!** 🎯

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM 
torch.manual_seed(0) # For reproducibility



  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x78c5f7fc7050>

You may recalls that in the previous notebook, we delve into the decoding strategies like greedy search, beam search, and top-k sampling. We also noticed that as the sequence length increases, the inference slows down significantly due to repeated attention computations. This is where KV caching comes to the rescue! Instead of recomputing attention keys and values for every token in the sequence, KV caching stores and reuses previously computed states, leading to massive speedups in autoregressive decoding (like in GPT models).🚀

Before we delve into KV caching, let's quickly review how transformers generate text during inference without caching. 

In [3]:
class Sampler:
    def __init__(self , model_name : str ='gpt2-medium') -> None:

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to("cpu").to(self.device)

    def encode(self, text):
        return self.tokenizer.encode(text, return_tensors='pt').to(self.device)

    def decode(self, ids):
        return self.tokenizer.decode(ids)

    def get_next_token_prob(self, input_ids: torch.Tensor):
        with torch.no_grad():
            logits = self.model(input_ids=input_ids).logits
        logits = logits[0, -1, :]
        return logits
    
class GreedySampler(Sampler):
    def __call__(self, prompt, max_new_tokens=10):
        predictions = []
        result = prompt
        # generate until max_len
        for i in range(max_new_tokens):
            
            print(f"step {i} input: {result}")
            input_ids = self.encode(result)
            next_token_probs = self.get_next_token_prob(input_ids=input_ids)
            
            # choose the token with the highest probability
            id = torch.argmax(next_token_probs, dim=-1).item()
            # convert to token and add new token to text
            result += self.decode(id)
            
            predictions.append(next_token_probs[id].item())

        return result


This code defines a text generation pipeline using a causal language model (like GPT-2). The `Sampler` class handles tokenization, encoding, decoding, and extracting next-token probabilities. The `GreedySampler` class extends `Sampler` to generate text using **greedy decoding**, where at each step, it picks the most likely next token. It continues this process for a specified number of steps (`max_new_tokens`), appending each selected token to the prompt to form a generated sequence. 

In [4]:
gen = GreedySampler() 
prompt = "Large Language Models are a type of AI model that"
print(gen(prompt, max_new_tokens=10))

2025-02-17 23:35:52.260101: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-17 23:35:52.729196: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739824553.189824   94028 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739824553.363533   94028 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-17 23:35:54.731814: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appro

step 0 input: Large Language Models are a type of AI model that
step 1 input: Large Language Models are a type of AI model that can
step 2 input: Large Language Models are a type of AI model that can be
step 3 input: Large Language Models are a type of AI model that can be used
step 4 input: Large Language Models are a type of AI model that can be used to
step 5 input: Large Language Models are a type of AI model that can be used to model
step 6 input: Large Language Models are a type of AI model that can be used to model language
step 7 input: Large Language Models are a type of AI model that can be used to model language.
step 8 input: Large Language Models are a type of AI model that can be used to model language. They
step 9 input: Large Language Models are a type of AI model that can be used to model language. They are
Large Language Models are a type of AI model that can be used to model language. They are based


### 🚀 The Power of KV Caching in Efficient Inference  

Do you see the problem that KV caching can solve?  

As the number of input tokens grows during inference, the computational cost (FLOPs – Floating Point Operations) **increases significantly**. This is because each new token requires recomputing attention scores over all previous tokens.  

**KV caching** solves this problem by **storing** the hidden representations of previously computed key-value pairs. Instead of recomputing them for every new token, the model reuses cached values—**reducing redundant computations and speeding up inference!** ⚡

Now lets take a look at the original attention mechanism in transformers to understand how KV caching fits in.
Here we looking at a single head instead of multiple heads for simplicity.

In [5]:
# Define model hyperparameters
embed_size = 768   # Size of the embedding vector for each token
block_size = 64    # Maximum sequence length for attention
head_size = 64     # Size of each attention head
dropout = 0.1      # Dropout rate to prevent overfitting

# Define a single attention head
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size

        # Linear layers for key, query, and value transformations
        self.key = nn.Linear(embed_size, head_size, bias=False)
        self.query = nn.Linear(embed_size, head_size, bias=False)
        self.value = nn.Linear(embed_size, head_size, bias=False)

        # Lower triangular matrix for causal masking (prevents attention to future tokens)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape  # Batch size, sequence length, embedding size

        # Compute key, query, and value projections
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # Compute attention scores (scaled dot product attention)
        wei = q @ k.transpose(2, 1) / self.head_size**0.5  # (B, T, T)

        # Apply causal masking to prevent attending to future tokens
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

        # Apply softmax to get attention weights
        wei = F.softmax(wei, dim=2)  # (B, T, T)

        # Apply dropout to attention weights
        wei = self.dropout(wei)

        # Compute the final weighted sum of values
        out = wei @ v  # (B, T, head_size)

        return out


## 🚀 Attention Head & KV Caching  

### 🔹 Computing Attention Scores  
- **Key (K), Query (Q), and Value (V)** are generated via linear layers with shape **(B, T, C)**, where **C** is the head size.  
- The **dot product** of **Q** and **K** produces a weight matrix **(B, T, T)**, capturing token relationships—**higher scores indicate stronger relevance**.  
- To prevent excessively large values that hinder optimization, we apply scaling:  

  \[
  \frac{1}{\sqrt{\text{head\_size}}}
  \]

  followed by **softmax** to normalize attention weights.  

### 🔹 Enforcing Causality  
- A **lower triangular mask** ensures each token **attends only to previous tokens**, preserving **auto-regressive generation** in models like GPT.  

#### 🔍 Illustration of Auto-Regressive Attention  
![Attention Mechanism](images/atten1.gif)  

---

## 🏎️ Optimizing Inference with KV Caching  

### 🔹 The Problem: Redundant Computation  
In vanilla transformers, every token recomputes **Key-Value (KV) pairs** at every step, leading to inefficiencies in long sequences.  

### 🔹 KV Caching: The Fix  
1️⃣ Instead of processing tokens **incrementally**, we process **one token at a time**. Previously, **all past tokens were needed** to regenerate Value tokens. With KV caching, we store **past K, V pairs**, allowing us to compute attention efficiently.  

2️⃣ **Masking is no longer required.** Since we now pass **a single query token**, the attention matrix **(QKᵀ)** reduces from **(B, T, T)** to **(B, 1, T)**, eliminating redundant masking operations.  

KV caching significantly speeds up inference, making transformers more scalable. Let's dive into its implementation! 🚀




In [6]:
# Define a single attention head
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size

        # Linear layers for key, query, and value transformations
        self.key = nn.Linear(embed_size, head_size, bias=False)
        self.query = nn.Linear(embed_size, head_size, bias=False)
        self.value = nn.Linear(embed_size, head_size, bias=False)

        # Lower triangular matrix for causal masking (prevents attention to future tokens)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)

        # Initialize KV cache (Added initialization)
        self.k_cache = None  # (Added initialization)
        self.v_cache = None  # (Added initialization)
        self.cache_index = 0  # (Added initialization)

    def forward(self, x):
        B, T, C = x.shape  # Batch size, sequence length, embedding size

        # Compute key, query, and value projections
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # Initialize the KV cache if it is None (Added check and initialization)
        if self.k_cache is None or self.v_cache is None:
            self.k_cache = torch.zeros(B, block_size, self.head_size, device=x.device)  # (Added initialization)
            self.v_cache = torch.zeros(B, block_size, self.head_size, device=x.device)  # (Added initialization)
            self.cache_index = 0  # (Added initialization)

        # Update the KV cache in-place (Added cache update logic)
        if self.cache_index + T <= block_size:
            # If there is space in the cache, add the new keys and values
            self.k_cache[:, self.cache_index:self.cache_index + T, :] = k  # (Cache update)
            self.v_cache[:, self.cache_index:self.cache_index + T, :] = v  # (Cache update)
        else:
            # If the cache is full, shift the old tokens and add new ones
            shift = self.cache_index + T - block_size  # (Added shift calculation)
            self.k_cache[:, :-shift, :] = self.k_cache[:, shift:, :].clone()  # (Cache shift)
            self.v_cache[:, :-shift, :] = self.v_cache[:, shift:, :].clone()  # (Cache shift)
            self.k_cache[:, -T:, :] = k  # Place new tokens at the end of the cache
            self.v_cache[:, -T:, :] = v  # Place new tokens at the end of the cache

        # Update the cache index (Added cache index update)
        self.cache_index = min(self.cache_index + T, block_size)

        # Compute attention scores using cached keys and values (No change)
        wei = q @ self.k_cache.transpose(2, 1) / self.head_size**0.5  # (B, T, block_size)

        # Apply causal masking to ensure autoregressive behavior
        mask = self.tril[:T, :block_size].to(x.device)  # Shape (T, block_size)
        wei = wei.masked_fill(mask == 0, float('-inf'))  # Broadcast mask across the batch
        wei = F.softmax(wei, dim=2)  # Apply softmax across the attention scores

        # Apply dropout to attention weights (No change)
        wei = self.dropout(wei)

        # Compute the final weighted sum of values (No change)
        out = wei @ self.v_cache  # (B, T, head_size)

        return out


### Explanation of Changes

The following changes were applied to the original `Head` class implementation to incorporate the KV (Key-Value) cache mechanism for more efficient attention computation:

#### 1. **KV Cache Initialization**:
   - **Purpose**: The KV cache stores the computed keys and values for efficient reuse during sequential token processing.
   - **Change**: I added a check in the `forward` method to initialize the `k_cache` and `v_cache` if they are `None` (i.e., the first time the function is run or the cache hasn't been initialized yet).
   ```python
   if self.k_cache is None or self.v_cache is None:
       self.k_cache = torch.zeros(B, block_size, self.head_size, device=x.device)
       self.v_cache = torch.zeros(B, block_size, self.head_size, device=x.device)
       self.cache_index = 0
   ```
   - **Explanation**: This ensures that the cache is only initialized once and stored on the same device as the input tensor `x`.

#### 2. **Cache Update Logic**:
   - **Purpose**: Efficiently update the KV cache with new keys and values while maintaining the previously stored values.
   - **Change**: Added logic to update the `k_cache` and `v_cache` in-place.
     - If there is enough space in the cache (`cache_index + T <= block_size`), the new keys and values are directly inserted at the current position in the cache.
     - If the cache is full, older values are shifted out, and new values are inserted at the end.
   ```python
   if self.cache_index + T <= block_size:
       self.k_cache[:, self.cache_index:self.cache_index + T, :] = k
       self.v_cache[:, self.cache_index:self.cache_index + T, :] = v
   else:
       shift = self.cache_index + T - block_size
       self.k_cache[:, :-shift, :] = self.k_cache[:, shift:, :].clone()
       self.v_cache[:, :-shift, :] = self.v_cache[:, shift:, :].clone()
       self.k_cache[:, -T:, :] = k
       self.v_cache[:, -T:, :] = v
   ```
   - **Explanation**: This mechanism ensures that only the most recent keys and values are retained in the cache while older values are discarded, making the process memory-efficient.

#### 3. **Cache Index Update**:
   - **Purpose**: Track the current position in the cache for where to insert new keys and values.
   - **Change**: After updating the cache, the `cache_index` is updated to reflect the new position.
   ```python
   self.cache_index = min(self.cache_index + T, block_size)
   ```
   - **Explanation**: This ensures that the cache index is incremented and doesn't exceed the maximum block size.

And this an illustration of the KV cache mechanism in action:

![Attention Mechanism](images/atten.gif)  


This is a simple illustration of how KV cache works

In [7]:
# Example to illustrate how the cache transforms

k_cache = torch.zeros(1, 3, 3)
v_cache = torch.zeros(1, 3, 3)

steps = 3
for i in range(steps):
    k_cache[:, i, :] = torch.randint(10, (1, 3))
print("k_cache Before:\n", k_cache)

shift = 1
k_cache[:, :-shift, :] = k_cache[:, shift:, :].clone()
v_cache[:, :-shift, :] = v_cache[:, shift:, :].clone()
print("k_cache After:\n", k_cache)


k_cache Before:
 tensor([[[4., 9., 3.],
         [0., 3., 9.],
         [7., 3., 7.]]])
k_cache After:
 tensor([[[0., 3., 9.],
         [7., 3., 7.],
         [7., 3., 7.]]])


Now lets test the heaed class with KV cache mechanism

In [8]:

# Test scenario
batch_size = 2  # Number of sequences in the batch
seq_length = 4  # Sequence length (number of tokens)
x = torch.randn(batch_size, seq_length, embed_size)  # Random input tensor

# Create an instance of the Head class
head = Head(head_size=head_size)

# Run a forward pass
output = head(x)

# Print the output shape
print(f"Output shape: {output.shape}")


Output shape: torch.Size([2, 4, 64])


And this is it guys… 

That’s how one would implement KV Cache into the Attention Mechanism. Now let’s move on to see how fast it makes the inference.

## Inference with and without KV-Cache
So far we’ve discussed the working and implementation of the KV-Cache optimization technique. Below is the code for our previous GPT implementation updated with KV-Cache. If you are new to my page you don’t need to go through every line, as the only thing that concerns this article is the Attention Head Class and the generate function.

In [9]:
from UTILS.load_weights import download_and_load_gpt2 , load_weights_into_gpt
from UTILS.model import GPTModel

## Choose the model to use
CHOOSE_MODEL = "gpt2-medium (355M)"

## Base configuration settings for the model
BASE_CONFIG = {
    "vocab_size": 50257,     # Size of the vocabulary used by the model
    "context_length": 1024,  # Maximum context length the model can handle
    "drop_rate": 0.0,        # Dropout rate for regularization
    "qkv_bias": True         # Whether to use bias terms in query, key, and value projections
}

## Dictionary containing configurations for different GPT-2 model sizes
model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},      # Config for small model
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},    # Config for medium model
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},     # Config for large model
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},       # Config for extra-large model
}

## Update the BASE_CONFIG with parameters specific to the chosen model
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])


model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

model = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model, params)
model.eval()


File already exists and is up-to-date: gpt2/355M/checkpoint
File already exists and is up-to-date: gpt2/355M/encoder.json
File already exists and is up-to-date: gpt2/355M/hparams.json
File already exists and is up-to-date: gpt2/355M/model.ckpt.data-00000-of-00001
File already exists and is up-to-date: gpt2/355M/model.ckpt.index
File already exists and is up-to-date: gpt2/355M/model.ckpt.meta
File already exists and is up-to-date: gpt2/355M/vocab.bpe


GPTModel(
  (tok_emb): Embedding(50257, 1024)
  (pos_emb): Embedding(1024, 1024)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=1024, out_features=1024, bias=True)
        (W_key): Linear(in_features=1024, out_features=1024, bias=True)
        (W_value): Linear(in_features=1024, out_features=1024, bias=True)
        (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=1024, out_features=4096, bias=True)
          (1): GELU()
          (2): Linear(in_features=4096, out_features=1024, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_resid): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_f

Lets borrow this code from the previous notebook

In [10]:
import torch
import torch.nn.functional as F
import tiktoken

def top_p_logits(logits, p=0.5):
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cumulative_probs > p] = 0
    filtered_logits = torch.zeros_like(logits).to(logits.device)  # Move to same device
    filtered_logits.scatter_(dim=-1, index=sorted_indices, src=sorted_probs)
    return filtered_logits

def generate(model, prompt, max_new_tokens, context_size, tokenizer, temperature=0.0, top_k=None, top_p=None, eos=None):
    # Detect device
    device = next(model.parameters()).device  # Get model's device

    # Encode and move input to the correct device
    idx = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)

    idx_gen = idx.clone()  # Start with the prompt indices

    for _ in range(max_new_tokens):
        idx_cond = idx_gen[:, -context_size:]

        with torch.no_grad():
            logits = model(idx_cond)  # Forward pass
            logits = logits[:, -1, :]  # Take the last token's logits
            
            # Apply top-k sampling
            if top_k is not None:
                top_k_values, _ = torch.topk(logits, k=top_k)
                min_value = top_k_values[:, -1].unsqueeze(1)  
                logits = torch.where(logits < min_value, torch.tensor(float('-inf')).to(device), logits)

            # Apply top-p sampling
            if top_p is not None:
                logits = top_p_logits(logits, p=top_p)  

            # Apply temperature
            if temperature > 0.0:
                logits = logits / temperature
                probs = F.softmax(logits, dim=-1)  # Convert to probabilities
                idx_next = torch.multinomial(probs, num_samples=1)  # Sample token
            else: 
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # Take max logit

            # EOS handling
            if eos is not None and torch.equal(idx_next, torch.tensor(eos).to(device)):
                break
            
            # Append new token
            idx_gen = torch.cat((idx_gen, idx_next), dim=1)

    return tokenizer.decode(idx_gen.squeeze(0).tolist())  # Convert tokens back to text


Testing the model without KV cache
here im using my laptop with gtx 1650 gpu 

In [11]:
import time
import torch
import tiktoken

# Detect device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

# Move model to detected device
gpt = model.to(device)

# Initialize tokenizer
tokenizer = tiktoken.get_encoding('gpt2')

# Tokenize input and move to device
prompt = "I have a dream that "
input_ids = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)

# Measure time
start_time = time.time()

# Run text generation
generated_text = generate(
    model=gpt,  
    prompt=prompt,
    max_new_tokens=1000,
    context_size=512,
    tokenizer=tokenizer,
    temperature=0.75,
    top_k=None,
    top_p=None,
)

# Measure end time
end_time = time.time()

# Print generated text
print(generated_text)

# Print benchmark results
print(f"Inference Time: {end_time - start_time:.4f} seconds")


Using device: cuda
I have a dream that ____ is going to show up in the Oval Office tomorrow with a massive heart attack and be like, "What the hell am I going to do?"

AP Bill Clinton.

KIDS: The only thing that saved us from Great Depression was — the only thing that prevented World War II from happening was —

CONAN: I just don't think there's any question that you were...

KIDS: — the person who was able to get the economy going again.

CONAN: You know, I think while his public image was somewhat tarnished by his scandals, I think he did some good things during that time. I think he was a good fit for the public service. That wasn't to say that we didn't trust him with that responsibility.

KIDS: I think if you look at the people around him, they were OK with the way he ran things.

(LAUGHTER)

Conan: But, you know, I do think public service is a - you know, is an important and valuable part of life, but people have grown up in a different time when people weren't served well. I thi

As you can see the model the genration took 184.6997 seconds to generate 1000 tokens

Let see how the model performs with KV cache

But first we need to define the model with KV cache , since the model doesn't have KV cache by default

In [12]:
# Accessing and printing the MultiHeadAttention layers of the model
for i, block in enumerate(model.trf_blocks):
    # Each block is a TransformerBlock, and its attention layer is in the 'att' attribute
    multihead_attention = block.att
    print(f"MultiHeadAttention {i}:")
    print(multihead_attention)
    print("\n" + "="*50 + "\n")


MultiHeadAttention 0:
MultiHeadAttention(
  (W_query): Linear(in_features=1024, out_features=1024, bias=True)
  (W_key): Linear(in_features=1024, out_features=1024, bias=True)
  (W_value): Linear(in_features=1024, out_features=1024, bias=True)
  (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)


MultiHeadAttention 1:
MultiHeadAttention(
  (W_query): Linear(in_features=1024, out_features=1024, bias=True)
  (W_key): Linear(in_features=1024, out_features=1024, bias=True)
  (W_value): Linear(in_features=1024, out_features=1024, bias=True)
  (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)


MultiHeadAttention 2:
MultiHeadAttention(
  (W_query): Linear(in_features=1024, out_features=1024, bias=True)
  (W_key): Linear(in_features=1024, out_features=1024, bias=True)
  (W_value): Linear(in_features=1024, out_features=1024, bias=True)
  (out_proj): Linear(in_features=102

In [2]:
import torch.nn as nn
import torch
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        
        # We'll create the mask dynamically in forward pass instead
        self.max_length = context_length

    def forward(self, x, past_kvs=None):
        b, num_tokens, d_in = x.shape

        # Create queries, keys, values
        queries = self.W_query(x)  # (b, num_tokens, d_out)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Reshape for multi-head attention
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        # Handle KV cache
        if past_kvs is not None:
            past_keys, past_values = past_kvs
            keys = torch.cat([past_keys, keys], dim=2)
            values = torch.cat([past_values, values], dim=2)

        # Calculate attention scores
        seq_len = keys.size(2)  # Total sequence length including past
        attn_scores = queries @ keys.transpose(-2, -1)  # (b, num_heads, num_tokens, seq_len)

        # Create causal mask dynamically
        device = queries.device
        mask = torch.ones((1, 1, num_tokens, seq_len), dtype=torch.bool, device=device)
        mask = torch.triu(mask.squeeze(0).squeeze(0), diagonal=1)
        mask = mask.unsqueeze(0).unsqueeze(0)

        # Apply mask and scaling
        attn_scores = attn_scores / math.sqrt(self.head_dim)
        attn_scores = attn_scores.masked_fill(mask, float('-inf'))
        
        # Apply attention
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = (attn_weights @ values)  # (b, num_heads, num_tokens, head_dim)

        # Reshape output
        context_vec = context_vec.transpose(1, 2).contiguous()  # (b, num_tokens, num_heads, head_dim)
        context_vec = context_vec.view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec, (keys, values)

class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift


class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_resid = nn.Dropout(cfg["drop_rate"])

    def forward(self, x, past_kvs=None):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x, new_kvs = self.att(x, past_kvs)   # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_resid(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_resid(x)
        x = x + shortcut  # Add the original input back

        return x, new_kvs


class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx, past_kvs=None):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        
        new_past_kvs = []
        for i, trf_block in enumerate(self.trf_blocks):
            x, kv = trf_block(x, past_kvs[i] if past_kvs is not None else None)
            new_past_kvs.append(kv)

        x = self.final_norm(x)
        logits = self.out_head(x)
        
        # Return logits and updated KV cache
        return logits, new_past_kvs


In [3]:
from UTILS.load_weights import download_and_load_gpt2, load_weights_into_gpt

## Choose the model to use
CHOOSE_MODEL = "gpt2-medium (355M)"

## Base configuration settings for the model
BASE_CONFIG = {
    "vocab_size": 50257,     # Size of the vocabulary used by the model
    "context_length": 1024,  # Maximum context length the model can handle
    "drop_rate": 0.0,        # Dropout rate for regularization
    "qkv_bias": True         # Whether to use bias terms in query, key, and value projections
}

## Dictionary containing configurations for different GPT-2 model sizes
model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},      # Config for small model
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},    # Config for medium model
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},     # Config for large model
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},       # Config for extra-large model
}

## Update the BASE_CONFIG with parameters specific to the chosen model
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

model_cache = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model_cache, params)  # Load weights into model_cache
model_cache.eval()  # Set the model in evaluation mode


2025-02-18 00:09:07.300536: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-18 00:09:07.313200: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739826547.335308  101344 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739826547.341483  101344 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-18 00:09:07.362836: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appro

File already exists and is up-to-date: gpt2/355M/checkpoint
File already exists and is up-to-date: gpt2/355M/encoder.json
File already exists and is up-to-date: gpt2/355M/hparams.json
File already exists and is up-to-date: gpt2/355M/model.ckpt.data-00000-of-00001
File already exists and is up-to-date: gpt2/355M/model.ckpt.index
File already exists and is up-to-date: gpt2/355M/model.ckpt.meta
File already exists and is up-to-date: gpt2/355M/vocab.bpe


GPTModel(
  (tok_emb): Embedding(50257, 1024)
  (pos_emb): Embedding(1024, 1024)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=1024, out_features=1024, bias=True)
        (W_key): Linear(in_features=1024, out_features=1024, bias=True)
        (W_value): Linear(in_features=1024, out_features=1024, bias=True)
        (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=1024, out_features=4096, bias=True)
          (1): GELU()
          (2): Linear(in_features=4096, out_features=1024, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_resid): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_f

In [4]:
import torch
import torch.nn.functional as F
import tiktoken

def top_p_logits(logits, p=0.5):
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cumulative_probs > p] = 0
    filtered_logits = torch.zeros_like(logits).to(logits.device)  # Move to same device
    filtered_logits.scatter_(dim=-1, index=sorted_indices, src=sorted_probs)
    return filtered_logits

def generate(model, prompt, max_new_tokens, context_size, tokenizer, temperature=0.0, top_k=None, top_p=None, eos=None):
    # Detect device
    device = next(model.parameters()).device  # Get model's device

    # Encode and move input to the correct device
    idx = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)  # Add batch dimension

    idx_gen = idx.clone()  # Start with the prompt indices
    kv_cache = None  # Initialize KV cache as None for the first pass

    for _ in range(max_new_tokens):
        idx_cond = idx_gen[:, -context_size:]

        with torch.no_grad():
            logits, kv_cache = model(idx_cond, kv_cache)  # Pass KV cache to model's forward pass

            # Ensure logits and idx_next have matching shapes
            batch_size = logits.size(0)  # Get batch size
            vocab_size = logits.size(-1)  # Get vocabulary size (50257)
            
            # Handle the attention mask size issue by creating a dynamic mask
            num_tokens = idx_cond.size(1)
            mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).to(device)
            mask = mask.bool()  # Create the attention mask for causal attention

            # Apply the mask
            logits[:, :, :num_tokens].masked_fill_(mask, -float('inf'))
            logits = logits[:, -1, :]  # Take the last token's logits

            # Apply top-k sampling
            if top_k is not None:
                top_k_values, _ = torch.topk(logits, k=top_k)
                min_value = top_k_values[:, -1].unsqueeze(1)  
                logits = torch.where(logits < min_value, torch.tensor(float('-inf')).to(device), logits)

            # Apply top-p sampling
            if top_p is not None:
                logits = top_p_logits(logits, p=top_p)  

            # Apply temperature
            if temperature > 0.0:
                logits = logits / temperature
                probs = torch.softmax(logits, dim=-1)  # Convert to probabilities
                idx_next = torch.multinomial(probs, num_samples=1)  # Sample token
            else: 
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # Take max logit

            # EOS handling
            if eos is not None and torch.equal(idx_next, torch.tensor(eos).to(device)):
                break
            
            # Append new token to the generated sequence
            idx_gen = torch.cat((idx_gen, idx_next), dim=1)

    return tokenizer.decode(idx_gen.squeeze(0).tolist())  # Convert tokens back to text

In [5]:
import torch
import tiktoken
import time

# Assuming you're using tiktoken for GPT-2 (this may vary depending on your tokenizer)
tokenizer = tiktoken.get_encoding("gpt2")

device = "cuda" if torch.cuda.is_available() else "cpu"  # Check if GPU is available
# Example prompt text
prompt = "I have a dream that "
model_cache = model_cache.to(device)  # Move model to device

# Handle EOS token
eos_token = "<|endoftext|>"  # Define the EOS token
eos_token_id = tokenizer.encode(eos_token, allowed_special={eos_token})[0]  # Allow EOS token

start_time = time.time()


# Run the generation function
generated_text = generate(
    model=model_cache,  # Your loaded model here
    prompt=prompt,
    max_new_tokens=1000,
    context_size=512,
    tokenizer=tokenizer,
    temperature=0/75,
    top_k=None,
    top_p=None,
    eos=eos_token_id
)

# Measure end time
end_time = time.time()

# Print generated text
print(generated_text)

# Print benchmark results
print(f"Inference Time: {end_time - start_time:.4f} seconds")







OutOfMemoryError: CUDA out of memory. Tried to allocate 28.00 MiB. GPU 0 has a total capacity of 3.81 GiB of which 19.12 MiB is free. Including non-PyTorch memory, this process has 3.79 GiB memory in use. Of the allocated memory 3.08 GiB is allocated by PyTorch, and 646.09 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)