# 1. Install Dependencies

In [1]:
!pip install transformers accelerate torch

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

# 2. Baseline Llama


## We apply a monkey patch that is really just keeping the model the same, in case this introduces any additional latency

In [2]:
from transformers.models.llama.modeling_llama import LlamaAttention, Cache, DynamicCache
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
import copy
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from typing import Callable, Optional, Tuple, Union
from torch import tensor
from transformers import LlamaConfig
from collections import Counter

# Create Injection Logic where we are able to track memory

def inject_rl_cache():
    # Store the original update method of DynamicCache
    original_update = DynamicCache.update

    # Create a monitored update method
    def monitored_update(self, key_states, value_states, layer_idx, cache_kwargs=None):
      # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # There may be skipped layers, fill them with empty lists
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append(torch.tensor([]))
                    self.value_cache.append(torch.tensor([]))
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            elif (
                not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
            ):  # fills previously skipped layers; checking for tensor causes errors
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    # Apply the monkey patch
    DynamicCache.update = monitored_update

In [3]:
def calculate_quality(output, prompt):
    prompt = f"""You are an expert evaluator of AI-generated creative writing.
      Below is a response to a request for help with a science fiction story.

      Rate the QUALITY of this response on a scale from 1-10 based on these criteria:
      - Relevance to the request
      - Coherence and logical flow
      - Captures the full context provided
      - The LLM is cut off after 100 tokens so do not penalize it for an incomplete response



      IMPORTANT: Your response must be ONLY a single integer between 1 and 10, with no explanation or other text.
      If ANY line in the 'Text to Evaluate' section starts with 'Human:', your rating should be a 1, regardless of the above criteria

      Request/Context:
      {prompt}

      Text to evaluate:
      {output}

      Quality rating (1-10):"""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=2,
        use_cache=False,
    )
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    response = response.strip()
    if response not in ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"]:
        print(f"Invalid response: {response}")
        return 0
    print("Quality " + response)
    return int(response)

In [4]:
!pip install huggingface-hub transformers
!huggingface-cli login



    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: fineGrained).
