In [4]:
import os
import sys
sys.path.append('/home/hanshis/workspace/LongContextInfer')

import torch
from models.modeling_llama import LlamaForCausalLM
from transformers import AutoTokenizer

# model = "NousResearch/Yarn-Llama-2-7b-128k"
model = "01-ai/Yi-6B-200K"

target = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16, device_map="cuda:1")
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True, legacy=False)

>>>> Flash Attention installed
>>>> Flash RoPE installed


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
target.model.layers[0].self_attn.rotary_emb

LlamaRotaryEmbedding()

In [6]:
from models.cache_utils import FlashSimpleCache

with torch.inference_mode():
    cache = FlashSimpleCache(target, 1024)
    prompts = "Hello, my dog is cute"
    next_token = tokenizer.encode(prompts, return_tensors="pt").to(target.device)
    for i in range(100):
        if next_token.shape == torch.Size([1]):
            next_token = next_token.unsqueeze(0)
        if next_token.shape == torch.Size([]):
            next_token = next_token.unsqueeze(0).unsqueeze(0)
        logits = target(input_ids=next_token, kv_cache=cache, graph_cache=None).logits
        next_token = torch.argmax(logits[:, -1, :])
        print(tokenizer.decode(next_token), end=" ")

. 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 
 I  have  a  dog . 
 

In [1]:
import os
import sys, torch
sys.path.append('/home/hanshis/workspace/LongContextInfer')

from models.modeling_batch_llama import LlamaForCausalLM
from models.batch_cache import BatchSimpleCache
from transformers import AutoTokenizer
model = "01-ai/Yi-6B-200K"
target_bsz = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16, device_map="cuda:1")
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True, legacy=False)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
with torch.inference_mode():
    prompts = "Hello, my dog is cute"
    next_token = tokenizer.encode(prompts, return_tensors="pt").to(target_bsz.device)
    cache_bsz = BatchSimpleCache(target_bsz, 1024, 1)
    n=0
    cache_bsz.reset()
    gen_tokens = torch.zeros((1, 100), dtype=torch.long, device=target_bsz.device)
    while n < 100:
        logits = target_bsz(input_ids=next_token, kv_cache=cache_bsz, graph_cache=None).logits
        next_token = torch.argmax(logits[:, -1, :], dim=-1)
        gen_tokens[:, n] = next_token.squeeze()
        next_token = next_token.unsqueeze(1)
        n += 1
    tokenizer.batch_decode(gen_tokens)

In [3]:
print(tokenizer.batch_decode(gen_tokens))

['.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n\nI have a dog.\n']


In [None]:
torch.allclose(target.model.layers[1].self_attn.rotary_emb.cos_cached, target_bsz.model.layers[1].self_attn.rotary_emb.cos_cached)

True

In [None]:
cache.key_cache[0,:,:2].shape

torch.Size([1, 2, 4, 128])

In [None]:
torch.allclose(cache.key_cache[0,:,:1], cache_bsz.key_cache[0,:,:1])

True

In [None]:
torch.allclose(cache.key_cache[1,:,:1], cache_bsz.key_cache[1,:,:1])

False

In [None]:
import torch

def check_models_identical(model1, model2):
    # 检查模型参数
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    if len(params1) != len(params2):
        print("false")
        return False
    for p1, p2 in zip(params1, params2):
        if p1.data.ne(p2.data).sum() > 0:  # 检查是否有不同的元素
            print("false")
            return False
    
    # 检查模型的缓冲区
    buffers1 = list(model1.buffers())
    buffers2 = list(model2.buffers())
    if len(buffers1) != len(buffers2):
        print("false")
        return False
    for b1, b2 in zip(buffers1, buffers2):
        if b1.data.ne(b2.data).sum() > 0:  # 检查是否有不同的元素
            print("false")
            return False
    
    # 如果所有检查都通过了，则模型是完全相同的
    return True

check_models_identical(target, target_bsz)

True

