In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '8'
import sys
root_dir = "/home/hanshis/workspace/LongContextInfer"
sys.path.append(root_dir)
import torch
import time
import argparse
import math
from tqdm import tqdm
import socket

from models.modeling_llama import LlamaForCausalLM, LlamaConfig
from models.cache_utils import SimpleCache, FlashSimpleCache, GraphFlashSimpleCache, GraphFlashStreamLLMCache
from utils.graph_infer import GraphInferenceEngine

PREFIX_LEN = 1000
T = 100
WARM_UP = 10

model_name = "NousResearch/Yarn-Llama-2-7b-128k"
config = LlamaConfig.from_pretrained(model_name)
config.flash = True
if config.max_position_embeddings < 4096:
    config.max_position_embeddings = 1024*128
model = LlamaForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16, device_map="auto")

DEC_LEN = 6
MAX_LEN = PREFIX_LEN + DEC_LEN

cache = FlashSimpleCache(model, MAX_LEN)
graph_cache = GraphFlashStreamLLMCache(model, max_budget=1000, prefill=PREFIX_LEN, gamma=DEC_LEN)

cache.reset()
graph_cache.reset()

prefix = torch.randint(low=3, high=30000, size=(1, PREFIX_LEN), device=model.device)
assert prefix.shape[-1] == PREFIX_LEN

graph_engine = GraphInferenceEngine(model, cache, graph_cache)
graph_engine.initialize_cuda_graph([DEC_LEN])

# prefill
graph_engine.inference(input_ids=prefix)
graph_engine.init_graph_cache()

>>>> Flash Attention installed
>>>> Flash RoPE installed


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

capturing graph...
capturing graph...
capturing graph...
capturing graph...
capturing graph...
capturing graph...


In [3]:
for i in range(32):
    assert torch.allclose(graph_engine.engine.graph_cache.key_cache[i][:,:16], graph_engine.engine.kv_cache.key_cache[i][:,:16]), f"{i}"
    assert torch.allclose(graph_engine.engine.graph_cache.key_cache[i][:,16:1000], graph_engine.engine.kv_cache.key_cache[i][:,PREFIX_LEN-984:PREFIX_LEN])

In [4]:
for gamma_offset in range(6):
    input_ids = torch.randint(low=3, high=30000, size=(1, 1), device=model.device)
    storage_ids = torch.tensor([graph_engine.engine.graph_cache.max_budget + gamma_offset], device=model.device)
    position_ids = torch.tensor([graph_engine.engine.graph_cache.max_budget + gamma_offset], device=model.device).unsqueeze(0)
    print(input_ids, storage_ids, position_ids, gamma_offset)

    # print(graph_engine.graph_inference(input_ids=input_ids, storage_ids=storage_ids, position_ids=position_ids, gamma_offset = gamma_offset))
    # print(graph_engine.inference(input_ids=input_ids))

    print(f"Verifying cache consistency, on {cache.seq_len}")
    for i in range(32):
        assert torch.allclose(graph_engine.engine.graph_cache.key_cache[i], graph_engine.engine.kv_cache.key_cache[i]), f"{i}"
        assert torch.allclose(graph_engine.engine.graph_cache.key_cache[i], graph_engine.engine.kv_cache.key_cache[i])

tensor([[22531]], device='cuda:0') tensor([1000], device='cuda:0') tensor([[1000]], device='cuda:0') 0
Verifying cache consistency, on 1000
tensor([[3889]], device='cuda:0') tensor([1001], device='cuda:0') tensor([[1001]], device='cuda:0') 1
Verifying cache consistency, on 1000
tensor([[15555]], device='cuda:0') tensor([1002], device='cuda:0') tensor([[1002]], device='cuda:0') 2
Verifying cache consistency, on 1000
tensor([[26835]], device='cuda:0') tensor([1003], device='cuda:0') tensor([[1003]], device='cuda:0') 3
Verifying cache consistency, on 1000
tensor([[16372]], device='cuda:0') tensor([1004], device='cuda:0') tensor([[1004]], device='cuda:0') 4
Verifying cache consistency, on 1000
tensor([[23916]], device='cuda:0') tensor([1005], device='cuda:0') tensor([[1005]], device='cuda:0') 5
Verifying cache consistency, on 1000


In [4]:
graph_cache.key_cache[1][:,1001]

