In [None]:
# Import all dependencies
%cd ..
import torch
import os
import transformers
from medusa.model.medusa_model import MedusaModel
import time
from contextlib import contextmanager
import numpy as np
from medusa.model.utils import *
from medusa.model.kv_cache import *
from medusa.model.medusa_choices import mc_sim_7b_63


In [None]:
# Medusa inference forward pass with per-stage wall-time measurement

@contextmanager
def timed(wall_times, key):
    """Measure the wall time of the enclosed block, including its GPU work.

    Args:
        wall_times: Mapping from stage name to a list of elapsed times in
            seconds; the new measurement is appended to ``wall_times[key]``.
        key: Stage name. It must already be present in ``wall_times``.

    The measurement is only recorded when the block exits without raising.
    """
    # Drain kernels that were queued before this block so that their run time
    # is not attributed to the stage being measured.
    torch.cuda.synchronize()
    # perf_counter is monotonic, unlike time.time, so clock adjustments cannot
    # distort (or negate) the measured interval.
    start = time.perf_counter()
    yield
    # CUDA launches are asynchronous: wait for the work issued inside the block
    # to finish before reading the clock.
    torch.cuda.synchronize()
    wall_times[key].append(time.perf_counter() - start)

def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, max_steps = 512):
    """Run a Medusa decoding loop and record the wall time of every stage.

    Args:
        input_ids: Prompt token ids. The code reads ``input_ids.shape[1]`` and
            indexes ``input_ids[0, input_len:]``, so it must be at least 2-D
            (presumably ``(1, seq_len)`` -- TODO confirm).
        model: Object exposing ``base_model``. The Medusa buffers, the chosen
            ``medusa_choices`` and the KV-cache handles are stored on it as
            attributes and reused by later calls.
        tokenizer: Only ``eos_token_id`` is read, for the stop condition.
        medusa_choices: Passed to ``generate_medusa_buffers``; also compared
            with ``model.medusa_choices`` to decide whether cached buffers can
            be reused.
        temperature: Forwarded unchanged to ``evaluate_posterior``.
        posterior_threshold: Forwarded unchanged to ``evaluate_posterior``.
        posterior_alpha: Forwarded unchanged to ``evaluate_posterior``.
        max_steps: Upper bound on the number of decoding iterations.

    Returns:
        Tuple ``(input_ids, new_token, idx, wall_times)``:
        ``input_ids`` as returned by the last ``update_inference_inputs`` call;
        ``new_token`` as returned by that same call (starts at 0; presumably
        the running count of generated tokens -- verify against the helper);
        ``idx`` is the zero-based index of the last executed iteration, so the
        number of executed steps is ``idx + 1``;
        ``wall_times`` maps 'init', 'medusa', 'tree', 'posterior' and 'update'
        to lists of measurements appended by ``timed``.

    NOTE(review): with ``max_steps=0`` the loop body never runs, ``idx`` is
    never bound and the ``return`` statement raises.
    """
    # One list of measurements per stage; 'init' receives a single entry, the
    # other stages receive one entry per loop iteration.
    wall_times = {'medusa': [], 'tree': [], 'posterior': [], 'update': [], 'init': []}

    with timed(wall_times, 'init'):
        if hasattr(model, "medusa_choices") and model.medusa_choices == medusa_choices:
            # Load the cached medusa buffer
            medusa_buffers = model.medusa_buffers
        else:
            # Initialize the medusa buffer
            medusa_buffers = generate_medusa_buffers(
                medusa_choices, device=model.base_model.device
            )
        # Cache the buffers on the model so a later call with the same choices
        # can skip regeneration.
        model.medusa_buffers = medusa_buffers
        model.medusa_choices = medusa_choices

        # Initialize the past key and value states
        if hasattr(model, "past_key_values"):
            past_key_values = model.past_key_values
            past_key_values_data = model.past_key_values_data
            current_length_data = model.current_length_data
            # Reset the past key and value states
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(model.base_model)
            model.past_key_values = past_key_values
            model.past_key_values_data = past_key_values_data
            model.current_length_data = current_length_data

        # Remember the prompt length so the EOS check below only inspects
        # tokens appended after the prompt.
        input_len = input_ids.shape[1]
        reset_medusa_mode(model)
        medusa_logits, logits = initialize_medusa(
                input_ids, model, medusa_buffers["medusa_attn_mask"], past_key_values
        )
    new_token = 0

    for idx in range(max_steps):
        # Stage 'medusa': build candidate continuations from the current logits.
        with timed(wall_times, 'medusa'):
            candidates, tree_candidates = generate_candidates(
                    medusa_logits,
                    logits,
                    medusa_buffers["tree_indices"],
                    medusa_buffers["retrieve_indices"],
                )

        # Stage 'tree': run the model on the tree candidates.
        with timed(wall_times, 'tree'):
            medusa_logits, logits, outputs = tree_decoding(
                    model,
                    tree_candidates,
                    past_key_values,
                    medusa_buffers["medusa_position_ids"],
                    input_ids,
                    medusa_buffers["retrieve_indices"],
                )

        # Stage 'posterior': select the best candidate and its accept length.
        with timed(wall_times, 'posterior'):
            best_candidate, accept_length = evaluate_posterior(
                    logits, candidates, temperature, posterior_threshold, posterior_alpha
                )

        # Stage 'update': refresh input_ids, logits, medusa_logits and
        # new_token from the accepted candidate for the next iteration.
        with timed(wall_times, 'update'):
            input_ids, logits, medusa_logits, new_token = update_inference_inputs(
                    input_ids,
                    candidates,
                    best_candidate,
                    accept_length,
                    medusa_buffers["retrieve_indices"],
                    outputs,
                    logits,
                    medusa_logits,
                    new_token,
                    past_key_values_data,
                    current_length_data,
                )

        # Stop as soon as an EOS token appears among the tokens after the prompt.
        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
            break

    return input_ids, new_token, idx, wall_times

