In [1]:
import transformers
import torch
import time
import shutil
from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Load the model
ckpt = "gradientai/Llama-3-8B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
).to("cuda")

generation_config = GenerationConfig.from_pretrained(ckpt)
eos_token_ids = generation_config.eos_token_id
if not isinstance(eos_token_ids, list):
    eos_token_ids = [eos_token_ids]

# add some tokens like "</user>" and </s> to eos ids
eos_token_ids += tokenizer.encode("</user>", add_special_tokens=False)
eos_token_ids += tokenizer.encode("</s>", add_special_tokens=False)
eos_token_ids += tokenizer.encode("</", add_special_tokens=False)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [21]:
from transformers.models.llama.modeling_llama import (
    logger,
    apply_rotary_pos_emb,
    repeat_kv,
    LlamaSdpaAttention,
    LlamaFlashAttention2,
)
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.cache_utils import Cache, DynamicCache, StaticCache, OffloadedCache, OffloadedStaticCache
import types
from transformers.modeling_flash_attention_utils  import _flash_attention_forward
from snapkv.monkeypatch.snapkv_utils import init_snapkv

def LlamaAttention_fast_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) :
    init_snapkv(self)
    output_attentions = False

    bsz, q_len, hd = hidden_states.size()
    chunk_size = hd // self.num_key_value_heads
    num_heads = self.num_heads // self.num_key_value_heads

    if not hasattr(self, 'q_proj_list'):
        self.q_proj_list = list((self.q_proj.weight.split(self.head_dim * num_heads, dim=0)))
        # self.q_proj.weight.data.storage().resize_(0)
    if not hasattr(self, 'k_proj_list'):
        self.k_proj_list = list((self.k_proj.weight.split(self.head_dim, dim=0)))
        # self.k_proj.weight.data.storage().resize_(0)
    if not hasattr(self, 'v_proj_list'):
        self.v_proj_list = list((self.v_proj.weight.split(self.head_dim, dim=0)))
        # self.v_proj.weight.data.storage().resize_(0)


    attn_output_list = [None for _ in range((self.num_key_value_heads))]
    
    for i in range(self.num_key_value_heads):
        bsz, q_len, hd = hidden_states.size()

        self.q_proj.weight.data = self.q_proj_list[i].data
        self.k_proj.weight.data = self.k_proj_list[i].data
        self.v_proj.weight.data = self.v_proj_list[i].data

        # print(hidden_states.shape, self.q_proj.weight.shape)
        
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, 1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, 1, self.head_dim).transpose(1, 2)

        if i == 0:
            kv_seq_len = key_states.shape[-2]
            if hasattr(self, "kv_seq_len"):
                self.kv_seq_len += kv_seq_len
            else:
                self.kv_seq_len = kv_seq_len
                
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
                key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
                key_states, value_states = past_key_value.update(key_states_compress, value_states_compress, self.layer_idx + i, cache_kwargs)
            else:
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx + i, cache_kwargs)
    
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
    
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=0.0,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )
    
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output_list[i] = attn_output
        
    attn_output = torch.cat(attn_output_list, dim=-1)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


LlamaFlashAttention2.forward = LlamaAttention_fast_forward
layer_idx = 0
for idx, layer in enumerate(model.model.layers):
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    module = layer.self_attn
    module.layer_idx = layer_idx
    layer_idx += module.num_key_value_heads
    module.head_group = 8

In [22]:

from transformers.cache_utils import Cache, DynamicCache, StaticCache, OffloadedCache, OffloadedStaticCache

# Initialize past_key_values to None
past_key_values = OffloadedCache()
# Manually perform inference using KV cache

inputs = tokenizer("Fun fact: The shortest", return_tensors="pt").to(model.device)
max_new_tokens = 23
generated_tokens = []
input_ids = inputs["input_ids"]

for _ in range(max_new_tokens):
    with torch.no_grad():
        outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
        
        # Extract the logits and past_key_values (the cache)
        next_token_logits = outputs.logits[:, -1, :]  # Logits of the last token
        past_key_values = outputs.past_key_values  # KV cache to be reused in the next step

        # Greedy decoding: get the token with the highest probability
        next_token = torch.argmax(next_token_logits, dim=-1)
        generated_tokens.append(next_token.item())

        # Only pass the new token for the next iteration
        input_ids = next_token.unsqueeze(-1)

# Convert generated token ids to text
output_text = "Fun fact: The shortest" + tokenizer.decode(generated_tokens, skip_special_tokens=True)
print(output_text)

Fun fact: The shortest war in history was between Zanzibar and Great Britain on August 27, 1896. Zanzibar
