In [1]:
import transformers
import torch
import time
import shutil
from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Load the tokenizer and the model (bf16 weights, FlashAttention-2 kernels), then move the model to the GPU.
ckpt = "models/Llama-3-8B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
).to("cuda")

generation_config = GenerationConfig.from_pretrained(ckpt)

# Build the stop-token list on a *copy*: when the config already stores a list,
# extending it in place would silently modify `generation_config.eos_token_id` too.
configured_eos = generation_config.eos_token_id
if configured_eos is None:
    # No EOS configured: start empty rather than carrying a `None` entry around.
    eos_token_ids = []
elif isinstance(configured_eos, list):
    eos_token_ids = list(configured_eos)
else:
    eos_token_ids = [configured_eos]

# add some tokens like "</user>" and </s> to eos ids
# NOTE(review): `encode` may split a string into several token ids, and each of them
# is added as an individual stop token -- confirm that this is the intent.
for stop_text in ("</user>", "</s>", "</"):
    eos_token_ids += tokenizer.encode(stop_text, add_special_tokens=False)

# Drop repeated ids while keeping first-seen order.
eos_token_ids = list(dict.fromkeys(eos_token_ids))

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:
from transformers.models.llama.modeling_llama import (
    logger,
    apply_rotary_pos_emb,
    repeat_kv,
    LlamaSdpaAttention,
    LlamaFlashAttention2,
    LlamaMLP,
    LlamaRMSNorm
)
from transformers.cache_utils import Cache, DynamicCache, StaticCache, OffloadedCache, OffloadedStaticCache
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.modeling_flash_attention_utils  import _flash_attention_forward

def minis_mlp_forward(self, x):
    """Gated-MLP forward pass evaluated slice by slice along dim 1.

    The input is cut into slices of at most ``self.hidden_size`` positions
    along dim 1; each slice goes through
    ``down_proj(act_fn(gate_proj(slice)) * up_proj(slice))`` and the partial
    results are concatenated back along dim 1.

    Args:
        x: 3-D input tensor (presumably ``(batch, seq_len, hidden)`` -- only
            the rank is checked here).

    Returns:
        The per-slice results concatenated along dim 1.
    """
    # The 3-way unpack doubles as a rank check: a non-3-D input raises here.
    _batch, _seq_len, _width = x.size()
    step = self.hidden_size

    # Reverse once so slices can be popped off the end in their original order;
    # popping also drops the list's reference to each slice once it is consumed.
    pending = list(x.split(step, dim=1))
    pending.reverse()

    pieces = []
    while pending:
        piece = pending.pop()
        gated = self.act_fn(self.gate_proj(piece)) * self.up_proj(piece)
        pieces.append(self.down_proj(gated))

    return torch.cat(pieces, dim=1)

def minis_Norm_forward(self, hidden_states):
    """RMS normalisation computed in slices of 4096 positions along dim 1.

    Each slice is upcast to float32, divided by the root-mean-square of its
    last dimension (with ``self.variance_epsilon`` added to the mean square
    for numerical stability), cast back to the input dtype, and the slices
    are concatenated along dim 1 before being scaled by ``self.weight``.

    Args:
        hidden_states: 3-D input tensor (presumably
            ``(batch, seq_len, hidden)`` -- only the rank is checked here).

    Returns:
        ``self.weight`` multiplied by the normalised tensor.
    """
    out_dtype = hidden_states.dtype

    # The 3-way unpack doubles as a rank check: a non-3-D input raises here.
    _batch, _seq_len, _width = hidden_states.size()
    step = 4096

    # Reverse once so slices can be popped off the end in their original order.
    pending = list(hidden_states.split(step, dim=1))
    pending.reverse()

    normed = []
    while pending:
        piece = pending.pop().to(torch.float32)
        mean_sq = piece.pow(2).mean(-1, keepdim=True)
        piece = piece * torch.rsqrt(mean_sq + self.variance_epsilon)
        normed.append(piece.to(out_dtype))

    return self.weight * torch.cat(normed, dim=1)
    