tensor([[[-1.5273e+00,  5.3809e-01, -1.7080e+00,  ..., -1.6025e+00,
          -1.6680e+00,  1.4502e+00],
         [ 9.7754e-01, -5.9229e-01,  1.3574e+00,  ..., -1.2910e+00,
          -1.1514e+00, -8.9111e-01],
         [-2.2021e-01, -3.5583e-02,  3.0200e-01,  ...,  1.7744e+00,
           1.0371e+00,  1.2500e+00],
         ...,
         [-2.2675e-02, -4.4098e-02, -2.2217e-01,  ..., -2.6880e-01,
          -1.5926e-03, -2.2354e-02],
         [-1.1035e+00, -3.0103e-01, -1.7957e-01,  ...,  7.5439e-01,
          -9.7021e-01,  3.5596e-01],
         [ 1.0293e+00, -9.9426e-02,  2.1631e-01,  ...,  2.1309e+00,
          -2.1738e+00,  1.5518e+00]]], device='cuda:0', dtype=torch.float16)

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '8'
import sys
root_dir = "/home/hanshis/workspace/LongContextInfer"
sys.path.append(root_dir)
import torch
import time
import argparse
import math
from tqdm import tqdm
import socket

from models.modeling_llama import LlamaForCausalLM, LlamaConfig
from models.cache_utils import SimpleCache, FlashSimpleCache, GraphFlashSimpleCache, GraphFlashStreamLLMCache
from utils.graph_infer import GraphInferenceEngine

PREFIX_LEN = 1000
T = 100
WARM_UP = 10

model_name = "NousResearch/Yarn-Llama-2-7b-128k"
config = LlamaConfig.from_pretrained(model_name)
config.flash = True
if config.max_position_embeddings < 4096:
    config.max_position_embeddings = 1024*128
model = LlamaForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16, device_map="auto")

DEC_LEN = 6
MAX_LEN = PREFIX_LEN + DEC_LEN

cache = FlashSimpleCache(model, MAX_LEN)
graph_cache = GraphFlashStreamLLMCache(model, max_budget=1000, prefill=PREFIX_LEN, gamma=DEC_LEN)

cache.reset()
graph_cache.reset()

prefix = torch.randint(low=3, high=30000, size=(1, PREFIX_LEN), device=model.device)
assert prefix.shape[-1] == PREFIX_LEN

graph_engine = GraphInferenceEngine(model, cache, graph_cache)
graph_engine.initialize_cuda_graph()

# prefill
graph_engine.inference(input_ids=prefix)
# graph_engine.init_graph_cache()

>>>> Flash Attention installed
>>>> Flash RoPE installed


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

capturing graph...
capturing graph...
capturing graph...
capturing graph...
capturing graph...
capturing graph...


tensor([[[ -9.5703, -10.8281,  -1.1270,  ..., -11.2422,  -9.7578,  -9.1797],
         [-10.6406, -10.2266,   0.2112,  ...,  -6.4961,  -7.9102,  -8.6406],
         [-12.8906, -13.6328,  -3.1113,  ...,  -8.7734,  -9.5156,  -8.8984],
         ...,
         [ -5.1562,  -3.8086,   1.4199,  ...,  -2.3457,  -1.5586,  -2.5527],
         [ -5.3203,  -3.6309,   1.7793,  ...,  -1.9453,  -1.2783,  -3.1035],
         [ -5.7461,  -5.1328,   1.1377,  ...,  -2.2402,  -1.7432,  -3.1699]]],
       device='cuda:0')

In [13]:
for i in range(32):
    assert torch.allclose(graph_engine.engine.graph_cache.key_cache[i][:,:16], graph_engine.engine.kv_cache.key_cache[i][:,:16]), f"{i}"
    assert torch.allclose(graph_engine.engine.graph_cache.key_cache[i][:,16:1000], graph_engine.engine.kv_cache.key_cache[i][:,PREFIX_LEN-984:PREFIX_LEN])

1006

In [18]:
-graph_cache.recent_size + cache.seq_len

22

In [19]:
cache.seq_len

1006

In [20]:
cache.key_cache[:].shape

torch.Size([32, 1, 1006, 32, 128])

In [17]:
cache.key_cache[:][:, -graph_cache.recent_size + cache.seq_len:cache.seq_len]

tensor([], device='cuda:0', size=(32, 0, 1006, 32, 128), dtype=torch.float16)

In [13]:
graph_cache.key_cache[:,:,994,0]

