In [4]:
import os
import sys

# Directory that contains the project-local `models` package imported below.
# It can be overridden with the LONGCONTEXTINFER_ROOT environment variable so the
# notebook is not tied to one machine; the default is the original location.
repo_root = os.environ.get('LONGCONTEXTINFER_ROOT', '/home/hanshis/workspace/LongContextInfer')
if repo_root not in sys.path:  # re-running this cell must not stack duplicate entries
    sys.path.append(repo_root)

import torch
from models.modeling_llama import LlamaForCausalLM
from transformers import AutoTokenizer

# Alternative checkpoint kept for quick switching:
# model = "NousResearch/Yarn-Llama-2-7b-128k"
model = "01-ai/Yi-6B-200K"

# Load the project's LlamaForCausalLM in float16 on the second CUDA device,
# together with the matching fast tokenizer.
target = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16, device_map="cuda:1")
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True, legacy=False)

>>>> Flash Attention installed
>>>> Flash RoPE installed


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Display the rotary-embedding submodule attached to layer 0's self-attention.
target.model.layers[0].self_attn.rotary_emb

LlamaRotaryEmbedding()

In [6]:
from models.cache_utils import FlashSimpleCache

# Greedy decoding with the single-sequence model and a FlashSimpleCache:
# feed the prompt once, then feed back the arg-max token for 100 steps,
# printing each decoded token as it is produced.
with torch.inference_mode():
    cache = FlashSimpleCache(target, 1024)
    prompts = "Hello, my dog is cute"
    next_token = tokenizer.encode(prompts, return_tensors="pt").to(target.device)
    for step in range(100):
        logits = target(input_ids=next_token, kv_cache=cache, graph_cache=None).logits
        # Reduce over the last axis only. Without `dim`, argmax returns an index
        # into the flattened tensor, which is a valid token id only when the
        # leading dimension is 1. keepdim=True leaves a trailing length-1 axis,
        # so the result can be fed back as `input_ids` without reshaping.
        # (assumes logits is (batch, seq, vocab) — TODO confirm against the model)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        # Row 0 is the only sequence in this cell.
        print(tokenizer.decode(next_token[0]), end=" ")

. 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 

In [1]:
import os
import sys
import torch

# Directory that contains the project-local `models` package imported below.
# It can be overridden with the LONGCONTEXTINFER_ROOT environment variable so the
# notebook is not tied to one machine; the default is the original location.
repo_root = os.environ.get('LONGCONTEXTINFER_ROOT', '/home/hanshis/workspace/LongContextInfer')
if repo_root not in sys.path:  # re-running this cell must not stack duplicate entries
    sys.path.append(repo_root)

# NOTE: this binds the name LlamaForCausalLM to the batch implementation; the
# setup cell at the top of the notebook imports the same name from
# models.modeling_llama, so whichever import cell ran last wins.
from models.modeling_batch_llama import LlamaForCausalLM
from models.batch_cache import BatchSimpleCache
from transformers import AutoTokenizer

model = "01-ai/Yi-6B-200K"
# Load the batch variant of the model in float16 on the second CUDA device,
# together with the matching fast tokenizer.
target_bsz = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16, device_map="cuda:1")
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True, legacy=False)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
# Greedy decoding with the batch model and a BatchSimpleCache: feed the prompt
# once, then feed back the arg-max token at every step, storing the generated
# ids in `gen_tokens` (one row per sequence, one column per step) so that a
# later cell can decode them.
with torch.inference_mode():
    prompts = "Hello, my dog is cute"
    next_token = tokenizer.encode(prompts, return_tensors="pt").to(target_bsz.device)
    cache_bsz = BatchSimpleCache(target_bsz, 1024, 1)
    cache_bsz.reset()
    # One constant drives both the output buffer width and the loop bound, so
    # the two can no longer drift apart.
    max_new_tokens = 100
    gen_tokens = torch.zeros((1, max_new_tokens), dtype=torch.long, device=target_bsz.device)
    for n in range(max_new_tokens):
        logits = target_bsz(input_ids=next_token, kv_cache=cache_bsz, graph_cache=None).logits
        # keepdim=True leaves a trailing length-1 axis, so the result can be fed
        # back as `input_ids` without the squeeze/unsqueeze round-trip.
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        gen_tokens[:, n] = next_token[:, 0]
    # The former trailing `tokenizer.batch_decode(gen_tokens)` was removed: a bare
    # expression inside a `with` block is not displayed, so its result was
    # discarded. Decoding is done by the following cell.

