In [1]:
"""
This script is adapted from 
https://github.com/gkamradt/LLMTest_NeedleInAHaystack
"""

import os
import glob
import json
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

import numpy as np
import argparse
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)

from datetime import datetime, timezone
import time
import torch

In [2]:
# Context-length sweep, in tokens: evenly spaced points from min to max (both included).
context_lengths_min = 120000
context_lengths_max = 1048000
context_lengths_num_intervals = 40
pretrained_len = 1048000
sparsity = 0.5

# Needle-depth sweep, in percent of the context (100 places the needle at the very end).
document_depth_percent_min = 0
document_depth_percent_max = 100
document_depth_percent_intervals = 10
document_depth_percent_interval_type = "linear"  # "linear" or "sigmoid"

# Tokens reserved (subtracted from the context length) when the needle is inserted.
final_context_length_buffer = 200
simulation_length = 50
# The original cell assigned 32000 here and immediately overwrote it with None;
# only the effective value is kept.
prefilling_chunk_size = None

needle="\n\nRemember, the best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.\n\n"
retrieval_question="what is the best thing to do in San Francisco?\n\nAnswer: The best thing to do in San Francisco is"
haystack_dir="eval/needle/PaulGrahamEssays"
testing_results = []


def logistic(x, L=100, x0=50, k=0.1):
    """Map a depth percentage onto a logistic curve.

    The "sigmoid" depth schedule below calls this function, but the notebook never
    defined it, so selecting that schedule raised NameError.

    Args:
        x: Depth in percent. The end points 0 and 100 are returned unchanged so the
            sweep still covers the exact start and end of the document.
        L: Upper asymptote of the curve.
        x0: Mid-point of the curve.
        k: Steepness of the curve.

    Returns:
        The transformed depth, rounded to 3 decimals (or 0 / 100 for the end points).
    """
    if x == 0:
        return 0
    if x == 100:
        return 100
    return np.round(L / (1 + np.exp(-k * (x - x0))), 3)


context_lengths = np.round(
    np.linspace(
        context_lengths_min,
        context_lengths_max,
        num=context_lengths_num_intervals,
        endpoint=True,
    )
).astype(int)

if document_depth_percent_interval_type == "linear":
    document_depth_percents = np.round(
        np.linspace(
            document_depth_percent_min,
            document_depth_percent_max,
            num=document_depth_percent_intervals,
            endpoint=True,
        )
    ).astype(int)
elif document_depth_percent_interval_type == "sigmoid":
    document_depth_percents = [
        logistic(x)
        for x in np.linspace(
            document_depth_percent_min,
            document_depth_percent_max,
            document_depth_percent_intervals,
        )
    ]
else:
    # Fail here rather than with a NameError on `document_depth_percents` in a later cell.
    raise ValueError(
        "document_depth_percent_interval_type must be 'linear' or 'sigmoid', got "
        f"{document_depth_percent_interval_type!r}"
    )

# Checkpoint under test; the description string is simply the same path.
model_name = "models/Llama-3-8B-Instruct-Gradient-1048k"
model_to_test_description = model_name

enc = AutoTokenizer.from_pretrained(model_name, use_fast=False)
generation_config = GenerationConfig.from_pretrained(model_name)

# Normalise the EOS id(s) to a list, whether the config stores one id or several.
eos_token_ids = generation_config.eos_token_id
if not isinstance(eos_token_ids, list):
    eos_token_ids = [eos_token_ids]

# When the tokenizer has no padding token, reuse the EOS id, or 0 as a last resort.
if enc.pad_token_id is None:
    enc.pad_token_id = enc.eos_token_id if enc.eos_token_id is not None else 0

print("Loading from %s" % model_name)

model_to_test = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model_to_test = model_to_test.eval()

Loading from models/Llama-3-8B-Instruct-Gradient-1048k


You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
from duo_attn.utils import load_attn_pattern, sparsify_attention_heads
from duo_attn.patch import enable_duo_attention_eval

# Load the attention pattern (per-layer, per-KV-head scores plus the sink / recent window sizes).
attn_heads, sink_size, recent_size = load_attn_pattern(
    "attn_patterns/Llama-3-8B-Instruct-Gradient-1048k/lr=0.02-reg=0.05-ctx=1000_32000-multi_passkey10"
)
model = model_to_test
print(attn_heads.shape)
print(sink_size)
print(recent_size)

# Sparsify attention heads.
# Pass the configured `sparsity` instead of a second hard-coded 0.5, so changing the
# value in the config cell actually takes effect. Note that `sparsity` is then rebound
# to the value returned by `sparsify_attention_heads`.
print(attn_heads, sparsity)
attn_heads, sparsity = sparsify_attention_heads(attn_heads, sparsity=sparsity)

