In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# from transformers.generation import GenerationConfig
from hip_attn.models.modeling_llama import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM
import torch

torch.manual_seed(1234)

<torch._C.Generator at 0x7f8658036a70>

In [2]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", trust_remote_code=True)

In [3]:
model = LlamaForCausalLM.from_pretrained(
    "gradientai/Llama-3-8B-Instruct-Gradient-1048k",
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).eval()

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
import time
from transformers.cache_utils import Cache, DynamicCache, StaticCache, OffloadedCache, OffloadedStaticCache

In [5]:

with torch.inference_mode():
    ### Simulate Prefill
    start_length = 10240
    chunks = 4
    input_ids = torch.randint(0, tokenizer.vocab_size, (1, start_length)).to('cuda')
    # for epoch in range(4):
    #     past_key_values = OffloadedCache()

    #     for i in range(chunks):
            
    #         torch.cuda.synchronize()
    #         start_time = time.time()

    #         model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, num_logits_to_keep=1)

    #         torch.cuda.synchronize()
    #         epoch_time = time.time() - start_time
    #         print(f"[Epoch {epoch+1}, Chunk {i+1}]   {epoch_time:.2f} seconds")

    #     del past_key_values
    #     torch.cuda.empty_cache()
    #     torch.cuda.synchronize()

    past_key_values = DynamicCache()
    config = model.config

    for idx, layer in enumerate(model.model.layers):
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype
        module = layer.self_attn

        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        # head_dim = config.hidden_size
        cache_shape = (1, config.num_key_value_heads, start_length, module.head_dim)
        key_states = torch.zeros(cache_shape, dtype=dtype, device=device)
        value_states = torch.zeros(cache_shape, dtype=dtype, device=device)

        key_states, value_states = past_key_values.update(key_states, value_states, module.layer_idx)


    ### Simulate Decode
    start_length = 10240

    max_new_tokens = 10
    input_ids = torch.randint(0, tokenizer.vocab_size, (1, 1)).to('cuda')

    for epoch in range(4):
        for i in range(max_new_tokens):
            torch.cuda.synchronize()
            start_time = time.time()
            
            model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, num_logits_to_keep=1)
            
            torch.cuda.synchronize()
            epoch_time = time.time() - start_time
            print(f"Epoch {i+1}  {epoch_time:.2f} seconds")
        start_length = start_length + 5120

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Epoch 1  0.33 seconds
Epoch 2  0.02 seconds
Epoch 3  0.02 seconds
Epoch 4  0.02 seconds
Epoch 5  0.02 seconds
Epoch 6  0.02 seconds
Epoch 7  0.02 seconds
Epoch 8  0.02 seconds
Epoch 9  0.02 seconds
Epoch 10  0.02 seconds
Epoch 1  0.02 seconds
Epoch 2  0.02 seconds
Epoch 3  0.02 seconds
Epoch 4  0.02 seconds
Epoch 5  0.02 seconds
Epoch 6  0.02 seconds
Epoch 7  0.02 seconds
Epoch 8  0.02 seconds
Epoch 9  0.02 seconds
Epoch 10  0.02 seconds
Epoch 1  0.02 seconds
Epoch 2  0.02 seconds
Epoch 3  0.02 seconds
Epoch 4  0.02 seconds
Epoch 5  0.02 seconds
Epoch 6  0.02 seconds
Epoch 7  0.02 seconds
Epoch 8  0.02 seconds
Epoch 9  0.02 seconds
Epoch 10  0.02 seconds
Epoch 1  0.02 seconds
Epoch 2  0.02 seconds
Epoch 3  0.02 seconds
Epoch 4  0.02 seconds
Epoch 5  0.02 seconds
Epoch 6  0.02 seconds
Epoch 7  0.02 seconds
Epoch 8  0.02 seconds
Epoch 9  0.02 seconds
Epoch 10  0.02 seconds


In [6]:

# context = "A quick brown fox jumps over the lazy dog. \n"
# # with open("demo/duo_attention.txt", "r") as f:
# #     needle = f.read()
# needle="\n\nRemember, the best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.\n\n"
# num_tokens_context = len(tokenizer.encode(context, add_special_tokens=False))
# num_repetitions = 1000000 // num_tokens_context

# text = (
#     "This is a very long story book: <book> "
#     + context * int(num_repetitions * 0.75)
#     + needle
#     + context * int(num_repetitions * (1 - 0.75))
#     + "what is the best thing to do in San Francisco?\n\nAnswer: The best thing to do in San Francisco is"
# )

# input_ids = tokenizer.encode(text, return_tensors="pt").to("cuda")

In [7]:
# output = model.generate(input_ids, do_sample=False, max_new_tokens=1)
# output = tokenizer.decode(output.cpu()[0], skip_special_tokens=False)
# print(output)

In [8]:
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (n