In [1]:
import transformers
import torch
import time
import shutil
from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Load the model
ckpt = "models/Llama-3-8B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
).to("cuda")

generation_config = GenerationConfig.from_pretrained(ckpt)
eos_token_ids = generation_config.eos_token_id
if not isinstance(eos_token_ids, list):
    eos_token_ids = [eos_token_ids]

# add some tokens like "</user>" and </s> to eos ids
eos_token_ids += tokenizer.encode("</user>", add_special_tokens=False)
eos_token_ids += tokenizer.encode("</s>", add_special_tokens=False)
eos_token_ids += tokenizer.encode("</", add_special_tokens=False)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:
from transformers.models.llama.modeling_llama import (
    logger,
    apply_rotary_pos_emb,
    repeat_kv,
    LlamaSdpaAttention,
    LlamaFlashAttention2,
    LlamaMLP
)
from transformers.cache_utils import Cache, DynamicCache, StaticCache, OffloadedCache, OffloadedStaticCache
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.modeling_flash_attention_utils  import _flash_attention_forward

def minis_mlp_forward(self, x):
    bsz, q_len, _ = x.size()
    chunk_size = self.hidden_size

    x_list = list(x.split(chunk_size, dim=1))

    output_list = [None for _ in range(len(x_list))]

    for i in range(len(x_list)):
        x = x_list[i]
        x_list[i] = None
        output_list[i] = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

    down_proj = torch.cat(output_list, dim=1)

    return down_proj
    
LlamaMLP.forward = minis_mlp_forward

In [6]:

# def LlamaAttention_fast_forward(
#     self,
#     hidden_states: torch.Tensor,
#     attention_mask: Optional[torch.Tensor] = None,
#     position_ids: Optional[torch.LongTensor] = None,
#     past_key_value: Optional[Cache] = None,
#     output_attentions: bool = False,
#     use_cache: bool = False,
#     cache_position: Optional[torch.LongTensor] = None,
#     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
#     **kwargs,
# ) :
#     output_attentions = False

#     bsz, q_len, hd = hidden_states.size()
#     chunk_size = hd // self.num_key_value_heads
#     num_heads = self.num_heads // self.num_key_value_heads

#     if not hasattr(self, 'q_proj_list'):
#         self.q_proj_list = list((self.q_proj.weight.split(self.head_dim * num_heads, dim=0)))
#         # self.q_proj.weight.data.storage().resize_(0)
#     if not hasattr(self, 'k_proj_list'):
#         self.k_proj_list = list((self.k_proj.weight.split(self.head_dim, dim=0)))
#         # self.k_proj.weight.data.storage().resize_(0)
#     if not hasattr(self, 'v_proj_list'):
#         self.v_proj_list = list((self.v_proj.weight.split(self.head_dim, dim=0)))
#         # self.v_proj.weight.data.storage().resize_(0)


#     attn_output_list = [None for _ in range((self.num_key_value_heads))]
    
#     for i in range(self.num_key_value_heads):
#         bsz, q_len, hd = hidden_states.size()

#         self.q_proj.weight.data = self.q_proj_list[i].data
#         self.k_proj.weight.data = self.k_proj_list[i].data
#         self.v_proj.weight.data = self.v_proj_list[i].data

#         # print(hidden_states.shape, self.q_proj.weight.shape)
        
#         query_states = self.q_proj(hidden_states)
#         key_states = self.k_proj(hidden_states)
#         value_states = self.v_proj(hidden_states)

#         # Flash attention requires the input to have the shape
#         # batch_size x seq_length x head_dim x hidden_dim
#         # therefore we just need to keep the original shape
#         query_states = query_states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
#         key_states = key_states.view(bsz, q_len, 1, self.head_dim).transpose(1, 2)
#         value_states = value_states.view(bsz, q_len, 1, self.head_dim).transpose(1, 2)
    
#         if position_embeddings is None:
#             logger.warning_once(
#                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
#                 "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
#                 "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
#                 "removed and `position_embeddings` will be mandatory."
#             )
#             cos, sin = self.rotary_emb(value_states, position_ids)
#         else:
#             cos, sin = position_embeddings
#         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    
#         if past_key_value is not None:
#             # sin and cos are specific to RoPE models; cache_position needed for the static cache
#             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
#             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx + i, cache_kwargs)
    
#         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
#         # to be able to avoid many of these transpose/reshape/view.
#         query_states = query_states.transpose(1, 2)
#         key_states = key_states.transpose(1, 2)
#         value_states = value_states.transpose(1, 2)
    