tensor([[[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
       dtype=torch.float16)

In [6]:
cache.key_cache[:,:,1000:] = 0

In [5]:
graph_cache.key_cache[i][:].shape

torch.Size([1, 1006, 32, 128])

In [15]:
graph_cache.print_status()
cache.print_status()

Max Budget: 1000  | Real Budget: 1006  | PreFill: 1000  | Start Size: 16  | Recent Size: 984
Cached Size: 1001 | Max Budget: 1006


In [19]:
graph_cache.key_cache[0,0,1001]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.float16)

In [20]:
cache.key_cache[0][0,1001]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.float16)

In [21]:
for i in range(32):
    assert torch.allclose(graph_cache.key_cache[i][:,1001], cache.key_cache[i][:,1001]), f"{i}"
    assert torch.allclose(graph_cache.key_cache[i][:,1001], cache.key_cache[i][:,1001])

In [13]:
graph_engine.graph_inference(input_ids=input_ids, storage_ids=storage_ids, position_ids=position_ids, gamma_offset = gamma_offset)

tensor([[[-5.5820, -1.0117,  2.8711,  ..., -1.5615, -1.1182, -1.6680]]],
       device='cuda:0')

In [14]:
graph_engine.inference(input_ids=input_ids)

tensor([[[-5.9414, -1.1523,  2.8789,  ..., -1.5420, -1.2578, -2.1348]]],
       device='cuda:0')

In [5]:
graph_engine.graph_inference(input_ids=input_ids, storage_ids=storage_ids, position_ids=position_ids)

tensor([[[-9.6484, -3.3320, -0.1942,  ..., -7.1016, -6.2188, -5.1758]]],
       device='cuda:0')

In [None]:
graph_engine.graph_inference(input_ids=input_ids, storage_ids=storage_ids, position_ids=position_ids)

graph_cache.update_stream_cache(kv_cache=cache)

In [9]:
a = torch.tensor([float('-inf')], dtype=torch.float16, device=model.device)

In [13]:
a*-1

tensor([inf], device='cuda:0', dtype=torch.float16)

In [3]:
graph_cache.key_cache[0][0][0]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.float16)

In [18]:
torch.allclose(graph_cache.key_cache[0][0][1000], torch.zeros_like(graph_cache.key_cache[0][0][1000]))

False

In [5]:
cache.key_cache[2][0][1000]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.float16)

In [15]:
graph_cache.key_cache[0][0][1000]

tensor([[-0.0194, -0.1855,  0.5117,  ...,  0.3369,  0.0435,  0.4014],
        [ 0.9253,  0.4568, -0.7075,  ..., -0.7070,  0.7300, -0.6943],
        [-0.5557, -0.6553, -0.3027,  ..., -0.0223, -0.1632, -0.1783],
        ...,
        [-0.1730,  0.0549, -0.0212,  ..., -0.3787,  0.4285, -0.3359],
        [ 0.2054, -0.2064, -2.2402,  ...,  0.4592, -0.2316, -0.2203],
        [-1.5889, -1.3701, -1.8203,  ...,  1.6953, -0.5571,  0.4343]],
       device='cuda:0', dtype=torch.float16)

In [7]:
input_ids = torch.randint(low=3, high=30000, size=(1, DEC_LEN), device=model.device)
storage_ids = torch.arange(DEC_LEN, device=model.device) + PREFIX_LEN
graph_engine.graph_inference(input_ids=input_ids, storage_ids=storage_ids)

tensor([[[-6.1914, -4.5703,  2.4707,  ..., -3.9902, -1.9297, -2.4102]]],
       device='cuda:0')

In [8]:
graph_engine

<utils.graph_infer.GraphInferenceEngine at 0x7fb4f6420af0>

In [9]:
# input_ids = torch.randint(low=3, high=30000, size=(1, DEC_LEN), device=model.device)
# storage_ids = torch.arange(DEC_LEN, device=model.device) + PREFIX_LEN
# for _ in range(WARM_UP):
#     graph_engine.graph_inference(input_ids=input_ids, storage_ids=storage_ids)

# cache.print_status()
# graph_cache.print_status()

In [1]:
# Load model directly
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import sys
root_dir = '/home/hanshis/workspace/LongContextInfer'
sys.path.append(root_dir)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

from models.modeling_llama_flash import LlamaForCausalLM
tokenizer = AutoTokenizer.from_pretrained("TheBloke/Yarn-Llama-2-7B-128K-GPTQ", trust_remote_code=True)
model = LlamaForCausalLM.from_pretrained("TheBloke/Yarn-Llama-2-7B-128K-GPTQ", revision="gptq-4bit-32g-actorder_True", device_map="cuda:9")

>>>> Flash Attention installed
>>>> Flash RoPE installed


