41 changes: 41 additions & 0 deletions tests/spec_decode/e2e/test_ngram_correctness.py
@@ -170,3 +170,44 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
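
For reference, here is a minimal usage sketch of the speculative_disable_by_batch_size knob that this new test exercises. This is an illustration only, assuming a local vLLM install and a CUDA device; the keyword arguments simply mirror the test fixtures above:

from vllm import LLM, SamplingParams

# Sketch: enable ngram speculative decoding, but automatically fall
# back to plain decoding once more than 4 sequences are in the
# running queue.
llm = LLM(
    model="JackFram/llama-68m",
    enforce_eager=True,            # skip CUDA graph capture, as in the test
    use_v2_block_manager=True,     # required for spec decode
    speculative_model="[ngram]",   # ngram prompt-lookup speculation
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=3,
    speculative_disable_by_batch_size=4,
)
outputs = llm.generate(["The quick brown fox"],
                       SamplingParams(temperature=0.0, max_tokens=32))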
16 changes: 10 additions & 6 deletions tests/spec_decode/test_dynamic_spec_decode.py
@@ -1,4 +1,4 @@
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch

import pytest
import torch
@@ -13,9 +13,9 @@
from .utils import create_batch, mock_worker


-@pytest.mark.parametrize('queue_size', [2, 4])
-@pytest.mark.parametrize('batch_size', [1, 2, 3, 6])
-@pytest.mark.parametrize('k', [1, 2, 5, 7, 10])
+@pytest.mark.parametrize('queue_size', [4])
+@pytest.mark.parametrize('batch_size', [1])
+@pytest.mark.parametrize('k', [1])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
"""Verify that speculative tokens are disabled when the batch size
@@ -42,8 +42,12 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
num_lookahead_slots=k,
running_queue_size=queue_size)

-    with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(execute_model_req=execute_model_req)
+    if queue_size > disable_by_batch_size:
+        with patch.object(worker,
+                          '_run_no_spec',
+                          side_effect=ValueError(exception_secret)), \
+                pytest.raises(ValueError, match=exception_secret):
+            worker.execute_model(execute_model_req=execute_model_req)

# When the batch size is larger than the threshold,
# we expect no speculative tokens (0).
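
The patched test above uses a sentinel-exception pattern: a unique exception is injected into _run_no_spec via patch.object, so pytest.raises can only observe it if execute_model actually routed through the non-speculative path. Below is a self-contained sketch of the same pattern; the Worker class here is a hypothetical stand-in, not the real SpecDecodeWorker:

from unittest.mock import patch

import pytest


class Worker:
    """Stand-in for SpecDecodeWorker, for illustration only."""

    def _run_no_spec(self):
        return "no-spec path"

    def execute_model(self, disable_spec: bool):
        # Mirrors the dispatch under test: take the non-speculative
        # path when speculation has been disabled.
        if disable_spec:
            return self._run_no_spec()
        return "spec path"


def test_routes_through_no_spec_path():
    worker = Worker()
    secret = "artificial stop"
    # If execute_model calls _run_no_spec, the injected side_effect
    # raises; pytest.raises catching it proves that path was taken.
    with patch.object(worker, '_run_no_spec',
                      side_effect=ValueError(secret)), \
            pytest.raises(ValueError, match=secret):
        worker.execute_model(disable_spec=True)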
17 changes: 12 additions & 5 deletions vllm/spec_decode/spec_decode_worker.py
@@ -273,10 +273,17 @@ def execute_model(
self._maybe_disable_speculative_tokens(
disable_all_speculation, execute_model_req.seq_group_metadata_list)

-        # If no spec tokens, call the proposer and scorer workers normally.
-        # Used for prefill.
+        # Speculative decoding is disabled in the following cases:
+        # 1. Prefill phase: Speculative decoding is not
+        #    used during the prefill phase.
+        # 2. Auto-disable enabled: The running queue size exceeds
+        #    the specified threshold.
+        # 3. No request: There are no requests in the batch.
+        # In any of these cases, the proposer and scorer workers
+        # are called normally.
        if num_lookahead_slots == 0 or len(
-                execute_model_req.seq_group_metadata_list) == 0:
+                execute_model_req.seq_group_metadata_list
+        ) == 0 or disable_all_speculation:
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all_speculation)

@@ -316,8 +323,8 @@ def _maybe_disable_speculative_tokens(
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
skip_proposer: bool) -> List[SamplerOutput]:
"""Run a prefill step, without any speculation. The input is sent to
the proposer and scorer model so that the KV cache is consistent
"""Run a single generation step without any speculation. The input is
sent to the proposer and scorer model so that the KV cache is consistent
        between the two. When skip_proposer is True, the proposer model is
        not called, meaning that the proposer's KV cache is not updated for
        these requests, so they cannot use spec decode for the rest of decoding.
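
Taken together, the new comment and the extended condition make the fallback rule explicit. As a paraphrase only (descriptive names, not the actual vLLM API), the dispatch reduces to a predicate like this:

from typing import Any, List


def should_skip_speculation(num_lookahead_slots: int,
                            seq_group_metadata_list: List[Any],
                            disable_all_speculation: bool) -> bool:
    # Mirrors the updated condition in execute_model(): run a plain
    # generation step for prefill (no lookahead slots), for an empty
    # batch, or when auto-disable tripped on the running queue size.
    return (num_lookahead_slots == 0
            or len(seq_group_metadata_list) == 0
            or disable_all_speculation)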