Introduce speculative decoding with draft models to vLLM #3029

Closed · wants to merge 3 commits
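The changes below thread a `--draft-model` option and a `--parallel-decoding-lookahead` length through the benchmark scripts. As background, draft-model speculative decoding lets a small draft model propose several tokens cheaply, then has the large target model verify them in one step, keeping only the prefix it agrees with. The sketch below is a minimal greedy-acceptance illustration of that loop only; `draft_next` and `target_next` are hypothetical callables, and none of this is the implementation proposed in this PR.

```python
# Minimal sketch of greedy draft-model speculative decoding.
# `draft_next` / `target_next` are hypothetical stand-ins for the small draft
# model and the large target model; this is NOT the code added by this PR.
from typing import Callable, List

def speculative_decode(prompt: List[int],
                       draft_next: Callable[[List[int]], int],
                       target_next: Callable[[List[int]], int],
                       lookahead: int,
                       max_new_tokens: int) -> List[int]:
    tokens = list(prompt)
    produced = 0
    while produced < max_new_tokens:
        # 1) Draft model proposes `lookahead` tokens autoregressively (cheap).
        proposal: List[int] = []
        for _ in range(lookahead):
            proposal.append(draft_next(tokens + proposal))
        # 2) Target model verifies the proposals; a real engine scores all
        #    proposed positions in one batched forward pass.
        accepted = 0
        for i, tok in enumerate(proposal):
            if target_next(tokens + proposal[:i]) == tok:
                accepted += 1
            else:
                break
        tokens.extend(proposal[:accepted])
        # 3) The target model always emits one token of its own: the
        #    correction, or the next token after a fully accepted block.
        tokens.append(target_next(tokens))
        produced += accepted + 1
    return tokens
```

When the draft model agrees with the target often, each expensive target-model step yields several tokens, which is what the lookahead flag added to the benchmarks below controls.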
16 changes: 15 additions & 1 deletion benchmarks/benchmark_latency.py
@@ -18,6 +18,7 @@ def main(args: argparse.Namespace):
# the engine will automatically process the request in multiple batches.
llm = LLM(
model=args.model,
draft_model=args.draft_model,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
@@ -26,11 +27,13 @@
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
device=args.device,
use_flash_attn=args.use_flash_attn,
parallel_decoding_lookahead=args.parallel_decoding_lookahead,
)

sampling_params = SamplingParams(
n=args.n,
temperature=0.0 if args.use_beam_search else 1.0,
temperature=0.0 if args.use_beam_search else args.temperature,
top_p=1.0,
use_beam_search=args.use_beam_search,
ignore_eos=True,
@@ -89,6 +92,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
@@ -103,6 +107,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
default=1,
help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument('--num-iters',
type=int,
default=3,
@@ -145,5 +150,14 @@ def run_to_completion(profile_dir: Optional[str] = None):
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument(
"--use-flash-attn",
action="store_true",
help="Use flash attention (requires flash-attn >= 2.5.0).")
parser.add_argument(
"--parallel-decoding-lookahead",
type=int,
default=1,
help="Number of lookahead steps for speculativespeculative decoding.")

args = parser.parse_args()
main(args)
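With these flags the latency benchmark can be pointed at a target/draft pair, e.g. `python benchmarks/benchmark_latency.py --model <target> --draft-model <draft> --parallel-decoding-lookahead 5`. The snippet below mirrors how the diff threads those values into the `LLM` constructor; the model names and lookahead value are placeholders, and the keyword arguments exist only on this PR's branch, not in upstream vLLM.

```python
# Sketch only: mirrors the constructor arguments this PR adds to LLM().
# Model names and the lookahead value are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-6.7b",          # target model (placeholder)
    draft_model="facebook/opt-125m",    # draft model (placeholder)
    use_flash_attn=True,                # per the help text, needs flash-attn >= 2.5.0
    parallel_decoding_lookahead=5,      # draft tokens proposed per verification step
)
params = SamplingParams(n=1, temperature=1.0, top_p=1.0,
                        ignore_eos=True, max_tokens=128)
outputs = llm.generate(["Hello, my name is"], params)
```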
33 changes: 26 additions & 7 deletions benchmarks/benchmark_throughput.py
@@ -61,22 +61,27 @@ def sample_requests(
def run_vllm(
requests: List[Tuple[str, int, int]],
model: str,
draft_model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
temperature: float,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
device: str,
use_flash_attn: Optional[bool] = False,
parallel_decoding_lookahead: Optional[int] = 1,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
model=model,
draft_model=draft_model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
@@ -87,13 +92,15 @@ def run_vllm(
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
use_flash_attn=use_flash_attn,
parallel_decoding_lookahead=parallel_decoding_lookahead,
)

# Add the requests to the engine.
for prompt, _, output_len in requests:
sampling_params = SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
temperature=0.0 if use_beam_search else temperature,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
@@ -206,12 +213,13 @@ def main(args: argparse.Namespace):
args.output_len)

if args.backend == "vllm":
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype, args.device)
elapsed_time = run_vllm(
requests, args.model, args.draft_model, args.tokenizer,
args.quantization, args.tensor_parallel_size, args.seed, args.n,
args.use_beam_search, args.temperature, args.trust_remote_code,
args.dtype, args.max_model_len, args.enforce_eager,
args.kv_cache_dtype, args.device, args.use_flash_attn,
args.parallel_decoding_lookahead)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -248,6 +256,7 @@ def main(args: argparse.Namespace):
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
@@ -259,6 +268,7 @@
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--num-prompts",
type=int,
default=1000,
@@ -302,6 +312,15 @@ def main(args: argparse.Namespace):
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument(
"--use-flash-attn",
action="store_true",
help="Use flash attention (requires flash-attn >= 2.5.0).")
parser.add_argument(
"--parallel-decoding-lookahead",
type=int,
default=1,
help="Number of lookahead steps for speculative decoding.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
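When sweeping `--parallel-decoding-lookahead` in the throughput benchmark, a useful back-of-the-envelope from the speculative decoding literature (Leviathan et al., 2023) is that with per-token acceptance rate α and k drafted tokens, each target-model step emits about (1 − α^(k+1)) / (1 − α) tokens in expectation. The helper below only evaluates that formula; it is not part of this PR.

```python
# Rough planning helper (not part of this PR): expected tokens emitted per
# target-model step under speculative decoding, after Leviathan et al. (2023).
def expected_tokens_per_step(acceptance_rate: float, lookahead: int) -> float:
    a, k = acceptance_rate, lookahead
    if a >= 1.0:
        return float(k + 1)  # every drafted token accepted, plus the target's own token
    return (1.0 - a ** (k + 1)) / (1.0 - a)

if __name__ == "__main__":
    # e.g. with an 80% acceptance rate, lookahead 5 yields ~3.7 tokens per step
    for k in (1, 3, 5, 8):
        print(f"lookahead={k}: {expected_tokens_per_step(0.8, k):.2f} tokens/step")
```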