Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Speculative Decoding #1797

Open
wants to merge 47 commits into
base: main
Choose a base branch
from

Conversation

LiuXiaoxuanPKU
Copy link
Collaborator

@LiuXiaoxuanPKU LiuXiaoxuanPKU commented Nov 27, 2023

This is an attempt to implement speculative decoding (paper) in vllm. It is not optimized, not tested (please avoid using it for now). The current design:

  1. Use huggingface instead of paged attention for the draft model.
  2. Does not keep kv cache for the draft model.
  3. Does not support tensor parallelism.
  4. Use kernel from prefix cache for token verification.

Copy link
Contributor

@pian13131 pian13131 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some tiny comments.

@@ -408,11 +416,24 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
if last_child_sample.accepted_tokens:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: if FLAGS.ENABLE_SD:?

vllm/engine/llm_engine.py Show resolved Hide resolved
vllm/engine/llm_engine.py Show resolved Hide resolved
Comment on lines +21 to +22
self.propose_cnt = config.propose_cnt
self.draft_model_config = config.draft_model_config
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: self.config = config?


# propose draft tokens
# the function will run the draft model and set draft_tokens and draft_token_probs of each seq
def set_draft_tokens(self, seq_group_list: List[SequenceGroupMetadata],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

propose() might be a better name

)
if FLAGS.ENABLE_SD:
output = _multi_query_cached_kv_attention(
query, key, value, key_cache, value_cache, input_metadata)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why need to pass key and value? I think the two vars already have copied to key_cache and value_cache by cache_ops.reshape_and_cache(). Maybe I am missing something?

Copy link

@void-main void-main left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work, congrats!

for seq_group_metadata in seq_group_metadata_list:
assert len(
seq_group_metadata.seq_data
) == 1, f"Speculative Decoding does nor beam search for now: {len(seq_group_metadata.seq_data)}"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a little typo in the assert message

@@ -573,6 +594,11 @@ def step(self) -> List[RequestOutput]:
if scheduler_outputs.is_empty():
return ignored

# only enable speculative decoding for generation run
if self.spec_dec_worker and (not scheduler_outputs.prompt_run):
self.spec_dec_worker.set_draft_tokens(seq_group_metadata_list,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in multi GPU inference scenario, will this method be called by all the workers?

do you think it's a better idea to only run on rank 0, and broadcast the tokens to other ranks?

logger.setLevel("WARNING")


class SpecDecWorker(Worker):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This worker is too tightly coupled with assisted decoding.

Do you think it's a good idea if we abstract an base class for SpD, and move these specific implementations to a concrete class like AssistedSpcDecWorker?

But I believe we could refactor this later.

@@ -69,7 +69,7 @@ def __init__(
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
gpu_memory_utilization: float = 0.8,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a little hacky to me. what if the sequence is long and could take more than 0.2 gpu memory?

do you think it's a better idea if we actual run the assisted model in profile_num_available_blocks?

pass


if triton.__version__ >= "2.1.0":

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we assert the version should be greater or equal to 2.1.0?

offs_d[:, None] // x) * stride_k_cache_d + (
(start_n + offs_n[None, :]) %
block_size) * stride_k_cache_bl + (
offs_d[:, None] % x) * stride_k_cache_x

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good job! this would be faster than my version! 👍

block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder what's special for K and V of the draft tokens, why we need process these tokens separately?

self.scale,
self.alibi_slopes,
)
if FLAGS.ENABLE_SD:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct me if I'm wrong, but for assisted decoding, usually the propse_cnt is small (maybe around 4?), which would cause first dimension of q to be small, thus the q@k gemm and qk@v gemm are small. for such cases, does it really worth using Tensor Core for GEMM?

@@ -573,6 +594,11 @@ def step(self) -> List[RequestOutput]:
if scheduler_outputs.is_empty():
return ignored

# only enable speculative decoding for generation run
if self.spec_dec_worker and (not scheduler_outputs.prompt_run):
Copy link

@void-main void-main Dec 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

report a bug here, when we start vllm api server with python3 -m vllm.entrypoints.api_server --model=/path/to/tgt_model/ --draft-model=/path/to/draft/model/ --propose-cnt=5, the server errors out. looks like you forgot set_draft_tokens and accept_tokens in AsyncLLMEngine

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants