Speculative decoding with draft model
Signed-off-by: Tao He <sighingnow@gmail.com>
sighingnow committed Feb 26, 2024
1 parent 776b60b commit 1642fa3
Showing 21 changed files with 1,019 additions and 192 deletions.
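
For orientation, here is a minimal sketch of how the parameters introduced by this commit fit together, based on the keyword arguments the benchmark scripts below pass to LLM (the model names are placeholders, not part of this commit):

from vllm import LLM, SamplingParams

# A small draft model proposes `parallel_decoding_lookahead` tokens per step,
# which the target model then verifies in a batched forward pass.
llm = LLM(
    model="facebook/opt-6.7b",        # placeholder target model
    draft_model="facebook/opt-125m",  # placeholder draft model
    parallel_decoding_lookahead=4,    # speculative lookahead steps
)

sampling_params = SamplingParams(n=1, temperature=0.0, top_p=1.0, max_tokens=128)
outputs = llm.generate(["Hello, my name is"], sampling_params)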
11 changes: 10 additions & 1 deletion benchmarks/benchmark_latency.py
@@ -18,6 +18,7 @@ def main(args: argparse.Namespace):
# the engine will automatically process the request in multiple batches.
llm = LLM(
model=args.model,
draft_model=args.draft_model,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
@@ -27,11 +28,12 @@
kv_cache_dtype=args.kv_cache_dtype,
device=args.device,
use_flash_attn=args.use_flash_attn,
parallel_decoding_lookahead=args.parallel_decoding_lookahead,
)

sampling_params = SamplingParams(
n=args.n,
temperature=0.0 if args.use_beam_search else 1.0,
temperature=0.0 if args.use_beam_search else args.temperature,
top_p=1.0,
use_beam_search=args.use_beam_search,
ignore_eos=True,
@@ -90,6 +92,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
@@ -104,6 +107,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
default=1,
help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument('--num-iters',
type=int,
default=3,
@@ -150,5 +154,10 @@ def run_to_completion(profile_dir: Optional[str] = None):
"--use-flash-attn",
action="store_true",
help="Use flash attention (requires flash-attn >= 2.5.0).")
parser.add_argument(
"--parallel-decoding-lookahead",
type=int,
default=1,
help="Number of lookahead steps for speculativespeculative decoding.")
args = parser.parse_args()
main(args)
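
Taken together, the new flags can be exercised from the command line roughly as follows (a hypothetical invocation; the model choices are placeholders):

python benchmarks/benchmark_latency.py \
    --model facebook/opt-6.7b \
    --draft-model facebook/opt-125m \
    --parallel-decoding-lookahead 4 \
    --temperature 0.0 \
    --num-iters 3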
28 changes: 20 additions & 8 deletions benchmarks/benchmark_throughput.py
@@ -61,23 +61,27 @@ def sample_requests(
def run_vllm(
requests: List[Tuple[str, int, int]],
model: str,
draft_model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
temperature: float,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
device: str,
use_flash_attn: Optional[bool] = False,
parallel_decoding_lookahead: Optional[int] = 1,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
model=model,
draft_model=draft_model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
@@ -89,13 +93,14 @@ def run_vllm(
kv_cache_dtype=kv_cache_dtype,
device=device,
use_flash_attn=use_flash_attn,
parallel_decoding_lookahead=parallel_decoding_lookahead,
)

# Add the requests to the engine.
for prompt, _, output_len in requests:
sampling_params = SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
temperature=0.0 if use_beam_search else temperature,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
@@ -208,13 +213,13 @@ def main(args: argparse.Namespace):
args.output_len)

if args.backend == "vllm":
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype, args.device,
args.use_flash_attn)
elapsed_time = run_vllm(
requests, args.model, args.draft_model, args.tokenizer,
args.quantization, args.tensor_parallel_size, args.seed, args.n,
args.use_beam_search, args.temperature, args.trust_remote_code,
args.dtype, args.max_model_len, args.enforce_eager,
args.kv_cache_dtype, args.device, args.use_flash_attn,
args.parallel_decoding_lookahead)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -251,6 +256,7 @@ def main(args: argparse.Namespace):
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
@@ -262,6 +268,7 @@
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--num-prompts",
type=int,
default=1000,
@@ -309,6 +316,11 @@ def main(args: argparse.Namespace):
"--use-flash-attn",
action="store_true",
help="Use flash attention (requires flash-attn >= 2.5.0).")
parser.add_argument(
"--parallel-decoding-lookahead",
type=int,
default=1,
help="Number of lookahead steps for speculative decoding.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
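
For the throughput benchmark, a comparable invocation might look like this (hypothetical; --backend, --input-len, and --output-len come from the upstream script rather than this diff, and the model choices are placeholders):

python benchmarks/benchmark_throughput.py \
    --backend vllm \
    --input-len 512 \
    --output-len 128 \
    --model facebook/opt-6.7b \
    --draft-model facebook/opt-125m \
    --parallel-decoding-lookahead 4 \
    --num-prompts 1000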
1 change: 1 addition & 0 deletions requirements.txt
@@ -13,4 +13,5 @@ aioprometheus[starlette]
pynvml == 11.5.0
triton >= 2.1.0
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
packaging
flash-attn >= 2.5.0
2 changes: 1 addition & 1 deletion tests/worker/spec_decode/utils.py
@@ -84,7 +84,7 @@ def create_worker(cls: type,
)

(model_config, cache_config, parallel_config, scheduler_config,
device_config, _) = engine_args.create_engine_configs()
device_config, _, _) = engine_args.create_engine_configs()

distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
3 changes: 3 additions & 0 deletions vllm/config.py
@@ -447,6 +447,7 @@ class SchedulerConfig:
max_model_len: Maximum length of a sequence (including prompt
and generated text).
max_paddings: Maximum number of paddings to be added to a batch.
parallel_decoding_lookahead: Number of tokens to look ahead for parallel decoding.
"""

def __init__(
@@ -455,6 +456,7 @@ def __init__(
max_num_seqs: int,
max_model_len: int,
max_paddings: int,
parallel_decoding_lookahead: Optional[int] = 1,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
@@ -465,6 +467,7 @@
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.max_paddings = max_paddings
self.parallel_decoding_lookahead = parallel_decoding_lookahead
self._verify_args()

def _verify_args(self) -> None:
11 changes: 11 additions & 0 deletions vllm/core/block_manager.py
@@ -170,6 +170,17 @@ def can_append_slot(self, seq_group: SequenceGroup) -> bool:
num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
return num_seqs <= num_free_gpu_blocks

def can_append_slots(self,
seq_group: SequenceGroup,
reserve: Optional[int] = 1) -> bool:
# Simple heuristic: as the maximum possible parallel decoding lookahead
# is 8 (less than block size), if there is at least one free block for
# each sequence, we can append.
assert reserve <= self.block_size, f"Expect reserve <= block_size, got {reserve} > {self.block_size}"
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
return num_seqs <= num_free_gpu_blocks

def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
"""Allocate a physical slot for a new token."""
logical_blocks = seq.logical_token_blocks
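
The reservation heuristic above leans on the lookahead never exceeding the block size; a small worked example with assumed values shows why one free GPU block per running sequence is sufficient:

import math

block_size = 16   # assumed block size, for illustration only
reserve = 8       # lookahead passed by the scheduler, asserted <= block_size

# Worst case: a sequence's last logical block is already full, so every one
# of the `reserve` speculative tokens must go into freshly allocated blocks.
max_new_blocks_per_seq = math.ceil(reserve / block_size)
assert max_new_blocks_per_seq == 1   # holds for any reserve <= block_size

# Hence `num_running_seqs <= num_free_gpu_blocks` is enough to guarantee that
# every running sequence can take its lookahead slots this step.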
48 changes: 46 additions & 2 deletions vllm/core/scheduler.py
@@ -39,6 +39,7 @@ def __init__(
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
ignored_seq_groups: List[SequenceGroup],
parallel_decoding_lookahead: Optional[int] = 1,
) -> None:
self.scheduled_seq_groups = scheduled_seq_groups
self.prompt_run = prompt_run
@@ -49,6 +50,7 @@
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.ignored_seq_groups = ignored_seq_groups
self.parallel_decoding_lookahead = parallel_decoding_lookahead

self.num_loras = len(self.lora_requests)
if self.num_loras > 0:
@@ -69,6 +71,17 @@ def _sort_by_lora_ids(self) -> bool:
def lora_requests(self) -> Set[LoRARequest]:
return {g.lora_request for g in self.scheduled_seq_groups}

def __str__(self) -> str:
return (
f"SchedulerOutputs(scheduled_seq_groups={self.scheduled_seq_groups}, "
f"prompt_run={self.prompt_run}, "
f"num_batched_tokens={self.num_batched_tokens}, "
f"blocks_to_swap_in={self.blocks_to_swap_in}, "
f"blocks_to_swap_out={self.blocks_to_swap_out}, "
f"blocks_to_copy={self.blocks_to_copy}, "
f"ignored_seq_groups={self.ignored_seq_groups}, "
f"parallel_decoding_lookahead={self.parallel_decoding_lookahead})")


class Scheduler:

@@ -279,7 +292,9 @@ def _schedule(self) -> SchedulerOutputs:
preempted: List[SequenceGroup] = []
while self.running:
seq_group = self.running.popleft()
while not self.block_manager.can_append_slot(seq_group):
while not self.block_manager.can_append_slots(
seq_group,
reserve=self.scheduler_config.parallel_decoding_lookahead):
if self.running:
# Preempt the lowest-priority sequence groups.
victim_seq_group = self.running.pop()
@@ -293,6 +308,9 @@ def _schedule(self) -> SchedulerOutputs:
break
else:
# Append new slots to the sequence group.
self._reserve_logical_slots(seq_group,
lookahead=self.scheduler_config.
parallel_decoding_lookahead)
self._append_slot(seq_group, blocks_to_copy)
running.append(seq_group)
self.running = running
@@ -336,6 +354,9 @@ def _schedule(self) -> SchedulerOutputs:
curr_loras.add(lora_int_id)
self.swapped.popleft()
self._swap_in(seq_group, blocks_to_swap_in)
self._reserve_logical_slots(seq_group,
lookahead=self.scheduler_config.
parallel_decoding_lookahead)
self._append_slot(seq_group, blocks_to_copy)
num_curr_seqs += num_new_seqs
self.running.append(seq_group)
@@ -349,14 +370,29 @@ def _schedule(self) -> SchedulerOutputs:
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)

lookahead = 1
for seq_group in self.running:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
k = self.scheduler_config.parallel_decoding_lookahead
k = min(k, self.scheduler_config.max_model_len - seq.get_len())
if seq_group.sampling_params.max_tokens:
k = min(
k, seq_group.sampling_params.max_tokens -
seq.get_output_len())
lookahead = max(lookahead, k)

if lookahead > 1:
num_batched_tokens *= lookahead

scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=self.running,
scheduled_seq_groups=running,
prompt_run=False,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
ignored_seq_groups=[],
parallel_decoding_lookahead=lookahead,
)
return scheduler_outputs

@@ -388,6 +424,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
lora_request=seq_group.lora_request,
prefix=seq_group.prefix,
state=seq_group.state,
parallel_decoding_lookahead=scheduler_outputs.
parallel_decoding_lookahead,
)
seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs
@@ -407,6 +445,12 @@ def _allocate(self, seq_group: SequenceGroup) -> None:
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING

def _reserve_logical_slots(self,
seq_group: SequenceGroup,
lookahead: int = 1) -> None:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.reserve_logical_blocks(lookahead - 1)

def _append_slot(
self,
seq_group: SequenceGroup,
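
To make the clamping in the decode path concrete, here is a small standalone rendering of the per-sequence computation with assumed numbers (the dictionaries stand in for Sequence and SamplingParams state):

configured_lookahead = 8   # scheduler_config.parallel_decoding_lookahead (assumed)
max_model_len = 2048       # assumed

# Stand-ins for the running sequences' state.
seqs = [
    {"len": 2045, "max_tokens": None, "output_len": 0},   # only 3 context slots left
    {"len": 100,  "max_tokens": 32,   "output_len": 30},  # only 2 tokens of budget left
]

lookahead = 1
for s in seqs:
    k = min(configured_lookahead, max_model_len - s["len"])
    if s["max_tokens"]:
        k = min(k, s["max_tokens"] - s["output_len"])
    lookahead = max(lookahead, k)

print(lookahead)  # 3: the decode step uses the largest per-sequence value

The batch's num_batched_tokens is then scaled by this value, matching the multiplication in the scheduler above.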