print(attn_heads, sparsity)

(32, 8)
128
256
[[8.59375000e-01 6.52343750e-01 1.00000000e+00 3.39843750e-01
  0.00000000e+00 6.79687500e-01 3.49609375e-01 2.73437500e-01]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 1.00000000e+00
  0.00000000e+00 8.24218750e-01 0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 6.91406250e-01 0.00000000e+00 9.60937500e-01
  7.30468750e-01 9.84375000e-01 9.64843750e-01 0.00000000e+00]
 [7.85156250e-01 2.30073929e-05 3.83853912e-05 0.00000000e+00
  8.71093750e-01 9.76562500e-01 1.00000000e+00 4.21875000e-01]
 [0.00000000e+00 0.00000000e+00 1.17675781e-01 8.63281250e-01
  1.00000000e+00 9.88281250e-01 8.35937500e-01 2.61306763e-04]
 [7.89062500e-01 9.84375000e-01 1.00000000e+00 5.82031250e-01
  1.00000000e+00 9.45312500e-01 9.52148438e-02 1.00000000e+00]
 [5.00000000e-01 8.63281250e-01 7.03125000e-01 7.18750000e-01
  9.10156250e-01 1.00000000e+00 1.00000000e+00 8.94531250e-01]
 [1.00000000e+00 8.32031250e-01 1.44042969e-02 1.00000000e+00
  6.10351562e-05 1.00000000e+00 1.00000000e+00 

In [4]:
from transformers.models.llama.modeling_llama import (
    logger,
    apply_rotary_pos_emb,
    repeat_kv,
    LlamaSdpaAttention,
    LlamaFlashAttention2,
)
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.cache_utils import Cache, DynamicCache, StaticCache, OffloadedCache, OffloadedStaticCache
import types
from transformers.modeling_flash_attention_utils  import _flash_attention_forward
from flash_attn import flash_attn_func

def LlamaAttention_fast_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) :
    """
    Attention forward that processes one key/value head (and the query heads sharing it) at a time.

    This function is assigned to `LlamaFlashAttention2.forward` later in this file. For each of
    the `self.num_key_value_heads` groups it:

    1. projects `hidden_states` with that group's slice of the q/k/v projection weights,
    2. applies rotary position embeddings,
    3. if a cache is given, calls `past_key_value.update` with cache slot `self.layer_idx + i`
       and a `full_head` flag taken from `self.full_attn_head_mask[i]`,
    4. runs `flash_attn_func` with `causal=True` and `dropout_p=0.0`.

    The per-group outputs are concatenated along the last dimension and passed through
    `self.o_proj`.

    Besides the usual attention attributes, the module must carry `layer_idx` and
    `full_attn_head_mask`; the patching loop later in this file sets both.

    Args:
        hidden_states: Input of size `(bsz, q_len, hd)` (unpacked as such below).
        attention_mask: Accepted for interface compatibility; not read by this function.
        position_ids: Only used to compute RoPE when `position_embeddings` is None.
        past_key_value: Cache object whose `update` returns the keys/values to attend over.
        output_attentions: Ignored; it is forced to False on entry.
        use_cache: Accepted for interface compatibility; not read by this function.
        cache_position: Forwarded to the cache through `cache_kwargs`.
        position_embeddings: Pre-computed `(cos, sin)` pair, if available.
        **kwargs: Ignored.

    Returns:
        `(attn_output, None, past_key_value)`; attention weights are never returned.
    """
    # Attention weights are never produced by this implementation.
    output_attentions = False

    bsz, q_len, hd = hidden_states.size()
    # NOTE(review): chunk_size is computed but never used below.
    chunk_size = hd // self.num_key_value_heads
    # Number of query heads that share each key/value head.
    num_heads = self.num_heads // self.num_key_value_heads

    # Split the projection weights once, into one slice per key/value head, and cache the
    # slices on the module so later calls reuse them.
    if not hasattr(self, 'q_proj_list'):
        self.q_proj_list = list((self.q_proj.weight.split(self.head_dim * num_heads, dim=0)))
        # self.q_proj.weight.data.storage().resize_(0)
    if not hasattr(self, 'k_proj_list'):
        self.k_proj_list = list((self.k_proj.weight.split(self.head_dim, dim=0)))
        # self.k_proj.weight.data.storage().resize_(0)
    if not hasattr(self, 'v_proj_list'):
        self.v_proj_list = list((self.v_proj.weight.split(self.head_dim, dim=0)))
        # self.v_proj.weight.data.storage().resize_(0)


    # One output slot per key/value head, filled in order by the loop below.
    attn_output_list = [None for _ in range((self.num_key_value_heads))]
    
    for i in range(self.num_key_value_heads):
        bsz, q_len, hd = hidden_states.size()

        # Point the shared projection modules at the i-th weight slice so that each
        # projection only produces this group's heads. After the loop, the modules are
        # left holding the last slice.
        self.q_proj.weight.data = self.q_proj_list[i].data
        self.k_proj.weight.data = self.k_proj_list[i].data
        self.v_proj.weight.data = self.v_proj_list[i].data

        # print(hidden_states.shape, self.q_proj.weight.shape)
        
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, 1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, 1, self.head_dim).transpose(1, 2)
    
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            # 'full_head' is this head's entry in the mask set by the patching loop; the cache
            # classes in this file read it in `update`.
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, 'full_head': self.full_attn_head_mask[i]}
            # Each key/value head owns its own cache slot: self.layer_idx + i.
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx + i, cache_kwargs)
    
        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
    
        # NOTE(review): dropout_rate is unused; flash_attn_func below is called with dropout_p=0.0.
        dropout_rate = self.attention_dropout if self.training else 0.0
    
        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)
    
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype
    
            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )
    
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)
    
        attn_output = flash_attn_func(
            query_states,
            key_states,
            value_states,
            causal=True,
            dropout_p=0.0,
        )
    
        # Flatten this group's heads into the last dimension before storing the result.
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output_list[i] = attn_output
        
    # Stitch the per-group outputs back together in head order, then apply the output projection.
    attn_output = torch.cat(attn_output_list, dim=-1)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value

