diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py index e0645ed43015..1d80ee987591 100644 --- a/tests/v1/core/test_async_scheduler.py +++ b/tests/v1/core/test_async_scheduler.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import deque +import numpy as np import pytest from vllm.v1.core.sched.output import SchedulerOutput @@ -21,7 +22,7 @@ def _make_model_runner_output( return ModelRunnerOutput( req_ids=req_ids, req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)}, - sampled_token_ids=[[i] for i in range(len(req_ids))], + sampled_token_ids=[np.array([i]) for i in range(len(req_ids))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 749cf7dc8397..6cd89389ff12 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -3,6 +3,7 @@ import dataclasses from unittest.mock import Mock +import numpy as np import pytest import torch @@ -165,7 +166,7 @@ def test_schedule_partial_requests(): req_id_to_index=req_to_index, # Only the first request has a sampled token id because # the rest requests are still being prefilled. - sampled_token_ids=[[0], [], []], + sampled_token_ids=[np.array([0]), np.array([]), np.array([])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -212,7 +213,7 @@ def test_no_mm_input_chunking(): model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, - sampled_token_ids=[[] for _ in range(len(requests))], + sampled_token_ids=[np.array([]) for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -272,7 +273,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, - sampled_token_ids=[[] for _ in range(len(requests))], + sampled_token_ids=[np.array([]) for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -296,7 +297,8 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, - sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)], + sampled_token_ids=[np.array([0]), np.array([0])] + + [np.array([]) for _ in range(len(requests) - 2)], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -343,8 +345,8 @@ def test_stop_via_update_from_output(): req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, sampled_token_ids=[ - [EOS_TOKEN_ID], - [10, 11], + np.array([EOS_TOKEN_ID]), + np.array([10, 11]), ], # First request hits EOS, second continues logprobs=None, prompt_logprobs_dict={}, @@ -388,7 +390,10 @@ def test_stop_via_update_from_output(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token + sampled_token_ids=[ + np.array([10, 42, 12]), + np.array([13, 14]), + ], # First request hits stop token logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -432,7 +437,10 @@ def test_stop_via_update_from_output(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens + sampled_token_ids=[ + np.array([10, 11, 12]), + np.array([13]), + ], # First request exceeds max_tokens logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -471,7 +479,7 @@ def test_stop_via_update_from_output(): model_output = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], + sampled_token_ids=[np.array([EOS_TOKEN_ID, 10, 11])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -612,7 +620,7 @@ def test_schedule_concurrent_batches( model_runner_output = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -629,7 +637,7 @@ def test_schedule_concurrent_batches( model_runner_output = ModelRunnerOutput( req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -666,7 +674,7 @@ def test_preempt_during_execution(): model_runner_output0 = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -683,7 +691,7 @@ def test_preempt_during_execution(): model_runner_output1 = ModelRunnerOutput( req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, - sampled_token_ids=[[42]], + sampled_token_ids=[np.array([42])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -700,14 +708,18 @@ def test_preempt_during_execution(): @pytest.mark.parametrize( "spec_tokens,output_tokens,expected", [ - ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match - ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch - ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])), # multiple sequences - ([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence - ([[]], [[5]], (0, 0, 0, [0])), # empty sequence + ([[1, 2, 3]], [np.array([1, 2, 3, 4])], (1, 3, 3, [1, 1, 1])), # perfect match + ([[1, 2, 3]], [np.array([1, 5])], (1, 3, 1, [1, 0, 0])), # early mismatch + ( + [[1, 2], [3]], + [np.array([1, 2, 5]), np.array([3, 4])], + (2, 3, 3, [2, 1]), + ), # multiple sequences + ([[1]], [np.array([1, 2])], (1, 1, 1, [1])), # single token sequence + ([[]], [np.array([5])], (0, 0, 0, [0])), # empty sequence ( [[1, 2, 3], [4, 5, 6]], - [[1, 2, 7], [4, 8]], + [np.array([1, 2, 7]), np.array([4, 8])], (2, 6, 3, [2, 1, 0]), ), # multiple mismatches ], @@ -741,7 +753,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[0] for _ in range(len(requests))], + sampled_token_ids=[np.array([0]) for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -924,7 +936,7 @@ def test_kv_connector_basic(): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[1000]] * len(req_ids), + sampled_token_ids=[np.array([1000])] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -974,7 +986,7 @@ def test_kv_connector_basic(): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[1000]] * len(req_ids), + sampled_token_ids=[np.array([1000])] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1044,7 +1056,7 @@ def test_external_prefix_cache_metrics(): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=[r.request_id for r in requests], req_id_to_index={r.request_id: i for i, r in enumerate(requests)}, - sampled_token_ids=[[1000]] * NUM_REQUESTS, + sampled_token_ids=[np.array([1000])] * NUM_REQUESTS, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1109,7 +1121,7 @@ def test_kv_connector_unable_to_allocate(): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[1000]] * len(req_ids), + sampled_token_ids=[np.array([1000])] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1194,7 +1206,7 @@ def test_kv_connector_handles_preemption(): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[1000]] * len(req_ids), + sampled_token_ids=[np.array([1000])] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1287,7 +1299,7 @@ def make_output(scheduler: Scheduler): return ModelRunnerOutput( req_ids=[req.request_id for req in scheduler.running], req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)}, - sampled_token_ids=[[1000]] * len(scheduler.running), + sampled_token_ids=[np.array([1000])] * len(scheduler.running), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1637,7 +1649,7 @@ def test_priority_scheduling_preemption(): req_id_to_index={ req.request_id: i for i, req in enumerate(low_priority_requests) }, - sampled_token_ids=[[100] for _ in low_priority_requests], + sampled_token_ids=[np.array([100]) for _ in low_priority_requests], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1706,7 +1718,7 @@ def test_priority_scheduling_no_preemption_when_space_available(): req_id_to_index={ req.request_id: i for i, req in enumerate(low_priority_requests) }, - sampled_token_ids=[[100] for _ in low_priority_requests], + sampled_token_ids=[np.array([100]) for _ in low_priority_requests], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1952,7 +1964,7 @@ def test_priority_scheduling_heap_property(): model_output = ModelRunnerOutput( req_ids=[req.req_id], req_id_to_index={req.req_id: 0}, - sampled_token_ids=[[100]], + sampled_token_ids=[np.array([100])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -2030,7 +2042,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(): model_output = ModelRunnerOutput( req_ids=[request_low.request_id], req_id_to_index={request_low.request_id: 0}, - sampled_token_ids=[[100]], + sampled_token_ids=[np.array([100])], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -2061,7 +2073,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[[100] for _ in requests], + sampled_token_ids=[np.array([100]) for _ in requests], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -2087,7 +2099,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[[], [100]], + sampled_token_ids=[np.array([]), np.array([100])], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index f0031643aa9d..aeee7e834ad5 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -6,6 +6,7 @@ from itertools import count from typing import Any +import numpy as np import torch from vllm import SamplingParams @@ -222,7 +223,7 @@ def create_model_runner_output( # Make sampled tokens. sampled_token = EOS_TOKEN_ID if use_eos else token_id - sampled_token_ids = [[sampled_token] for _ in req_ids] + sampled_token_ids = [np.array([sampled_token]) for _ in req_ids] kv_connector_output = ( None diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 47d05a20a65d..bcce44dd9d1b 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -3,6 +3,7 @@ from unittest import mock +import numpy as np import pytest import torch @@ -112,7 +113,9 @@ def test_prepare_next_token_ids(): sampled_token_ids_tensor = torch.tensor( sampled_token_ids, dtype=torch.int32, device=device ) - sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids] + sampled_token_ids_cpu = [ + np.array([i for i in seq if i != -1]) for seq in sampled_token_ids + ] expected_next_token_ids_cpu = [1, 4, 30, 40] expected_next_token_ids_tensor = torch.tensor( diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 692c39282c37..563bc1d957f4 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -77,7 +77,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match. token_ids_cpu = np.array([[1, 2, 3, 4, 5]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -88,7 +88,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match for 4-gram. token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -99,7 +99,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match for 4-gram but match for 3-gram. token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -111,7 +111,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # In this case, the proposer should return the 4-gram match. token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]]) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -122,7 +122,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Match for 2-gram and 3-gram, but not 4-gram. token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]]) result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -133,7 +133,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Multiple 3-gram matched, but always pick the first one. token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -144,7 +144,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # check empty input token_ids_cpu = np.array([[]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -157,7 +157,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # second request has 3 tokens and no match. Padded with -1 for max len 5 token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( - sampled_token_ids=[[0], [1]], + sampled_token_ids=[np.array([0]), np.array([1])], req_ids=["0", "1"], num_tokens_no_spec=np.array([5, 3]), token_ids_cpu=token_ids_cpu, @@ -181,7 +181,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: input_2[:3] = [4, 5, 6] token_ids_cpu = np.array([input_1, input_2]) result = ngram_proposer.propose( - sampled_token_ids=[[0], [1]], + sampled_token_ids=[np.array([0]), np.array([1])], req_ids=["0", "1"], num_tokens_no_spec=np.array([len(input_1), 3]), token_ids_cpu=token_ids_cpu, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c17b19b58c97..b77ae84f89aa 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -942,8 +942,8 @@ def update_from_output( continue req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids = ( - sampled_token_ids[req_index] if sampled_token_ids else [] + generated_token_ids: list[int] = ( + sampled_token_ids[req_index].tolist() if sampled_token_ids else [] ) scheduled_spec_token_ids = ( diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index b5cba96e1026..f7d0d51e3102 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -5,6 +5,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, NamedTuple +import numpy as np import torch if TYPE_CHECKING: @@ -148,7 +149,7 @@ class ModelRunnerOutput: # num_generated_tokens is the number of tokens # generated in the current step. It can be different for # each request due to speculative/jump decoding. - sampled_token_ids: list[list[int]] + sampled_token_ids: list[np.ndarray] # [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1] diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 926305d25f56..f31a0cddda9a 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -3,6 +3,7 @@ from dataclasses import replace +import numpy as np import torch import torch.nn as nn @@ -204,7 +205,7 @@ def _get_logprobs_tensors( def parse_output( output_token_ids: torch.Tensor, vocab_size: int, - ) -> list[list[int]]: + ) -> list[np.ndarray]: """Parse the output of the rejection sampler. Args: output_token_ids: The sampled token IDs in shape @@ -220,10 +221,7 @@ def parse_output( valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & ( output_token_ids_np < vocab_size ) - outputs = [ - row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np) - ] - return outputs + return [row[valid_mask[i]] for i, row in enumerate(output_token_ids_np)] def apply_logits_processors( self, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 75a4140fd655..0e2e101a799a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -482,7 +482,7 @@ def propose( def prepare_next_token_ids_cpu( self, - sampled_token_ids: list[list[int]], + sampled_token_ids: list[np.ndarray], requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, num_scheduled_tokens: dict[str, int], @@ -497,7 +497,7 @@ def prepare_next_token_ids_cpu( req_ids = gpu_input_batch.req_ids next_token_ids: list[int] = [] for i, token_ids in enumerate(sampled_token_ids): - if token_ids: + if token_ids.shape[0] > 0: # Common case. next_token_id = token_ids[-1] else: @@ -508,10 +508,9 @@ def prepare_next_token_ids_cpu( seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id] next_token_id = req_state.get_token_id(seq_len) next_token_ids.append(next_token_id) - next_token_ids = torch.tensor( + return torch.tensor( next_token_ids, dtype=torch.int32, device=self.input_ids.device ) - return next_token_ids def prepare_next_token_ids_padded( self, diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index e2f83cb24aa9..378937dba988 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -54,7 +54,7 @@ def __init__(self, vllm_config: VllmConfig): # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. self.propose( - [[]] * 1024, + [np.array([])] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32), np.zeros((1024, self.max_model_len), dtype=np.int32), @@ -131,7 +131,7 @@ def batch_propose( def propose( self, - sampled_token_ids: list[list[int]], + sampled_token_ids: list[np.ndarray], req_ids: list[str], num_tokens_no_spec: np.ndarray, token_ids_cpu: np.ndarray, @@ -140,7 +140,7 @@ def propose( # find which requests need ngram proposals valid_ngram_requests = [] for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = len(sampled_ids) + num_sampled_ids = sampled_ids.shape[0] if not num_sampled_ids: # Skip speculative decoding. continue diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index 049e335db325..d76e0ffe778d 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import numpy as np + from vllm.config import VllmConfig from vllm.v1.worker.gpu_input_batch import InputBatch @@ -32,16 +34,16 @@ def __init__(self, vllm_config: VllmConfig): def propose( self, input_batch: InputBatch, - sampled_token_ids: list[list[int]], + sampled_token_ids: list[np.ndarray], ) -> list[list[int]]: """ Propose speculative tokens for each request in the input batch. Suffix Decoding will speculate a dynamic number of tokens for each request every decoding step, so each entry in the returned list may have different lengths. """ - draft_token_ids: list[list[int]] = [] + draft_token_ids: list[np.ndarray] = [] for i, sampled_ids in enumerate(sampled_token_ids): - if not sampled_ids: + if sampled_ids.shape[0] == 0: # Skip speculative decoding for partial prefills. draft_token_ids.append([]) continue @@ -70,7 +72,7 @@ def propose( self.suffix_cache.start_request(req_id, prompt_token_ids) # Append the newly sampled ids to the suffix cache for this request. - self.suffix_cache.add_active_response(req_id, sampled_ids) + self.suffix_cache.add_active_response(req_id, sampled_ids.tolist()) # Suffix decoding only uses the most recent tokens up to max_tree_depth, so # we extract the pattern from the end of the input. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 26007d29d61b..b4830fbd0a55 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -212,9 +212,11 @@ def get_output(self) -> ModelRunnerOutput: del self._logprobs_tensors del self._sampled_token_ids - valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() + valid_sampled_token_ids: list[np.ndarray] = [ + row for row in self.sampled_token_ids_cpu.numpy() + ] for i in self._invalid_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids[i] = np.array([]) output = self._model_runner_output output.sampled_token_ids = valid_sampled_token_ids @@ -2344,7 +2346,7 @@ def _bookkeeping_sync( ) -> tuple[ dict[str, int], LogprobsLists | None, - list[list[int]], + list[np.ndarray], dict[str, LogprobsTensors | None], list[str], dict[str, int], @@ -2370,6 +2372,7 @@ def _bookkeeping_sync( num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids invalid_req_indices = [] + valid_sampled_token_ids: list[np.ndarray] if not self.use_async_scheduling: # Get the valid generated tokens. max_gen_len = sampled_token_ids.shape[-1] @@ -2384,7 +2387,7 @@ def _bookkeeping_sync( ) # Mask out the sampled tokens that should not be sampled. for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[int(i)].clear() + valid_sampled_token_ids[int(i)] = np.array([]) else: valid_sampled_token_ids = [] invalid_req_indices = discard_sampled_tokens_req_indices.tolist() @@ -2412,19 +2415,24 @@ def _bookkeeping_sync( [0] if spec_decode_metadata and logprobs_tensors else None ) for req_idx in range(num_sampled_tokens): + sampled_ids: np.ndarray | None if self.use_async_scheduling: - sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None + sampled_ids = ( + np.array([-1]) if req_idx not in invalid_req_indices_set else None + ) else: sampled_ids = valid_sampled_token_ids[req_idx] - num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0 + num_sampled_ids: int = ( + sampled_ids.shape[0] if sampled_ids is not None else 0 + ) if cu_num_accepted_tokens is not None: cu_num_accepted_tokens.append( cu_num_accepted_tokens[-1] + num_sampled_ids ) - if not sampled_ids: + if sampled_ids is None or num_sampled_ids == 0: continue start_idx = self.input_batch.num_tokens_no_spec[req_idx] @@ -2759,7 +2767,9 @@ def sample_tokens( with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) - def propose_draft_token_ids(sampled_token_ids): + def propose_draft_token_ids( + sampled_token_ids: torch.Tensor | list[np.ndarray], + ) -> None: assert spec_decode_common_attn_metadata is not None with record_function_or_nullcontext("Draft"): self._draft_token_ids = self.propose_draft_token_ids( @@ -2874,14 +2884,14 @@ def take_draft_token_ids(self) -> DraftTokenIds | None: def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", - sampled_token_ids: torch.Tensor | list[list[int]], + sampled_token_ids: torch.Tensor | list[np.ndarray], sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, aux_hidden_states: list[torch.Tensor] | None, spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, - ) -> list[list[int]] | torch.Tensor: + ) -> torch.Tensor | list[list[int]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": assert isinstance(sampled_token_ids, list) @@ -2913,7 +2923,7 @@ def propose_draft_token_ids( for num_draft, tokens in zip( spec_decode_metadata.num_draft_tokens, sampled_token_ids ): - indices.append(offset + len(tokens) - 1) + indices.append(offset + tokens.shape[0] - 1) offset += num_draft + 1 indices = torch.tensor(indices, device=self.device) hidden_states = sample_hidden_states[indices] @@ -4841,7 +4851,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return kv_cache_spec - def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[np.ndarray]: # This is a short term mitigation for issue mentioned in # https://github.com/vllm-project/vllm/issues/22754. # `tolist` would trigger a cuda wise stream sync, which @@ -4854,4 +4864,4 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: pinned.copy_(sampled_token_ids, non_blocking=True) self.transfer_event.record() self.transfer_event.synchronize() - return pinned.tolist() + return [row for row in pinned.numpy()]