In [1]:
import transformers
import torch
import time
import shutil
from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from transformers.cache_utils import Cache, DynamicCache, StaticCache, OffloadedCache, OffloadedStaticCache

# Load the model
ckpt = "models/Llama-3-8B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="eager",
).to("cuda")


generation_config = GenerationConfig.from_pretrained(ckpt)
eos_token_ids = generation_config.eos_token_id
if not isinstance(eos_token_ids, list):
    eos_token_ids = [eos_token_ids]

# add some tokens like "</user>" and </s> to eos ids
eos_token_ids += tokenizer.encode("</user>", add_special_tokens=False)
eos_token_ids += tokenizer.encode("</s>", add_special_tokens=False)
eos_token_ids += tokenizer.encode("</", add_special_tokens=False)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:

from transformers.cache_utils import Cache, DynamicCache, StaticCache, OffloadedCache, OffloadedStaticCache
import types

# Manually perform inference using KV cache
input_ids = torch.randint(0, tokenizer.vocab_size, (1, 32000)).to('cuda')
next_input_ids = torch.randint(0, tokenizer.vocab_size, (1, 1)).to('cuda')

max_new_tokens = 10
next_new_tokens = 10
generated_tokens = []

# Initialize past_key_values to None
past_key_values = OffloadedCache()


for epoch in range(max_new_tokens):
    with torch.no_grad():
        start_time = time.time()
        outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, num_logits_to_keep=1)
        
        # Extract the logits and past_key_values (the cache)
        past_key_values = outputs.past_key_values  # KV cache to be reused in the next step

        torch.cuda.empty_cache()
        
        epoch_time = time.time() - start_time
        print(f"Epoch {epoch+1}/{max_new_tokens} - Time: {epoch_time:.2f} seconds")
        print(
            "Peak allocated bytes on {:4f}GB".format(
                torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2**30
            )
        )

for epoch in range(next_new_tokens):
    with torch.no_grad():
        start_time = time.time()
        outputs = model(input_ids=next_input_ids, past_key_values=past_key_values, use_cache=True, num_logits_to_keep=1)
        
        # Extract the logits and past_key_values (the cache)
        past_key_values = outputs.past_key_values  # KV cache to be reused in the next step

        torch.cuda.empty_cache()
        
        epoch_time = time.time() - start_time
        print(f"Epoch {epoch+1}/{next_new_tokens} - Time: {epoch_time:.2f} seconds")
        print(
            "Peak allocated bytes on {:4f}GB".format(
                torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2**30
            )
        )

OutOfMemoryError: CUDA out of memory. Tried to allocate 61.04 GiB. GPU 0 has a total capacity of 23.55 GiB of which 3.92 GiB is free. Process 2076925 has 19.62 GiB memory in use. Of the allocated memory 18.23 GiB is allocated by PyTorch, and 934.24 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)