In [None]:
# Load the Medusa model and set the decoding hyper-parameters.

# Expose only the first GPU to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Checkpoint locations: the base language model and the Medusa head weights.
base_model = "/alg_vepfs/public/hqy/baichuan/wizardlm-10-16-dsl/checkpoint-100"
medusa_path = "/root/zhaliangyu/llm_inference_acc/Medusa/medusa_weights/wizardlm13b_medusa"

model = MedusaModel.from_pretrained(
    base_model=base_model, medusa_head_name_or_path=medusa_path, device_map="auto"
)
tokenizer = model.get_tokenizer()

# Arguments forwarded to medusa_forward (and from there to evaluate_posterior).
temperature = 0.0
posterior_threshold = 0.09
posterior_alpha = 0.3

In [None]:
# Prompt template: {db_info} and {user_query} are filled in by str.format below.
PROMPT = "数据库的信息如下所示：{db_info},\n schema'和'detail'表示数据库的内容; 'foreign_keys'表示数据库多表之间的连接关系. " \
         "根据数据库信息以及用户的输入生成符合json格式输出的指令. \n\nUSER: {user_query}\nASSISTANT:"

db_info = """{"db_name": "Twenty-Five_Cents", "db_info": {"Twenty-Five_Cents": {"numeric_info": {"Year": [2006, 2007, 2007], "Issue price": [24.95, 24.95, 24.95]}, "categorical_info": {"Theme": ["Calgary Flames", "Edmonton Oilers", "Montreal Canadiens", "Ottawa Senators", "Toronto Maple Leafs", "Vancouver Canucks"], "Artist": ["N/A"], "Mintage": ["1264", "1634", "2213", "2952", "3527", "832", "N/A"]}, "date_cols_info": {}}}, "foreign_keys": []}"""
# NOTE(review): str.format does not re-parse substituted values, so this
# doubling is not needed for escaping and the braces stay doubled in the final
# prompt. Kept as-is because the checkpoint may have been tuned on prompts in
# exactly this form -- confirm before removing.
db_info = db_info.replace("{","{{").replace("}", "}}")
user_query = "不同主题的平均发行价格是多少？"
PROMPT = PROMPT.format(db_info=db_info,
                       user_query=user_query)
medusa_choices = mc_sim_7b_63

with torch.inference_mode():
    input_ids = tokenizer([PROMPT]).input_ids
    output_ids, new_token, idx, wall_time = medusa_forward(
        torch.as_tensor(input_ids).cuda(),
        model,
        tokenizer,
        medusa_choices,
        temperature,
        posterior_threshold,
        posterior_alpha,
    )
    # medusa_forward returns the zero-based index of its last iteration, so the
    # number of decoding steps is idx + 1. Dividing by idx overstated the ratio
    # and divided by zero when generation stopped after a single step.
    num_steps = idx + 1
    print("Output length:", output_ids.size(-1))
    print("Compression ratio:", new_token / num_steps)

