# Demo: KV-Caching in Action (Conceptual)

**Goal:** This demo focuses on showing how `model.generate()` uses Key-Value (KV) Caching by default and conceptually explaining the internal process.

We will walk through:
1. **Setup:** Importing the libraries we need.
2. **Loading:** Loading a model (`Llama-3.2-1B`) and its tokenizer.
3. **Generation:** Calling `model.generate()` to see KV caching in action.
4. **Explanation:** Breaking down what happens under the hood.

## Step 1: Setup Environment

In [None]:
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time

## Step 2: Load Model and Tokenizer

We'll define our model configuration. We'll use `meta-llama/Llama-3.2-1B`, loaded here from a local path with the Hugging Face Hub in offline mode. We also check for an available CUDA GPU to speed things up.

In [None]:
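# Run offline and load the weights from a local copy.
# Note: HF_HUB_OFFLINE is typically read when the libraries are imported, so set it
# before importing transformers if it needs to take effect.
os.environ["HF_HUB_OFFLINE"] = "1"
model_name = "/voc/shared/models/llama/Llama-3.2-1B"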

device = "cuda" if torch.cuda.is_available() else "cpu"

# Use bfloat16 for faster computation if supported on CUDA, otherwise use standard float32
dtype = torch.bfloat16 if device == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

print(f"Using model: {model_name}")
print(f"Using device: {device}")
print(f"Using dtype: {dtype}")

Using model: meta-llama/Llama-3.2-1B
Using device: cuda
Using dtype: torch.float32


Now, let's load the tokenizer and the model.

In [5]:
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Tokenizer loaded successfully.")

print("\nLoading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
).to(device)
print("Model loaded successfully and moved to device.")

# Set the model to evaluation mode (disables dropout, etc.)
model.eval()

Loading tokenizer...


Tokenizer loaded successfully.

Loading model...


Model loaded successfully and moved to device.


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

Some models don't have a `pad_token` set by default. We'll set it to the `eos_token` (end-of-sequence) to prevent warnings during generation.

In [6]:
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id
    print("Pad token set to EOS token.")

Pad token set to EOS token.


## Step 3: Generate Text with KV Caching

First, we define our prompt and tokenize it, preparing it for the model.

In [7]:
prompt = "The best way to optimize LLM inference is"
print(f"Prompt: \"{prompt}\"")

inputs = tokenizer(prompt, return_tensors="pt").to(device)

Prompt: "The best way to optimize LLM inference is"


Now for the main event. We call `model.generate()`. 

Crucially, the `use_cache=True` argument is **on by default** for most autoregressive models. This is what enables KV caching. We are explicitly writing it here to make it clear, but you usually don't have to.

In [8]:
print("Running model.generate() with use_cache=True (default behavior)...")

start_time = time.perf_counter()
with torch.no_grad():
    outputs = model.generate(
        **inputs,  # Passes input_ids and attention_mask; with pad == eos the mask cannot be inferred
        max_new_tokens=50,
        use_cache=True,  # Explicitly request KV caching (already the default)
        pad_token_id=tokenizer.pad_token_id,
    )
end_time = time.perf_counter()

generated_ids = outputs[0]
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

print("--- Generated Text ---")
print(generated_text)
print("----------------------")
print(f"(Generation took: {end_time - start_time:.4f} seconds)")


Running model.generate() with use_cache=True (default behavior)...
--- Generated Text ---
The best way to optimize LLM inference is to use a large, diverse dataset. We have a large, diverse dataset of 4.3 million English Wikipedia articles. We are using this dataset to train our LLMs on a variety of tasks, including question answering, summarization, and
----------------------
(Generation took: 7.5997 seconds)


## Step 4: Conceptual Breakdown - What Just Happened?

The `model.generate()` call did all the heavy lifting for us, but it was managing a cache behind the scenes. Let's break down how.

### Stage 1: Prompt Processing (The First Pass)

When generation starts, the model first processes the entire input prompt (`"The best way to optimize LLM inference is"`) in a single forward pass.

- For **every token** in this prompt, each attention layer calculates the corresponding **Key (K)** and **Value (V)** vectors.
- These K and V vectors (for the entire prompt, across all layers) are then stored in a cache, often called `past_key_values`.

This initial step is computationally intensive but is only done **once**.

### Stage 2: Autoregressive Generation (The Token-by-Token Loop)

Now, the model generates the rest of the text one token at a time. This is where the cache becomes critical.

- **To generate the 1st new token:**
    - No extra forward pass is needed. The prompt pass from Stage 1 already produced logits at the *last prompt position*, and the 1st new token is chosen from them.
    - At this point the cache holds the K and V vectors for the prompt tokens only.

- **To generate the 2nd new token:**
    - The model processes only the *1st new token* it just generated.
    - It calculates Q, K, and V for *this token only* and **appends** the new K and V to the cache.
    - The Q vector then attends to the *entire, updated cache* (prompt tokens + 1st new token), and the resulting logits give the 2nd new token.

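This loop continues, and at each step the model runs on a single new token instead of re-calculating K and V for all previous tokens. The projection and MLP work per step stays constant; only the attention read over the cache grows with sequence length.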
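The loop above can be sketched with a toy single-head attention layer in plain Python. This is a minimal illustration, not how Transformers implements it: the width `D`, the random matrices `W_q`, `W_k`, `W_v`, the `KVCache` class, and the stand-in token embeddings are all hypothetical, and positional encoding, multiple heads, and multiple layers are left out. It feeds tokens one at a time (a real prefill handles all prompt tokens in one pass) and checks that the cached path gives the same attention output as recomputing everything.

```python
import math
import random

D = 4  # toy model width (illustrative only, not Llama's)
rng = random.Random(0)


def rand_matrix():
    """Random D x D projection matrix, stored as a list of rows."""
    return [[rng.uniform(-1, 1) for _ in range(D)] for _ in range(D)]


def matvec(w, x):
    """Multiply a D x D matrix by a length-D vector."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def attend(q, keys, values):
    """Scaled dot-product attention of one query over all given keys/values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(D) for k in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [sum(w / total * v[i] for w, v in zip(weights, values)) for i in range(D)]


W_q, W_k, W_v = rand_matrix(), rand_matrix(), rand_matrix()
tokens = [[rng.uniform(-1, 1) for _ in range(D)] for _ in range(6)]  # stand-in embeddings


def last_output_no_cache(seq):
    """Recompute K and V for every token in seq, then attend from the last token."""
    keys = [matvec(W_k, x) for x in seq]
    values = [matvec(W_v, x) for x in seq]
    return attend(matvec(W_q, seq[-1]), keys, values), len(seq)


class KVCache:
    """Hypothetical minimal cache: one growing list of keys and one of values."""

    def __init__(self):
        self.keys, self.values = [], []

    def step(self, x):
        """Project only the new token, append its K/V, attend over the whole cache."""
        self.keys.append(matvec(W_k, x))
        self.values.append(matvec(W_v, x))
        return attend(matvec(W_q, x), self.keys, self.values)


cache = KVCache()
kv_projections_no_cache = 0
max_diff = 0.0
for t in range(1, len(tokens) + 1):
    reference, n_projected = last_output_no_cache(tokens[:t])
    kv_projections_no_cache += n_projected
    cached = cache.step(tokens[t - 1])
    max_diff = max(max_diff, max(abs(a - b) for a, b in zip(reference, cached)))

kv_projections_cached = len(cache.keys)
print(f"max output difference: {max_diff}")
print(f"K/V projections without cache: {kv_projections_no_cache}")
print(f"K/V projections with cache:    {kv_projections_cached}")
```

With six tokens, the uncached path performs 1 + 2 + ... + 6 = 21 K/V projections while the cached path performs 6, and the attention outputs match at every step.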

### The Magic of Library Abstraction

Libraries like Hugging Face Transformers abstract away this complex state management. By simply using `model.generate()`, we automatically get the benefits of KV Caching without needing to manually handle the `past_key_values` object at each step.

### Final Takeaways

- **What it is:** A technique to store and reuse Key/Value vectors of past tokens during autoregressive generation.
- **Why it's used:** To **reduce computation** and **lower latency** (less time per generated token) by not re-processing earlier tokens at every step.
- **The Trade-off:** It consumes **more memory (VRAM)** because the cache grows linearly with sequence length (prompt plus generated tokens), in every layer.
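
To put a rough number on the memory cost, the cache size can be estimated from the model printout in Step 2: 16 decoder layers, with `k_proj` and `v_proj` each producing 512 values per token. The helper `kv_cache_bytes` below is a hypothetical back-of-the-envelope sketch; it assumes one stored value per K/V output feature per layer and ignores any allocator or framework overhead.

```python
def kv_cache_bytes(seq_len, num_layers=16, kv_width=512, bytes_per_value=4, batch_size=1):
    """Estimate bytes needed to hold K and V for seq_len tokens.

    Defaults come from the printed Llama-3.2-1B structure (16 layers,
    k_proj/v_proj out_features=512) and float32 (4 bytes per value).
    """
    k_and_v = 2  # one K tensor and one V tensor per layer
    return k_and_v * num_layers * kv_width * bytes_per_value * seq_len * batch_size


fp32_per_token = kv_cache_bytes(1)
bf16_per_token = kv_cache_bytes(1, bytes_per_value=2)

print(f"float32:  {fp32_per_token // 1024} KiB per token")
print(f"bfloat16: {bf16_per_token // 1024} KiB per token")
print(f"float32, 1,000 tokens: {kv_cache_bytes(1000) / 2**20:.1f} MiB")
```

In float32 (the dtype reported in this run) that works out to 64 KiB per token, or 32 KiB per token in bfloat16, and the total scales linearly with both sequence length and batch size.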