<a href="https://colab.research.google.com/github/sjgoodlife/LLM/blob/main/Comparision_of_inference_time_with_and_without_KV_cache.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Set HF token

In [1]:
import os

from google.colab import userdata

# Read the Hugging Face access token from the Colab secret store and expose it
# through the HF_TOKEN environment variable for the rest of the session.
hf_token = userdata.get('HF_TOKEN')
os.environ['HF_TOKEN'] = hf_token

## Load a model and tokenizer

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import numpy as np
import matplotlib.pyplot as plt

def setup_model(model_id="meta-llama/Llama-3.2-1B-Instruct"):
    """Load a causal language model and its matching tokenizer.

    Args:
        model_id: Checkpoint identifier passed to ``from_pretrained`` for both
            the model and the tokenizer. Defaults to the Llama 3.2 1B Instruct
            checkpoint used throughout this notebook.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # load weights in half precision
        device_map="auto",  # leave device placement to the library
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, tokenizer

## With/Without KV cache

- `DynamicCache`: the class that stores the KV cache.
- `cache_position`: the positions at which the new KV cache entries are stored.
- `past_key_values`: the cache object, passed to the model as an argument on every call.

The high-level equivalent is `model.generate(use_cache=True)`.

In [4]:
# Reference: https://huggingface.co/docs/transformers/main/kv_cache

In [7]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

def generate_with_kv_cache(model, tokenizer, prompt, max_new_tokens=100):
    """Greedy-decode with an explicit KV cache and time every forward pass.

    The first forward pass receives the whole chat-formatted prompt; every
    later pass receives only the single token chosen in the previous step,
    together with the cache object that is passed to every call.

    Args:
        model: Causal LM that is callable with ``input_ids``,
            ``attention_mask``, ``cache_position`` and ``past_key_values`` and
            exposes a ``device`` attribute.
        tokenizer: Tokenizer providing ``apply_chat_template``.
        prompt: User message to wrap in the chat template.
        max_new_tokens: Number of decoding steps to run and time.

    Returns:
        A list with one wall-clock duration (seconds) per decoding step.
    """
    # Follow the model's own device instead of hard-coding "cuda:0".
    device = model.device
    on_cuda = device.type == "cuda"

    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)

    past_key_values = DynamicCache()  # cache object passed to every forward pass
    prompt_length = inputs["input_ids"].shape[1]
    # The first pass covers positions 0..prompt_length-1.
    cache_position = torch.arange(prompt_length, dtype=torch.int64, device=device)

    times = []

    # Inference only: without this, autograd records every step and the cached
    # tensors can keep that history alive for the whole loop.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # CUDA kernels launch asynchronously; synchronize on both sides of
            # the timed region so the interval covers the actual GPU work.
            if on_cuda:
                torch.cuda.synchronize(device)
            start_time = time.perf_counter()
            outputs = model(
                **inputs,
                cache_position=cache_position,  # where the new KV entries go
                past_key_values=past_key_values,  # previously stored KV cache
            )
            if on_cuda:
                torch.cuda.synchronize(device)
            times.append(time.perf_counter() - start_time)

            # Greedy choice from the logits at the last position.
            next_token_ids = outputs.logits[:, -1:].argmax(-1)

            # The mask grows by one for each generated token, while the model
            # input shrinks to just that token.
            attention_mask = inputs["attention_mask"]
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                dim=-1,
            )
            inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}
            cache_position = cache_position[-1:] + 1  # position of the next token

    return times

In [8]:
def generate_without_kv_cache(model, tokenizer, prompt, max_new_tokens=10):
    """Greedy-decode without a KV cache and time every forward pass.

    This is the baseline for the KV-cache comparison: each step feeds the
    whole sequence generated so far back through the model with
    ``use_cache=False``.

    Args:
        model: Causal LM that is callable with ``input_ids``,
            ``attention_mask`` and ``use_cache`` and exposes a ``device``
            attribute.
        tokenizer: Tokenizer providing ``apply_chat_template``.
        prompt: User message to wrap in the chat template.
        max_new_tokens: Number of decoding steps to run and time.

    Returns:
        A list with one wall-clock duration (seconds) per decoding step.
    """
    # Follow the model's own device instead of hard-coding "cuda:0".
    device = model.device
    on_cuda = device.type == "cuda"

    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)

    generated_ids = inputs["input_ids"]

    times = []

    # Inference only: avoid recording autograd history for every step.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # CUDA kernels launch asynchronously; synchronize on both sides of
            # the timed region so the interval covers the actual GPU work.
            if on_cuda:
                torch.cuda.synchronize(device)
            start_time = time.perf_counter()
            outputs = model(
                **inputs,
                use_cache=False,
            )
            if on_cuda:
                torch.cuda.synchronize(device)
            times.append(time.perf_counter() - start_time)

            # Greedy choice from the logits at the last position.
            next_token_ids = outputs.logits[:, -1:].argmax(-1)
            generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)

            # With no cache, the next step must see the whole sequence again.
            attention_mask = inputs["attention_mask"]
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                dim=-1,
            )
            inputs = {"input_ids": generated_ids, "attention_mask": attention_mask}

    return times

In [9]:
def plot_comparison(results, max_new_tokens):
    """Plot per-token latency with and without the KV cache and save it to disk.

    Args:
        results: Mapping with the keys ``'kv'`` and ``'no_kv'``, each holding
            one timing (seconds) per generated token.
        max_new_tokens: Number of generated tokens; the x-axis runs from 1 to
            this value, so both series must have exactly this many entries.

    The figure is written to ``time_comparison.png`` and then closed.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    token_index = list(range(1, max_new_tokens + 1))

    # One line per strategy: (key in results, matplotlib format, legend label).
    series = (
        ('kv', 'b-', 'KV Cache'),
        ('no_kv', 'r-', 'No KV Cache'),
    )
    for key, fmt, label in series:
        ax.plot(token_index, results[key], fmt, label=label)

    ax.set_xlabel('Number of Tokens')
    ax.set_ylabel('Time by Token (seconds)')
    ax.legend()
    ax.grid(True)

    fig.savefig('time_comparison.png')
    plt.close(fig)

