-
-
Notifications
You must be signed in to change notification settings - Fork 10.5k
Add DeepSpeed MII backend to benchmark script #1649
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,18 +6,21 @@ | |
from typing import List, Optional, Tuple | ||
|
||
import torch | ||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase | ||
from transformers import (AutoModelForCausalLM, AutoTokenizer, | ||
PreTrainedTokenizerBase) | ||
from tqdm import tqdm | ||
|
||
from vllm import LLM, SamplingParams | ||
from vllm.transformers_utils.tokenizer import get_tokenizer | ||
|
||
|
||
def sample_requests( | ||
dataset_path: str, | ||
num_requests: int, | ||
tokenizer: PreTrainedTokenizerBase, | ||
fixed_output_len: Optional[int], | ||
) -> List[Tuple[str, int, int]]: | ||
if fixed_output_len is not None: | ||
if fixed_output_len < 4: | ||
raise ValueError("output_len too small") | ||
|
||
# Load the dataset. | ||
with open(dataset_path) as f: | ||
dataset = json.load(f) | ||
|
@@ -35,6 +38,8 @@ def sample_requests( | |
tokenized_dataset = [] | ||
for i in range(len(dataset)): | ||
output_len = len(completion_token_ids[i]) | ||
if fixed_output_len is not None: | ||
output_len = fixed_output_len | ||
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) | ||
|
||
# Filter out too long sequences. | ||
|
@@ -66,6 +71,7 @@ def run_vllm( | |
trust_remote_code: bool, | ||
dtype: str, | ||
) -> float: | ||
from vllm import LLM, SamplingParams | ||
llm = LLM( | ||
model=model, | ||
tokenizer=tokenizer, | ||
|
@@ -160,14 +166,37 @@ def run_hf( | |
return end - start | ||
|
||
|
||
def run_mii( | ||
requests: List[Tuple[str, int, int]], | ||
model: str, | ||
tensor_parallel_size: int, | ||
output_len: int, | ||
) -> float: | ||
from mii import pipeline | ||
llm = pipeline(model, tensor_parallel=tensor_parallel_size) | ||
prompts = [prompt for prompt, _, _ in requests] | ||
|
||
start = time.perf_counter() | ||
llm(prompts, max_new_tokens=output_len) | ||
end = time.perf_counter() | ||
return end - start | ||
|
||
|
||
def main(args: argparse.Namespace): | ||
print(args) | ||
random.seed(args.seed) | ||
|
||
# Sample the requests. | ||
tokenizer = get_tokenizer(args.tokenizer, | ||
trust_remote_code=args.trust_remote_code) | ||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer) | ||
tokenizer = AutoTokenizer.from_pretrained( | ||
args.tokenizer, trust_remote_code=args.trust_remote_code) | ||
if args.dataset is None: | ||
# Synthesize a prompt with the given input length. | ||
prompt = "hi" * (args.input_len - 1) | ||
requests = [(prompt, args.input_len, args.output_len) | ||
for _ in range(args.num_prompts)] | ||
else: | ||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer, | ||
args.output_len) | ||
|
||
if args.backend == "vllm": | ||
elapsed_time = run_vllm(requests, args.model, args.tokenizer, | ||
|
@@ -179,6 +208,9 @@ def main(args: argparse.Namespace): | |
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, | ||
args.use_beam_search, args.hf_max_batch_size, | ||
args.trust_remote_code) | ||
elif args.backend == "mii": | ||
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, | ||
args.output_len) | ||
else: | ||
raise ValueError(f"Unknown backend: {args.backend}") | ||
total_num_tokens = sum(prompt_len + output_len | ||
|
@@ -191,12 +223,21 @@ def main(args: argparse.Namespace): | |
parser = argparse.ArgumentParser(description="Benchmark the throughput.") | ||
parser.add_argument("--backend", | ||
type=str, | ||
choices=["vllm", "hf"], | ||
choices=["vllm", "hf", "mii"], | ||
default="vllm") | ||
parser.add_argument("--dataset", | ||
type=str, | ||
required=True, | ||
default=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does this line changed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh it's because users can set the fixed input and output lengths instead of providing a dataset. |
||
help="Path to the dataset.") | ||
parser.add_argument("--input-len", | ||
type=int, | ||
default=None, | ||
help="Input prompt length for each request") | ||
parser.add_argument("--output-len", | ||
type=int, | ||
default=None, | ||
help="Output length for each request. Overrides the " | ||
"output length from the dataset.") | ||
parser.add_argument("--model", type=str, default="facebook/opt-125m") | ||
parser.add_argument("--tokenizer", type=str, default=None) | ||
parser.add_argument('--quantization', | ||
|
@@ -231,6 +272,13 @@ def main(args: argparse.Namespace): | |
'for FP32 and FP16 models, and BF16 precision ' | ||
'for BF16 models.') | ||
args = parser.parse_args() | ||
if args.tokenizer is None: | ||
args.tokenizer = args.model | ||
if args.dataset is None: | ||
assert args.input_len is not None | ||
assert args.output_len is not None | ||
else: | ||
assert args.input_len is None | ||
|
||
if args.backend == "vllm": | ||
if args.hf_max_batch_size is not None: | ||
|
@@ -240,7 +288,18 @@ def main(args: argparse.Namespace): | |
raise ValueError("HF max batch size is required for HF backend.") | ||
if args.quantization is not None: | ||
raise ValueError("Quantization is only for vLLM backend.") | ||
if args.tokenizer is None: | ||
args.tokenizer = args.model | ||
|
||
elif args.backend == "mii": | ||
if args.dtype != "auto": | ||
raise ValueError("dtype must be auto for MII backend.") | ||
if args.n != 1: | ||
raise ValueError("n must be 1 for MII backend.") | ||
if args.use_beam_search: | ||
raise ValueError("Beam search is not supported for MII backend.") | ||
if args.quantization is not None: | ||
raise ValueError("Quantization is only for vLLM backend.") | ||
if args.hf_max_batch_size is not None: | ||
raise ValueError("HF max batch size is only for HF backend.") | ||
if args.tokenizer != args.model: | ||
raise ValueError("Tokenizer must be the same as the model for MII " | ||
"backend.") | ||
main(args) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this line be
" ".join(["hi"] * args.input_len)
? In general, how can you make sure the prompt you generate has the number of tokens you specified with a bunch of "hi"s?Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I agree it's a bit hacky. However, I found this worked for LLaMA and OPT because "hi" is a single token in their tokenizers and "hi" * n is split into n "hi" tokens.