
[Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling #3103

Merged
merged 73 commits into vllm-project:main on Mar 9, 2024

Conversation

@cadedaniel (Collaborator) commented Feb 29, 2024

This PR implements a vLLM Worker which invokes a draft worker to obtain proposals, invokes a target worker to obtain probabilities for each proposal, and then applies rejection sampling to accept or reject each speculated token. It is part of the speculative decoding contribution by Anyscale to vLLM; see #2188 for more info.

High-level design

The high-level design is as follows:
[Screenshot: high-level design diagram]

Currently, only the "draft model" approach to speculative decoding is implemented, with top-1 proposals from the draft model and lossless rejection sampling. In the future, other proposal approaches may be added, such as Medusa/Eagle (requiring top-k proposals and tree attention scoring), Lookahead, RAG, etc. The key contribution of this PR is a lightweight framework for proposing, scoring, and verifying speculative tokens using non-contiguous KV memory.
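
To make the accept/reject step concrete, here is a hedged, self-contained sketch of lossless rejection sampling as it is commonly described in the speculative decoding literature; the function name and signature are illustrative, not the PR's actual implementation.

# Hedged sketch of lossless rejection sampling (illustrative, not the PR's code).
import torch

def rejection_sample(proposal: torch.Tensor,       # [k] proposed token ids
                     draft_probs: torch.Tensor,    # [k, vocab_size] draft distributions
                     target_probs: torch.Tensor    # [k+1, vocab_size] target distributions
                     ) -> torch.Tensor:
    """Return the emitted tokens: an accepted prefix of the proposal plus one extra token."""
    accepted = []
    for i, tok in enumerate(proposal.tolist()):
        p, q = target_probs[i, tok], draft_probs[i, tok]
        # Accept the draft token with probability min(1, p/q).
        if torch.rand(1).item() < min(1.0, (p / q).item()):
            accepted.append(tok)
            continue
        # On rejection, resample from the normalized residual max(0, p - q) and stop.
        residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)
        accepted.append(int(torch.multinomial(residual / residual.sum(), 1)))
        return torch.tensor(accepted)
    # All k proposals accepted: sample a bonus token from the target's final position.
    accepted.append(int(torch.multinomial(target_probs[-1], 1)))
    return torch.tensor(accepted)

# Toy example: vocabulary of size 4, k = 2 proposed tokens.
torch.manual_seed(0)
draft = torch.softmax(torch.randn(2, 4), dim=-1)
target = torch.softmax(torch.randn(3, 4), dim=-1)
print(rejection_sample(torch.tensor([1, 3]), draft, target))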

Notes for reviewers

  • The worker has decent unit test coverage but does not yet work end-to-end with the LLMEngine. That will come in followup PRs 4/9 and 6/9 of the open sourcing plan.
  • Some work remains on my part to clean this up for review. Namely:
    • Port over the proposal routine for the multi-step worker (currently mocked)
    • Decouple "batch expansion" (see below) from the proposal/scoring/verification scheme. This will allow future contribution of MQA for scoring.
    • Documentation
    • General cleanup of messy bits from open sourcing

What is "batch expansion"?

This PR does not use MQA for scoring proposal tokens. Instead, it uses the single-query PagedAttention kernel (i.e., normal vLLM decode attention) to score the proposal tokens. This was done because, at the time of implementation, we did not yet have performant MQA kernels for non-contiguous KV memory. We now have an abundance of these (notably FlashAttention and FlashInfer, along with Triton implementations, e.g. in #2607). Batch expansion should be replaced by these kernels to gain some efficiency in verification time.

[Screenshot: batch expansion diagram]

More details on batch expansion and the optimization opportunity can be found here.
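
For illustration, a minimal sketch of the batch expansion idea (toy token lists and illustrative names, not the PR's code): each of the k proposal positions, plus one bonus position, becomes its own single-query decode entry in an expanded batch.

from typing import List, Tuple

def expand_batch(context: List[int], proposal: List[int]) -> List[Tuple[List[int], int]]:
    """Return (kv_tokens, query_token) pairs, one single-query decode per scored position."""
    expanded = []
    for i in range(len(proposal) + 1):  # k proposal positions + 1 bonus position
        prefix = context + proposal[:i]
        # The query is the newest token; the kernel attends over the earlier tokens' KV.
        expanded.append((prefix[:-1], prefix[-1]))
    return expanded

# Example: context [1, 2, 3] with proposal [4, 5] expands into 3 single-query entries.
for kv, query in expand_batch([1, 2, 3], [4, 5]):
    print(kv, "->", query)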

@robertgshaw2-neuralmagic (Collaborator):

cool!

@cadedaniel cadedaniel changed the title [WIP] [Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling [Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling Mar 6, 2024
@cadedaniel cadedaniel marked this pull request as ready for review March 6, 2024 06:42
@cadedaniel (Collaborator Author):

Ready for review. cc @LiuXiaoxuanPKU @ymwangg @robertgshaw2-neuralmagic @Yard1

@cadedaniel cadedaniel requested a review from pcmoritz March 6, 2024 06:43
@LiuXiaoxuanPKU (Collaborator) left a comment:

Thanks for the great work! Just some minor questions & comments.

    i for i, (_, proposal_len) in enumerate(
        zip(seq_group_metadata_list, proposal_lens_list))
    if proposal_len == 0
]
Collaborator:

Nit: we can merge the two for loops above.

Collaborator Author:

Seems there are two concerns here:

  • performance of two loops
  • readability of two loops

I'll defer performance optimization until later; I'll put these into a helper function to make it easier to read.
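
For reference, one way such a helper could look (a hedged sketch with illustrative names, not the PR's final code): a single pass that splits sequence indices into speculative and non-speculative groups.

from typing import List, Sequence, Tuple

def split_spec_and_non_spec(proposal_lens: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Split sequence indices by proposal length: 0 means no speculation this step."""
    spec_indices: List[int] = []
    non_spec_indices: List[int] = []
    for i, proposal_len in enumerate(proposal_lens):
        (non_spec_indices if proposal_len == 0 else spec_indices).append(i)
    return spec_indices, non_spec_indices

# Example: sequences 0 and 2 are speculated on, sequence 1 is not.
assert split_spec_and_non_spec([5, 0, 5]) == ([0, 2], [1])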

@@ -0,0 +1,347 @@
from typing import List, Tuple, Optional, Dict
Collaborator:

Currently, there is some complexity from handling spec and non_spec sequences in separate ways. In the future, we will remove this complexity by introducing variable proposal lengths and the FlashInfer kernel. Maybe we can add some comments about this?

Collaborator Author:

Good point, will add a comment!

    accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
    k: int,
) -> List[SamplerOutput]:
    """Given the accepted token ids, create a list of SamplerOutput.
Collaborator:

Thanks for adding comments for almost all the functions! Really appreciate it!

Collaborator Author:

😄
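
For readers skimming the thread, a hedged sketch of what the docstring above describes (illustrative names and plain tuples rather than vLLM's actual SamplerOutput type): the [batch_size, k+1] tensor of accepted token ids is unpacked into k+1 per-step outputs, skipping the -1 entries that mark rejected or padded positions.

import torch
from typing import List, Tuple

def accepted_ids_to_steps(accepted_token_ids: torch.Tensor) -> List[List[Tuple[int, int]]]:
    """For each of the k+1 steps, return the (sequence_index, token_id) pairs emitted."""
    batch_size, num_steps = accepted_token_ids.shape
    steps: List[List[Tuple[int, int]]] = []
    for step in range(num_steps):
        emitted = []
        for seq_idx in range(batch_size):
            token_id = int(accepted_token_ids[seq_idx, step])
            if token_id != -1:  # -1 marks positions after the first rejection
                emitted.append((seq_idx, token_id))
        steps.append(emitted)
    return steps

# Example: sequence 0 accepted two tokens, sequence 1 accepted one.
print(accepted_ids_to_steps(torch.tensor([[11, 12, -1], [21, -1, -1]])))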

seq_data = next(iter(seq_group_metadata.seq_data.values()))
seq_len = seq_data.get_len()

if seq_len + max_proposal_len < self._max_model_len:
Collaborator:

Maybe add a comment here saying:
(1) we want to handle the different model lengths of the draft and target models;
(2) for now, proposal_lens can only be max_proposal_len or 0; it cannot be a length between 0 and max_proposal_len.

Collaborator Author:

Good points, will add!
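
A hedged illustration of the all-or-nothing rule discussed above (illustrative function, not the PR's code): a sequence either receives the full max_proposal_len speculative tokens or none at all.

def select_proposal_len(seq_len: int, max_proposal_len: int, max_model_len: int) -> int:
    """Return max_proposal_len if speculation still fits in the model's context, else 0."""
    if seq_len + max_proposal_len >= max_model_len:
        return 0  # speculating would exceed the max model length, so skip this sequence
    return max_proposal_len  # intermediate lengths (1..max_proposal_len-1) are not used yet

assert select_proposal_len(seq_len=100, max_proposal_len=5, max_model_len=2048) == 5
assert select_proposal_len(seq_len=2045, max_proposal_len=5, max_model_len=2048) == 0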


if non_spec_indices:
    all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
    all_probs[non_spec_indices, 1:, :] = non_spec_target_probs
Collaborator:

A bit confused by the 1 here: why start from 1?

Collaborator Author:

This is a bug! Should be :1. Saved me some headache during correctness testing, thank you 😄
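
A toy illustration of the indexing fix being discussed (illustrative shapes): with the original 1: slice, the single scored position of a non-speculative sequence silently broadcasts into positions 1..k; :1 writes it only into position 0.

import torch

all_probs = torch.zeros(4, 3, 8)             # [batch_size, k+1, vocab_size] with k = 2
non_spec_indices = [1, 3]
non_spec_target_probs = torch.rand(2, 1, 8)  # one scored position per non-spec sequence

# Buggy: writes (via broadcasting) into positions 1..k instead of position 0.
# all_probs[non_spec_indices, 1:, :] = non_spec_target_probs

# Fixed: write only into position 0, the single real decode step.
all_probs[non_spec_indices, :1, :] = non_spec_target_probs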

Comment on lines 162 to 163
all_tokens = torch.ones(
    original_bs, k + 1, device=self._device, dtype=torch.long) * -1
Collaborator:

Instead of torch.ones(...) * -1, use torch.full with -1 as the fill value.

Collaborator Author:

added, thanks
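
The suggested change, spelled out with toy shapes (illustrative, matching the snippet above):

import torch

original_bs, k = 4, 3
device = torch.device("cpu")

# Before: allocate ones, then multiply by -1 (an extra temporary tensor).
before = torch.ones(original_bs, k + 1, device=device, dtype=torch.long) * -1

# After: fill with -1 directly.
after = torch.full((original_bs, k + 1), -1, device=device, dtype=torch.long)

assert torch.equal(before, after)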

@LiuXiaoxuanPKU LiuXiaoxuanPKU merged commit 8437bae into vllm-project:main Mar 9, 2024
24 checks passed
dtransposed pushed a commit to afeldman-nm/vllm that referenced this pull request Mar 26, 2024