In [1]:
import os
# Restrict this process to GPU index 8. Done before torch is imported so the
# restriction is in place before CUDA gets initialized (the recorded outputs
# further down show the tensors on 'cuda:0', i.e. the single visible device).
os.environ['CUDA_VISIBLE_DEVICES'] = '8'
import sys
# Append the project checkout to sys.path; the `models` and `utils` packages
# imported below are presumably resolved from there.
root_dir = "/home/hanshis/workspace/LongContextInfer"
sys.path.append(root_dir)
import torch
import time
import argparse
import math
from tqdm import tqdm
import socket

# Project-local model, KV-cache and graph-inference helpers.
from models.modeling_llama import LlamaForCausalLM, LlamaConfig
from models.cache_utils import SimpleCache, FlashSimpleCache, GraphFlashSimpleCache
from utils.graph_infer import GraphInferenceEngine

# Benchmark knobs:
#   PREFIX_LEN - number of random prompt tokens generated for each run
#   T          - presumably the number of timed repetitions (TODO confirm)
#   WARM_UP    - presumably the number of untimed warm-up iterations (TODO confirm)
PREFIX_LEN, T, WARM_UP = 1000, 100, 10

# Choose the CSV report that matches the machine the notebook runs on: hosts
# whose name contains 'lovelace' log to the L40 report, every other host to
# the A100 report.
host = socket.gethostname()
if 'lovelace' in host:
    file_path = "/home/hanshis/workspace/LongContextInfer/benchmark/report/L40_llama_7B_128K_graph.csv"
else:
    file_path = "/fsx-storygen/beidic/hanshi/LongContextInfer/benchmark/report/A100_llama_7B_128K_graph.csv"

# Make sure the report directory exists; without it, opening the file for
# appending below raises FileNotFoundError.
os.makedirs(os.path.dirname(file_path), exist_ok=True)

# Write the column header exactly once: only when the report is missing or
# still empty. The size comes from file metadata, so an existing (possibly
# large) report is never read into memory just to test for emptiness.
if not os.path.exists(file_path) or os.path.getsize(file_path) == 0:
    with open(file_path, 'a') as f:
        f.write("model,prefill,len,latency,repeat_time,flash\n")

# Load the checkpoint in fp16 and let `device_map="auto"` place it on the
# visible device(s).
model_name = "NousResearch/Yarn-Llama-2-7b-128k"
config = LlamaConfig.from_pretrained(model_name)

# Only configs advertising fewer than 4096 positions are bumped to 128K.
if config.max_position_embeddings < 4096:
    config.max_position_embeddings = 128 * 1024

# Flag read by the project's Llama implementation (presumably selects the
# flash-attention path - TODO confirm).
config.flash = True

model = LlamaForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Decode lengths to benchmark. The full sweep (currently disabled) is
#   [1, 2, 4, 8] + list(range(16, 513, 16))
# i.e. 1, 2, 4, 8 and then every multiple of 16 up to 512.
DEC_LEN_LIST = [1]

# Cache budget: the prefix plus 512 further positions (512 is the largest
# decode length of the full sweep above).
MAX_LEN = 512 + PREFIX_LEN

# Allocate the regular cache first, then the graph cache, both with the same
# budget.
cache, graph_cache = (
    FlashSimpleCache(model, MAX_LEN),
    GraphFlashSimpleCache(model, MAX_LEN),
)

for DEC_LEN in DEC_LEN_LIST:
    # Every decode length starts from freshly reset caches.
    for kv_cache in (cache, graph_cache):
        kv_cache.reset()

    # Random token ids in [3, 30000) stand in for a real prompt.
    prefix = torch.randint(3, 30000, (1, PREFIX_LEN), device=model.device)
    assert prefix.size(-1) == PREFIX_LEN

    # Build a new engine around the shared model/caches and set up its CUDA
    # graph for this decode length.
    graph_engine = GraphInferenceEngine(model, cache, graph_cache)
    graph_engine.initialize_cuda_graph([DEC_LEN])

    # Run the prefix through the engine, then report what each cache holds.
    graph_engine.inference(input_ids=prefix)

    for kv_cache in (cache, graph_cache):
        kv_cache.print_status()

>>>> Flash Attention installed
>>>> Flash RoPE installed


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Cached Size: 1000 | Max Budget: 1512
Max Budget: 1512


In [2]:
# Debug probe of the regular cache at key_cache[1][0][1000]. Index 1000 equals
# PREFIX_LEN; presumably this is (layer 1, batch 0, position 1000), i.e. the
# first slot after the prefix - TODO confirm the cache layout.
cache.key_cache[1][0][1000]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.float16)

In [3]:
# Debug probe of the graph cache at key_cache[0][0][0] (presumably layer 0,
# batch 0, first slot - TODO confirm the cache layout).
graph_cache.key_cache[0][0][0]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.float16)

In [18]:
# Is the graph-cache entry at [0][0][1000] entirely zero (within allclose's
# default tolerances)? Index the entry once and compare it with zeros.
slot = graph_cache.key_cache[0][0][1000]
torch.allclose(slot, torch.zeros_like(slot))

False

In [5]:
# Same debug probe as for key_cache[1], now on key_cache[2] of the regular
# cache (presumably layer 2, batch 0, position 1000 - TODO confirm layout).
cache.key_cache[2][0][1000]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.float16)