# Route every LlamaFlashAttention2 module through the per-KV-head forward defined above.
LlamaFlashAttention2.forward = LlamaAttention_fast_forward
full_attention_heads = attn_heads

# `layer_idx` is a running counter: each attention module gets the index of its first
# cache slot, and the counter then advances by that module's number of key/value heads.
layer_idx = 0
for idx, layer in enumerate(model.model.layers):
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    module = layer.self_attn
    module.layer_idx = layer_idx
    layer_idx = layer_idx + module.num_key_value_heads

    # A head is a full-attention head when its score is at least 0.5; the forward reads
    # this mask entry by entry and passes it to the cache as `full_head`.
    module.full_attn_head_mask = full_attention_heads[idx] >= 0.5
    module.num_full_attn_head = module.full_attn_head_mask.sum().item()

    print(module.full_attn_head_mask)


[False False  True False False False False False]
[False False False  True False False False False]
[False False False False False  True  True False]
[False False False False False  True  True False]
[False False False False  True  True False False]
[False  True  True False  True False False  True]
[False False False False False  True  True False]
[ True False False  True False  True  True False]
[ True  True  True False  True False  True  True]
[False False False  True False  True  True  True]
[ True False False  True  True False  True  True]
[ True  True False False False False False  True]
[ True False False False False  True False False]
[ True  True  True False  True False  True  True]
[False  True False False  True  True  True  True]
[ True False  True False  True False  True  True]
[ True  True  True False False  True  True False]
[ True  True False  True False  True  True  True]
[False  True False  True False  True False False]
[ True False  True  True False  True False  True]


In [5]:

