In [1]:
import transformers
import torch
import time
import shutil
from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Load the tokenizer and the model (bf16 weights, FlashAttention-2 kernels) and move the model to the GPU.
ckpt = "models/Llama-3-8B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
).to("cuda")


# Build the list of token ids that terminate generation, starting from the
# checkpoint's own generation config.
generation_config = GenerationConfig.from_pretrained(ckpt)
configured_eos = generation_config.eos_token_id
if configured_eos is None:
    eos_token_ids = []
elif isinstance(configured_eos, list):
    # Copy: appending below must not mutate generation_config.eos_token_id.
    eos_token_ids = list(configured_eos)
else:
    eos_token_ids = [configured_eos]

# Add extra stop markers. Stopping is decided one token id at a time, so a
# marker is only usable when it encodes to exactly ONE id; adding every
# sub-word piece of a multi-token marker would turn ordinary fragments
# (e.g. "s" or ">") into stop tokens. "</" is listed explicitly so that the
# common prefix of "</user>" and "</s>" still ends generation.
for stop_text in ("</user>", "</s>", "</"):
    stop_ids = tokenizer.encode(stop_text, add_special_tokens=False)
    if len(stop_ids) == 1 and stop_ids[0] not in eos_token_ids:
        eos_token_ids.append(stop_ids[0])

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:

# Filler sentence that is repeated to pad the prompt out to roughly 10k tokens.
context = "A quick brown fox jumps over the lazy dog. \n"
# with open("demo/duo_attention.txt", "r") as f:
#     needle = f.read()
# The fact the model is later asked to retrieve from inside the filler.
needle = "Mary's favorite number is 34251"

# Number of filler repetitions that fit into a 10000-token budget.
context_token_ids = tokenizer.encode(context, add_special_tokens=False)
num_tokens_context = len(context_token_ids)
num_repetitions = 10000 // num_tokens_context

# The needle is placed after 75% of the repetitions; the remainder follows it.
repeats_before_needle = int(num_repetitions * 0.75)
repeats_after_needle = int(num_repetitions * (1 - 0.75))

text = "".join(
    [
        "This is a very long story book: <book> ",
        context * repeats_before_needle,
        needle,
        context * repeats_after_needle,
        "</book>\n Based on the content of the book, please briefly tell me about what is Mary's favorite number.\nAnswer:",
    ]
)

# Tokenize the full prompt and place it on the GPU.
input_ids = tokenizer.encode(text, return_tensors="pt").to("cuda")

print(f"Input sequence length: {input_ids.shape[1]}\n")

Input sequence length: 10033



In [3]:

from transformers.cache_utils import Cache, DynamicCache, StaticCache, OffloadedCache, OffloadedStaticCache
# Manually perform greedy decoding, reusing the KV cache between steps.

max_new_tokens = 500
generated_tokens = []

# Start from an empty cache; the first forward pass fills it with the whole prompt.
past_key_values = DynamicCache()
# NOTE(review): these are attached as plain attributes and nothing in this
# notebook reads them -- confirm that the attention implementation in use
# actually consumes sink_size / recent_size.
past_key_values.sink_size = 64
past_key_values.recent_size = 256

# Feed the model through a loop-local variable so the prompt stored in
# `input_ids` is never overwritten and this cell can be re-run safely.
step_input_ids = input_ids

with torch.no_grad():
    for _ in range(max_new_tokens):
        outputs = model(
            input_ids=step_input_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Only the logits at the last position decide the next token.
        next_token_logits = outputs.logits[:, -1, :]
        # Carry the updated cache into the next step.
        past_key_values = outputs.past_key_values

        # Greedy decoding: pick the highest-scoring token.
        next_token = torch.argmax(next_token_logits, dim=-1)
        # .item() requires a single element, i.e. a batch of one prompt.
        next_token_id = next_token.item()
        if next_token_id in eos_token_ids:
            break
        generated_tokens.append(next_token_id)

        # From now on only the newly generated token is fed in; the cache
        # already holds everything before it.
        step_input_ids = next_token.unsqueeze(-1)

# Convert generated token ids to text
output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
print(output_text)

 Mary's favorite number is 34251. It is not mentioned in the book why it is her favorite number. The book is just a repetition of the sentence "A quick brown fox jumps over the lazy dog" many times. It is not related to Mary's favorite number. The book is just a long story book.
