In [1]:
import transformers
import torch
import time
import shutil
from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Load the tokenizer and the model (bf16 weights, FlashAttention-2 kernels), then move the model to the GPU.
ckpt = "models/Llama-3-8B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
).to("cuda")


# Build the list of token ids that stop the manual decoding loop.
generation_config = GenerationConfig.from_pretrained(ckpt)
configured_eos = generation_config.eos_token_id
if isinstance(configured_eos, (list, tuple)):
    # Copy so that the `+=` below does not mutate `generation_config.eos_token_id` in place.
    eos_token_ids = list(configured_eos)
elif configured_eos is None:
    eos_token_ids = []
else:
    eos_token_ids = [configured_eos]

# Add some extra stop markers such as "</user>" and "</s>" to the eos ids.
# Every id returned by `encode` is appended, so if a marker is split into several tokens
# each of those tokens individually ends generation.
for stop_text in ("</user>", "</s>", "</"):
    eos_token_ids += tokenizer.encode(stop_text, add_special_tokens=False)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:

# Build a long "haystack" prompt with one "needle" sentence buried inside it.
context = "A quick brown fox jumps over the lazy dog. \n"
# Alternative needle, read from disk instead of the inline sentence:
# with open("demo/duo_attention.txt", "r") as f:
#     needle = f.read()
needle = "Mary's favorite number is 34251"

# Number of filler repetitions needed for roughly 10000 tokens of filler.
num_tokens_context = len(tokenizer.encode(context, add_special_tokens=False))
num_repetitions = 10000 // num_tokens_context

# 75% of the filler goes before the needle, the remaining 25% after it.
text = "".join(
    [
        "This is a very long story book: <book> ",
        context * int(num_repetitions * 0.75),
        needle,
        context * int(num_repetitions * (1 - 0.75)),
        "</book>\n Based on the content of the book, please briefly tell me about what is Mary's favorite number.\nAnswer:",
    ]
)

# Tokenize the whole prompt and place it on the GPU.
input_ids = tokenizer.encode(text, return_tensors="pt").to("cuda")

print(f"Input sequence length: {input_ids.size(1)}\n")

Input sequence length: 10033



In [3]:

from transformers.cache_utils import Cache, StaticCache, OffloadedCache, OffloadedStaticCache

from typing import Any, Dict, List, Optional, Tuple, Union