CUDA extension not installed.
CUDA extension not installed.


In [4]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['CUDA_VISIBLE_DEVICES'] = '9'

import sys
root_dir = '/home/hanshis/workspace/LongContextInfer'
sys.path.append(root_dir)
import torch

from transformers import AutoTokenizer, TextGenerationPipeline
# from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from models.llama_gptq import LlamaGPTQForCausalLM


model = LlamaGPTQForCausalLM.from_quantized("TheBloke/Yarn-Llama-2-7B-128K-GPTQ", device_map='auto')

1. You disabled CUDA extensions compilation by setting BUILD_CUDA_EXT=0 when install auto_gptq from source.
2. You are using pytorch without CUDA support.
3. CUDA and nvcc are not installed in your device.
INFO - The layer lm_head is not quantized.


  0%|          | 0/1187 [00:00<?, ?w/s]

Skipping module injection for FusedLlamaMLPForQuantizedModel as currently not supported with use_triton=False.


In [None]:
from 

In [None]:
import torch
import numpy as np
from sklearn.cluster import KMeans

def rerange_kv_cache(kv_cache, chunk_size):

    num_clusters = kv_cache.seq_len // chunk_size
    assert num_clusters * chunk_size == kv_cache.seq_len, "max_budget should be divisible by chunk_size"

    for layer in range(kv_cache.layers):
        for head_index in range(kv_cache.num_heads):
            # (bsz, max_budget, head_dim) --> (bsz * max_budget, head_dim)
            head_key_cache = kv_cache.key_cache[layer][:, :, head_index, :].reshape(-1, kv_cache.head_dim).numpy() 
            head_value_cache = kv_cache.value_cache[layer][:, :, head_index, :].reshape(-1, kv_cache.head_dim).numpy()
            
            kmeans = KMeans(n_clusters=num_clusters, random_state=head_index).fit(head_key_cache)
            
            labels = kmeans.labels_
            sorted_indices = np.argsort(labels)
            sorted_head_key = head_key_cache[sorted_indices]
            sorted_head_value = head_value_cache[sorted_indices]

        
            kv_cache.key_cache[layer][:, :, head_index, :] = torch.tensor(sorted_head_key).reshape(1, kv_cache.seq_len, kv_cache.head_dim)
            kv_cache.value_cache[layer][:, :, head_index, :] = torch.tensor(sorted_head_value).reshape(1, kv_cache.seq_len, kv_cache.head_dim)
    print("Rerange KV cache complete")

In [27]:
import torch
import numpy as np
from sklearn.cluster import KMeans

# 设置参数
num_heads = 8
max_budget = 8 * 16  # 每个头有16个KV对，总共8个头
head_dim = 64
num_clusters = 8  # 总共8个聚类，每个聚类对应于一个头的16个KV对
chunk_size = 16  # 每个聚类包含16个KV对

# 模拟的KV缓存数据
key_cache = torch.rand(1, max_budget, num_heads, head_dim)
value_cache = torch.rand(1, max_budget, num_heads, head_dim)

# 初始化一个空的tensor用于存放重排列后的KV缓存
clustered_key_cache = torch.empty(1, max_budget, num_heads, head_dim)
clustered_value_cache = torch.empty(1, max_budget, num_heads, head_dim)

# 对每个头独立进行聚类
for head_index in range(num_heads):
    # 提取当前头的KV缓存
    head_key_cache = key_cache[:, :, head_index, :].reshape(-1, head_dim).numpy() # (bsz, max_budget, head_dim) --> (bsz * max_budget, head_dim)
    head_value_cache = value_cache[:, :, head_index, :].reshape(-1, head_dim).numpy()
    
    # 执行KMeans聚类
    kmeans = KMeans(n_clusters=num_clusters, random_state=head_index).fit(head_key_cache)
    
    # 获取聚类标签并根据这些标签排序KV缓存
    labels = kmeans.labels_
    sorted_indices = np.argsort(labels)
    sorted_head_key = head_key_cache[sorted_indices]
    sorted_head_value = head_value_cache[sorted_indices]

    
    # 将排序后的KV缓存重新放入对应的头中
    # print(torch.tensor(sorted_head_kv).shape, clustered_kv_cache[:, :, head_index, :].shape)
    clustered_key_cache[:, :, head_index, :] = torch.tensor(sorted_head_key).reshape(1, max_budget, head_dim)
    clustered_value_cache[:, :, head_index, :] = torch.tensor(sorted_head_value).reshape(1, max_budget, head_dim)