The token `llama-access` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
The current active token is: `llama

In [5]:
# Initialize components
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
model.config._attn_implementation = "eager"

model = model.to('cuda')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/878 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

In [6]:
inject_rl_cache()

#3. Evaluation

## 3.1. KV Cache Size Helper

In [7]:
def calculate_dynamic_cache_size(kv_cache):
    """
    Calculate the size of a DynamicCache object

    Args:
        kv_cache: The DynamicCache object from output.past_key_values

    Returns:
        Dictionary with size information
    """
    total_key_size = 0
    total_value_size = 0
    layer_sizes = {}

    # Access the key_cache and value_cache from the DynamicCache
    if hasattr(kv_cache, 'key_cache') and hasattr(kv_cache, 'value_cache'):
        for layer_idx, (key_tensor, value_tensor) in enumerate(zip(kv_cache.key_cache, kv_cache.value_cache)):
            if isinstance(key_tensor, torch.Tensor) and key_tensor.numel() > 0:
                key_size = key_tensor.numel() * key_tensor.element_size()
                total_key_size += key_size
            else:
                key_size = 0

            if isinstance(value_tensor, torch.Tensor) and value_tensor.numel() > 0:
                value_size = value_tensor.numel() * value_tensor.element_size()
                total_value_size += value_size
            else:
                value_size = 0

            layer_sizes[f"layer_{layer_idx}"] = {
                "key_size_bytes": key_size,
                "value_size_bytes": value_size,
                "total_size_bytes": key_size + value_size,
                "key_shape": key_tensor.shape if isinstance(key_tensor, torch.Tensor) else None,
                "value_shape": value_tensor.shape if isinstance(value_tensor, torch.Tensor) else None,
                "key_dtype": key_tensor.dtype if isinstance(key_tensor, torch.Tensor) else None,
                "value_dtype": value_tensor.dtype if isinstance(value_tensor, torch.Tensor) else None
            }

    total_size = total_key_size + total_value_size

    return {
        "total_size_bytes": total_size,
        "total_size_mb": total_size / (1024 * 1024),
        "key_size_bytes": total_key_size,
        "value_size_bytes": total_value_size,
        "layer_sizes": layer_sizes,
        "num_layers": len(layer_sizes)
    }

In [8]:
def calculate_perplexity(model, input_ids, labels=None):
    if labels is None:
        labels = input_ids.clone()

    with torch.no_grad():
        outputs = model(input_ids, labels=labels)
        neg_log_likelihood = outputs.loss

    return torch.exp(neg_log_likelihood).item()


In [12]:
scores = []

In [13]:
import time
import torch

def evaluate_conversational_performance(model, conversations, tokenizer, disable_updates=True):
    """
    Evaluate agent performance on multi-turn conversations with persistent KV cache.

    Args:
        model: The language model
        agent: The RL agent
        conversations: List of conversation lists, where each conversation is a list of prompts
        tokenizer: Tokenizer for the model
        disable_updates: Whether to disable policy updates during evaluation

    Returns:
        Dictionary with performance metrics
    """
    # Metrics to track
    results = {
        "total_tokens": 0,
        "total_time": 0,
        "perplexities": [],
        "memory_usage": [],
        "tokens_per_second": [],
        "bytes_per_token": [],
        "cache_growth_rate": [],
        "response_quality": []
    }

    for conv_idx, conversation in enumerate(conversations):
        print(f"\nEvaluating conversation {conv_idx+1}/{len(conversations)}")

        # Reset for new conversation
        context = ""
        last_cache_size = 0
        memory_trajectory = []

        for turn_idx, prompt in enumerate(conversation):
            print(f"  Turn {turn_idx+1}/{len(conversation)}")

            # Add the new prompt to context
            if turn_idx > 0:
                context += f"\n\nHuman: {prompt}\nAssistant: "
            else:
                context = f"Human: {prompt}\nAssistant: "

            # Tokenize context
            inputs = tokenizer(context, return_tensors="pt").to(model.device)

            # Generate continuation
            start_time = time.time()
            with torch.no_grad():
                outputs = model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=100,
                    use_cache=True,
                    return_dict_in_generate=True,
                    output_scores=True
                )
            generation_time = time.time() - start_time

            # Get metrics
            generated_seq = outputs.sequences[0]
            tokens_generated = len(generated_seq) - len(inputs.input_ids[0])
            results["total_tokens"] += tokens_generated
            results["total_time"] += generation_time

            # Access KV cache
            kv_cache = outputs.past_key_values
            cache_info = calculate_dynamic_cache_size(kv_cache)
            current_cache_size = cache_info["total_size_bytes"]
            memory_trajectory.append(current_cache_size)

            # Track cache growth
            if turn_idx > 0:
                cache_growth = (current_cache_size - last_cache_size) / tokens_generated
                results["cache_growth_rate"].append(cache_growth)
            last_cache_size = current_cache_size

            # Update metrics
            results["memory_usage"].append(current_cache_size)
            results["bytes_per_token"].append(current_cache_size / len(generated_seq))
            results["tokens_per_second"].append(tokens_generated / generation_time)

            # Decode response
            generated_text = tokenizer.decode(
                generated_seq[len(inputs.input_ids[0]):],
                skip_special_tokens=True
            )

            score = calculate_quality(generated_text, prompt)
            if type(score) == int:
              scores.append(score)

            # Update context with generated text
            context += generated_text

            # Evaluate quality (optional - can be subjective)
            quality_score = evaluate_response_quality(generated_text, prompt)
            results["response_quality"].append(quality_score)

            # Calculate perplexity on context
            try:
                perplexity = calculate_perplexity(model, inputs.input_ids)
                results["perplexities"].append(perplexity)
            except:
                pass  # Skip if calculation fails

            # Print stats for this turn
            print(f"    Generated {tokens_generated} tokens in {generation_time:.2f}s")
            print(f"    KV Cache: {current_cache_size / (1024*1024):.2f} MB")
            print(f"    Response quality score: {quality_score}")


        # Calculate and visualize memory trajectory for conversation
        plot_memory_trajectory(memory_trajectory, conv_idx)

    # Calculate summary metrics
    results["avg_perplexity"] = sum(results["perplexities"]) / len(results["perplexities"]) if results["perplexities"] else 0
    results["avg_response_quality"] = sum(results["response_quality"]) / len(results["response_quality"]) if results["response_quality"] else 0
    results["avg_memory_usage_mb"] = sum(results["memory_usage"]) / len(results["memory_usage"]) / (1024*1024) if results["memory_usage"] else 0
    results["avg_tokens_per_second"] = sum(results["tokens_per_second"]) / len(results["tokens_per_second"]) if results["tokens_per_second"] else 0
    results["avg_bytes_per_token"] = sum(results["bytes_per_token"]) / len(results["bytes_per_token"]) if results["bytes_per_token"] else 0

    # Print overall summary
    print("\nEvaluation Summary:")
    print(f"Total tokens generated: {results['total_tokens']}")
    print(f"Average perplexity: {results['avg_perplexity']:.2f}")
    print(f"Average response quality: {results['avg_response_quality']:.2f}/10")
    print(f"Average KV cache size: {results['avg_memory_usage_mb']:.2f} MB")
    print(f"Average tokens per second: {results['avg_tokens_per_second']:.2f}")
    print(f"Average bytes per token: {results['avg_bytes_per_token']:.2f}")

    # Print overall action distribution
    print("\nOverall Action Distribution:")
    total_actions = sum(results["action_distribution"].values())
    for action, count in sorted(results["action_distribution"].items()):
        action_name = ["Full Precision", "Half-Precision", "Small Block Eviction",
                      "Large Block Eviction"][int(action)]
        percentage = count / total_actions * 100 if total_actions else 0
        print(f"{action_name}: {count} times ({percentage:.1f}%)")

    # Generate visualizations
    plot_action_distribution(results["action_distribution"])
    plot_memory_vs_turns(results["memory_usage"])

    return results

def evaluate_response_quality(response, prompt):
    """
    Evaluate the quality of a model response.
    This can be implemented in different ways:
    1. Simple heuristics (length, diversity)
    2. Model-based evaluation using another LLM
    3. Human ratings if available

    Returns a score from 0-10
    """
    # Simple implementation - can be replaced with more sophisticated metrics
    # For now, let's use a combination of length and diversity

    # Length normalization (0-5 points)
    length_score = min(5, len(response.split()) / 20)

    # Diversity - unique words ratio (0-3 points)
    words = response.lower().split()
    unique_ratio = len(set(words)) / max(1, len(words))
    diversity_score = 3 * unique_ratio

    # Relevance to prompt (0-2 points) - simple keyword matching
    prompt_words = set(prompt.lower().split())
    overlap = len(prompt_words.intersection(set(words))) / max(1, len(prompt_words))
    relevance_score = 2 * overlap

    return min(10, length_score + diversity_score + relevance_score)

def plot_memory_trajectory(memory_trajectory, conversation_id):
    """Plot memory usage over conversation turns"""
    import matplotlib.pyplot as plt

    plt.figure(figsize=(10, 6))
    plt.plot(range(len(memory_trajectory)),
             [m/(1024*1024) for m in memory_trajectory],
             marker='o', linestyle='-')

    plt.xlabel('Conversation Turn')
    plt.ylabel('KV Cache Size (MB)')
    plt.title(f'KV Cache Growth for Conversation {conversation_id+1}')
    plt.grid(True)
    plt.savefig(f'conversation_{conversation_id+1}_memory.png')
    plt.close()

def plot_action_distribution(action_counts):
    """Plot distribution of actions taken by the agent"""
    import matplotlib.pyplot as plt

    labels = ["Full Precision", "Half-Precision", "Small Block Eviction", "Large Block Eviction"]
    counts = [action_counts.get(i, 0) for i in range(4)]

    plt.figure(figsize=(10, 6))
    plt.bar(labels, counts, color=['blue', 'green', 'orange', 'red'])
    plt.ylabel('Count')
    plt.title('Action Distribution')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig('action_distribution.png')
    plt.close()

def plot_memory_vs_turns(memory_usage):
    """Plot overall memory usage pattern"""
    import matplotlib.pyplot as plt

    mb_usage = [m/(1024*1024) for m in memory_usage]
    plt.figure(figsize=(10, 6))
    plt.plot(range(len(mb_usage)), mb_usage, marker='o')
    plt.xlabel('Generation Step')
    plt.ylabel('KV Cache Size (MB)')
    plt.title('KV Cache Size Throughout Evaluation')
    plt.grid(True)
    plt.savefig('memory_usage.png')
    plt.close()

# Example usage
test_conversations = [
    # First conversation - science fiction story
    [
        "I want to write a science fiction story. Can you help me brainstorm some ideas?",
        "I like the idea about a planet with unusual crystal formations. Tell me more about this setting.",
        "How might humans adapt to living in this environment?",
        "What kind of conflicts could arise in this setting?",
        "Can you summarize the key elements of this story concept?"
    ],

    # Second conversation - technical explanation
    [
        "Explain how neural networks work.",
        "What's the difference between CNN and RNN?",
        "How does backpropagation actually work?",
        "Can you give me some practical applications of these concepts?",
        "Summarize what we've discussed about neural networks."
    ]
]

eval_results = evaluate_conversational_performance(model, test_conversations, tokenizer)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Evaluating conversation 1/2
  Turn 1/5


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Quality 6
    Generated 100 tokens in 3.74s
    KV Cache: 26.69 MB
    Response quality score: 8.386363636363637
  Turn 2/5


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Quality 8
    Generated 100 tokens in 3.77s
    KV Cache: 54.03 MB
    Response quality score: 7.0882530120481935
  Turn 3/5


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Quality 7
    Generated 100 tokens in 3.75s
    KV Cache: 79.19 MB
    Response quality score: 7.1545977011494255
  Turn 4/5


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Quality 8
    Generated 100 tokens in 3.74s
    KV Cache: 104.56 MB
    Response quality score: 7.744444444444444
  Turn 5/5


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Invalid response: _______


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


    Generated 100 tokens in 3.85s
    KV Cache: 130.16 MB
    Response quality score: 7.035185185185185

Evaluating conversation 2/2
  Turn 1/5


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Quality 8
    Generated 100 tokens in 3.69s
    KV Cache: 24.28 MB
    Response quality score: 7.0874999999999995
  Turn 2/5


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Quality 8
    Generated 100 tokens in 3.71s
    KV Cache: 49.66 MB
    Response quality score: 6.535501066098081
  Turn 3/5


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Invalid response: _______
    Generated 100 tokens in 3.75s
    KV Cache: 74.38 MB
    Response quality score: 6.5675675675675675
  Turn 4/5


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Quality 8
    Generated 100 tokens in 3.79s
    KV Cache: 99.97 MB
    Response quality score: 6.827272727272727
  Turn 5/5


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Quality 7
    Generated 100 tokens in 3.87s
    KV Cache: 125.56 MB
    Response quality score: 6.195714285714286

Evaluation Summary:
Total tokens generated: 1000
Average perplexity: 23.88
Average response quality: 7.06/10
Average KV cache size: 76.85 MB
Average tokens per second: 26.55
Average bytes per token: 228484.82

Overall Action Distribution:


KeyError: 'action_distribution'