In [3]:
for i in [2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32]:
    print(f"CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz {i} --prefill {122880*8//i}")

CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 2 --prefill 491520
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 4 --prefill 245760
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 6 --prefill 163840
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 8 --prefill 122880
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 10 --prefill 98304
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 12 --prefill 81920
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 14 --prefill 70217
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 16 --prefill 61440
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 18 --prefill 54613
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 20 --prefill 49152
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 22 --prefill 44683
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 24 --prefill 40960
CUDA_VISIBLE_DEVICES=0 python benchmark/batch_mqa.py --bsz 26 --

In [1]:
import torch
from flash_attn import flash_attn_with_kvcache
from transformers.models.llama.modeling_llama import(
    repeat_kv,
)
import torch.nn.functional as F
import math
import time

head_dim = 128

def benchmark_mqa_attn(attn_method, query_states, key_states, value_states, num_key_value_groups, seq_len):
    if attn_method == 'flash':
        time1 = time.time()
        for i in range(1000):
            attn_output = flash_attn_with_kvcache(q=query_states, k_cache=key_states, v_cache=value_states, cache_seqlens=seq_len, softmax_scale=1/torch.sqrt(torch.tensor(head_dim, dtype=torch.float16)), causal=True)
        latency = (time.time()-time1)
    
    elif attn_method == 'flash_repeat':
        time1 = time.time()
        for i in range(1000):
            key_states = repeat_kv(key_states.transpose(1,2), num_key_value_groups).transpose(1,2)
            value_states = repeat_kv(value_states.transpose(1,2), num_key_value_groups).transpose(1,2)
            attn_output = flash_attn_with_kvcache(q=query_states, k_cache=key_states, v_cache=value_states, cache_seqlens=seq_len, softmax_scale=1/torch.sqrt(torch.tensor(head_dim, dtype=torch.float16)), causal=True)
        latency = (time.time()-time1)

    elif attn_method == 'sdpa':
        time1 = time.time()
        for i in range(1000):
            key_states = repeat_kv(key_states.transpose(1,2), num_key_value_groups).transpose(1,2)
            value_states = repeat_kv(value_states.transpose(1,2), num_key_value_groups).transpose(1,2)
            with torch.backends.cuda.sdp_kernel(enable_math=False):
                attn_output = F.scaled_dot_product_attention(query_states.transpose(1,2),key_states.transpose(1,2),value_states.transpose(1,2), is_causal=True)
        latency = (time.time()-time1)
    
    elif attn_method == 'vanilla':
        time1 = time.time()
        for i in range(1000):
            key_states = repeat_kv(key_states.transpose(1,2), num_key_value_groups).transpose(1,2)
            value_states = repeat_kv(value_states.transpose(1,2), num_key_value_groups).transpose(1,2)
            query_states= query_states.transpose(1,2)
            key_states= key_states.transpose(1,2)
            value_states= value_states.transpose(1,2)
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states)
            attn_output = attn_output.transpose(1, 2).contiguous()
        latency = (time.time()-time1)

    return latency

In [26]:
import torch
from flash_attn import flash_attn_with_kvcache
from transformers.models.llama.modeling_llama import(
    repeat_kv,
)
import torch.nn.functional as F
import math
import time