# 验证最终形状
print("最终的KV缓存形状:", clustered_key_cache.shape, clustered_value_cache.shape)


最终的KV缓存形状: torch.Size([1, 128, 8, 64]) torch.Size([1, 128, 8, 64])


  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)


In [28]:
# 假设 kv_cache 是原始的KV缓存，clustered_kv_cache 是重排列后的KV缓存

# 扁平化并排序两个KV缓存
original_flat_sorted = torch.sort(key_cache.reshape(-1))[0]
new_flat_sorted = torch.sort(clustered_key_cache.reshape(-1))[0]

# 比较两个排序后的扁平化KV缓存是否完全相同
are_elements_identical = torch.all(original_flat_sorted == new_flat_sorted)

original_flat_sorted = torch.sort(value_cache.reshape(-1))[0]
new_flat_sorted = torch.sort(clustered_value_cache.reshape(-1))[0]

# 比较两个排序后的扁平化KV缓存是否完全相同
are_elements_identical = torch.all(original_flat_sorted == new_flat_sorted)

print("原始KV缓存中的元素是否与重排列后的KV缓存中的完全相同？", are_elements_identical.item())


原始KV缓存中的元素是否与重排列后的KV缓存中的完全相同？ True


In [29]:
original_key_cache = key_cache  # 使用之前示例中的模拟数据
original_value_cache = value_cache  # 使用之前示例中的模拟数据

# 假设 rearranged_kv_cache 是重排列后的KV缓存，其形状为 [1, 128, num_heads, head_dim]
rearranged_key_cache = clustered_key_cache 

for head in range(num_heads):
    original_head_data = original_key_cache[:, :, head, :].reshape(-1, head_dim)
    rearranged_head_data = rearranged_key_cache[:, :, head, :].reshape(-1, head_dim)
    
    # 因为重排列改变了元素的顺序，我们不能直接比较对应位置的元素
    # 相 stat代，我们检查重排列后的每个元素是否在原始数据中
    found_all_elements = True
    for elem in rearranged_head_data:
        if not any(torch.allclose(elem, original_elem) for original_elem in original_head_data):
            found_all_elements = False
            break
    
    if found_all_elements:
        print(f"All elements in head {head} are correctly rearranged.")
    else:
        print(f"Not all elements in head {head} are found in the rearranged KV cache.")

original_key_cache = key_cache  # 使用之前示例中的模拟数据

# 假设 rearranged_kv_cache 是重排列后的KV缓存，其形状为 [1, 128, num_heads, head_dim]
rearranged_value_cache = clustered_value_cache 

for head in range(num_heads):
    original_head_data = original_value_cache[:, :, head, :].reshape(-1, head_dim)
    rearranged_head_data = rearranged_value_cache[:, :, head, :].reshape(-1, head_dim)
    
    # 因为重排列改变了元素的顺序，我们不能直接比较对应位置的元素
    # 相 stat代，我们检查重排列后的每个元素是否在原始数据中
    found_all_elements = True
    for elem in rearranged_head_data:
        if not any(torch.allclose(elem, original_elem) for original_elem in original_head_data):
            found_all_elements = False
            break
    
    if found_all_elements:
        print(f"All elements in head {head} are correctly rearranged.")
    else:
        print(f"Not all elements in head {head} are found in the rearranged KV cache.")

All elements in head 0 are correctly rearranged.
All elements in head 1 are correctly rearranged.
All elements in head 2 are correctly rearranged.
All elements in head 3 are correctly rearranged.
All elements in head 4 are correctly rearranged.
All elements in head 5 are correctly rearranged.
All elements in head 6 are correctly rearranged.
All elements in head 7 are correctly rearranged.
All elements in head 0 are correctly rearranged.
All elements in head 1 are correctly rearranged.
All elements in head 2 are correctly rearranged.
All elements in head 3 are correctly rearranged.
All elements in head 4 are correctly rearranged.
All elements in head 5 are correctly rearranged.
All elements in head 6 are correctly rearranged.
All elements in head 7 are correctly rearranged.


In [7]:
import torch

def capture_graph(mempool=None, n_warmups :int=3):
    device = "cuda:0"
    # draft run is incremental decoding
    static_position_ids = torch.tensor([[100]], device=device)
    
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            out = torch.arange(0, static_position_ids[0, 0] + 1, device=device)
        s.synchronize()
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        out = torch.arange(0, static_position_ids[0, 0] + 1, device=device)
    
    def run(position_ids):
        static_position_ids.copy_(position_ids)
        graph.replay()
        return out.clone()

    return run