# Lesson 2 - Exercise 1: Quantifying KV Cache Memory Growth

**Goal:** Empirically measure and analyze the growth in peak GPU memory consumption during LLM inference as the number of generated tokens increases, thereby quantifying the memory footprint of the Key-Value (KV) Cache.

## 1. Installation & Setup

In [None]:
!pip install torch transformers accelerate matplotlib datasets rouge_score kagglehub evaluate huggingface_hub

In [None]:
import os
# TODO: Replace "YOUR_HUGGING_FACE_TOKEN_HERE" with your actual Hugging Face token
# The token is published as the HF_TOKEN environment variable of this process.
os.environ.update({"HF_TOKEN": "YOUR_HUGGING_FACE_TOKEN_HERE"})


In [None]:
!huggingface-cli login --token $HF_TOKEN

## 2. Imports and Configuration

In [None]:
import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc # For garbage collection
import matplotlib.pyplot as plt

# --- Configuration ---
# Hugging Face Hub identifier of the causal language model to profile.
model_name = "meta-llama/Llama-3.2-1B"

# Peak-memory statistics are read through torch.cuda, so a GPU is needed for
# meaningful numbers; on CPU the notebook still runs but skips measurement.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("WARNING: This exercise requires a CUDA-enabled GPU for meaningful memory measurements. Results on CPU will not reflect KV cache growth on GPU.")

# Pick the compute dtype: bfloat16 where the GPU supports it, float16 on other
# GPUs, and float32 on CPU.
if device == "cuda":
    if torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
        print("Using torch.bfloat16")
    else:
        dtype = torch.float16
        print("Using torch.float16 (bf16 not supported)")
else:
    dtype = torch.float32  # Default for CPU
    print("Using torch.float32 for CPU")

prompt = "The best way to optimize LLM inference is"

# max_new_tokens values to sweep. They are evenly spaced and ascending so the
# memory delta between consecutive runs maps to a fixed number of extra tokens.
generation_lengths = [50, 100, 150, 200, 250]

print(f"Device: {device}")
print(f"Model: {model_name}")
print(f"Prompt: '{prompt}'")
print(f"Generation lengths to test: {generation_lengths}")

## 3. Helper Function to Measure Peak Memory

In [None]:
def get_peak_gpu_memory_mb(device_to_check="cuda"):
    """Return the peak GPU memory allocated since the last stats reset, in MB.

    Args:
        device_to_check: Device to query, given as a string such as "cuda" or
            "cuda:0", or as a ``torch.device``. Any non-CUDA device yields 0.

    Returns:
        Peak allocated memory in mebibytes (bytes / 1024**2) as reported by
        ``torch.cuda.max_memory_allocated``, or 0 when CUDA is unavailable or
        the requested device is not a CUDA device.
    """
    # str() lets both "cuda:0" and torch.device("cuda", 0) pass the same check.
    wants_cuda = str(device_to_check).startswith("cuda")
    if not (wants_cuda and torch.cuda.is_available()):
        return 0  # Nothing to measure off-GPU

    cuda_device = torch.device(device_to_check)
    # Kernels run asynchronously; wait for them so the reading covers all work.
    torch.cuda.synchronize(cuda_device)
    peak_mem_bytes = torch.cuda.max_memory_allocated(cuda_device)
    return peak_mem_bytes / (1024 * 1024)  # Convert to MB

## 4. Load Model and Tokenizer

In [None]:
print(f"\nLoading model and tokenizer for {model_name}...")

# Bind both names up front so later cells can safely test `model is not None`
# even when loading fails part-way through.
tokenizer = None
model = None
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # generate() needs a pad token id; fall back to EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the weights in the dtype selected in the configuration cell, using
    # the SDPA attention implementation, then place the model on the device.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        attn_implementation="sdpa",
    ).to(device)

    # Inference only: switch off training-time behaviour such as dropout.
    model.eval()

    print("Model and tokenizer loaded successfully.")
except Exception as e:
    # Notebook-level boundary: report the failure and leave `model` as None so
    # the measurement cell skips itself instead of raising.
    print(f"Error loading model/tokenizer: {e}")

## 5. Test Configuration & Measurement Loop

In [None]:
# Parallel result lists: entry i of each list belongs to generation_lengths[i].
peak_memory_results_mb = []
actual_tokens_generated_list = []
generation_times = []

print("\n--- Measuring KV Cache Memory Growth ---")