class DynamicCache(Cache):
    """
    A key/value cache that grows as tokens are generated and can optionally be trimmed to a
    "sink + recent" window after every update.

    Keys and values are stored as one entry per layer in `key_cache` / `value_cache`. The expected shape of each
    tensor is `[batch_size, num_heads, seq_len, head_dim]`; new states are concatenated along dim `-2`.
    Layers that were skipped hold an empty list `[]` as a placeholder.

    Eviction is controlled by two attributes:

        sink_size: number of positions kept from the start of the sequence.
        recent_size: number of positions kept from the end of the sequence.

    When both are set and a layer holds more than `sink_size + recent_size` positions, everything between the
    two ranges is dropped from the stored cache. When either is `None` (the default) nothing is evicted.
    The attributes may be passed to the constructor or assigned after construction:

        ```python
        >>> past_key_values = DynamicCache()
        >>> past_key_values.sink_size = 64
        >>> past_key_values.recent_size = 256
        ```
    """

    def __init__(
        self,
        num_hidden_layers: Optional[int] = None,
        sink_size: Optional[int] = None,
        recent_size: Optional[int] = None,
    ) -> None:
        """
        Parameters:
            num_hidden_layers (`int`, `optional`):
                Accepted for signature compatibility; not used.
            sink_size (`int`, `optional`):
                Number of leading positions to keep when evicting. `None` disables eviction.
            recent_size (`int`, `optional`):
                Number of trailing positions to keep when evicting. `None` disables eviction.
        """
        super().__init__()
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # Defined here so that `update` never hits an AttributeError on caches created by the
        # classmethods / `batch_split`, or before the caller has assigned the attributes.
        self.sink_size = sink_size
        self.recent_size = recent_size

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.

        Raises:
            KeyError: if `layer_idx` is not smaller than the number of cached layers.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values
        """
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def _evict(self, layer_idx: int) -> None:
        """Trim the stored keys/values of `layer_idx` to the first `sink_size` and last `recent_size` positions."""
        if self.sink_size is None or self.recent_size is None:
            return  # eviction not configured

        seq_len = self.key_cache[layer_idx].shape[-2]
        if seq_len <= self.sink_size + self.recent_size:
            return  # everything still fits in the window

        recent_start = seq_len - self.recent_size
        # Keys and values are trimmed with the same boundaries.
        for store in (self.key_cache, self.value_cache):
            states = store[layer_idx]
            store[layer_idx] = torch.cat(
                [states[..., : self.sink_size, :], states[..., recent_start:seq_len, :]], dim=-2
            )

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`, then applies the
        sink + recent eviction to what is stored.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. Not used here.

        Return:
            A tuple containing the key and value states as they were *before* eviction, i.e. the previously stored
            states concatenated with the new ones. Only the stored copy is trimmed.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            # There may be skipped layers, fill them with empty lists
            for _ in range(len(self.key_cache), layer_idx):
                self.key_cache.append([])
                self.value_cache.append([])
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        elif len(self.key_cache[layer_idx]) == 0:  # fills previously skipped layers; checking for tensor causes errors
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        # Grab the full states for the caller first, then shrink what is kept for the next step.
        key, value = self.key_cache[layer_idx], self.value_cache[layer_idx]
        self._evict(layer_idx)

        return key, value

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed.

        The value is read from the stored key tensor, so after eviction it is the retained length rather than
        the total number of tokens seen (`_seen_tokens`).
        NOTE(review): callers that derive token positions from this value will see the shorter length once
        eviction has happened; confirm this is intended.
        """
        # TODO: deprecate this function in favor of `cache_position`
        is_empty_layer = (
            len(self.key_cache) == 0  # no cache in any layer
            or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
        )
        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
        return layer_seq_length

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. Always `None` for this cache."""
        return None

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
        backward compatibility."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(
        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
    ) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
        backward compatibility. The returned cache has eviction disabled until `sink_size` / `recent_size`
        are assigned."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache

    def crop(self, max_length: int):
        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
        # In case it is negative
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)

        if self.get_seq_length() <= max_length:
            return

        self._seen_tokens = max_length
        for idx in range(len(self.key_cache)):
            # Skipped layers hold `[]`; same emptiness check as in `update` / `get_seq_length`.
            if len(self.key_cache[idx]) != 0:
                self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
                self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

    def batch_split(
        self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
    ) -> List["DynamicCache"]:
        """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
        `_split_model_inputs()` in `generation.utils`. Each split inherits this cache's eviction settings."""
        out = []
        for i in range(0, full_batch_size, split_size):
            current_split = DynamicCache()
            current_split._seen_tokens = self._seen_tokens
            current_split.sink_size = self.sink_size
            current_split.recent_size = self.recent_size
            current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
            current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
            out.append(current_split)
        return out

    @classmethod
    def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache":
        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
        `generation.utils`. The eviction settings are taken from the first split."""
        cache = cls()
        for idx in range(len(splits[0])):
            layer_keys = [current.key_cache[idx] for current in splits if len(current.key_cache[idx]) != 0]
            # Values must come from `value_cache`, not from `key_cache`.
            layer_values = [current.value_cache[idx] for current in splits if len(current.value_cache[idx]) != 0]
            if layer_keys:
                cache.update(torch.cat(layer_keys, dim=0), torch.cat(layer_values, dim=0), idx)
        # Assigned after the merge so that rebuilding the layers above never trims anything.
        cache.sink_size = getattr(splits[0], "sink_size", None)
        cache.recent_size = getattr(splits[0], "recent_size", None)
        return cache

    def batch_repeat_interleave(self, repeats: int):
        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
        for layer_idx in range(len(self)):
            if len(self.key_cache[layer_idx]) == 0:
                continue  # skipped layer placeholder, nothing to repeat
            self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
            self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
        for layer_idx in range(len(self)):
            if len(self.key_cache[layer_idx]) == 0:
                continue  # skipped layer placeholder, nothing to select
            self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]

In [4]:

# Manually perform greedy decoding, re-using the KV cache between steps.

max_new_tokens = 500
generated_tokens = []

# Start from an empty cache and configure its eviction window through the attributes
# that `DynamicCache.update` reads.
past_key_values = DynamicCache()
past_key_values.sink_size = 64
past_key_values.recent_size = 256

# The first step consumes the whole prompt; later steps consume only the newest token.
# A separate name is used so the prompt tensor `input_ids` is not overwritten and this
# cell gives the same result when it is run again.
step_input_ids = input_ids

with torch.no_grad():
    for _ in range(max_new_tokens):
        outputs = model(input_ids=step_input_ids, past_key_values=past_key_values, use_cache=True)

        # Extract the logits of the last position and the cache to reuse in the next step
        next_token_logits = outputs.logits[:, -1, :]
        past_key_values = outputs.past_key_values

        # Greedy decoding: pick the token with the highest logit
        next_token = torch.argmax(next_token_logits, dim=-1)
        next_token_id = next_token.item()
        if next_token_id in eos_token_ids:
            break
        generated_tokens.append(next_token_id)

        # Only pass the new token for the next iteration
        step_input_ids = next_token.unsqueeze(-1)

# Convert generated token ids to text
output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
print(output_text)

 Mary had a little lamb. 
Mary had a little lamb. Its fleece of white lamb. Its fleece as white as white as white as white as snow. And Mary's fleece. And Mary's little lamb. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And one. And. And. And. And. And. And. And. And. And. And. And. And. 