In [1]:
import os
import time

import numpy as np
import torch
import argparse

from datasets import load_dataset
from transformers import (AutoModelForCausalLM,
                          AutoTokenizer, 
                          Phi3ForCausalLM,
                          Phi3Config,
                          pipeline,
                          )
from transformers.pipelines.pt_utils import KeyDataset

import utils

# os.environ["VLLM_USE_MODELSCOPE"] = "True"
# os.environ["TOKENIZERS_PARALLELISM"] = "False"
# os.environ["CUDA_MODULE_LOADING"] = "LAZY"

def str2bool(value):
    """Convert a command-line string into a bool for argparse flags.

    ``type=bool`` is a trap in argparse: it calls ``bool(str)``, which is True
    for every non-empty string (including ``"False"``), so ``--ignore_eos False``
    would silently keep the flag on. This accepts the common spellings instead.

    Args:
        value: the raw argument string (a bool is passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description="vLLM + MInference")
# Model / data locations (resolved relative to the current working directory later on).
parser.add_argument("--model_path", type=str, default="models/")
parser.add_argument("--model_name", type=str, default="Phi-3-medium-4k-instruct")
parser.add_argument("--data_path", type=str, default="data/")
parser.add_argument("--dataset", type=str, default="test_dataset.jsonl")
# Engine options.
parser.add_argument('--seed', type=int, default=0)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument('--dtype', type=str, default="auto")
parser.add_argument('--gpu_memory_utilization', type=float, default=0.9) # 0.3, 0.5
parser.add_argument('--enforce_eager', type=str2bool, default=False)

# Sampling options.
parser.add_argument('--n', type=int, default=1)
parser.add_argument('--temperature', type=float, default=0.0)
parser.add_argument('--top_p', type=float, default=1.0)
# top_k is an integer count of candidate tokens; -1 disables it.
parser.add_argument('--top_k', type=int, default=-1)
parser.add_argument('--ignore_eos', type=str2bool, default=True)
parser.add_argument('--max_tokens', type=int, default=500)

# Inside a notebook, parse an empty argv so Jupyter's own kernel arguments are
# not mistaken for ours; change the defaults above (or pass a list) to tune a run.
config = parser.parse_args([])
# Working directory against which the relative model/data paths are joined later.
CURR_PATH = os.getcwd()
# Project helper; presumably seeds the Python/NumPy/torch RNGs — confirm in utils.
utils.seed_everything(config.seed)

In [None]:
#######  Set up  #######
from vllm import LLM, SamplingParams
model_id = os.path.join(CURR_PATH, config.model_path, config.model_name) # please replace with local model path

# Engine construction options for vLLM's LLM.
model_args = {
    "trust_remote_code": True,
    "gpu_memory_utilization": config.gpu_memory_utilization,
    "seed": config.seed,
    "dtype": config.dtype,
    "enforce_eager": config.enforce_eager,
}

model = LLM(model=model_id, **model_args)
# Use the same remote-code policy as the engine so both load the checkpoint consistently.
tokenizer = AutoTokenizer.from_pretrained(
    model_id, trust_remote_code=model_args["trust_remote_code"]
)

# Decoding settings shared by the warm-up and the timed run.
# `use_beam_search` is deliberately not passed: it defaulted to False, and newer
# vLLM releases removed it from SamplingParams (passing it raises TypeError).
sampling_args = {
    "n": config.n,
    "temperature": config.temperature,
    "top_p": config.top_p,
    # int() keeps this valid even if --top_k was parsed as a float.
    "top_k": int(config.top_k),
    "seed": config.seed,
    "ignore_eos": config.ignore_eos,
    "max_tokens": config.max_tokens,
}

In [None]:
####### Section 2. GPU Warm up #######
# A single short multi-turn chat, generated once before the timed section so
# one-time start-up costs are not counted in the benchmark.
messages = [
    {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
    {"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey."},
    {"role": "user", "content": "What about solving an 2x + 3 = 7 equation?"},
]

# tokenize=False -> this is the rendered prompt *string*, not token ids.
warmup_prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

sampling_params = SamplingParams(**sampling_args)
output = model.generate(warmup_prompt, sampling_params)

print(output)

In [None]:
# `model.generate` returns a list with one RequestOutput per prompt; show the
# first completion of the single warm-up prompt.
output[0].outputs[0].text

In [None]:
# Evaluation set: a JSON / JSON Lines file at <cwd>/<data_path>/<dataset>.
datafile_path = os.path.join(CURR_PATH, config.data_path, config.dataset)
# A single data file is exposed by the "json" builder as its "train" split.
data = load_dataset("json", data_files=datafile_path, split="train")

In [None]:
# Renders the warm-up conversation again as a prompt string.
# NOTE(review): `prompt` is not referenced by any cell shown here — candidate for removal.
prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)

In [None]:
####### Section 3. Load data and Inference -> Performance evaluation part #######
start = time.perf_counter()
# data = load_dataset("json", data_files=input_file_path)['train']
data = load_dataset("json", data_files=output_file_path)['train']
outs = pipe(KeyDataset(data, 'message'), **generation_args)
end = time.perf_counter()

In [None]:
####### Section 4. Accuracy (Just for leasderboard) #######
print("===== Answers =====")
correct = 0
for i, out in enumerate(outs):
    correct_answer = data[i]["answer"]
    answer = out[0]["generated_text"].lstrip().replace("\n","")
    if answer == correct_answer:
        correct += 1
    print(answer)
 
print("===== Perf result =====")
print("Elapsed_time: ", end-start)
print(f"Correctness: {correct}/{len(data)}")