# Proceed only if the model and tokenizer loaded and we are on CUDA.
if device == "cuda" and model is not None and tokenizer is not None:
    # Baseline: peak memory of one tiny forward pass on the loaded model.
    # no_grad keeps autograd buffers from inflating the baseline.
    torch.cuda.reset_peak_memory_stats(device)
    with torch.no_grad():
        model(torch.tensor([[0]], device=device))
    torch.cuda.synchronize(device)
    initial_peak_memory_mb = get_peak_gpu_memory_mb(device)
    print(f"Initial peak memory (model loaded on GPU): {initial_peak_memory_mb:.2f} MB")

    # The prompt is the same for every run, so tokenize it once.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    num_input_tokens = inputs["input_ids"].shape[1]

    for length in generation_lengths:
        print(f"\nGenerating {length} new tokens...")

        # Reset so the peak recorded below belongs to this run only.
        torch.cuda.reset_peak_memory_stats(device)

        start_time = time.perf_counter()
        try:
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,  # input_ids plus the tokenizer's attention mask
                    max_new_tokens=length,
                    use_cache=True,  # the KV cache is what is being measured
                    pad_token_id=tokenizer.pad_token_id,
                )
            # Wait for queued GPU work so the timing and peak are complete.
            torch.cuda.synchronize(device)
        except Exception as e:
            print(f"  Error during generation for length {length}: {e}")
            # Keep the result lists aligned with generation_lengths.
            peak_memory_results_mb.append(float('nan'))
            actual_tokens_generated_list.append(0)
            generation_times.append(float('nan'))
            # Free what we can (e.g. after an OOM) before the next run.
            gc.collect()
            torch.cuda.empty_cache()
            continue

        gen_time = time.perf_counter() - start_time
        current_peak_memory_mb = get_peak_gpu_memory_mb(device)

        # generate() returns prompt + continuation; it may stop early at EOS,
        # so count what was actually produced rather than trusting `length`.
        num_output_tokens = outputs.shape[1]
        actual_new_tokens = num_output_tokens - num_input_tokens

        actual_tokens_generated_list.append(actual_new_tokens)
        generation_times.append(gen_time)
        peak_memory_results_mb.append(current_peak_memory_mb)

        print(f"  Actual new tokens generated: {actual_new_tokens}")
        print(f"  Peak memory allocated for this generation: {current_peak_memory_mb:.2f} MB")
        print(f"  Generation time: {gen_time:.4f} seconds")

        # Cleanup: drop Python references first, then release cached blocks.
        del outputs
        gc.collect()
        torch.cuda.empty_cache()
else:
    print("Skipping measurement loop: Model not loaded or not on CUDA device.")

## 6. Deliverables: Analysis and Reporting

In [None]:
import math  # NaN checks on plain Python floats

# Print one table row per requested length. Column widths match the header so
# the pipes line up (14 / 17 / 27 / 8 characters).
print("\n--- Results Summary ---")
print("Max New Tokens | Actual New Tokens | Peak Memory (MB) during Gen | Time (s)")
print("---------------|-------------------|-----------------------------|----------")
for i, requested_length in enumerate(generation_lengths):
    if i < len(peak_memory_results_mb) and i < len(actual_tokens_generated_list) and i < len(generation_times):
        actual_len = actual_tokens_generated_list[i]
        mem_val = peak_memory_results_mb[i]
        time_val = generation_times[i]
        print(f"{requested_length:14} | {actual_len:17} | {mem_val:27.2f} | {time_val:8.4f}")
    else:
        # No measurement was recorded for this length (e.g. the loop was skipped).
        print(f"{requested_length:14} | {'N/A':17} | {'N/A':27} | {'N/A':8}")

# TODO: Trend Analysis:
# 1. Describe the relationship you observe between 'Actual New Tokens' and 'Peak Memory (MB) during Gen'.
#    Is it linear? Does it increase as expected?
#    (You can also consider 'Peak Memory (MB) during Gen' - initial_peak_memory_mb to see the growth above the model size).

# TODO: Estimate Memory per Token:
# 2. Calculate an estimate of how much *additional* memory the KV cache consumes per *actually generated new token*.
#    - Select two data points from your results (e.g., generation for 50 tokens and 150 tokens).
#    - Calculate ΔMemory = PeakMemory_LongerRun - PeakMemory_ShorterRun
#    - Calculate ΔTokens = ActualTokens_LongerRun - ActualTokens_ShorterRun
#    - Estimate MemoryPerToken = ΔMemory / ΔTokens (if ΔTokens > 0)
#    - Print your chosen data points and the calculated memory per token.

# TODO: Practical Impact Discussion:
# 3. Briefly discuss how this growing KV cache memory footprint can become a limiting factor for generating 
#    very long sequences, especially on GPUs with constrained VRAM. How does this relate to the concept 
#    of "context window limits" in practice (even if the model theoretically supports a large window)?

# --- Plotting ---
# Keep only runs with a real (non-NaN) numeric measurement; failed runs are
# recorded as NaN by the measurement loop.
valid_indices = [
    i for i, mem in enumerate(peak_memory_results_mb)
    if isinstance(mem, (int, float)) and not math.isnan(mem)
]
has_positive_measurement = any(peak_memory_results_mb[i] > 0 for i in valid_indices)

if device != "cuda":
    print("\nPlotting is designed for CUDA memory measurements.")
elif not has_positive_measurement:
    # Say so explicitly instead of silently producing no plot.
    print("Not enough valid data points to plot.")
else:
    plot_lengths = [actual_tokens_generated_list[i] for i in valid_indices]
    plot_memory = [peak_memory_results_mb[i] for i in valid_indices]
    try:
        plt.figure(figsize=(10, 6))
        plt.plot(plot_lengths, plot_memory, marker='o', linestyle='-')
        plt.title(f'Peak GPU Memory vs. Number of Actual New Tokens Generated\nModel: {model_name}')
        plt.xlabel('Number of Actual New Tokens Generated')
        plt.ylabel('Peak GPU Memory Allocated During Generation (MB)')
        plt.grid(True)
        plt.tight_layout()
        plot_filename = "kv_cache_memory_growth.png"
        # Save before show(): some backends clear the figure once it is shown.
        plt.savefig(plot_filename)
        print(f"\nPlot saved to {plot_filename}")
        plt.show()
    except Exception as e:
        print(f"Could not generate plot: {e}. Ensure matplotlib is installed.")

print("\nExercise Complete. Please fill in the TODO sections for analysis.")