#         dropout_rate = self.attention_dropout if self.training else 0.0
    
#         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
#         # therefore the input hidden states gets silently casted in float32. Hence, we need
#         # cast them back in the correct dtype just to be sure everything works as expected.
#         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
#         # in fp32. (LlamaRMSNorm handles it correctly)
    
#         input_dtype = query_states.dtype
#         if input_dtype == torch.float32:
#             if torch.is_autocast_enabled():
#                 target_dtype = torch.get_autocast_gpu_dtype()
#             # Handle the case where the model is quantized
#             elif hasattr(self.config, "_pre_quantization_dtype"):
#                 target_dtype = self.config._pre_quantization_dtype
#             else:
#                 target_dtype = self.q_proj.weight.dtype
    
#             logger.warning_once(
#                 f"The input hidden states seems to be silently casted in float32, this might be related to"
#                 f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
#                 f" {target_dtype}."
#             )
    
#             query_states = query_states.to(target_dtype)
#             key_states = key_states.to(target_dtype)
#             value_states = value_states.to(target_dtype)
    
#         attn_output = _flash_attention_forward(
#             query_states,
#             key_states,
#             value_states,
#             attention_mask,
#             q_len,
#             position_ids=position_ids,
#             dropout=dropout_rate,
#             sliding_window=getattr(self, "sliding_window", None),
#             use_top_left_mask=self._flash_attn_uses_top_left_mask,
#             is_causal=self.is_causal,
#         )
    
#         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
#         attn_output_list[i] = attn_output
        
#     attn_output = torch.cat(attn_output_list, dim=-1)
#     attn_output = self.o_proj(attn_output)

#     if not output_attentions:
#         attn_weights = None

#     return attn_output, attn_weights, past_key_value


# LlamaFlashAttention2.forward = LlamaAttention_fast_forward
# layer_idx = 0
# for name, module in model.named_modules():
#     if "self_attn" in name and hasattr(module, "q_proj"):
#         module.layer_idx = layer_idx
#         layer_idx += module.num_key_value_heads

In [4]:

from torch import nn

