Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tests/v1/e2e/test_async_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
# Set small draft model len to force doesn't-fit-in-drafter case.
spec_config_short = spec_config | {"max_model_len": 50}

test_sampling_params = [
dict(),
dict(logprobs=2),
]

# test_preemption, executor, async_scheduling,
# spec_config, test_prefill_chunking
test_configs = [
Expand All @@ -103,7 +108,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(True, "uni", True, spec_config_short, True),
]

run_tests(monkeypatch, MTP_MODEL, test_configs, [{}])
run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)


@dynamo_config.patch(cache_size_limit=16)
Expand Down
2 changes: 0 additions & 2 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,8 +1088,6 @@ def update_from_output(
and request.sampling_params.logprobs is not None
and logprobs
):
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)

if new_token_ids and self.structured_output_manager.should_advance(request):
Expand Down
14 changes: 12 additions & 2 deletions vllm/v1/sample/rejection_sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence
from dataclasses import replace

import torch
Expand Down Expand Up @@ -204,14 +205,18 @@ def _get_logprobs_tensors(
def parse_output(
output_token_ids: torch.Tensor,
vocab_size: int,
) -> list[list[int]]:
discard_req_indices: Sequence[int] = (),
return_cu_num_tokens: bool = False,
) -> tuple[list[list[int]], list[int] | None]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
[batch_size, max_spec_len + 1]. The rejected tokens are
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
vocab_size: The size of the vocabulary.
discard_req_indices: Optional row indices to discard tokens in.
return_cu_num_tokens: Whether to also return cumulative token counts.
Returns:
A list of lists of token IDs.
"""
Expand All @@ -220,10 +225,15 @@ def parse_output(
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size
)
cu_num_tokens = None
if return_cu_num_tokens:
cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
if len(discard_req_indices) > 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this done after computing cu_num_tokens?

Copy link
Member Author

@njhill njhill Nov 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because cu_num_tokens is used to index into the logprobs tensors that don't take the discarded indices into account, so doing it beforehand results in incorrect output.

Originally it was done before which was a bug, fixed by #29216.

valid_mask[discard_req_indices] = False
outputs = [
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
]
return outputs
return outputs, cu_num_tokens

def apply_logits_processors(
self,
Expand Down
37 changes: 17 additions & 20 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __init__(
self,
model_runner_output: ModelRunnerOutput,
sampled_token_ids: torch.Tensor,
logprobs_tensors: torch.Tensor | None,
logprobs_tensors: LogprobsTensors | None,
invalid_req_indices: list[int],
async_output_copy_stream: torch.cuda.Stream,
vocab_size: int,
Expand Down Expand Up @@ -214,28 +214,29 @@ def get_output(self) -> ModelRunnerOutput:

This function blocks until the copy is finished.
"""
max_gen_len = self.sampled_token_ids_cpu.shape[-1]
self.async_copy_ready_event.synchronize()

# Release the device tensors once the copy has completed.
del self._logprobs_tensors
del self._sampled_token_ids
max_gen_len = self.sampled_token_ids_cpu.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()
cu_num_tokens = None
else:
valid_sampled_token_ids = RejectionSampler.parse_output(
valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
self.sampled_token_ids_cpu,
self.vocab_size,
self._invalid_req_indices,
return_cu_num_tokens=self._logprobs_tensors_cpu is not None,
)
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()

output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids
if self._logprobs_tensors_cpu:
# NOTE(nick): this will need to be updated to use cu_num_accepted_tokens
# for async sched + spec decode + logprobs compatibility.
output.logprobs = self._logprobs_tensors_cpu.tolists()
output.logprobs = self._logprobs_tensors_cpu.tolists(cu_num_tokens)
return output


Expand Down Expand Up @@ -2468,28 +2469,24 @@ def _bookkeeping_sync(
sampled_token_ids = sampler_output.sampled_token_ids
logprobs_tensors = sampler_output.logprobs_tensors
invalid_req_indices = []
cu_num_new_tokens: list[int] | None = None
cu_num_tokens: list[int] | None = None
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = self._to_list(sampled_token_ids)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)].clear()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
discard_sampled_tokens_req_indices,
return_cu_num_tokens=logprobs_tensors is not None,
)
if logprobs_tensors:
# Needed for extracting logprobs when spec decoding.
# This must be done prior to discarding sampled tokens.
cu_num_new_tokens = [0]
for toks in valid_sampled_token_ids:
cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)].clear()
else:
valid_sampled_token_ids = []
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
Expand Down Expand Up @@ -2543,7 +2540,7 @@ def _bookkeeping_sync(
req_state.output_token_ids.extend(sampled_ids)

logprobs_lists = (
logprobs_tensors.tolists(cu_num_new_tokens)
logprobs_tensors.tolists(cu_num_tokens)
if not self.use_async_scheduling and logprobs_tensors is not None
else None
)
Expand Down