def run_comparison(model, tokenizer, prompt, max_new_tokens=10):
    """Time generation with and without a KV cache and plot the comparison.

    Args:
        model: Causal LM exposing ``generate`` and a ``device`` attribute.
        tokenizer: Tokenizer matching ``model``.
        prompt: User message used for both timed runs.
        max_new_tokens: Number of tokens to generate (and time) in each run.

    Returns:
        None. The timings are handed to ``plot_comparison``.
    """
    # Warm-up run, so one-time start-up cost is not charged to the first timed
    # run. Inputs follow the model's device instead of a hard-coded "cuda";
    # pad_token_id is passed explicitly to avoid the "Setting `pad_token_id`"
    # notice that generate() otherwise prints.
    warmup_inputs = tokenizer("how are you?", return_tensors="pt").to(model.device)
    _ = model.generate(
        **warmup_inputs,
        max_new_tokens=1,
        pad_token_id=tokenizer.eos_token_id,
    )

    # 1. With KV cache
    times_explicit_kv = generate_with_kv_cache(model, tokenizer, prompt, max_new_tokens)

    # 2. Without KV cache
    times_no_kv = generate_without_kv_cache(model, tokenizer, prompt, max_new_tokens)

    results = {
        'kv': times_explicit_kv,
        'no_kv': times_no_kv,
    }

    plot_comparison(results, max_new_tokens)

In [10]:
# Load the model and tokenizer once; the benchmark cell below reuses them.
model, tokenizer = setup_model()

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

In [12]:
# Use a long prompt and time 100 generated tokens for each strategy.
# run_comparison returns nothing; the plot is saved to time_comparison.png.
prompt = "Can you provide a comprehensive explanation of quantum entanglement and its implications for quantum computing? I'm particularly interested in understanding how entangled particles maintain their correlation regardless of distance, the concept of quantum superposition, and how these principles are being applied in the development of quantum computers. Also, could you discuss the potential impact of quantum computing on current encryption methods and data security systems?"
run_comparison(model, tokenizer, prompt=prompt, max_new_tokens=100)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
