In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# from transformers.generation import GenerationConfig
from hip_attn.models.modeling_llama import LlamaForCausalLM
# from transformers.models.llama.modeling_llama import LlamaForCausalLM
import torch

torch.manual_seed(1234)

<torch._C.Generator at 0x77616457db50>

In [2]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", trust_remote_code=True)

In [3]:
model = LlamaForCausalLM.from_pretrained(
    "gradientai/Llama-3-8B-Instruct-Gradient-1048k",
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).eval()

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
LlamaForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
import time
from transformers.cache_utils import Cache, DynamicCache, StaticCache, OffloadedCache, OffloadedStaticCache

In [5]:

with torch.inference_mode():
    ### Simulate Prefill
    start_length = 1024
    chunks = 1
    input_ids = torch.randint(0, tokenizer.vocab_size, (1, start_length)).to('cuda')
    for epoch in range(4):
        past_key_values = DynamicCache()

        for i in range(chunks):
            
            torch.cuda.synchronize()
            start_time = time.time()

            model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, num_logits_to_keep=1)

            torch.cuda.synchronize()
            epoch_time = time.time() - start_time
            print(f"[Epoch {epoch+1}, Chunk {i+1}]   {epoch_time:.2f} seconds")

        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    ### Simulate Decode
    start_length = 10240

    max_new_tokens = 100
    input_ids = torch.randint(0, tokenizer.vocab_size, (1, 1)).to('cuda')

    for epoch in range(4):
        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(max_new_tokens):
            
            model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, num_logits_to_keep=1)
            
        torch.cuda.synchronize()
        epoch_time = (time.time() - start_time)/max_new_tokens
        print(f"Epoch {i+1}  {epoch_time:.2f} seconds")
        start_length = start_length + 5120

[Epoch 1, Chunk 1]   1.21 seconds
[Epoch 2, Chunk 1]   0.12 seconds
[Epoch 3, Chunk 1]   0.12 seconds
[Epoch 4, Chunk 1]   0.12 seconds
Epoch 100  0.04 seconds
Epoch 100  0.04 seconds
Epoch 100  0.03 seconds
Epoch 100  0.03 seconds


In [6]:
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaCustomAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
          (hyper_attention): HyperAttention(
            (lsh): AngularLSH(num_proj=7, proj_dir.shape=torch.Size([1, 1, 128, 7]))
          )
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
     