class DynamicCache(Cache):
    """
    A dynamically growing key/value cache whose entries can individually be "streaming".

    This class shadows the `DynamicCache` imported from `transformers.cache_utils` for the
    rest of the notebook. Keys and values are kept in two parallel lists of tensors, one
    entry per `layer_idx` passed to `update` (the patched attention forward in this notebook
    passes one index per key/value head). Each tensor is expected to have shape
    `[batch_size, num_heads, seq_len, head_dim]`.

    When `update` receives `cache_kwargs["full_head"] == False`, the stored entry is cut down
    to its first `self.sink_size` and last `self.recent_size` positions after every update.
    Those two attributes are NOT set by `__init__`: assign them on the instance (or class)
    before the first streaming update, otherwise `update` raises AttributeError.
    """

    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
        super().__init__()
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values
        """
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value is the
        number of cache entries (one per `layer_idx` seen by `update`).
        """
        return len(self.key_cache)

    def _keep_sink_and_recent(self, states: torch.Tensor, seq_len: int) -> torch.Tensor:
        """Return a new tensor holding only the first `sink_size` and last `recent_size` positions."""
        sink_states = states[:, :, : self.sink_size, :]
        recent_states = states[:, :, seq_len - self.recent_size : seq_len, :]
        # torch.cat allocates fresh storage, so the untruncated tensor can be freed once the
        # caller drops its reference.
        return torch.cat([sink_states, recent_states], dim=-2)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the entry `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the cache entry to update.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Only `"full_head"` is read. When it equals False the stored entry is truncated to
                sink + recent positions after this update. When `cache_kwargs` is None or has no
                `"full_head"` key, the entry is treated as a full (untruncated) head, which is what
                `from_legacy_cache` and `from_batch_splits` rely on.

        Return:
            A tuple with the key and value states to attend over for this call: everything
            stored before the call plus the new states, i.e. BEFORE any truncation is applied.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            # There may be skipped layers, fill them with empty lists
            for _ in range(len(self.key_cache), layer_idx):
                self.key_cache.append([])
                self.value_cache.append([])
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        elif len(self.key_cache[layer_idx]) == 0:  # fills previously skipped layers; checking for tensor causes errors
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        # What the caller attends over: the complete, not yet truncated, entry.
        key, value = self.key_cache[layer_idx], self.value_cache[layer_idx]

        # Previously `cache_kwargs['full_head']` was subscripted unconditionally, so every
        # caller that passes no kwargs (from_legacy_cache, from_batch_splits) crashed.
        full_head = True if cache_kwargs is None else cache_kwargs.get("full_head", True)
        if full_head == False:  # noqa: E712 - the flag may be a numpy bool, so `is False` would not match
            incoming_kv_seq_len = key.shape[2]
            if incoming_kv_seq_len > self.sink_size + self.recent_size:
                self.key_cache[layer_idx] = self._keep_sink_and_recent(key, incoming_kv_seq_len)
                self.value_cache[layer_idx] = self._keep_sink_and_recent(value, incoming_kv_seq_len)

        return key, value

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        # NOTE(review): for an entry that has been truncated to sink + recent positions this
        # returns the truncated length, not the number of tokens seen (`self._seen_tokens`).
        # Confirm that callers deriving positions from this value are passing a full head.
        is_empty_layer = (
            len(self.key_cache) == 0  # no cache in any layer
            or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
        )
        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
        return layer_seq_length

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        return None

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
        backward compatibility."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(
        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
    ) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
        backward compatibility. Every entry is stored as a full (untruncated) head."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache

    def crop(self, max_length: int):
        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
        # In case it is negative
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)

        if self.get_seq_length() <= max_length:
            return

        self._seen_tokens = max_length
        for idx in range(len(self.key_cache)):
            if self.key_cache[idx] != []:
                self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
                self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

    def batch_split(
        self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
    ) -> List["DynamicCache"]:
        """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
        `_split_model_inputs()` in `generation.utils`"""
        out = []
        for i in range(0, full_batch_size, split_size):
            current_split = DynamicCache()
            current_split._seen_tokens = self._seen_tokens
            # Carry the window sizes over (when they have been set) so that a split can keep
            # serving streaming updates.
            for attr in ("sink_size", "recent_size"):
                if hasattr(self, attr):
                    setattr(current_split, attr, getattr(self, attr))
            current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
            current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
            out.append(current_split)
        return out

    @classmethod
    def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache":
        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
        `generation.utils`"""
        cache = cls()
        for idx in range(len(splits[0])):
            key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
            # Values must come from `value_cache`; the previous code re-read `key_cache` here,
            # so the merged cache held keys in place of values.
            value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []]
            if key_cache != []:
                layer_keys = torch.cat(key_cache, dim=0)
                layer_values = torch.cat(value_cache, dim=0)
                cache.update(layer_keys, layer_values, idx)
        return cache

    def batch_repeat_interleave(self, repeats: int):
        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
            self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]

In [6]:

class OffloadedCache(DynamicCache):
    """
    A cache that keeps streaming entries in `key_cache`/`value_cache` and moves full entries
    between their original device and the CPU.

    This class shadows the `OffloadedCache` imported from `transformers.cache_utils`.
    Every `layer_idx` passed to `update` is classified on first use:

    * streaming entries (`cache_kwargs["full_head"] == False`) live in `key_cache` /
      `value_cache` and are truncated to the first `self.sink_size` plus last
      `self.recent_size` positions after each update;
    * full entries live in `offload_key_cache` / `offload_value_cache`; the previous full
      entry is moved to the CPU and the next one is prefetched on a side CUDA stream.

    `id_type_list[layer_idx]` records whether an entry is streaming, and
    `real_id_dict[layer_idx]` maps it to its position in the corresponding pair of lists.
    `sink_size` and `recent_size` are not set by `__init__` and must be assigned before
    the first streaming update.
    """

    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
        if not torch.cuda.is_available():
            raise RuntimeError("OffloadedCache can only be used with a GPU")
        super().__init__()
        self.original_device = []
        self.prefetch_stream = torch.cuda.Stream()
        self.beam_idx = None  # used to delay beam search operations

        # Streaming (sink + recent) entries.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

        # Full entries, moved between their original device and the CPU.
        self.offload_key_cache: List[torch.Tensor] = []
        self.offload_value_cache: List[torch.Tensor] = []

        # id_type_list[layer_idx] is True for streaming entries, False for full ones.
        self.id_type_list = []
        # Maps a caller-side layer_idx to its position in the matching pair of lists.
        self.real_id_dict = {}

    def prefetch_layer(self, layer_idx: int):
        "Starts prefetching the next layer cache"
        if layer_idx < len(self.offload_key_cache):
            with torch.cuda.stream(self.prefetch_stream):
                # Prefetch next layer tensors to GPU
                device = self.original_device[layer_idx]
                self.offload_key_cache[layer_idx] = self.offload_key_cache[layer_idx].to(device, non_blocking=True)
                self.offload_value_cache[layer_idx] = self.offload_value_cache[layer_idx].to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        "Moves the previous layer cache to the CPU"
        if len(self.offload_key_cache) > 2:
            # We do it on the default stream so it occurs after all earlier computations on these tensors are done
            prev_layer_idx = (layer_idx - 1) % len(self.offload_key_cache)
            self.offload_key_cache[prev_layer_idx] = self.offload_key_cache[prev_layer_idx].to("cpu", non_blocking=True)
            self.offload_value_cache[prev_layer_idx] = self.offload_value_cache[prev_layer_idx].to("cpu", non_blocking=True)

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Return the `(key, value)` pair for the caller-side index `layer_idx`.

        Streaming entries are returned as stored. For full entries the previous full entry is
        evicted, the prefetch stream is synchronised, any delayed beam reordering is applied,
        and the next full entry starts prefetching.
        """
        if layer_idx < len(self.id_type_list):
            if self.id_type_list[layer_idx] == True:
                layer_idx = self.real_id_dict[layer_idx]
                return (self.key_cache[layer_idx], self.value_cache[layer_idx])
            else:
                layer_idx = self.real_id_dict[layer_idx]
                # Evict the previous layer if necessary
                torch.cuda.current_stream().synchronize()
                self.evict_previous_layer(layer_idx)
                # Load current layer cache to its original device if not already there
                original_device = self.original_device[layer_idx]
                self.prefetch_stream.synchronize()
                key_tensor = self.offload_key_cache[layer_idx]
                value_tensor = self.offload_value_cache[layer_idx]
                # Now deal with beam search ops which were delayed
                if self.beam_idx is not None:
                    self.beam_idx = self.beam_idx.to(original_device)
                    key_tensor = key_tensor.index_select(0, self.beam_idx)
                    value_tensor = value_tensor.index_select(0, self.beam_idx)
                # Prefetch the next layer. The wrap-around must use the number of OFFLOADED
                # entries: `len(self)` counts the streaming list, which gave a wrong index
                # whenever the two lists differ in length and divided by zero when there
                # were no streaming entries at all.
                self.prefetch_layer((layer_idx + 1) % len(self.offload_key_cache))
                return (key_tensor, value_tensor)
        else:
            raise KeyError(
                f"Cache only has {len(self.id_type_list)} layers, attempted to access layer with index {layer_idx}"
            )

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Saves the beam indices and reorders the cache when the tensor is back to its device."""
        # We delay this operation until the tensors are back to their original
        # device because performing torch.index_select on the CPU is very slow
        del self.beam_idx
        self.beam_idx = beam_idx.clone()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the entry `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The caller-side index of the cache entry to update.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Only `"full_head"` is read. False selects the streaming (sink + recent) store;
                anything else, including a missing key or `cache_kwargs=None`, selects the
                offloaded full store.

        Return:
            A tuple with the key and value states to attend over for this call. For streaming
            entries this is the entry BEFORE truncation.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Tolerate callers that pass no kwargs (e.g. the inherited from_legacy_cache);
        # previously `cache_kwargs['full_head']` raised TypeError for them.
        full_head = True if cache_kwargs is None else cache_kwargs.get("full_head", True)
        is_stream_head = (full_head == False)  # noqa: E712 - the flag may be a numpy bool
        if len(self.id_type_list) <= layer_idx:
            self.id_type_list.append(is_stream_head)

        original_layer_layer_idx = layer_idx

        if is_stream_head:
            # Translate the caller-side index into a position in the streaming lists.
            if layer_idx not in self.real_id_dict:
                self.real_id_dict[layer_idx] = len(self.key_cache)
            layer_idx = self.real_id_dict[layer_idx]

            if len(self.key_cache) <= layer_idx:
                # There may be skipped layers, fill them with empty lists
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append([])
                    self.value_cache.append([])
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            elif len(self.key_cache[layer_idx]) == 0:  # fills previously skipped layers; checking for tensor causes errors
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

            # The caller attends over the complete entry; only the stored copy is truncated.
            key_tensor, value_tensor = self.key_cache[layer_idx], self.value_cache[layer_idx]
            incoming_kv_seq_len = self.key_cache[layer_idx].shape[2]
            if incoming_kv_seq_len > self.sink_size + self.recent_size:
                sink_key_states = self.key_cache[layer_idx][:, :, : self.sink_size, :]
                recent_key_states = self.key_cache[layer_idx][
                    :, :, incoming_kv_seq_len - self.recent_size : incoming_kv_seq_len, :
                ]
                self.key_cache[layer_idx] = torch.cat([sink_key_states, recent_key_states], dim=-2)

                sink_value_states = self.value_cache[layer_idx][:, :, : self.sink_size, :]
                recent_value_states = self.value_cache[layer_idx][
                    :, :, incoming_kv_seq_len - self.recent_size : incoming_kv_seq_len, :
                ]
                self.value_cache[layer_idx] = torch.cat([sink_value_states, recent_value_states], dim=-2)
        else:
            # Translate the caller-side index into a position in the offloaded lists.
            if layer_idx not in self.real_id_dict:
                self.real_id_dict[layer_idx] = len(self.offload_key_cache)
            layer_idx = self.real_id_dict[layer_idx]

            if len(self.offload_key_cache) < layer_idx:
                raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
            elif len(self.offload_key_cache) == layer_idx:
                # First time this entry is seen: store it and push the previous one to the CPU.
                self.offload_key_cache.append(key_states)
                self.offload_value_cache.append(value_states)
                self.original_device.append(key_states.device)
                self.evict_previous_layer(layer_idx)
                key_tensor, value_tensor = key_states, value_states
            else:
                # Going through __getitem__ (with the caller-side index) performs the
                # evict / synchronise / prefetch steps before the tensors are extended.
                key_tensor, value_tensor = self[original_layer_layer_idx]
                self.offload_key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2)
                self.offload_value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2)
                key_tensor = self.offload_key_cache[layer_idx]
                value_tensor = self.offload_value_cache[layer_idx]

        return key_tensor, value_tensor
        

In [7]:
# Move the model to the GPU. The load step above printed a Flash Attention 2.0 warning
# that the model was not initialized on GPU and must be moved there before use.
model_to_test = model_to_test.cuda()

KeyboardInterrupt: 

In [None]:
# Banner summarising the sweep that is about to run.
print("\n")
print("Starting Needle In A Haystack Testing...")
print(f"- Model: {model_name}")
print(
    f"- Context Lengths: {len(context_lengths)}, "
    f"Min: {min(context_lengths)}, "
    f"Max: {max(context_lengths)}"
)
print(
    f"- Document Depths: {len(document_depth_percents)}, "
    f"Min: {min(document_depth_percents)}%, "
    f"Max: {max(document_depth_percents)}%"
)
print(f"- Needle: {needle.strip()}")
print("\n\n")

In [None]:

def get_context_length_in_tokens(context):
    """Return how many token ids the module-level tokenizer `enc` yields for `context`.

    The text is encoded with `enc.encode` using its default arguments.
    """
    token_ids = enc.encode(context)
    return len(token_ids)
    
def read_context_files():
    """Build a haystack string at least `max(context_lengths)` tokens long.

    Every ``*.txt`` file under `haystack_dir` is read once (as UTF-8, in
    sorted path order so the haystack is reproducible across file systems)
    and concatenated into a corpus. The corpus is then appended to the
    result repeatedly until `get_context_length_in_tokens` reports at least
    the largest configured context length.

    Returns:
        str: the concatenated haystack text.

    Raises:
        FileNotFoundError: if more text is needed but `haystack_dir` holds no
            non-empty ``*.txt`` file. Without this check the loop would
            never terminate.
    """
    max_context_length = max(context_lengths)

    # Read each essay a single time instead of re-opening every file on
    # every pass of the growth loop.
    pieces = []
    for path in sorted(glob.glob(f"{haystack_dir}/*.txt")):
        with open(path, "r", encoding="utf-8") as f:
            pieces.append(f.read())
    corpus = "".join(pieces)

    context = ""
    while get_context_length_in_tokens(context) < max_context_length:
        if not corpus:
            raise FileNotFoundError(
                f"No non-empty .txt haystack files found in {haystack_dir!r}"
            )
        context += corpus
    return context
    