# Monkey-patch at class level: every LlamaMLP / LlamaRMSNorm instance, including
# ones created before this cell ran, now dispatches to the chunked forwards above.
setattr(LlamaMLP, "forward", minis_mlp_forward)
setattr(LlamaRMSNorm, "forward", minis_Norm_forward)

In [3]:

class OffloadedCache(DynamicCache):
    """
    A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
    Useful for generating from models with very long context.

    In addition to the default CUDA stream, where all forward() computations happen,
    this class uses another stream, the prefetch stream, which it creates itself.
    Since scheduling of operations on separate streams happens independently, this class uses
    the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
    The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to
    ensure the eviction is scheduled after all computations on that cache are finished.

    NOTE(review): every `.to(...)` transfer in this class is issued with `non_blocking=False`,
    so the "asynchronous" wording above may no longer describe the actual overlap -- confirm.
    NOTE(review): this definition rebinds the name `OffloadedCache` that the notebook also
    imports from `transformers.cache_utils`; after this cell runs, the name refers to this class.
    """

    def __init__(self) -> None:
        """Create an empty cache together with its own prefetch stream.

        Raises:
            RuntimeError: if CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("OffloadedCache can only be used with a GPU")
        super().__init__()
        # Device each layer's tensors were on when first cached; indexed by layer.
        self.original_device = []
        self.prefetch_stream = torch.cuda.Stream()
        self.beam_idx = None  # used to delay beam search operations

    def prefetch_layer(self, layer_idx: int):
        """Starts prefetching the next layer cache.

        Inside the prefetch stream context, moves the key and value tensors of `layer_idx`
        to the device recorded for that layer. Does nothing when `layer_idx >= len(self)`.
        """
        if layer_idx < len(self):
            with torch.cuda.stream(self.prefetch_stream):
                # Prefetch next layer tensors to GPU
                device = self.original_device[layer_idx]
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=False)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=False)

    def evict_previous_layer(self, layer_idx: int):
        """Moves the previous layer cache to the CPU.

        The previous layer is `(layer_idx - 1) % len(self)`, so layer 0 wraps around to the
        last layer. Nothing is moved while `len(self) <= 2`.
        """
        if len(self) > 2:
            # We do it on the default stream so it occurs after all earlier computations on these tensors are done
            prev_layer_idx = (layer_idx - 1) % len(self)
            self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=False)
            self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=False)

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer.

        Returns:
            A `(key_tensor, value_tensor)` tuple for `layer_idx`. When a beam index is pending,
            the returned tensors are reordered with `index_select` along dim 0; the reordered
            tensors are returned but not written back to `key_cache` / `value_cache` here.

        Raises:
            KeyError: if `layer_idx` is not smaller than `len(self)`.
        """
        if layer_idx < len(self):
            # Evict the previous layer if necessary
            # Wait for work queued on the current stream before moving the previous layer away.
            torch.cuda.current_stream().synchronize()
            self.evict_previous_layer(layer_idx)
            # Load current layer cache to its original device if not already there
            original_device = self.original_device[layer_idx]
            # Wait for any prefetch issued on the prefetch stream before reading the tensors.
            self.prefetch_stream.synchronize()
            key_tensor = self.key_cache[layer_idx]
            value_tensor = self.value_cache[layer_idx]
            # Now deal with beam search ops which were delayed
            # NOTE(review): nothing visible in this class resets `beam_idx` to None afterwards.
            if self.beam_idx is not None:
                self.beam_idx = self.beam_idx.to(original_device)
                key_tensor = key_tensor.index_select(0, self.beam_idx)
                value_tensor = value_tensor.index_select(0, self.beam_idx)
            # Prefetch the next layer
            # The modulo makes the last layer prefetch layer 0.
            self.prefetch_layer((layer_idx + 1) % len(self))
            return (key_tensor, value_tensor)
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Saves the beam indices and reorders the cache when the tensor is back to its device."""
        # We delay this operation until the tensors are back to their original
        # device because performing torch.index_select on the CPU is very slow
        del self.beam_idx
        self.beam_idx = beam_idx.clone()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
        Return:
            A tuple containing the updated key and value states.
        Raises:
            ValueError: if `layer_idx` is larger than the number of layers cached so far
                (i.e. a layer was skipped).
        """
        # Update the number of seen tokens
        # Counted once per forward pass, on layer 0, using the second-to-last dimension of the keys.
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) < layer_idx:
            raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
        elif len(self.key_cache) == layer_idx:
            # First time this layer is seen: store the tensors, remember their device,
            # then evict the previous layer.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self.original_device.append(key_states.device)
            self.evict_previous_layer(layer_idx)
        else:
            # Layer already cached: fetch it via __getitem__ (which also evicts/prefetches
            # neighbouring layers) and append the new states along dim -2.
            key_tensor, value_tensor = self[layer_idx]
            self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError
    # if a method is not supposed to be supported in a subclass we should set it to None
    from_legacy_cache = None

    to_legacy_cache = None