class OffloadedCache(DynamicCache):
    """
    A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
    Useful for generating from models with very long context.

    In addition to the default CUDA stream, where all forward() computations happen,
    this class uses another stream, the prefetch stream, which it creates itself.
    Since scheduling of operations on separate streams happens independently, this class uses
    the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
    The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to
    ensure the eviction is scheduled after all computations on that cache are finished.
    """

    def __init__(self) -> None:
        if not torch.cuda.is_available():
            raise RuntimeError("OffloadedCache can only be used with a GPU")
        super().__init__()
        self.original_device = []
        self.prefetch_stream = torch.cuda.Stream()
        self.beam_idx = None  # used to delay beam search operations

        self.proj = nn.Linear(128, 128, bias=False, dtype=torch.bfloat16).cuda()

    def prefetch_layer(self, layer_idx: int):
        "Starts prefetching the next layer cache"
        if layer_idx < len(self):
            with torch.cuda.stream(self.prefetch_stream):
                # Prefetch next layer tensors to GPU
                device = self.original_device[layer_idx]
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        "Moves the previous layer cache to the CPU"
        if len(self) > 2:
            # We do it on the default stream so it occurs after all earlier computations on these tensors are done
            prev_layer_idx = (layer_idx - 1) % len(self)
            self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
            self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
        if layer_idx < len(self):
            # Evict the previous layer if necessary
            torch.cuda.current_stream().synchronize()
            self.evict_previous_layer(layer_idx)
            # Load current layer cache to its original device if not already there
            original_device = self.original_device[layer_idx]
            self.prefetch_stream.synchronize()
            key_tensor = self.key_cache[layer_idx]
            value_tensor = self.value_cache[layer_idx]
            # Now deal with beam search ops which were delayed
            if self.beam_idx is not None:
                self.beam_idx = self.beam_idx.to(original_device)
                key_tensor = key_tensor.index_select(0, self.beam_idx)
                value_tensor = value_tensor.index_select(0, self.beam_idx)
            # Prefetch the next layer
            self.prefetch_layer((layer_idx + 1) % len(self))
            return (key_tensor, value_tensor)
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Saves the beam indices and reorders the cache when the tensor is back to its device."""
        # We delay this operation until the tensors are back to their original
        # device because performing torch.index_select on the CPU is very slow
        del self.beam_idx
        self.beam_idx = beam_idx.clone()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) < layer_idx // 8:
            raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
        elif len(self.key_cache) == layer_idx // 8:
            if layer_idx % 8 == 0:
                layer_idx = layer_idx // 8
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
                self.original_device.append(key_states.device)
                self.evict_previous_layer(layer_idx)
            else:
                layer_idx = layer_idx // 8 
        else:
            if layer_idx % 8 == 0:
                layer_idx = layer_idx // 8
                key_tensor, value_tensor = self[layer_idx]
                self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2)
            else:
                layer_idx = layer_idx // 8 
                key = self.proj(self.key_cache[layer_idx])
                value = self.proj(self.value_cache[layer_idx])
                return key, value

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError
    # if a method is not supposed to be supported in a subclass we should set it to None
    from_legacy_cache = None

    to_legacy_cache = None


In [5]:
# Manually perform inference using KV cache
input_ids = torch.randint(0, tokenizer.vocab_size, (1, 32000)).to('cuda')
next_input_ids = torch.randint(0, tokenizer.vocab_size, (1, 1)).to('cuda')

max_new_tokens = 1
next_new_tokens = 10
generated_tokens = []

# # Initialize past_key_values to None
# past_key_values = OffloadedCache()
# past_key_values.sink_size = 64
# past_key_values.recent_size = 256

# config = model.config
# for idx, layer in enumerate(model.model.layers):
#     device = next(model.parameters()).device
#     dtype = next(model.parameters()).dtype
#     module = layer.self_attn

#     # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
#     # head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
#     head_dim = config.hidden_size
    
#     cache_shape = (1, 1, 32000 * (max_new_tokens)  + 15, head_dim)
#     key_states = torch.zeros(cache_shape, dtype=dtype, device=device)
#     value_states = torch.zeros(cache_shape, dtype=dtype, device=device)

#     # for i in range(config.num_key_value_heads):
#     cache_kwargs = {'full_head': None}
#     key_states, value_states = past_key_values.update(key_states, value_states, module.layer_idx, cache_kwargs)


# Initialize past_key_values to None
past_key_values = DynamicCache()

for epoch in range(max_new_tokens):
    with torch.no_grad():
        start_time = time.time()
        outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, num_logits_to_keep=1)
        
        # Extract the logits and past_key_values (the cache)
        past_key_values = outputs.past_key_values  # KV cache to be reused in the next step

        torch.cuda.empty_cache()
        
        epoch_time = time.time() - start_time
        print(f"Epoch {epoch+1}/{max_new_tokens} - Time: {epoch_time:.2f} seconds")
        print(
            "Peak allocated bytes on {:4f}GB".format(
                torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2**30
            )
        )


for epoch in range(next_new_tokens):
    with torch.no_grad():
        start_time = time.time()
        outputs = model(input_ids=next_input_ids, past_key_values=past_key_values, use_cache=True, num_logits_to_keep=1)
        
        # Extract the logits and past_key_values (the cache)
        past_key_values = outputs.past_key_values  # KV cache to be reused in the next step

        torch.cuda.empty_cache()
        
        epoch_time = time.time() - start_time
        print(f"Epoch {epoch+1}/{next_new_tokens} - Time: {epoch_time:.2f} seconds")
        print(
            "Peak allocated bytes on {:4f}GB".format(
                torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2**30
            )
        )

Epoch 1/1 - Time: 3.90 seconds
Peak allocated bytes on 20.598151GB
Epoch 1/10 - Time: 0.16 seconds
Peak allocated bytes on 20.598151GB
Epoch 2/10 - Time: 0.13 seconds
Peak allocated bytes on 20.598151GB
Epoch 3/10 - Time: 0.13 seconds
Peak allocated bytes on 20.598151GB
Epoch 4/10 - Time: 0.13 seconds
Peak allocated bytes on 20.598151GB
Epoch 5/10 - Time: 0.13 seconds
Peak allocated bytes on 20.598151GB
Epoch 6/10 - Time: 0.13 seconds
Peak allocated bytes on 20.598151GB
Epoch 7/10 - Time: 0.13 seconds
Peak allocated bytes on 20.598151GB
Epoch 8/10 - Time: 0.13 seconds
Peak allocated bytes on 20.598151GB
Epoch 9/10 - Time: 0.13 seconds
Peak allocated bytes on 20.598151GB
Epoch 10/10 - Time: 0.13 seconds
Peak allocated bytes on 20.598151GB
