From 8403a61ccbb82d0fcb36bf66b926ef39a2fd2f48 Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Fri, 7 Nov 2025 10:08:14 -0800 Subject: [PATCH 01/13] Replace list[list[int]] with list[np.ndarray] Signed-off-by: Jialin Ouyang --- vllm/v1/core/sched/scheduler.py | 4 ++-- vllm/v1/outputs.py | 3 ++- vllm/v1/sample/rejection_sampler.py | 8 +++----- vllm/v1/spec_decode/eagle.py | 7 +++---- vllm/v1/spec_decode/ngram_proposer.py | 4 ++-- vllm/v1/spec_decode/suffix_decoding.py | 18 ++++++++++-------- vllm/v1/worker/gpu_model_runner.py | 25 +++++++++++++++---------- 7 files changed, 37 insertions(+), 32 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c17b19b58c97..b77ae84f89aa 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -942,8 +942,8 @@ def update_from_output( continue req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids = ( - sampled_token_ids[req_index] if sampled_token_ids else [] + generated_token_ids: list[int] = ( + sampled_token_ids[req_index].tolist() if sampled_token_ids else [] ) scheduled_spec_token_ids = ( diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index b5cba96e1026..f7d0d51e3102 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -5,6 +5,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, NamedTuple +import numpy as np import torch if TYPE_CHECKING: @@ -148,7 +149,7 @@ class ModelRunnerOutput: # num_generated_tokens is the number of tokens # generated in the current step. It can be different for # each request due to speculative/jump decoding. - sampled_token_ids: list[list[int]] + sampled_token_ids: list[np.ndarray] # [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1] diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 926305d25f56..f31a0cddda9a 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -3,6 +3,7 @@ from dataclasses import replace +import numpy as np import torch import torch.nn as nn @@ -204,7 +205,7 @@ def _get_logprobs_tensors( def parse_output( output_token_ids: torch.Tensor, vocab_size: int, - ) -> list[list[int]]: + ) -> list[np.ndarray]: """Parse the output of the rejection sampler. Args: output_token_ids: The sampled token IDs in shape @@ -220,10 +221,7 @@ def parse_output( valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & ( output_token_ids_np < vocab_size ) - outputs = [ - row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np) - ] - return outputs + return [row[valid_mask[i]] for i, row in enumerate(output_token_ids_np)] def apply_logits_processors( self, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 75a4140fd655..0e2e101a799a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -482,7 +482,7 @@ def propose( def prepare_next_token_ids_cpu( self, - sampled_token_ids: list[list[int]], + sampled_token_ids: list[np.ndarray], requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, num_scheduled_tokens: dict[str, int], @@ -497,7 +497,7 @@ def prepare_next_token_ids_cpu( req_ids = gpu_input_batch.req_ids next_token_ids: list[int] = [] for i, token_ids in enumerate(sampled_token_ids): - if token_ids: + if token_ids.shape[0] > 0: # Common case. next_token_id = token_ids[-1] else: @@ -508,10 +508,9 @@ def prepare_next_token_ids_cpu( seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id] next_token_id = req_state.get_token_id(seq_len) next_token_ids.append(next_token_id) - next_token_ids = torch.tensor( + return torch.tensor( next_token_ids, dtype=torch.int32, device=self.input_ids.device ) - return next_token_ids def prepare_next_token_ids_padded( self, diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index e2f83cb24aa9..f0f870684b6c 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -131,7 +131,7 @@ def batch_propose( def propose( self, - sampled_token_ids: list[list[int]], + sampled_token_ids: list[np.ndarray], req_ids: list[str], num_tokens_no_spec: np.ndarray, token_ids_cpu: np.ndarray, @@ -140,7 +140,7 @@ def propose( # find which requests need ngram proposals valid_ngram_requests = [] for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = len(sampled_ids) + num_sampled_ids = sampled_ids.shape[0] if not num_sampled_ids: # Skip speculative decoding. continue diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index 049e335db325..85f0c632a9af 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import numpy as np + from vllm.config import VllmConfig from vllm.v1.worker.gpu_input_batch import InputBatch @@ -32,31 +34,31 @@ def __init__(self, vllm_config: VllmConfig): def propose( self, input_batch: InputBatch, - sampled_token_ids: list[list[int]], - ) -> list[list[int]]: + sampled_token_ids: list[np.ndarray], + ) -> list[np.ndarray]: """ Propose speculative tokens for each request in the input batch. Suffix Decoding will speculate a dynamic number of tokens for each request every decoding step, so each entry in the returned list may have different lengths. """ - draft_token_ids: list[list[int]] = [] + draft_token_ids: list[np.ndarray] = [] for i, sampled_ids in enumerate(sampled_token_ids): if not sampled_ids: # Skip speculative decoding for partial prefills. - draft_token_ids.append([]) + draft_token_ids.append(np.array([])) continue # Skip requests that require sampling parameters that are not # supported with speculative decoding. req_id = input_batch.req_ids[i] if req_id in input_batch.spec_decode_unsupported_reqs: - draft_token_ids.append([]) + draft_token_ids.append(np.array([])) continue num_tokens = input_batch.num_tokens_no_spec[i] if num_tokens >= self.max_model_len: # Skip requests that have already reached the max model length. - draft_token_ids.append([]) + draft_token_ids.append(np.array([])) continue index = input_batch.req_id_to_index[req_id] @@ -70,7 +72,7 @@ def propose( self.suffix_cache.start_request(req_id, prompt_token_ids) # Append the newly sampled ids to the suffix cache for this request. - self.suffix_cache.add_active_response(req_id, sampled_ids) + self.suffix_cache.add_active_response(req_id, sampled_ids.tolist()) # Suffix decoding only uses the most recent tokens up to max_tree_depth, so # we extract the pattern from the end of the input. @@ -86,7 +88,7 @@ def propose( min_token_prob=self.min_token_prob, ) - draft_token_ids.append(draft.token_ids) + draft_token_ids.append(np.array(draft.token_ids)) # Stop requests that were not seen in the input batch. for req_id in ( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 26007d29d61b..802d4ea3b47f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2344,7 +2344,7 @@ def _bookkeeping_sync( ) -> tuple[ dict[str, int], LogprobsLists | None, - list[list[int]], + list[np.ndarray], dict[str, LogprobsTensors | None], list[str], dict[str, int], @@ -2370,6 +2370,7 @@ def _bookkeeping_sync( num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids invalid_req_indices = [] + valid_sampled_token_ids: list[np.ndarray] if not self.use_async_scheduling: # Get the valid generated tokens. max_gen_len = sampled_token_ids.shape[-1] @@ -2384,7 +2385,7 @@ def _bookkeeping_sync( ) # Mask out the sampled tokens that should not be sampled. for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[int(i)].clear() + valid_sampled_token_ids[int(i)] = np.array((0,)) else: valid_sampled_token_ids = [] invalid_req_indices = discard_sampled_tokens_req_indices.tolist() @@ -2413,11 +2414,13 @@ def _bookkeeping_sync( ) for req_idx in range(num_sampled_tokens): if self.use_async_scheduling: - sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None + sampled_ids = ( + np.array([-1]) if req_idx not in invalid_req_indices_set else None + ) else: sampled_ids = valid_sampled_token_ids[req_idx] - num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0 + num_sampled_ids: int = sampled_ids.shape[0] if sampled_ids else 0 if cu_num_accepted_tokens is not None: cu_num_accepted_tokens.append( @@ -2759,7 +2762,9 @@ def sample_tokens( with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) - def propose_draft_token_ids(sampled_token_ids): + def propose_draft_token_ids( + sampled_token_ids: torch.Tensor | list[np.ndarray], + ) -> None: assert spec_decode_common_attn_metadata is not None with record_function_or_nullcontext("Draft"): self._draft_token_ids = self.propose_draft_token_ids( @@ -2874,14 +2879,14 @@ def take_draft_token_ids(self) -> DraftTokenIds | None: def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", - sampled_token_ids: torch.Tensor | list[list[int]], + sampled_token_ids: torch.Tensor | list[np.ndarray], sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, aux_hidden_states: list[torch.Tensor] | None, spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, - ) -> list[list[int]] | torch.Tensor: + ) -> torch.Tensor | list[np.ndarray]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": assert isinstance(sampled_token_ids, list) @@ -2913,7 +2918,7 @@ def propose_draft_token_ids( for num_draft, tokens in zip( spec_decode_metadata.num_draft_tokens, sampled_token_ids ): - indices.append(offset + len(tokens) - 1) + indices.append(offset + tokens.shape[0] - 1) offset += num_draft + 1 indices = torch.tensor(indices, device=self.device) hidden_states = sample_hidden_states[indices] @@ -4841,7 +4846,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return kv_cache_spec - def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[np.ndarray]: # This is a short term mitigation for issue mentioned in # https://github.com/vllm-project/vllm/issues/22754. # `tolist` would trigger a cuda wise stream sync, which @@ -4854,4 +4859,4 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: pinned.copy_(sampled_token_ids, non_blocking=True) self.transfer_event.record() self.transfer_event.synchronize() - return pinned.tolist() + return [row for row in pinned] From 597e95cc453ded5292f75a4507bd654247c3b6fe Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Fri, 7 Nov 2025 15:57:54 -0800 Subject: [PATCH 02/13] DEBUG LOGGING Signed-off-by: Jialin Ouyang --- vllm/v1/core/sched/scheduler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index b77ae84f89aa..a429bf98b57e 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -898,6 +898,9 @@ def update_from_output( model_runner_output: ModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: sampled_token_ids = model_runner_output.sampled_token_ids + logger.info(f"===Jialin {len(model_runner_output.sampled_token_ids)=}") + if len(model_runner_output.sampled_token_ids) > 0: + logger.info(f"===Jialin {type(model_runner_output.sampled_token_ids[0])=}") logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens @@ -943,7 +946,7 @@ def update_from_output( req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids: list[int] = ( - sampled_token_ids[req_index].tolist() if sampled_token_ids else [] + sampled_token_ids[req_index] if sampled_token_ids else [] ) scheduled_spec_token_ids = ( From 90b10039e2e72f9050365149bfaf4d9ce86e8733 Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Fri, 7 Nov 2025 17:24:02 -0800 Subject: [PATCH 03/13] apply same changes to async scheduling Signed-off-by: Jialin Ouyang --- vllm/v1/core/sched/scheduler.py | 5 +---- vllm/v1/worker/gpu_model_runner.py | 8 +++++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a429bf98b57e..b77ae84f89aa 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -898,9 +898,6 @@ def update_from_output( model_runner_output: ModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: sampled_token_ids = model_runner_output.sampled_token_ids - logger.info(f"===Jialin {len(model_runner_output.sampled_token_ids)=}") - if len(model_runner_output.sampled_token_ids) > 0: - logger.info(f"===Jialin {type(model_runner_output.sampled_token_ids[0])=}") logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens @@ -946,7 +943,7 @@ def update_from_output( req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids: list[int] = ( - sampled_token_ids[req_index] if sampled_token_ids else [] + sampled_token_ids[req_index].tolist() if sampled_token_ids else [] ) scheduled_spec_token_ids = ( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 802d4ea3b47f..5edc565e2cce 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -212,9 +212,11 @@ def get_output(self) -> ModelRunnerOutput: del self._logprobs_tensors del self._sampled_token_ids - valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() + valid_sampled_token_ids: list[np.ndarray] = [ + row for row in self.sampled_token_ids_cpu.numpy() + ] for i in self._invalid_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids[i] = np.array([]) output = self._model_runner_output output.sampled_token_ids = valid_sampled_token_ids @@ -4859,4 +4861,4 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[np.ndarray]: pinned.copy_(sampled_token_ids, non_blocking=True) self.transfer_event.record() self.transfer_event.synchronize() - return [row for row in pinned] + return [row for row in pinned.numpy()] From 712bc8140d9959e5525f7fd08db178fcf393d22a Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Mon, 10 Nov 2025 06:45:04 -0800 Subject: [PATCH 04/13] fix bookkeeping errors Signed-off-by: Jialin Ouyang --- vllm/v1/spec_decode/ngram_proposer.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index f0f870684b6c..b83946258522 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -54,7 +54,7 @@ def __init__(self, vllm_config: VllmConfig): # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. self.propose( - [[]] * 1024, + [np.empty((0,))] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32), np.zeros((1024, self.max_model_len), dtype=np.int32), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5edc565e2cce..d8e40b6fd344 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2415,6 +2415,7 @@ def _bookkeeping_sync( [0] if spec_decode_metadata and logprobs_tensors else None ) for req_idx in range(num_sampled_tokens): + sampled_ids: np.ndarray | None if self.use_async_scheduling: sampled_ids = ( np.array([-1]) if req_idx not in invalid_req_indices_set else None @@ -2422,14 +2423,16 @@ def _bookkeeping_sync( else: sampled_ids = valid_sampled_token_ids[req_idx] - num_sampled_ids: int = sampled_ids.shape[0] if sampled_ids else 0 + num_sampled_ids: int = ( + sampled_ids.shape[0] if sampled_ids is not None else 0 + ) if cu_num_accepted_tokens is not None: cu_num_accepted_tokens.append( cu_num_accepted_tokens[-1] + num_sampled_ids ) - if not sampled_ids: + if sampled_ids is None or num_sampled_ids == 0: continue start_idx = self.input_batch.num_tokens_no_spec[req_idx] From 65658186f568acca685c791846e6a9c63be36c09 Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Mon, 10 Nov 2025 06:47:33 -0800 Subject: [PATCH 05/13] Fix scheduler tests Signed-off-by: Jialin Ouyang --- tests/v1/core/test_async_scheduler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py index e0645ed43015..1d80ee987591 100644 --- a/tests/v1/core/test_async_scheduler.py +++ b/tests/v1/core/test_async_scheduler.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import deque +import numpy as np import pytest from vllm.v1.core.sched.output import SchedulerOutput @@ -21,7 +22,7 @@ def _make_model_runner_output( return ModelRunnerOutput( req_ids=req_ids, req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)}, - sampled_token_ids=[[i] for i in range(len(req_ids))], + sampled_token_ids=[np.array([i]) for i in range(len(req_ids))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], From 707e3b62fbca58474fc0401f7b1ab082fd1dc52b Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Mon, 10 Nov 2025 13:23:11 -0800 Subject: [PATCH 06/13] Fix all scheduler unit tests Signed-off-by: Jialin Ouyang --- tests/v1/core/test_scheduler.py | 76 +++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 749cf7dc8397..6cd89389ff12 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -3,6 +3,7 @@ import dataclasses from unittest.mock import Mock +import numpy as np import pytest import torch @@ -165,7 +166,7 @@ def test_schedule_partial_requests(): req_id_to_index=req_to_index, # Only the first request has a sampled token id because # the rest requests are still being prefilled. - sampled_token_ids=[[0], [], []], + sampled_token_ids=[np.array([0]), np.array([]), np.array([])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -212,7 +213,7 @@ def test_no_mm_input_chunking(): model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, - sampled_token_ids=[[] for _ in range(len(requests))], + sampled_token_ids=[np.array([]) for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -272,7 +273,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, - sampled_token_ids=[[] for _ in range(len(requests))], + sampled_token_ids=[np.array([]) for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -296,7 +297,8 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, - sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)], + sampled_token_ids=[np.array([0]), np.array([0])] + + [np.array([]) for _ in range(len(requests) - 2)], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -343,8 +345,8 @@ def test_stop_via_update_from_output(): req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, sampled_token_ids=[ - [EOS_TOKEN_ID], - [10, 11], + np.array([EOS_TOKEN_ID]), + np.array([10, 11]), ], # First request hits EOS, second continues logprobs=None, prompt_logprobs_dict={}, @@ -388,7 +390,10 @@ def test_stop_via_update_from_output(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token + sampled_token_ids=[ + np.array([10, 42, 12]), + np.array([13, 14]), + ], # First request hits stop token logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -432,7 +437,10 @@ def test_stop_via_update_from_output(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens + sampled_token_ids=[ + np.array([10, 11, 12]), + np.array([13]), + ], # First request exceeds max_tokens logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -471,7 +479,7 @@ def test_stop_via_update_from_output(): model_output = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], + sampled_token_ids=[np.array([EOS_TOKEN_ID, 10, 11])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -612,7 +620,7 @@ def test_schedule_concurrent_batches( model_runner_output = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -629,7 +637,7 @@ def test_schedule_concurrent_batches( model_runner_output = ModelRunnerOutput( req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -666,7 +674,7 @@ def test_preempt_during_execution(): model_runner_output0 = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -683,7 +691,7 @@ def test_preempt_during_execution(): model_runner_output1 = ModelRunnerOutput( req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, - sampled_token_ids=[[42]], + sampled_token_ids=[np.array([42])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -700,14 +708,18 @@ def test_preempt_during_execution(): @pytest.mark.parametrize( "spec_tokens,output_tokens,expected", [ - ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match - ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch - ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])), # multiple sequences - ([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence - ([[]], [[5]], (0, 0, 0, [0])), # empty sequence + ([[1, 2, 3]], [np.array([1, 2, 3, 4])], (1, 3, 3, [1, 1, 1])), # perfect match + ([[1, 2, 3]], [np.array([1, 5])], (1, 3, 1, [1, 0, 0])), # early mismatch + ( + [[1, 2], [3]], + [np.array([1, 2, 5]), np.array([3, 4])], + (2, 3, 3, [2, 1]), + ), # multiple sequences + ([[1]], [np.array([1, 2])], (1, 1, 1, [1])), # single token sequence + ([[]], [np.array([5])], (0, 0, 0, [0])), # empty sequence ( [[1, 2, 3], [4, 5, 6]], - [[1, 2, 7], [4, 8]], + [np.array([1, 2, 7]), np.array([4, 8])], (2, 6, 3, [2, 1, 0]), ), # multiple mismatches ], @@ -741,7 +753,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[0] for _ in range(len(requests))], + sampled_token_ids=[np.array([0]) for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -924,7 +936,7 @@ def test_kv_connector_basic(): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[1000]] * len(req_ids), + sampled_token_ids=[np.array([1000])] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -974,7 +986,7 @@ def test_kv_connector_basic(): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[1000]] * len(req_ids), + sampled_token_ids=[np.array([1000])] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1044,7 +1056,7 @@ def test_external_prefix_cache_metrics(): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=[r.request_id for r in requests], req_id_to_index={r.request_id: i for i, r in enumerate(requests)}, - sampled_token_ids=[[1000]] * NUM_REQUESTS, + sampled_token_ids=[np.array([1000])] * NUM_REQUESTS, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1109,7 +1121,7 @@ def test_kv_connector_unable_to_allocate(): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[1000]] * len(req_ids), + sampled_token_ids=[np.array([1000])] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1194,7 +1206,7 @@ def test_kv_connector_handles_preemption(): MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=req_to_index, - sampled_token_ids=[[1000]] * len(req_ids), + sampled_token_ids=[np.array([1000])] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1287,7 +1299,7 @@ def make_output(scheduler: Scheduler): return ModelRunnerOutput( req_ids=[req.request_id for req in scheduler.running], req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)}, - sampled_token_ids=[[1000]] * len(scheduler.running), + sampled_token_ids=[np.array([1000])] * len(scheduler.running), logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1637,7 +1649,7 @@ def test_priority_scheduling_preemption(): req_id_to_index={ req.request_id: i for i, req in enumerate(low_priority_requests) }, - sampled_token_ids=[[100] for _ in low_priority_requests], + sampled_token_ids=[np.array([100]) for _ in low_priority_requests], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1706,7 +1718,7 @@ def test_priority_scheduling_no_preemption_when_space_available(): req_id_to_index={ req.request_id: i for i, req in enumerate(low_priority_requests) }, - sampled_token_ids=[[100] for _ in low_priority_requests], + sampled_token_ids=[np.array([100]) for _ in low_priority_requests], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1952,7 +1964,7 @@ def test_priority_scheduling_heap_property(): model_output = ModelRunnerOutput( req_ids=[req.req_id], req_id_to_index={req.req_id: 0}, - sampled_token_ids=[[100]], + sampled_token_ids=[np.array([100])], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -2030,7 +2042,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(): model_output = ModelRunnerOutput( req_ids=[request_low.request_id], req_id_to_index={request_low.request_id: 0}, - sampled_token_ids=[[100]], + sampled_token_ids=[np.array([100])], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -2061,7 +2073,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[[100] for _ in requests], + sampled_token_ids=[np.array([100]) for _ in requests], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, @@ -2087,7 +2099,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, - sampled_token_ids=[[], [100]], + sampled_token_ids=[np.array([]), np.array([100])], # spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, From 5dafc452ec078b4a3e3abbac68f784c913e25281 Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Mon, 10 Nov 2025 13:41:02 -0800 Subject: [PATCH 07/13] Use np.array() as much as possible for better readability Signed-off-by: Jialin Ouyang --- vllm/v1/spec_decode/ngram_proposer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index b83946258522..378937dba988 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -54,7 +54,7 @@ def __init__(self, vllm_config: VllmConfig): # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. self.propose( - [np.empty((0,))] * 1024, + [np.array([])] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32), np.zeros((1024, self.max_model_len), dtype=np.int32), From bb5c46ed8280b9c5feb580c0067c6bcea971f2b6 Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Mon, 10 Nov 2025 19:11:18 -0800 Subject: [PATCH 08/13] Fix wrong empty array set Signed-off-by: Jialin Ouyang --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d8e40b6fd344..07344a60cba6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2387,7 +2387,7 @@ def _bookkeeping_sync( ) # Mask out the sampled tokens that should not be sampled. for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[int(i)] = np.array((0,)) + valid_sampled_token_ids[int(i)] = np.array([]) else: valid_sampled_token_ids = [] invalid_req_indices = discard_sampled_tokens_req_indices.tolist() From d306075bf1436c5e2b1c48f8ceba9c85934542a5 Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Mon, 10 Nov 2025 19:53:15 -0800 Subject: [PATCH 09/13] Fix kv_connector.util sampled ids types Signed-off-by: Jialin Ouyang --- tests/v1/kv_connector/unit/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index f0031643aa9d..aeee7e834ad5 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -6,6 +6,7 @@ from itertools import count from typing import Any +import numpy as np import torch from vllm import SamplingParams @@ -222,7 +223,7 @@ def create_model_runner_output( # Make sampled tokens. sampled_token = EOS_TOKEN_ID if use_eos else token_id - sampled_token_ids = [[sampled_token] for _ in req_ids] + sampled_token_ids = [np.array([sampled_token]) for _ in req_ids] kv_connector_output = ( None From d08f121d1e997ec6247cb9530d7f780853822bb8 Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Tue, 11 Nov 2025 10:31:26 -0800 Subject: [PATCH 10/13] Fix types for spec decoding Signed-off-by: Jialin Ouyang --- tests/v1/spec_decode/test_eagle.py | 5 ++++- vllm/v1/core/sched/scheduler.py | 3 ++- vllm/v1/spec_decode/suffix_decoding.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 47d05a20a65d..bcce44dd9d1b 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -3,6 +3,7 @@ from unittest import mock +import numpy as np import pytest import torch @@ -112,7 +113,9 @@ def test_prepare_next_token_ids(): sampled_token_ids_tensor = torch.tensor( sampled_token_ids, dtype=torch.int32, device=device ) - sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids] + sampled_token_ids_cpu = [ + np.array([i for i in seq if i != -1]) for seq in sampled_token_ids + ] expected_next_token_ids_cpu = [1, 4, 30, 40] expected_next_token_ids_tensor = torch.tensor( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index b77ae84f89aa..a60a6913d015 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1142,10 +1142,11 @@ def update_draft_token_ids( self, draft_token_ids: DraftTokenIds, ) -> None: - for req_id, spec_token_ids in zip( + for req_id, spec_token_ids_np in zip( draft_token_ids.req_ids, draft_token_ids.draft_token_ids, ): + spec_token_ids = spec_token_ids_np.tolist() request = self.requests.get(req_id) if request is None or request.is_finished(): # The request may have been finished. Skip. diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index 85f0c632a9af..11cee6366b38 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -43,7 +43,7 @@ def propose( """ draft_token_ids: list[np.ndarray] = [] for i, sampled_ids in enumerate(sampled_token_ids): - if not sampled_ids: + if sampled_ids.shape[0] == 0: # Skip speculative decoding for partial prefills. draft_token_ids.append(np.array([])) continue From 216cbee5cd5f1f163d80436959808314b353c3ad Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Tue, 11 Nov 2025 10:54:03 -0800 Subject: [PATCH 11/13] Minimize change continue using list[list[int]] in DraftTokenIds Signed-off-by: Jialin Ouyang --- vllm/v1/core/sched/scheduler.py | 3 +-- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a60a6913d015..b77ae84f89aa 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1142,11 +1142,10 @@ def update_draft_token_ids( self, draft_token_ids: DraftTokenIds, ) -> None: - for req_id, spec_token_ids_np in zip( + for req_id, spec_token_ids in zip( draft_token_ids.req_ids, draft_token_ids.draft_token_ids, ): - spec_token_ids = spec_token_ids_np.tolist() request = self.requests.get(req_id) if request is None or request.is_finished(): # The request may have been finished. Skip. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 07344a60cba6..d56ea4f65346 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -530,7 +530,7 @@ def __init__( self.runner_only_attn_layers: set[str] = set() # Cached outputs. - self._draft_token_ids: list[list[int]] | torch.Tensor | None = None + self._draft_token_ids: list[np.ndarray] | torch.Tensor | None = None self.transfer_event = torch.cuda.Event() self.sampled_token_ids_pinned_cpu = torch.empty( (self.max_num_reqs, 1), @@ -2877,7 +2877,7 @@ def take_draft_token_ids(self) -> DraftTokenIds | None: if isinstance(self._draft_token_ids, torch.Tensor): draft_token_ids = self._draft_token_ids.tolist() else: - draft_token_ids = self._draft_token_ids + draft_token_ids = [row.tolist() for row in self._draft_token_ids] self._draft_token_ids = None return DraftTokenIds(req_ids, draft_token_ids) From f8d3ca5908de75fbffbb5d01bef7ddb6dc5f0e6a Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Tue, 11 Nov 2025 14:08:40 -0800 Subject: [PATCH 12/13] Revert more type changes for draft tokens Signed-off-by: Jialin Ouyang --- vllm/v1/spec_decode/suffix_decoding.py | 10 +++++----- vllm/v1/worker/gpu_model_runner.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index 11cee6366b38..d76e0ffe778d 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -35,7 +35,7 @@ def propose( self, input_batch: InputBatch, sampled_token_ids: list[np.ndarray], - ) -> list[np.ndarray]: + ) -> list[list[int]]: """ Propose speculative tokens for each request in the input batch. Suffix Decoding will speculate a dynamic number of tokens for each request every decoding step, @@ -45,20 +45,20 @@ def propose( for i, sampled_ids in enumerate(sampled_token_ids): if sampled_ids.shape[0] == 0: # Skip speculative decoding for partial prefills. - draft_token_ids.append(np.array([])) + draft_token_ids.append([]) continue # Skip requests that require sampling parameters that are not # supported with speculative decoding. req_id = input_batch.req_ids[i] if req_id in input_batch.spec_decode_unsupported_reqs: - draft_token_ids.append(np.array([])) + draft_token_ids.append([]) continue num_tokens = input_batch.num_tokens_no_spec[i] if num_tokens >= self.max_model_len: # Skip requests that have already reached the max model length. - draft_token_ids.append(np.array([])) + draft_token_ids.append([]) continue index = input_batch.req_id_to_index[req_id] @@ -88,7 +88,7 @@ def propose( min_token_prob=self.min_token_prob, ) - draft_token_ids.append(np.array(draft.token_ids)) + draft_token_ids.append(draft.token_ids) # Stop requests that were not seen in the input batch. for req_id in ( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d56ea4f65346..b4830fbd0a55 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -530,7 +530,7 @@ def __init__( self.runner_only_attn_layers: set[str] = set() # Cached outputs. - self._draft_token_ids: list[np.ndarray] | torch.Tensor | None = None + self._draft_token_ids: list[list[int]] | torch.Tensor | None = None self.transfer_event = torch.cuda.Event() self.sampled_token_ids_pinned_cpu = torch.empty( (self.max_num_reqs, 1), @@ -2877,7 +2877,7 @@ def take_draft_token_ids(self) -> DraftTokenIds | None: if isinstance(self._draft_token_ids, torch.Tensor): draft_token_ids = self._draft_token_ids.tolist() else: - draft_token_ids = [row.tolist() for row in self._draft_token_ids] + draft_token_ids = self._draft_token_ids self._draft_token_ids = None return DraftTokenIds(req_ids, draft_token_ids) @@ -2891,7 +2891,7 @@ def propose_draft_token_ids( aux_hidden_states: list[torch.Tensor] | None, spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, - ) -> torch.Tensor | list[np.ndarray]: + ) -> torch.Tensor | list[list[int]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": assert isinstance(sampled_token_ids, list) From 0be7d4b9fd60d04ab0aa3fa0fc5a269f7d8e3032 Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Tue, 11 Nov 2025 16:53:02 -0800 Subject: [PATCH 13/13] Fix ngram spec unit tests Signed-off-by: Jialin Ouyang --- tests/v1/spec_decode/test_ngram.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 692c39282c37..563bc1d957f4 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -77,7 +77,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match. token_ids_cpu = np.array([[1, 2, 3, 4, 5]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -88,7 +88,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match for 4-gram. token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -99,7 +99,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # No match for 4-gram but match for 3-gram. token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -111,7 +111,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # In this case, the proposer should return the 4-gram match. token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]]) result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -122,7 +122,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Match for 2-gram and 3-gram, but not 4-gram. token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]]) result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -133,7 +133,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Multiple 3-gram matched, but always pick the first one. token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -144,7 +144,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # check empty input token_ids_cpu = np.array([[]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( - sampled_token_ids=[[0]], + sampled_token_ids=[np.array([0])], req_ids=["0"], num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), token_ids_cpu=token_ids_cpu, @@ -157,7 +157,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # second request has 3 tokens and no match. Padded with -1 for max len 5 token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]]) result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( - sampled_token_ids=[[0], [1]], + sampled_token_ids=[np.array([0]), np.array([1])], req_ids=["0", "1"], num_tokens_no_spec=np.array([5, 3]), token_ids_cpu=token_ids_cpu, @@ -181,7 +181,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: input_2[:3] = [4, 5, 6] token_ids_cpu = np.array([input_1, input_2]) result = ngram_proposer.propose( - sampled_token_ids=[[0], [1]], + sampled_token_ids=[np.array([0]), np.array([1])], req_ids=["0", "1"], num_tokens_no_spec=np.array([len(input_1), 3]), token_ids_cpu=token_ids_cpu,