# output_ids is indexed as [batch, position] by medusa_forward; decode() takes
# a single sequence of ids, so pass the first (only) row rather than the batch.
output = tokenizer.decode(
    output_ids[0],
    spaces_between_special_tokens=False,
)
print(output)

In [None]:
# Width, in characters, of every line of the timing report.
max_length = 50


def format_string(text, value, max_length):
    """Return ``text`` left-aligned and ``value`` (three decimals) right after it.

    The label is padded so that label plus number are ``max_length`` characters
    wide; a label longer than the remaining width is kept in full.
    """
    number = format(value, ".3f")
    label_width = max_length - len(number)
    return format(text, f"<{label_width}") + number

# Total time spent in each stage, summed over all recorded measurements.
time_init, time_medusa, time_tree, time_posterior, time_update = (
    np.sum(wall_time[stage])
    for stage in ("init", "medusa", "tree", "posterior", "update")
)
time_total = time_init + time_medusa + time_tree + time_posterior + time_update

heavy_rule = "=" * max_length
light_rule = "-" * max_length

# (label, value) rows of the report: absolute totals first, then each loop
# stage as a fraction of the overall total.
absolute_rows = [
    ("Wall time init: ", time_init),
    ("Wall time medusa: ", time_medusa),
    ("Wall time Tree: ", time_tree),
    ("Wall time Posterior: ", time_posterior),
    ("Wall time Update: ", time_update),
]
portion_rows = [
    ("Wall time portion medusa: ", time_medusa / time_total),
    ("Wall time portion Tree: ", time_tree / time_total),
    ("Wall time portion Posterior: ", time_posterior / time_total),
    ("Wall time portion Update: ", time_update / time_total),
]

print(heavy_rule)
for label, seconds in absolute_rows:
    print(format_string(label, seconds, max_length))
print(light_rule)
for label, portion in portion_rows:
    print(format_string(label, portion, max_length))
print(light_rule)
print(format_string("Tokens/second: ", new_token / time_total, max_length))
print(heavy_rule)

In [None]:
# Baseline: time plain Hugging Face generation with the same base checkpoint.
cache_dir = "./"
model_max_length = 2048
device = "cuda"

config = transformers.AutoConfig.from_pretrained(
    base_model,
    cache_dir=cache_dir,
)

# Autoregressive generation needs a causal-LM head; the masked-LM auto class
# does not provide one for decoder-only checkpoints.
# device_map requires low_cpu_mem_usage=True; from_pretrained rejects the
# combination with False.
original_model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model,
    config=config,
    cache_dir=cache_dir,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

original_model.eval()

ori_tokenizer = transformers.AutoTokenizer.from_pretrained(
    base_model,
    config=config,
    model_max_length=model_max_length,
    padding_side="right",
    use_fast=False,
)

# NOTE(review): do_sample is not set, so presumably decoding stays greedy and
# temperature/top_p/top_k are ignored -- confirm this is the intended baseline.
generate_configs = transformers.GenerationConfig(
    temperature=0.1,
    top_p=0.9,
    top_k=10
)

# return_tensors="pt" yields a tensor; without it input_ids is a plain Python
# list, which has no .to().
input_ids = ori_tokenizer([PROMPT], return_tensors="pt").input_ids.to(device)

# Synchronize around the timed region so that only generate() is measured and
# asynchronously queued GPU work is fully accounted for.
torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():
    output_ids = original_model.generate(
        input_ids=input_ids,
        generation_config=generate_configs,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=512,
    )
torch.cuda.synchronize()
end = time.perf_counter()
original_model_total = end - start

s = output_ids.sequences
# The returned sequences start with the prompt, so subtract its length to count
# only the newly generated tokens.
new_token_num = s.shape[1] - input_ids.shape[1]
# Decode with the tokenizer that produced input_ids for this model.
output = ori_tokenizer.batch_decode(s, skip_special_tokens=True)
output = output[0].split("ASSISTANT:")[1].strip()
print("Original generate total time:", original_model_total)
# Throughput is tokens divided by seconds (the previous ratio was inverted).
print("Original Tokens/second: ", new_token_num / original_model_total)