def get_tokens_from_context(context):
    """Encode `context` with the module-level tokenizer `enc` and return the ids.

    `enc.encode` is called with its default arguments.
    """
    token_ids = enc.encode(context)
    return token_ids
    
def decode_tokens(tokens, context_length=None):
    """Decode the leading `context_length` ids of `tokens` back into text.

    With the default `context_length=None` the whole sequence is decoded.
    Decoding goes through `enc.decode` with `skip_special_tokens=True`.
    """
    kept_ids = tokens[slice(None, context_length)]
    return enc.decode(kept_ids, skip_special_tokens=True)
    
def encode_and_trim(context, context_length):
    """Cut `context` down to at most `context_length` tokens.

    The text is tokenised with `get_tokens_from_context`. If it already fits
    it is returned unchanged; otherwise the first `context_length` ids are
    decoded back to text with `decode_tokens`.
    """
    token_ids = get_tokens_from_context(context)
    if len(token_ids) <= context_length:
        return context
    return decode_tokens(token_ids, context_length)

def encode_text_to_tokens(text):
    """Encode `text` with `enc`, asking it not to add special tokens."""
    token_ids = enc.encode(text, add_special_tokens=False)
    return token_ids
    
def insert_needle(context, depth_percent, context_length):
    """Place the module-level `needle` inside `context` at a relative depth.

    Args:
        context: haystack text.
        depth_percent: where to insert the needle, as a percentage of the
            (possibly truncated) haystack token count; 100 appends the
            needle at the very end.
        context_length: total token budget requested for this trial.
            `final_context_length_buffer` tokens are subtracted from it
            first (the original note says this leaves room for the system
            message, the user question and the response).

    Returns:
        str: the decoded text of haystack tokens with the needle tokens
        spliced in.
    """
    tokens_needle = encode_text_to_tokens(needle)
    tokens_context = encode_text_to_tokens(context)

    # Reserve the configured buffer out of the requested budget.
    context_length -= final_context_length_buffer

    # If haystack + needle overflow the budget, drop haystack tokens from the
    # end to make room for the needle.
    if len(tokens_context) + len(tokens_needle) > context_length:
        # Clamp at zero: a negative slice bound would trim from the tail
        # instead of emptying the haystack, overshooting the budget.
        haystack_budget = max(0, context_length - len(tokens_needle))
        tokens_context = tokens_context[:haystack_budget]

    if depth_percent == 100:
        # Needle is the last thing in the document.
        tokens_new_context = tokens_context + tokens_needle
    else:
        insertion_point = int(len(tokens_context) * (depth_percent / 100))
        tokens_new_context = (
            tokens_context[:insertion_point]
            + tokens_needle
            + tokens_context[insertion_point:]
        )

    # Convert back to a string and return it.
    return decode_tokens(tokens_new_context)
    