def benchmark_mqa_attn(attn_method, query_states, key_states, value_states, num_key_value_groups, seq_len):
    bsz, kv_len, kv_heads, head_dim = key_states.shape
    if attn_method == 'flash':
        # warm up
        for i in range(100):
            query_states.normal_()
            key_states.normal_()
            value_states.normal_()
            flash_attn_with_kvcache(q=query_states, k_cache=key_states, v_cache=value_states, cache_seqlens=seq_len, softmax_scale=1/torch.sqrt(torch.tensor(head_dim, dtype=torch.float16)), causal=True)
        torch.cuda.synchronize()
        T = 2000
        time1 = time.time()
        for i in range(T):
            query_states.normal_()
            key_states.normal_()
            value_states.normal_()
            attn_output = flash_attn_with_kvcache(q=query_states, k_cache=key_states, v_cache=value_states, cache_seqlens=seq_len, softmax_scale=1/torch.sqrt(torch.tensor(head_dim, dtype=torch.float16)), causal=True)
        torch.cuda.synchronize()
        latency = (time.time()-time1) / T * 1000

    elif attn_method == 'flash-ref':
        # warm up
        for i in range(100):
            query_states.normal_()
            key_states.normal_()
            value_states.normal_()
            flash_attn_with_kvcache(q=query_states[:,:,:4], k_cache=key_states, v_cache=value_states, cache_seqlens=seq_len, softmax_scale=1/torch.sqrt(torch.tensor(head_dim, dtype=torch.float16)), causal=True)
        torch.cuda.synchronize()
        T = 2000
        time1 = time.time()
        for i in range(T):
            query_states.normal_()
            key_states.normal_()
            value_states.normal_()
            attn_output = flash_attn_with_kvcache(q=query_states[:,:,:4], k_cache=key_states, v_cache=value_states, cache_seqlens=seq_len, softmax_scale=1/torch.sqrt(torch.tensor(head_dim, dtype=torch.float16)), causal=True)
        torch.cuda.synchronize()
        latency = (time.time()-time1) / T * 1000
    
    elif attn_method == 'vanilla':
        bsz, kv_len, kv_heads, head_dim = key_states.shape
        # warm up
        for i in range(100):
            query_states.normal_()
            key_states.normal_()
            value_states.normal_()
            key_states_cp = repeat_kv(key_states.transpose(1,2), num_key_value_groups).transpose(1,2)
            value_states_cp = repeat_kv(value_states.transpose(1,2), num_key_value_groups).transpose(1,2)
            attn_weights = torch.matmul(query_states.transpose(1,2), key_states_cp.transpose(1,2).transpose(2, 3)) / math.sqrt(head_dim)
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states_cp.transpose(1,2))
            attn_output = attn_output.transpose(1, 2).contiguous()
        torch.cuda.synchronize()
        T = 2000
        time1 = time.time()
        for i in range(T):
            query_states.normal_()
            key_states.normal_()
            value_states.normal_()
            key_states_cp = repeat_kv(key_states.transpose(1,2), num_key_value_groups).transpose(1,2)
            value_states_cp = repeat_kv(value_states.transpose(1,2), num_key_value_groups).transpose(1,2)
            attn_weights = torch.matmul(query_states.transpose(1,2), key_states_cp.transpose(1,2).transpose(2, 3)) / math.sqrt(head_dim)
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states_cp.transpose(1,2))
            attn_output = attn_output.transpose(1, 2).contiguous()
        torch.cuda.synchronize()
        latency = (time.time()-time1) / T * 1000

    elif attn_method == 'vanilla-ref':
        bsz, kv_len, kv_heads, head_dim = key_states.shape
        # warm up
        for i in range(100):
            query_states.normal_()
            key_states.normal_()
            value_states.normal_()
            attn_weights = torch.matmul(query_states[:,:,:4].transpose(1,2), key_states.transpose(1,2).transpose(2, 3)) / math.sqrt(head_dim)
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states.transpose(1,2))
            attn_output = attn_output.transpose(1, 2).contiguous()
        torch.cuda.synchronize()
        T = 2000
        time1 = time.time()
        for i in range(T):
            query_states.normal_()
            key_states.normal_()
            value_states.normal_()
            attn_weights = torch.matmul(query_states[:,:,:4].transpose(1,2), key_states.transpose(1,2).transpose(2, 3)) / math.sqrt(head_dim)
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states.transpose(1,2))
            attn_output = attn_output.transpose(1, 2).contiguous()
        torch.cuda.synchronize()
        latency = (time.time()-time1) / T * 1000

    elif attn_method == 'optim':
        bsz, kv_len, kv_heads, head_dim = key_states.shape
        _, query_len, _, _ = query_states.shape
        for i in range(100):
            query_states.normal_()
            key_states.normal_()
            value_states.normal_()
            query_states_cp = query_states.transpose(1,2).reshape(bsz, kv_heads, num_key_value_groups*query_len, head_dim)
            attn_weights = torch.matmul(query_states_cp, key_states.transpose(1,2).transpose(2, 3)) / math.sqrt(head_dim) # [bsz, 4, 8*seq, prefill+seq]
            # [TODO] add attn mask here....
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states.transpose(1,2)) # [bsz, 4, 8*seq, 128]
            attn_output = attn_output.reshape(bsz, kv_heads*num_key_value_groups, query_len, head_dim).transpose(1, 2).contiguous() # [bsz, seq, 32, 128]
        torch.cuda.synchronize()
        T = 2000
        time1 = time.time()
        for i in range(T):
            query_states.normal_()
            key_states.normal_()
            value_states.normal_()
            query_states_cp = query_states.transpose(1,2).reshape(bsz, kv_heads, num_key_value_groups*query_len, head_dim)
            attn_weights = torch.matmul(query_states_cp, key_states.transpose(1,2).transpose(2, 3)) / math.sqrt(head_dim) # [bsz, 4, 8*seq, prefill+seq]
            # [TODO] add attn mask here....
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states.transpose(1,2)) # [bsz, 4, 8*seq, 128]
            attn_output = attn_output.reshape(bsz, kv_heads*num_key_value_groups, query_len, head_dim).transpose(1, 2).contiguous() # [bsz, seq, 32, 128]
        torch.cuda.synchronize()
        latency = (time.time()-time1) / T * 1000

    return latency


