diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py
index 00d93e1ba0b5..945276376d66 100644
--- a/tests/v1/e2e/test_async_scheduling.py
+++ b/tests/v1/e2e/test_async_scheduling.py
@@ -87,6 +87,11 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
     # Set small draft model len to force doesn't-fit-in-drafter case.
     spec_config_short = spec_config | {"max_model_len": 50}
 
+    test_sampling_params = [
+        dict(),
+        dict(logprobs=2),
+    ]
+
     # test_preemption, executor, async_scheduling,
     # spec_config, test_prefill_chunking
     test_configs = [
@@ -103,7 +108,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
         (True, "uni", True, spec_config_short, True),
     ]
 
-    run_tests(monkeypatch, MTP_MODEL, test_configs, [{}])
+    run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
 
 
 @dynamo_config.patch(cache_size_limit=16)
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index a7ec0de37263..9e1283266c7c 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -1088,8 +1088,6 @@ def update_from_output(
             and request.sampling_params.logprobs is not None
             and logprobs
         ):
-            # NOTE: once we support N tokens per step (spec decode),
-            # the outer lists can be of length > 1.
             new_logprobs = logprobs.slice(req_index, req_index + 1)
 
         if new_token_ids and self.structured_output_manager.should_advance(request):
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 926305d25f56..ccaf07e18c46 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from collections.abc import Sequence
 from dataclasses import replace
 
 import torch
@@ -204,7 +205,9 @@ def _get_logprobs_tensors(
     def parse_output(
         output_token_ids: torch.Tensor,
         vocab_size: int,
-    ) -> list[list[int]]:
+        discard_req_indices: Sequence[int] = (),
+        return_cu_num_tokens: bool = False,
+    ) -> tuple[list[list[int]], list[int] | None]:
         """Parse the output of the rejection sampler.
         Args:
             output_token_ids: The sampled token IDs in shape
@@ -212,6 +215,8 @@ def parse_output(
                 replaced with `PLACEHOLDER_TOKEN_ID` by the rejection
                 sampler and will be filtered out in this function.
             vocab_size: The size of the vocabulary.
+            discard_req_indices: Optional row indices whose tokens are discarded.
+            return_cu_num_tokens: Whether to also return cumulative valid-token counts.
         Returns:
             A list of lists of token IDs.
""" @@ -220,10 +225,15 @@ def parse_output( valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & ( output_token_ids_np < vocab_size ) + cu_num_tokens = None + if return_cu_num_tokens: + cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist() + if len(discard_req_indices) > 0: + valid_mask[discard_req_indices] = False outputs = [ row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np) ] - return outputs + return outputs, cu_num_tokens def apply_logits_processors( self, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 979f97758703..ee3080994e33 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -178,7 +178,7 @@ def __init__( self, model_runner_output: ModelRunnerOutput, sampled_token_ids: torch.Tensor, - logprobs_tensors: torch.Tensor | None, + logprobs_tensors: LogprobsTensors | None, invalid_req_indices: list[int], async_output_copy_stream: torch.cuda.Stream, vocab_size: int, @@ -214,28 +214,29 @@ def get_output(self) -> ModelRunnerOutput: This function blocks until the copy is finished. """ + max_gen_len = self.sampled_token_ids_cpu.shape[-1] self.async_copy_ready_event.synchronize() # Release the device tensors once the copy has completed. del self._logprobs_tensors del self._sampled_token_ids - max_gen_len = self.sampled_token_ids_cpu.shape[-1] if max_gen_len == 1: valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() + for i in self._invalid_req_indices: + valid_sampled_token_ids[i].clear() + cu_num_tokens = None else: - valid_sampled_token_ids = RejectionSampler.parse_output( + valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output( self.sampled_token_ids_cpu, self.vocab_size, + self._invalid_req_indices, + return_cu_num_tokens=self._logprobs_tensors_cpu is not None, ) - for i in self._invalid_req_indices: - valid_sampled_token_ids[i].clear() output = self._model_runner_output output.sampled_token_ids = valid_sampled_token_ids if self._logprobs_tensors_cpu: - # NOTE(nick): this will need to be updated to use cu_num_accepted_tokens - # for async sched + spec decode + logprobs compatibility. - output.logprobs = self._logprobs_tensors_cpu.tolists() + output.logprobs = self._logprobs_tensors_cpu.tolists(cu_num_tokens) return output @@ -2468,28 +2469,24 @@ def _bookkeeping_sync( sampled_token_ids = sampler_output.sampled_token_ids logprobs_tensors = sampler_output.logprobs_tensors invalid_req_indices = [] - cu_num_new_tokens: list[int] | None = None + cu_num_tokens: list[int] | None = None if not self.use_async_scheduling: # Get the valid generated tokens. max_gen_len = sampled_token_ids.shape[-1] if max_gen_len == 1: # No spec decode tokens. valid_sampled_token_ids = self._to_list(sampled_token_ids) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[int(i)].clear() else: # Includes spec decode tokens. - valid_sampled_token_ids = self.rejection_sampler.parse_output( + valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output( sampled_token_ids, self.input_batch.vocab_size, + discard_sampled_tokens_req_indices, + return_cu_num_tokens=logprobs_tensors is not None, ) - if logprobs_tensors: - # Needed for extracting logprobs when spec decoding. - # This must be done prior to discarding sampled tokens. 
-                    cu_num_new_tokens = [0]
-                    for toks in valid_sampled_token_ids:
-                        cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
-            # Mask out the sampled tokens that should not be sampled.
-            for i in discard_sampled_tokens_req_indices:
-                valid_sampled_token_ids[int(i)].clear()
         else:
             valid_sampled_token_ids = []
             invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
@@ -2543,7 +2540,7 @@
             req_state.output_token_ids.extend(sampled_ids)
 
         logprobs_lists = (
-            logprobs_tensors.tolists(cu_num_new_tokens)
+            logprobs_tensors.tolists(cu_num_tokens)
             if not self.use_async_scheduling and logprobs_tensors is not None
             else None
        )
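
For reference, a minimal standalone sketch of the updated parse_output contract (hypothetical toy values; assumes the patched signature above and PLACEHOLDER_TOKEN_ID == -1 as defined in rejection_sampler.py):

import torch

from vllm.v1.sample.rejection_sampler import RejectionSampler

# Two requests, 1 target + 2 spec positions each; -1 marks rejected tokens.
sampled = torch.tensor([[11, 12, -1],
                        [21, -1, -1]])

# Request 1 is discarded (e.g. it finished or was preempted). The
# cumulative counts are computed over the pre-discard valid mask, so
# they remain aligned with the flat logprobs tensors.
tokens, cu_num_tokens = RejectionSampler.parse_output(
    sampled,
    vocab_size=32_000,          # toy value
    discard_req_indices=[1],
    return_cu_num_tokens=True,
)
assert tokens == [[11, 12], []]
assert cu_num_tokens == [0, 2, 3]

Computing cu_num_tokens before clearing the discarded rows is what the removed "must be done prior to discarding sampled tokens" comment enforced by hand; the new helper folds that ordering into one place shared by the sync and async paths.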