def generate_context(context_length, depth_percent):
    """Produce the haystack text for one trial.

    Steps, in order: load the essay text with `read_context_files`, trim it
    to `context_length` tokens with `encode_and_trim`, then splice the needle
    in at `depth_percent` with `insert_needle`.
    """
    haystack = encode_and_trim(read_context_files(), context_length)
    return insert_needle(haystack, depth_percent, context_length)
    
def generate_prompt(context):
    """Wrap `context` in the book-style prompt template.

    The template ends with the module-level `retrieval_question`.
    """
    return f"<|im_start|> This is a very long story book: <book> {context} </book>.\n\nQuestion: Based on the content of the book, {retrieval_question}"
    
def bound_evaluate_and_log(context_length, depth_percent, max_new_tokens=50):
    """Run one needle-in-a-haystack trial, score it, and persist the outcome.

    The function builds the haystack and prompt, prefills `model_to_test`
    with a fresh `OffloadedCache` (in chunks of `prefilling_chunk_size`
    tokens when that is set, otherwise in a single forward pass), greedily
    decodes a short answer, scores it against `needle` with ROUGE-1
    F-measure scaled by 10, appends the result dict to `testing_results`,
    prints a summary, and writes the context and result files under
    ``contexts/<model_version>/`` and ``results/<model_version>/``.

    Args:
        context_length: requested prompt size in tokens.
        depth_percent: relative needle position passed to `generate_context`.
        max_new_tokens: maximum number of decode steps after the token
            predicted by the prefill (default 50, the previously hard-coded
            value).

    Returns:
        tuple: ``(None, generated_prompt)`` where `generated_prompt` is the
        prompt string given to the tokenizer.
    """
    # Build the haystack with the needle in place, then wrap it in the prompt.
    context = generate_context(context_length, depth_percent)
    generated_prompt = generate_prompt(context)

    test_start_time = time.time()

    encoded_prompt = enc(generated_prompt, return_tensors="pt")
    prompt_input_ids = encoded_prompt["input_ids"].to(model_to_test.device)

    with torch.no_grad():
        past_key_values = OffloadedCache()
        # Window sizes set on the cache object for this trial.
        past_key_values.sink_size = 64
        past_key_values.recent_size = 256

        if prefilling_chunk_size is not None:
            # Chunked prefill: feed the prompt piece by piece, carrying the cache.
            for i in range(0, prompt_input_ids.size(1), prefilling_chunk_size):
                chunk = prompt_input_ids[:, i : i + prefilling_chunk_size]
                output = model_to_test(
                    input_ids=chunk,
                    past_key_values=past_key_values,
                    use_cache=True,
                    num_logits_to_keep=1,
                )
                past_key_values = output.past_key_values
        else:
            output = model_to_test(
                input_ids=prompt_input_ids,
                past_key_values=past_key_values,
                use_cache=True,
                num_logits_to_keep=1,
            )
            past_key_values = output.past_key_values

        # Greedy decoding: the prefill's last logits give the first answer token.
        pred_token_idx = output.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        generated_content = [pred_token_idx.item()]
        for _ in range(max_new_tokens):
            # Checking the most recent token here also covers the very first
            # prediction, so an immediate EOS stops generation at once.
            if generated_content[-1] in eos_token_ids:
                break
            outputs = model_to_test(
                input_ids=pred_token_idx,
                past_key_values=past_key_values,
                use_cache=True,
                num_logits_to_keep=1,
            )
            past_key_values = outputs.past_key_values
            pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
            generated_content.append(pred_token_idx.item())

    response = enc.decode(generated_content, skip_special_tokens=True).strip()

    test_end_time = time.time()
    test_elapsed_time = test_end_time - test_start_time
    score = scorer.score(needle, response)["rouge1"].fmeasure * 10

    results = {
        # 'context' : context, # Uncomment this line if you'd like to save the context the model was asked to retrieve from. Warning: This will become very large.
        "model": model_to_test_description,
        "context_length": int(context_length),
        "depth_percent": float(depth_percent),
        "needle": needle,
        "model_response": response,
        "score": score,
        "test_duration_seconds": test_elapsed_time,
        "test_timestamp_utc": datetime.now(timezone.utc).strftime(
            "%Y-%m-%d %H:%M:%S%z"
        ),
    }

    testing_results.append(results)
    print("-- Test Summary -- ")
    print(f"Duration: {test_elapsed_time:.1f} seconds")
    print(f"Context: {context_length} tokens")
    print(f"Depth: {depth_percent}%")
    print(f"Score: {score}")
    print(f"Response: {response}\n")

    model_version = model_name.split("/")[-1]
    context_file_location = f'{model_version.replace(".", "_")}_len_{context_length}_depth_{int(depth_percent*100)}'

    # Added after the append on purpose: it is the same dict object, so the
    # entry in `testing_results` carries the file name too.
    results["file_name"] = context_file_location

    # Save the context to file for retesting. `exist_ok=True` creates any
    # missing parent directory and tolerates one that already exists.
    context_dir = f"contexts/{model_version}"
    os.makedirs(context_dir, exist_ok=True)
    with open(
        f"{context_dir}/{context_file_location}_context.txt",
        "w",
        encoding="utf-8",
    ) as f:
        f.write(context)

    # Save the result to file for retesting.
    results_dir = f"results/{model_version}"
    os.makedirs(results_dir, exist_ok=True)
    p = f"{results_dir}/{context_file_location}_results.json"
    print("Writing at %s" % p)
    print(p)
    with open(p, "w", encoding="utf-8") as f:
        json.dump(results, f)

    return None, generated_prompt

# Only context lengths inside [s_len, e_len] are evaluated.
s_len = 1
e_len = pretrained_len
tasks = []
for context_length in context_lengths:
    print(context_length)
    if s_len <= context_length <= e_len:
        # Sweep every needle depth for this context length.
        for depth_percent in document_depth_percents:
            print(depth_percent)
            task = bound_evaluate_and_log(context_length, depth_percent)

        # Stop after the first in-range context length has been evaluated.
        break