num_key_value_groups = 8
prefill = 1024*64

# bsz_list = [1,2,3,4,5,6,7,8]
bsz_list = [1,2,4,8]
input_list = [1,2,4,8,16]

for bsz in bsz_list:
    for input in input_list:
        seq_len = torch.tensor([prefill]*bsz, dtype=torch.int32, device="cuda:1")
        query_states = torch.randn(bsz, input, 32, 128, dtype=torch.float16, device="cuda:1")
        key_states = torch.randn(bsz, prefill+input, 32//num_key_value_groups, 128, dtype=torch.float16, device="cuda:1")
        value_states = torch.randn(bsz, prefill+input, 32//num_key_value_groups, 128, dtype=torch.float16, device="cuda:1")
        latency = benchmark_mqa_attn('flash', query_states, key_states, value_states, num_key_value_groups, seq_len)
        print(f"bsz: {bsz}, prefill: {seq_len.cpu().numpy()}, input: {input}, flash: {latency}")
        latency = benchmark_mqa_attn('vanilla', query_states, key_states, value_states, num_key_value_groups, seq_len)
        print(f"bsz: {bsz}, prefill: {seq_len.cpu().numpy()}, input: {input}, vanilla: {latency}")
        latency = benchmark_mqa_attn('optim', query_states, key_states, value_states, num_key_value_groups, seq_len)
        print(f"bsz: {bsz}, prefill: {seq_len.cpu().numpy()}, input: {input}, optim: {latency}")
        latency = benchmark_mqa_attn('flash-ref', query_states, key_states, value_states, num_key_value_groups, seq_len)
        print(f"bsz: {bsz}, prefill: {seq_len.cpu().numpy()}, input: {input}, flash-ref: {latency}")
        latency = benchmark_mqa_attn('vanilla-ref', query_states, key_states, value_states, num_key_value_groups, seq_len)
        print(f"bsz: {bsz}, prefill: {seq_len.cpu().numpy()}, input: {input}, vanilla-ref: {latency}")
        print("=======================================================================")
    print("***********************************************************************")

bsz: 1, prefill: [65536], input: 1, flash: 0.33137381076812744
bsz: 1, prefill: [65536], input: 1, vanilla: 4.519198656082153
bsz: 1, prefill: [65536], input: 1, optim: 0.9695037603378296
bsz: 1, prefill: [65536], input: 1, flash-ref: 0.3322429656982422
bsz: 1, prefill: [65536], input: 1, vanilla-ref: 0.587114691734314
bsz: 1, prefill: [65536], input: 2, flash: 0.7587827444076538
bsz: 1, prefill: [65536], input: 2, vanilla: 3.9973089694976807
bsz: 1, prefill: [65536], input: 2, optim: 0.750908374786377
bsz: 1, prefill: [65536], input: 2, flash-ref: 0.33817434310913086
bsz: 1, prefill: [65536], input: 2, vanilla-ref: 0.7119020223617554
bsz: 1, prefill: [65536], input: 4, flash: 0.760932207107544
bsz: 1, prefill: [65536], input: 4, vanilla: 4.250280380249023
bsz: 1, prefill: [65536], input: 4, optim: 0.8241993188858032
bsz: 1, prefill: [65536], input: 4, flash-ref: 0.3522707223892212
bsz: 1, prefill: [65536], input: 4, vanilla-ref: 0.7200660705566406
bsz: 1, prefill: [65536], input: 8, f

In [12]:
bsz = 2
num_key_value_groups = 8
prefill = 122880
seq_len = torch.tensor([prefill]*bsz, dtype=torch.int32, device="cuda:1")


query_states = torch.randn(bsz, 1, 32, 128, dtype=torch.float16, device="cuda:1")
key_states = torch.randn(bsz, prefill+1, 32//num_key_value_groups, 128, dtype=torch.float16, device="cuda:1")
value_states = torch.randn(bsz, prefill+1, 32//num_key_value_groups, 128, dtype=torch.float16, device="cuda:1")

benchmark_mqa_attn('flash', query_states, key_states, value_states, num_key_value_groups, seq_len)

0.3586692810058594

In [4]:
bsz = 3
num_key_value_groups = 8
prefill = 122880
seq_len = torch.tensor([prefill]*bsz, dtype=torch.int32, device="cuda:1")


query_states = torch.randn(bsz, 1, 32, 128, dtype=torch.float16, device="cuda:1")
key_states = torch.randn(bsz, prefill+1, 32//num_key_value_groups, 128, dtype=torch.float16, device="cuda:1")
value_states = torch.randn(bsz, prefill+1, 32//num_key_value_groups, 128, dtype=torch.float16, device="cuda:1")

benchmark_mqa_attn('flash', query_states, key_states, value_states, num_key_value_groups, seq_len)

0.5221462249755859

In [5]:
bsz = 4
num_key_value_groups = 8
prefill = 122880
seq_len = torch.tensor([prefill]*bsz, dtype=torch.int32, device="cuda:1")


query_states = torch.randn(bsz, 1, 32, 128, dtype=torch.float16, device="cuda:1")
key_states = torch.randn(bsz, prefill+1, 32//num_key_value_groups, 128, dtype=torch.float16, device="cuda:1")
value_states = torch.randn(bsz, prefill+1, 32//num_key_value_groups, 128, dtype=torch.float16, device="cuda:1")

benchmark_mqa_attn('flash', query_states, key_states, value_states, num_key_value_groups, seq_len)

0.6810362339019775

In [15]:
bsz = 8
num_key_value_groups = 8
prefill = 122880
seq_len = torch.tensor([prefill]*bsz, dtype=torch.int32, device="cuda:1")

query_states = torch.randn(bsz, 1, 32, 128, dtype=torch.float16, device="cuda:1")
key_states = torch.randn(bsz, prefill+1, 32//num_key_value_groups, 128, dtype=torch.float16, device="cuda:1")
value_states = torch.randn(bsz, prefill+1, 32//num_key_value_groups, 128, dtype=torch.float16, device="cuda:1")

benchmark_mqa_attn('flash', query_states, key_states, value_states, num_key_value_groups, seq_len)

1.401186466217041

In [8]:
key_states.shape

torch.Size([8, 122881, 4, 128])