In [4]:
# Manually perform inference using KV cache
# Random token ids stand in for a real prompt: only timing and memory are measured here.
input_ids = torch.randint(0, tokenizer.vocab_size, (1, 900000)).to('cuda')
next_input_ids = torch.randint(0, tokenizer.vocab_size, (1, 1)).to('cuda')

max_new_tokens = 1
next_new_tokens = 10
generated_tokens = []

# Start from an empty offloaded KV cache; the forward pass fills it in.
past_key_values = OffloadedCache()

for epoch in range(max_new_tokens):
    with torch.no_grad():
        # CUDA kernels are launched asynchronously, so drain pending work before
        # starting the clock and again before stopping it; otherwise the interval
        # may not cover the GPU work it is meant to measure.
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, num_logits_to_keep=1)

        past_key_values = outputs.past_key_values  # KV cache to be reused in the next step

        torch.cuda.synchronize()
        epoch_time = time.perf_counter() - start_time

        # Release cached allocator blocks outside the timed region.
        torch.cuda.empty_cache()

        print(f"Epoch {epoch+1}/{max_new_tokens} - Time: {epoch_time:.2f} seconds")
        print(
            # `.4f` = four decimals (the previous `4f` only set a minimum field width).
            "Peak allocated bytes on {:.4f}GB".format(
                torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2**30
            )
        )

# for epoch in range(next_new_tokens):
#     with torch.no_grad():
#         start_time = time.time()
#         outputs = model(input_ids=next_input_ids, past_key_values=past_key_values, use_cache=True, num_logits_to_keep=1)
        
#         # Extract the logits and past_key_values (the cache)
#         past_key_values = outputs.past_key_values  # KV cache to be reused in the next step

#         torch.cuda.empty_cache()
        
#         epoch_time = time.time() - start_time
#         print(f"Epoch {epoch+1}/{next_new_tokens} - Time: {epoch_time:.2f} seconds")
#         print(
#             "Peak allocated bytes on {:4f}GB".format(
#                 torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2**30
#             )
#         )

OutOfMemoryError: CUDA out of memory. Tried to allocate 6.87 GiB. GPU 0 has a total capacity of 79.26 GiB of which 1.37 GiB is free. Process 467170 has 77.87 GiB memory in use. Of the allocated memory 66.91 GiB is allocated by PyTorch, and 10.45 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# print(len(past_key_values))

In [None]:
# import torch
# import torch.nn.functional as F
# for i in range(len(past_key_values)):
#     print(past_key_values.key_cache[i].shape)

#     for k in range(8):
#         for l in range(8):
#             tensors0 = past_key_values.key_cache[i][:, k]
#             tensors1 = past_key_values.key_cache[i][:, l]
#             cosine_similarities = F.cosine_similarity(tensors0, tensors1, dim=1)
#             print(k, l, cosine_similarities)
#     break