In [3]:
# Decode the token ids collected in `gen_tokens` by the previous cell and print
# the resulting list of strings.
print(tokenizer.batch_decode(gen_tokens))

['.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n']


In [None]:
# Sanity check: layer 1's rotary embedding holds a numerically close `cos_cached`
# tensor in both models. Needs both `target` (single-sequence model) and
# `target_bsz` (batch model) to be alive in the kernel.
torch.allclose(target.model.layers[1].self_attn.rotary_emb.cos_cached, target_bsz.model.layers[1].self_attn.rotary_emb.cos_cached)

True

In [None]:
# Shape of a key-cache slice: index 0 on dim 0, everything on dim 1, and the
# first two entries on dim 2.
# NOTE(review): dim 0 is presumably the layer index and the remaining axes
# (batch, position, kv_heads, head_dim) — confirm against FlashSimpleCache.
cache.key_cache[0,:,:2].shape

torch.Size([1, 2, 4, 128])

In [None]:
# Compare the first dim-2 entry of slice 0 between the single-sequence cache
# (`cache`) and the batch cache (`cache_bsz`).
# NOTE(review): slice 0 is presumably layer 0 — confirm against the cache layout.
torch.allclose(cache.key_cache[0,:,:1], cache_bsz.key_cache[0,:,:1])

True

In [None]:
# Same comparison for slice 1 (presumably layer 1 — confirm). The saved output of
# this cell is False, i.e. the two caches differ here beyond allclose's tolerances.
# NOTE(review): torch.allclose defaults to rtol=1e-05, atol=1e-08. If the caches
# are float16 (both models are loaded with torch_dtype=torch.float16), ordinary
# rounding differences could already fail this check — confirm whether the
# mismatch is a real divergence before treating it as a bug.
torch.allclose(cache.key_cache[1,:,:1], cache_bsz.key_cache[1,:,:1])

False

In [None]:
import torch

def _tensor_lists_identical(tensors1, tensors2):
    """Return True when both lists have equal length and pairwise-identical tensors."""
    if len(tensors1) != len(tensors2):
        return False
    for t1, t2 in zip(tensors1, tensors2):
        # Compare shapes first: `ne` broadcasts, so it would raise on incompatible
        # shapes and silently stretch compatible-but-different ones.
        if t1.shape != t2.shape:
            return False
        if bool(t1.ne(t2).any()):  # at least one element differs
            return False
    return True


def check_models_identical(model1, model2):
    """Check whether two modules hold identical parameters and buffers.

    Parameters are paired positionally in the order yielded by ``parameters()``,
    buffers in the order yielded by ``buffers()``; names are not compared.
    Two tensors count as identical when they have the same shape and no
    element-wise difference (so a NaN never equals another NaN).

    Args:
        model1: first module; must provide ``parameters()`` and ``buffers()``.
        model2: second module, same requirements.

    Returns:
        True if every parameter and every buffer matches, False otherwise.
        "false" is printed before returning False.
    """
    # Check the model parameters, then the model buffers.
    for tensors1, tensors2 in (
        (list(model1.parameters()), list(model2.parameters())),
        (list(model1.buffers()), list(model2.buffers())),
    ):
        if not _tensor_lists_identical(tensors1, tensors2):
            print("false")
            return False

    # Every check passed, so the models are identical.
    return True

# Compare the single-sequence model with the batch model; both must already be
# loaded in this kernel (`target` and `target_bsz`).
check_models_identical(target, target_bsz)

True

In [3]:
# Print one benchmark command per even batch size from 2 through 32. The prefill
# argument is 122880*8 floor-divided by the batch size, so it shrinks as the
# batch size grows.
for i in range(2, 33, 2):
    print(f"CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz {i} --prefill {122880*8//i}")

CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 2 --prefill 491520
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 4 --prefill 245760
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 6 --prefill 163840
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 8 --prefill 122880
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 10 --prefill 98304
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 12 --prefill 81920
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 14 --prefill 70217
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 16 --prefill 61440
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 18 --prefill 54613
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 20 --prefill 49152
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 22 --prefill 44683
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 24 --prefill 40960
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 26 --