
[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. #3951

Merged
187 commits merged into vllm-project:main on Apr 23, 2024

Conversation

Collaborator

@cadedaniel cadedaniel commented Apr 9, 2024

This PR adds e2e correctness tests for speculative decoding. It is PR 7/9 in the speculative decoding open sourcing plan.

The E2E correctness tests verify that the generated output of a sequence with speculative decoding is equal to the generated output without speculative decoding when temperature is 0. We test various batch sizes, models, speculative lengths, block sizes, num_gpu_blocks (& preemption), and max_model_lens (& skipping speculation for some/all sequences) and verify that this core greedy-equality property holds.

See test_correctness.py for more details on test methodology.
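The core property is greedy equality: with temperature 0, the output token ids with and without speculative decoding must match exactly. Below is a minimal pytest-style sketch of that idea; the model name, engine arguments (speculative_model, num_speculative_tokens), and test structure are illustrative assumptions, not the actual contents of test_correctness.py.

```python
# Illustrative sketch only; the real tests cover many batch sizes, models,
# speculative lengths, block sizes, and preemption/skip-speculation cases.
from vllm import LLM, SamplingParams

PROMPTS = ["Hello, my name is", "The capital of France is"]
GREEDY = SamplingParams(temperature=0.0, max_tokens=32)


def _greedy_token_ids(**extra_engine_args):
    """Generate with greedy sampling and return per-prompt token ids."""
    llm = LLM(model="JackFram/llama-68m", **extra_engine_args)
    outputs = llm.generate(PROMPTS, GREEDY)
    return [tuple(out.outputs[0].token_ids) for out in outputs]


def test_spec_decode_greedy_equality():
    baseline = _greedy_token_ids()
    with_spec = _greedy_token_ids(
        speculative_model="JackFram/llama-68m",  # same model as draft & target
        num_speculative_tokens=3,
    )
    # Speculative decoding must not change greedy outputs.
    assert baseline == with_spec
```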

Bugfixes

To make the tests pass, this PR introduces several fixes that are listed in order of notoriety:

  • The vLLM sampler now modifies the probability distributions so that the sampling method is encoded within the distribution itself; this is gated behind a flag used by speculative decoding. Concretely, the token chosen by greedy sampling has its probability set to 1.0, which allows speculative decoding's rejection sampler to guarantee output equality (see the sketch after this list).
  • Batch expansion was incorrectly scoring tokens when some sequences were skipped. This is fixed.
  • When a "bonus token" is emitted, its KV is never generated by the draft worker. This reduces accuracy for proposers that use KV, e.g. draft models. This PR disables bonus tokens, with a follow-up issue to re-enable them --> [Speculative decoding] [Performance]: Re-enable bonus tokens #4212
  • Fixed an incorrect system efficiency calculation.
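For context on why this guarantees equality: the standard speculative-decoding rejection rule accepts a proposed token x with probability min(1, p_target(x) / p_draft(x)). Once greedy sampling is encoded in both distributions (probability 1.0 on the argmax token), that ratio is either 1 or 0, so acceptance becomes deterministic and matches plain greedy decoding. A hedged sketch of that acceptance rule, not vLLM's RejectionSampler:

```python
import torch


def accept_draft_tokens(target_probs: torch.Tensor,
                        draft_probs: torch.Tensor,
                        draft_token_ids: torch.Tensor) -> torch.Tensor:
    """Sketch of the standard acceptance rule for speculative decoding.

    target_probs, draft_probs: [batch, k, vocab] probability distributions.
    draft_token_ids: [batch, k] tokens proposed by the draft model.
    Returns a boolean mask of accepted proposals.
    """
    idx = draft_token_ids.unsqueeze(-1)
    p = target_probs.gather(-1, idx).squeeze(-1)  # target prob of proposal
    q = draft_probs.gather(-1, idx).squeeze(-1)   # draft prob of proposal (> 0)
    # Accept with probability min(1, p/q). With greedy encoded as one-hot
    # distributions, p/q is 1.0 when draft and target agree and 0.0 otherwise,
    # so the comparison below becomes deterministic.
    u = torch.rand_like(p)
    return u < torch.clamp(p / q, max=1.0)
```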

Minor feature additions

The following features were added:

  • The vLLM sampler can now return sampling results as on-GPU tensors instead of only Python data structures. This allows the rejection sampler to consume GPU tensors directly rather than serializing to CPU and back to GPU.
  • Spec decode metrics are emitted by the vLLM engine via the stats object. This was required to verify in the correctness tests that spec decode with the same model for draft and target has a 100% acceptance rate (a sketch of the metric definitions follows this list).
  • The draft model now has a configurable max_model_len for use in testing. This was required to test preemption.
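The two metrics the tests lean on are straightforward ratios; a hedged sketch with illustrative field names (not the exact vLLM metrics objects) is below. Draft acceptance rate is accepted draft tokens over proposed draft tokens; system efficiency compares emitted tokens against the per-step maximum of k accepted tokens plus one bonus token.

```python
from dataclasses import dataclass


@dataclass
class SpecDecodeCounters:
    """Illustrative counters; field names are assumptions, not vLLM's."""
    num_draft_tokens: int     # tokens proposed by the draft worker
    num_accepted_tokens: int  # proposals accepted by the rejection sampler
    num_emitted_tokens: int   # tokens actually appended to sequences
    num_spec_steps: int       # number of verification iterations
    k: int                    # proposal length per step


def draft_acceptance_rate(c: SpecDecodeCounters) -> float:
    return c.num_accepted_tokens / c.num_draft_tokens


def system_efficiency(c: SpecDecodeCounters) -> float:
    # Each verification step can emit at most k accepted tokens + 1 bonus token.
    return c.num_emitted_tokens / (c.num_spec_steps * (c.k + 1))
```

With the same model as draft and target under greedy sampling, every proposal should be accepted, so draft acceptance rate is expected to be 1.0; this is what the correctness test asserts via the emitted stats.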

@cadedaniel cadedaniel changed the title [Draft] [Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. [Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. Apr 22, 2024
@cadedaniel cadedaniel marked this pull request as ready for review April 22, 2024 07:15
@cadedaniel cadedaniel enabled auto-merge (squash) April 22, 2024 17:37
Collaborator

@LiuXiaoxuanPKU LiuXiaoxuanPKU left a comment

Thoughts before discussing this PR. Skip sampler & tests.

speculative_max_model_len is mainly used for testing that sequences can
skip speculation.
"""

Collaborator

Do we want to add a check to make sure speculative_max_model_len < min(draft_max_model_len, target_max_model_len) in case the user sets speculative_max_model_len inappropriately?
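A hedged sketch of the kind of bound check being suggested here; the function and argument names are illustrative, not vLLM's actual config code:

```python
from typing import Optional


def resolve_speculative_max_model_len(
        speculative_max_model_len: Optional[int],
        draft_max_model_len: int,
        target_max_model_len: int) -> int:
    """Validate or default the speculative max model length (sketch)."""
    upper_bound = min(draft_max_model_len, target_max_model_len)
    if speculative_max_model_len is None:
        # Default: allow speculation whenever both models fit the sequence.
        return upper_bound
    if speculative_max_model_len > upper_bound:
        raise ValueError(
            f"speculative_max_model_len={speculative_max_model_len} must not "
            f"exceed min(draft_max_model_len, target_max_model_len)="
            f"{upper_bound}.")
    return speculative_max_model_len
```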

Collaborator Author

  • Cade to fix

# process the output tokens. Otherwise, they are (chunked) prefill
# samples and should not be processed.
stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
if all(stage == SequenceStage.DECODE for stage in stages):
Collaborator

  1. A bit of a concern here: from this, it seems we assume the DECODE stage only has 1 new token?
  2. I assume chunked prefill and speculative decoding cannot be turned on at the same time? Did we explicitly check or document that somewhere?

Collaborator Author

  • (Cade fill out answer)
  • Cade to verify args and raise if chunked prefill enabled while spec decode enabled

Collaborator Author

Answer for future readers:

  • Chunked prefill and speculative decoding are compatible from a systems perspective; however, the current vLLM implementations need more work before they can be enabled together. I'll add a validation check which raises if both are enabled (sketched below).
  • The DECODE stage currently only reports 1 new token. The scheduler uses this via the token budget to prevent a batch from becoming compute-bound. When chunked prefill is enabled, we will need to adjust this value to account for the "new tokens" computed during speculative verification. When chunked prefill is disabled, the new-token budget is max_num_batched_tokens, and we are OK with the budget system not taking speculative decoding into account.
  • I'll make an issue for integrating chunked prefill with spec decode soon!
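A hedged sketch of the validation described above; the argument names echo vLLM's engine flags, but the function itself is illustrative:

```python
from typing import Optional


def verify_spec_decode_args(enable_chunked_prefill: bool,
                            speculative_model: Optional[str]) -> None:
    """Raise if chunked prefill and speculative decoding are both enabled."""
    if enable_chunked_prefill and speculative_model is not None:
        raise ValueError(
            "Speculative decoding is currently incompatible with chunked "
            "prefill; disable one of the two features.")
```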

@@ -680,12 +760,36 @@ def _get_logprobs(
return result_prompt_logprobs, result_sample_logprobs


def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
Collaborator Author

Cade to list how this fits into the sampler overall

  • Why not use a very small temperature instead?
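For future readers, a hedged sketch of what a greedy-probability rewrite like _modify_greedy_probs_inplace does (an illustrative reimplementation, not the code in this diff): rows that were sampled greedily are overwritten with a one-hot distribution so the rejection sampler sees probability 1.0 on the sampled token. A very small but nonzero temperature would not give the same guarantee, since the distribution would still place mass on other tokens and rejection sampling would remain stochastic.

```python
import torch


def modify_greedy_probs_inplace(logprobs: torch.Tensor,
                                probs: torch.Tensor,
                                sample_indices: torch.Tensor,
                                greedy_samples: torch.Tensor) -> None:
    """Encode greedy sampling into the distributions in place (sketch).

    probs, logprobs: [num_tokens, vocab_size]
    sample_indices:  rows that were sampled greedily
    greedy_samples:  argmax token id chosen for each of those rows
    """
    # Put all probability mass on the greedily sampled token so that
    # downstream rejection sampling accepts/recovers it deterministically.
    probs[sample_indices, :] = 0.0
    probs[sample_indices, greedy_samples] = 1.0
    # Keep logprobs consistent: log(1) = 0 for the chosen token, -inf elsewhere.
    logprobs[sample_indices, :] = -float("inf")
    logprobs[sample_indices, greedy_samples] = 0.0
```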

Collaborator

@LiuXiaoxuanPKU LiuXiaoxuanPKU left a comment

Addressed my concerns after discussion; please add some docs to clarify, thanks!

@cadedaniel
Collaborator Author

Applied feedback @LiuXiaoxuanPKU. I will enable auto-merge with your approval; if you have any more comments, I'm happy to take them in a future PR.

@cadedaniel cadedaniel enabled auto-merge (squash) April 22, 2024 21:25
@cadedaniel
Collaborator Author

main branch was broken, merging again to get #4271

@cadedaniel cadedaniel merged commit 62b8aeb into vllm-project:main Apr 23, 2024
47 checks passed
xjpang pushed a commit to xjpang/vllm that referenced this pull request Apr 25, 2024
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Apr 26, 2024
alexeykondrat pushed a commit to alexeykondrat/ci-vllm that referenced this pull request May 1, 2024
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request Jun 3, 2024