# GPT2 Key-Value Cache Visualization

This notebook visualizes how key-value caches work in GPT2-style transformer models across all layers and attention heads. We'll examine:
1. The structure of KV caches
2. How they're updated during generation
3. Attention patterns with cached keys and values

In [50]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Fall back to CPU when no GPU is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load model and tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [51]:

# For visualization
def plot_all_attention_patterns(key_states, value_states, attention_weights, tokens, title="Attention Patterns"):
    num_layers = len(key_states)
    num_heads = key_states[0].shape[1]
    
    # Create a large figure for all layers and heads
    fig = plt.figure(figsize=(20, 20))
    plt.suptitle(title, fontsize=16, y=0.95)
    
    # Add text showing input tokens
    plt.figtext(0.02, 0.98, f'Input text: {tokens}', fontsize=12, wrap=True)
    
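    # Plot a matrix for each layer and head, capping the grid at
    # 5 layers x 6 heads so the individual panels stay readable
    num_layers = min(num_layers, 5)
    num_heads = min(num_heads, 6)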
    for layer in range(num_layers):
        for head in range(num_heads):
            # Calculate subplot position
            plt.subplot(num_layers, num_heads, layer * num_heads + head + 1)
            
            # Plot attention weights with token labels
            sns.heatmap(
                attention_weights[layer][0, head].detach().cpu().numpy(),
                xticklabels=tokens,
                yticklabels=tokens,
                cmap='viridis',
                square=True,
                cbar=False,
            )
            
            if head == 0:
                plt.ylabel(f'Layer {layer}')
            if layer == 0:
                plt.title(f'Head {head}')
            
            # Rotate labels for better readability
            plt.xticks(rotation=45, ha='right')
            plt.yticks(rotation=0)
    
    plt.tight_layout()
    plt.subplots_adjust(top=0.92)
    plt.show()

def plot_kv_cache_details(key_states, value_states, tokens, layer=0, head=0):
    plt.figure(figsize=(15, 5))
    
    # Plot 1: Key States (rows are tokens, columns are per-head dimensions)
    plt.subplot(1, 2, 1)
    sns.heatmap(
        key_states[layer][0, head].detach().cpu().numpy(),
        xticklabels=range(key_states[layer].shape[-1]),
        yticklabels=tokens,
        cmap='viridis'
    )
    plt.title(f'Key States (Layer {layer}, Head {head})')
    plt.xlabel('Head Dimension')
    
    # Plot 2: Value States (same layout as the key plot)
    plt.subplot(1, 2, 2)
    sns.heatmap(
        value_states[layer][0, head].detach().cpu().numpy(),
        xticklabels=range(value_states[layer].shape[-1]),
        yticklabels=tokens,
        cmap='viridis'
    )
    plt.title(f'Value States (Layer {layer}, Head {head})')
    plt.xlabel('Head Dimension')
    
    plt.tight_layout()
    plt.show()

# Extract KV cache from model
@torch.no_grad()
def extract_kv_cache(model, input_ids):
    # Get model outputs with attention
    outputs = model(input_ids, output_attentions=True, use_cache=True)
    
    # Extract all layers' cache
    all_key_states = []
    all_value_states = []
    for layer_cache in outputs.past_key_values:
        all_key_states.append(layer_cache[0])  # Shape: [batch, num_heads, seq_len, head_dim]
        all_value_states.append(layer_cache[1])
    
    return all_key_states, all_value_states, outputs.attentions

In [52]:
key_states[0].shape

torch.Size([1, 12, 8, 64])
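
The shape reads as `[batch, num_heads, seq_len, head_dim]`: one sequence, 12 heads, 8 cached tokens, and 64 dimensions per head. Each layer stores one such tensor for keys and one for values. As a rough sketch of how much memory that adds up to, the helper below (`kv_cache_num_bytes` is a hypothetical name, not part of `transformers`) multiplies the dimensions out. It assumes float32 storage (4 bytes per element) and GPT-2 small's 12 layers.

```python
def kv_cache_num_bytes(num_layers, batch, num_heads, seq_len, head_dim,
                       bytes_per_element=4):
    """Bytes held by a KV cache: one key and one value tensor per layer.

    bytes_per_element=4 assumes float32 storage; use 2 for float16.
    """
    elements_per_tensor = batch * num_heads * seq_len * head_dim
    return 2 * num_layers * elements_per_tensor * bytes_per_element


# GPT-2 small with the 8 cached tokens from the shape above
small = kv_cache_num_bytes(num_layers=12, batch=1, num_heads=12,
                           seq_len=8, head_dim=64)
print(small)  # 589824 bytes

# The same model with a full 1024-token context
full = kv_cache_num_bytes(num_layers=12, batch=1, num_heads=12,
                          seq_len=1024, head_dim=64)
print(full / 2**20)  # 72.0 MiB
```

The cache grows linearly with sequence length, so 128 times as many cached tokens means 128 times the memory.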

In [53]:
model.device


device(type='cuda', index=0)

In [54]:
generated_ids = model.generate(input_ids.to(device), do_sample=True, max_length=200)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [55]:
tokenizer.batch_decode(generated_ids)[0]

"Human: can fish fly?\nAssistant answer: fish.\nAssistant answer: I mean that's a big deal.\nAssistant answer: OK.\nAssistant answer: So now, what are you going to do if you get caught?\nAssistant answer: The only option is to go to Fish, and fish can fly.\nAssistant answer: Okay.\nAssistant answer: Ok. So you've been caught. So you have a chance, I mean, one of their chances is to catch an enormous fish.\nAssistant answer: Ok.\nAssistant answer: It's not a chance unless you have an experience.\nAssistant answer: And what?\nAssistant answer: You know, I have three fish. You can fish just about one or two fish a year, or you can catch multiple fish at once.\nAssistant answer: OK.\nAssistant answer: So this is, um, there was one great fish, one of them was a fish, it was, he ran,"

## Analyzing the Visualization

The plotting helpers defined above produce two views:
1. **Attention Patterns**: How each attention head in each layer attends to different input tokens
2. **Key-Value Cache Details**: Detailed view of how tokens are encoded in the key and value states

During generation, the model:
1. Stores computed key-value pairs in cache
2. Reuses them for subsequent tokens
3. Only computes new KV pairs for new tokens
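
The three steps above can be sketched with a toy single-head attention in NumPy. This is a minimal illustration under stated assumptions, not GPT-2's actual implementation: the weights are random, there is one head and no layer stacking, and the function names (`full_causal_attention`, `cached_attention_step`) are made up for this example. It checks that attending over cached keys and values one token at a time gives the same result as recomputing everything with a causal mask.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def full_causal_attention(x, w_q, w_k, w_v):
    """Recompute queries, keys and values for every position (no cache)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])
    # Mask future positions so token i only attends to tokens 0..i
    future = np.triu(np.ones_like(scores, dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ v


def cached_attention_step(x_new, w_q, w_k, w_v, cache):
    """Process one new token, reusing the keys and values already cached."""
    q = x_new @ w_q                  # query for the new token only
    cache["k"].append(x_new @ w_k)   # the new key is computed once and stored
    cache["v"].append(x_new @ w_v)   # likewise for the new value
    k = np.stack(cache["k"])         # [tokens_so_far, head_dim]
    v = np.stack(cache["v"])
    weights = softmax(q @ k.T / np.sqrt(k.shape[-1]))
    return weights @ v


rng = np.random.default_rng(0)
seq_len, d_model, head_dim = 6, 16, 8
x = rng.normal(size=(seq_len, d_model))  # stand-in for token embeddings
w_q, w_k, w_v = (rng.normal(size=(d_model, head_dim)) for _ in range(3))

cache = {"k": [], "v": []}
incremental = np.stack(
    [cached_attention_step(x[t], w_q, w_k, w_v, cache) for t in range(seq_len)]
)
reference = full_causal_attention(x, w_q, w_k, w_v)

outputs_match = np.allclose(incremental, reference)
print(outputs_match)
```

Each cached step performs one key projection and one value projection, whereas recomputing without a cache would redo them for every earlier token at every step. The causal mask is what makes the reuse valid: earlier tokens never attend to later ones, so their keys and values do not change as the sequence grows.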

In [58]:
# Prepare multiple test cases
test_prompts = [
    "Article: A car crash occurred on Highway 15 today, involving three cars. Authorities were called to the scene shortly after the crash was reported, and an investigation is underway to determine the cause of the accident. Paramedics arrived on the scene, offering medical assistance to those involved in the crash, and all three individuals were taken to the nearby hospital for further treatment. TL;DR:",
    # "Who was the author of the art of war?",
]


In [60]:
# Test each prompt
for prompt in test_prompts:
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
    
    # Generate continuation
    output = model.generate(
        input_ids,
        max_length=500,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
        temperature=0.7,
        do_sample=True
    )
    
    # Print results
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print("\nInput prompt:", prompt)
    print("Generated response:", generated_text)
    
    # Extract and visualize KV cache
    # key_states, value_states, attention = extract_kv_cache(model, input_ids)
    # print('encoded')


Input prompt: Article: A car crash occurred on Highway 15 today, involving three cars. Authorities were called to the scene shortly after the crash was reported, and an investigation is underway to determine the cause of the accident. Paramedics arrived on the scene, offering medical assistance to those involved in the crash, and all three individuals were taken to the nearby hospital for further treatment. TL;DR:
Generated response: Article: A car crash occurred on Highway 15 today, involving three cars. Authorities were called to the scene shortly after the crash was reported, and an investigation is underway to determine the cause of the accident. Paramedics arrived on the scene, offering medical assistance to those involved in the crash, and all three individuals were taken to the nearby hospital for further treatment. TL;DR: There were no injuries reported. No further information was available today.

Update: A car was found lying in the roadway near Haverhill, and the driver of 