In [15]:
# Debug probe of the graph cache at key_cache[0][0][1000]. NOTE(review): the
# execution count shows this cell ran after the graph_inference cell, which
# passes storage_ids starting at PREFIX_LEN (= 1000); presumably that call is
# what populated this entry - confirm by re-running in order.
graph_cache.key_cache[0][0][1000]

tensor([[-0.0194, -0.1855,  0.5117,  ...,  0.3369,  0.0435,  0.4014],
        [ 0.9253,  0.4568, -0.7075,  ..., -0.7070,  0.7300, -0.6943],
        [-0.5557, -0.6553, -0.3027,  ..., -0.0223, -0.1632, -0.1783],
        ...,
        [-0.1730,  0.0549, -0.0212,  ..., -0.3787,  0.4285, -0.3359],
        [ 0.2054, -0.2064, -2.2402,  ...,  0.4592, -0.2316, -0.2203],
        [-1.5889, -1.3701, -1.8203,  ...,  1.6953, -0.5571,  0.4343]],
       device='cuda:0', dtype=torch.float16)

In [7]:
# One decode call through the graph engine: DEC_LEN random token ids, with
# storage_ids PREFIX_LEN .. PREFIX_LEN + DEC_LEN - 1 (presumably the cache
# slots directly after the prefix - TODO confirm against the engine).
input_ids = torch.randint(3, 30000, (1, DEC_LEN), device=model.device)
storage_ids = torch.arange(PREFIX_LEN, PREFIX_LEN + DEC_LEN, device=model.device)
graph_engine.graph_inference(input_ids=input_ids, storage_ids=storage_ids)

tensor([[[-6.1914, -4.5703,  2.4707,  ..., -3.9902, -1.9297, -2.4102]]],
       device='cuda:0')

In [8]:
# Display the engine object left over from the last iteration of the prefill
# loop (only its default repr is shown).
graph_engine

<utils.graph_infer.GraphInferenceEngine at 0x7fb4f6420af0>

In [9]:
# input_ids = torch.randint(low=3, high=30000, size=(1, DEC_LEN), device=model.device)
# storage_ids = torch.arange(DEC_LEN, device=model.device) + PREFIX_LEN
# for _ in range(WARM_UP):
#     graph_engine.graph_inference(input_ids=input_ids, storage_ids=storage_ids)

# cache.print_status()
# graph_cache.print_status()

In [1]:
# Load the GPTQ checkpoint through the project's flash Llama class.
import os
# Set before the tokenizer is imported/created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import sys
# Project checkout appended to sys.path for the `models` import below.
root_dir = '/home/hanshis/workspace/LongContextInfer'
sys.path.append(root_dir)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

from models.modeling_llama_flash import LlamaForCausalLM

# Tokenizer and weights come from the same Hub repository; only the weights
# are pinned to a specific revision.
gptq_repo = "TheBloke/Yarn-Llama-2-7B-128K-GPTQ"
tokenizer = AutoTokenizer.from_pretrained(gptq_repo, trust_remote_code=True)
model = LlamaForCausalLM.from_pretrained(
    gptq_repo,
    revision="gptq-4bit-32g-actorder_True",
    device_map="cuda:9",
)

>>>> Flash Attention installed
>>>> Flash RoPE installed


CUDA extension not installed.
CUDA extension not installed.


In [4]:
# Load the GPTQ checkpoint through the project's LlamaGPTQForCausalLM wrapper,
# pinned to GPU index 9.
import os
import sys
import warnings

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# CUDA_VISIBLE_DEVICES is only honored if it is set before CUDA is initialized
# in this process. This cell can run in a kernel where an earlier cell already
# put a model on a GPU; in that case the assignment below is silently ignored
# and device_map='auto' would use whatever devices are already visible. Warn
# instead of failing quietly.
_loaded_torch = sys.modules.get("torch")
if _loaded_torch is not None and _loaded_torch.cuda.is_initialized():
    warnings.warn(
        "CUDA is already initialized in this kernel, so setting "
        "CUDA_VISIBLE_DEVICES here has no effect; restart the kernel and run "
        "this cell first to pin the model to GPU 9.",
        RuntimeWarning,
    )
os.environ['CUDA_VISIBLE_DEVICES'] = '9'

# Append the project checkout only once so re-running the cell does not keep
# growing sys.path with duplicate entries.
root_dir = '/home/hanshis/workspace/LongContextInfer'
if root_dir not in sys.path:
    sys.path.append(root_dir)
import torch

from transformers import AutoTokenizer, TextGenerationPipeline
# from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from models.llama_gptq import LlamaGPTQForCausalLM


model = LlamaGPTQForCausalLM.from_quantized("TheBloke/Yarn-Llama-2-7B-128K-GPTQ", device_map='auto')

1. You disabled CUDA extensions compilation by setting BUILD_CUDA_EXT=0 when install auto_gptq from source.
2. You are using pytorch without CUDA support.
3. CUDA and nvcc are not installed in your device.
INFO - The layer lm_head is not quantized.


  0%|          | 0/1187 [00:00<?, ?w/s]

Skipping module injection for FusedLlamaMLPForQuantizedModel as currently not supported with use_triton=False.
