diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 76408fba2e16..aaac2deb12ac 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -30,7 +30,6 @@
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
-from vllm.v1.structured_output.request import StructuredOutputRequest
 
 from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
 
@@ -335,10 +334,10 @@ def test_stop_via_update_from_output():
             requests[0].request_id: [],
             requests[1].request_id: [10],
         },
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
 
@@ -383,10 +382,10 @@ def test_stop_via_update_from_output():
             requests[0].request_id: [10, 42],
             requests[1].request_id: [13],
         },
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
 
@@ -429,10 +428,10 @@ def test_stop_via_update_from_output():
             requests[0].request_id: [10, 11],
             requests[1].request_id: [],
         },
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
 
@@ -470,10 +469,10 @@ def test_stop_via_update_from_output():
         total_num_scheduled_tokens=3,
         scheduled_encoder_inputs={},
         scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]},
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
 
@@ -1941,7 +1940,6 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
         sampling_params=sampling_params,
         pooling_params=None,
         eos_token_id=EOS_TOKEN_ID,
-        structured_output_request=StructuredOutputRequest(sampling_params),
     )
     scheduler.add_request(request)
     output = scheduler.schedule()
diff --git a/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py b/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
index 0bb67b574fa1..b5c8f378be18 100644
--- a/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
+++ b/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
@@ -26,7 +26,7 @@ def _make_empty_scheduler_output():
         num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
         kv_connector_metadata=SharedStorageConnectorMetadata(),
     )
diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py
index df9fcdc37fa3..e471174ef674 100644
--- a/tests/v1/tpu/worker/test_tpu_model_runner.py
+++ b/tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -89,10 +89,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
         total_num_scheduled_tokens=total_num_scheduled_tokens,
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
 
@@ -168,10 +168,10 @@ def test_update_states_request_finished(model_runner):
         total_num_scheduled_tokens=0,
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids={req_id},
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
 
@@ -198,10 +198,10 @@ def test_update_states_request_resumed(model_runner):
         total_num_scheduled_tokens=0,
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
 
@@ -225,10 +225,10 @@ def test_update_states_request_resumed(model_runner):
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
 
@@ -256,10 +256,10 @@ def test_update_states_no_changes(model_runner):
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
 
@@ -291,10 +291,10 @@ def test_update_states_request_unscheduled(model_runner):
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 817cd7f10c1c..fe52f565c8a8 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -146,10 +146,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
         total_num_scheduled_tokens=total_num_scheduled_tokens,
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
 
@@ -212,10 +212,10 @@ def test_update_states_request_finished(model_runner, dist_init):
         total_num_scheduled_tokens=0,
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids={req_id},
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
 
@@ -244,10 +244,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
         total_num_scheduled_tokens=0,
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
 
@@ -273,10 +273,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
 
@@ -366,10 +366,10 @@ def test_update_states_no_changes(model_runner, dist_init):
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
 
@@ -403,10 +403,10 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[],
         finished_req_ids=set(),
         free_encoder_mm_hashes=[],
-        structured_output_request_ids={},
+        structured_output_request_ids=[],
         grammar_bitmask=None,
     )
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index bce15e1a476f..619dcd178a13 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -165,9 +165,8 @@ class SchedulerOutput:
     # freed from the encoder cache.
    free_encoder_mm_hashes: list[str]
 
-    # Dict of request ids to their index within the batch
-    # for filling the next token bitmask
-    structured_output_request_ids: dict[str, int]
+    # IDs of structured output requests included in the bitmask, in order.
+    structured_output_request_ids: list[str]
     # the bitmask for the whole batch
     grammar_bitmask: "npt.NDArray[np.int32] | None"
 
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index cbbdf48c6e0c..8282d70dc883 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -5,7 +5,7 @@
 import time
 from collections import defaultdict
 from collections.abc import Iterable
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from vllm.config import VllmConfig
 from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
@@ -34,6 +34,10 @@
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
 
+if TYPE_CHECKING:
+    import numpy as np
+    import numpy.typing as npt
+
 logger = init_logger(__name__)
 
 
@@ -610,11 +614,8 @@ def schedule(self) -> SchedulerOutput:
             scheduled_spec_decode_tokens,
             req_to_new_blocks,
         )
-        scheduled_requests = (
-            scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs
-        )
         structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
-            scheduled_requests, scheduled_spec_decode_tokens
+            num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
         )
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
@@ -878,32 +879,28 @@ def _try_schedule_encoder_inputs(
 
     def get_grammar_bitmask(
         self,
-        requests: list[Request],
+        scheduled_request_ids: Iterable[str],
         scheduled_spec_decode_tokens: dict[str, list[int]],
-    ):
-        # NOTE: structured_output_request_ids maps
-        # a request's (request that uses structured output)
-        # request_id to its index in the batch.
-        # This will help us determine to slice the grammar bitmask
-        # and only applies valid mask for requests that
-        # uses structured decoding.
-        structured_output_request_ids: dict[str, int] = {}
-        for i, req in enumerate(requests):
-            if req.use_structured_output:
-                # PERF: in case of chunked prefill,
-                # request might not include any new tokens.
-                # Therefore, we might introduce some additional
-                # cycle to fill in the bitmask, which could be a big no-op.
-                structured_output_request_ids[req.request_id] = i
-
+    ) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
+        # Collect the list of scheduled request ids that use structured output.
+        # The corresponding rows of the bitmask will be in this order.
+        # PERF: in the case of chunked prefill, a request might not
+        # include any new tokens. Therefore, we might introduce
+        # additional cycles to fill in the bitmask, which could be
+        # a big no-op.
+        structured_output_request_ids = [
+            req_id
+            for req_id in scheduled_request_ids
+            if (req := self.requests.get(req_id)) and req.use_structured_output
+        ]
         if not structured_output_request_ids:
-            bitmask = None
-        else:
-            bitmask = self.structured_output_manager.grammar_bitmask(
-                self.requests,
-                structured_output_request_ids,
-                scheduled_spec_decode_tokens,
-            )
+            return structured_output_request_ids, None
+
+        bitmask = self.structured_output_manager.grammar_bitmask(
+            self.requests,
+            structured_output_request_ids,
+            scheduled_spec_decode_tokens,
+        )
         return structured_output_request_ids, bitmask
 
     def update_from_output(
@@ -1011,12 +1008,10 @@ def update_from_output(
                 new_logprobs = logprobs.slice(req_index, req_index + 1)
 
             if new_token_ids and self.structured_output_manager.should_advance(request):
-                # NOTE: structured_output_request
-                # should not be None if use_structured_output, we have
-                # checked above, so safe to ignore type warning
-                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
-                    req_id, new_token_ids
-                )
+                struct_output_request = request.structured_output_request
+                assert struct_output_request is not None
+                assert struct_output_request.grammar is not None
+                struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
 
             if num_nans_in_logits is not None and req_id in num_nans_in_logits:
                 request.num_nans_in_logits = num_nans_in_logits[req_id]
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 5926bf5b46ee..864b0eb7fa41 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -40,7 +40,6 @@ def __init__(
         prompt_embeds: torch.Tensor | None = None,
         mm_features: list[MultiModalFeatureSpec] | None = None,
         lora_request: Optional["LoRARequest"] = None,
-        structured_output_request: Optional["StructuredOutputRequest"] = None,
         cache_salt: str | None = None,
         priority: int = 0,
         trace_headers: Mapping[str, str] | None = None,
@@ -54,11 +53,12 @@ def __init__(
         # Because of LoRA, the eos token id can be different for each request.
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
-        self.structured_output_request = structured_output_request
+        self.structured_output_request = StructuredOutputRequest.from_sampling_params(
+            sampling_params
+        )
         self.arrival_time = arrival_time if arrival_time is not None else time.time()
 
         self.status = RequestStatus.WAITING
-        self.use_structured_output = False
         self.events: list[EngineCoreEvent] = []
         self.stop_reason: int | str | None = None
 
@@ -72,9 +72,8 @@ def __init__(
             # Generative models.
             assert sampling_params.max_tokens is not None
             self.max_tokens = sampling_params.max_tokens
-            if sampling_params.structured_outputs is not None:
+            if self.structured_output_request is not None:
                 self.status = RequestStatus.WAITING_FOR_FSM
-                self.use_structured_output = True
 
             if sampling_params.extra_args is not None:
                 self.kv_transfer_params = sampling_params.extra_args.get(
@@ -145,11 +144,6 @@ def from_engine_core_request(
             eos_token_id=request.eos_token_id,
             arrival_time=request.arrival_time,
             lora_request=request.lora_request,
-            structured_output_request=StructuredOutputRequest(
-                sampling_params=request.sampling_params
-            )
-            if request.sampling_params
-            else None,
             cache_salt=request.cache_salt,
             priority=request.priority,
             trace_headers=request.trace_headers,
@@ -170,6 +164,10 @@ def append_output_token_ids(
         if self.get_hash_new_full_blocks is not None:
             self.block_hashes.extend(self.get_hash_new_full_blocks())
 
+    @property
+    def use_structured_output(self) -> bool:
+        return self.structured_output_request is not None
+
     @property
     def is_output_corrupted(self) -> bool:
         return self.num_nans_in_logits > 0
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py
index 336a0eb98682..8d7f4b5d6896 100644
--- a/vllm/v1/structured_output/__init__.py
+++ b/vllm/v1/structured_output/__init__.py
@@ -167,7 +167,7 @@ def _async_submit_fill_bitmask(
     def grammar_bitmask(
         self,
         requests: dict[str, Request],
-        structured_output_request_ids: dict[str, int],
+        structured_output_request_ids: list[str],
         scheduled_spec_decode_tokens: dict[str, list[int]],
     ) -> "npt.NDArray[np.int32] | None":
         # Prepare the structured output bitmask for this batch.
@@ -196,17 +196,16 @@ def grammar_bitmask(
         # masks for each request, one for each possible bonus token position.
         # These are stored inline in the tensor and unpacked by the gpu runner.
         cumulative_index = 0
-        ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1])
 
         # Optimized parallel filling of bitmasks for
         # non-spec, large-batch-size cases
         if (
-            len(ordered_seq) > self.fill_bitmask_parallel_threshold
+            len(structured_output_request_ids) > self.fill_bitmask_parallel_threshold
             and max_num_spec_tokens == 0
         ):
             promises = []
             batch = []
-            for req_id, _ in ordered_seq:
+            for req_id in structured_output_request_ids:
                 request = requests[req_id]
                 structured_output_request = request.structured_output_request
                 if TYPE_CHECKING:
@@ -230,7 +229,7 @@ def grammar_bitmask(
                 promise.result()
         else:
             # Fallback to serial filling of bitmasks for small-batch-size cases
-            for req_id, _ in ordered_seq:
+            for req_id in structured_output_request_ids:
                 request = requests[req_id]
                 structured_output_request = request.structured_output_request
 
@@ -295,21 +294,20 @@ def should_advance(self, request: Request) -> bool:
         assert request.structured_output_request.grammar is not None
         # by default, we should always advance
         # for cases that don't use thinking mode.
-        if self.reasoner is not None:
-            structured_req = request.structured_output_request
+        if self.reasoner is None:
+            return True
 
-            if structured_req.reasoning_ended:
-                return True
+        structured_req = request.structured_output_request
+        if structured_req.reasoning_ended:
+            return True
 
-            # Check if reasoning ends in *this* step
-            if self.reasoner.is_reasoning_end(request.all_token_ids):
-                # Reasoning just ended, so we shouldn't advance til
-                # next pass
-                structured_req.reasoning_ended = True
+        # Check if reasoning ends in *this* step
+        if self.reasoner.is_reasoning_end(request.all_token_ids):
+            # Reasoning just ended, so we shouldn't advance until
+            # the next pass
+            structured_req.reasoning_ended = True
 
-            return False
-        else:
-            return True
+        return False
 
     def clear_backend(self) -> None:
         if self.backend is not None:
diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py
index c37193e667aa..8e75b99f8481 100644
--- a/vllm/v1/structured_output/backend_guidance.py
+++ b/vllm/v1/structured_output/backend_guidance.py
@@ -252,7 +252,7 @@ def _process_schema(
 def validate_guidance_grammar(
     sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
 ) -> None:
-    tp, grm = get_structured_output_key(sampling_params)
+    tp, grm = get_structured_output_key(sampling_params.structured_outputs)
     guidance_grm = serialize_guidance_grammar(tp, grm)
     err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
     if err:
diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py
index 9e149b186c63..afe0e4b3f3a7 100644
--- a/vllm/v1/structured_output/request.py
+++ b/vllm/v1/structured_output/request.py
@@ -7,7 +7,7 @@
 from concurrent.futures._base import TimeoutError
 from typing import cast
 
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, StructuredOutputsParams
 from vllm.v1.structured_output.backend_types import (
     StructuredOutputGrammar,
     StructuredOutputKey,
@@ -17,10 +17,19 @@
 
 @dataclasses.dataclass
 class StructuredOutputRequest:
-    sampling_params: SamplingParams
+    params: StructuredOutputsParams
     _grammar: Future[StructuredOutputGrammar] | StructuredOutputGrammar | None = None
     reasoning_ended: bool | None = None
 
+    @staticmethod
+    def from_sampling_params(
+        sampling_params: SamplingParams | None,
+    ) -> "StructuredOutputRequest | None":
+        if sampling_params is None:
+            return None
+        params = sampling_params.structured_outputs
+        return StructuredOutputRequest(params=params) if params else None
+
     def _check_grammar_completion(self) -> bool:
         # NOTE: We have to lazy import to gate circular imports
         from vllm.v1.request import RequestStatus
@@ -53,31 +62,28 @@ def grammar(
 
     @functools.cached_property
     def structured_output_key(self) -> StructuredOutputKey:
-        return get_structured_output_key(self.sampling_params)
+        return get_structured_output_key(self.params)
 
 
-def get_structured_output_key(sampling_params: SamplingParams) -> StructuredOutputKey:
-    params = sampling_params.structured_outputs
-    assert params is not None, "params can't be None."
+def get_structured_output_key(params: StructuredOutputsParams) -> StructuredOutputKey:
     if params.json is not None:
         if not isinstance(params.json, str):
             json_str = json.dumps(params.json)
         else:
             json_str = params.json
-        return (StructuredOutputOptions.JSON, json_str)
-    elif params.json_object:
-        return (StructuredOutputOptions.JSON_OBJECT, "")
-    elif params.regex is not None:
-        return (StructuredOutputOptions.REGEX, params.regex)
-    elif params.choice is not None:
+        return StructuredOutputOptions.JSON, json_str
+    if params.json_object:
+        return StructuredOutputOptions.JSON_OBJECT, ""
+    if params.regex is not None:
+        return StructuredOutputOptions.REGEX, params.regex
+    if params.choice is not None:
         if not isinstance(params.choice, str):
             json_str = json.dumps(params.choice)
         else:
             json_str = params.choice
-        return (StructuredOutputOptions.CHOICE, json_str)
-    elif params.grammar is not None:
-        return (StructuredOutputOptions.GRAMMAR, params.grammar)
-    elif params.structural_tag is not None:
-        return (StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag)
-    else:
-        raise ValueError("No valid structured output parameter found")
+        return StructuredOutputOptions.CHOICE, json_str
+    if params.grammar is not None:
+        return StructuredOutputOptions.GRAMMAR, params.grammar
+    if params.structural_tag is not None:
+        return StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag
+    raise ValueError("No valid structured output parameter found")
diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py
index 2520dc217c79..4b793b9a72fd 100644
--- a/vllm/v1/structured_output/utils.py
+++ b/vllm/v1/structured_output/utils.py
@@ -47,7 +47,6 @@ def apply_grammar_bitmask(
     scheduler_output: SchedulerOutput,
     input_batch: InputBatch,
     logits: torch.Tensor,
-    device: torch.device,
 ) -> None:
     """
     Apply grammar bitmask to output logits of the model with xgrammar function.
@@ -56,7 +55,6 @@ def apply_grammar_bitmask(
         scheduler_output (SchedulerOutput): The result of engine scheduling.
         input_batch (InputBatch): The input of model runner.
         logits (torch.Tensor): The output logits of model forward.
-        device (torch.device): The device that model runner running on.
""" grammar_bitmask = scheduler_output.grammar_bitmask if grammar_bitmask is None: @@ -91,10 +89,7 @@ def apply_grammar_bitmask( dtype=grammar_bitmask.dtype, ) cumulative_index = 0 - seq = sorted( - scheduler_output.structured_output_request_ids.items(), key=lambda x: x[1] - ) - for req_id, _ in seq: + for req_id in scheduler_output.structured_output_request_ids: num_spec_tokens = len( scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) ) @@ -117,7 +112,7 @@ def apply_grammar_bitmask( xgr.apply_token_bitmask_inplace( logits, - grammar_bitmask.to(device, non_blocking=True), + grammar_bitmask.to(logits.device, non_blocking=True), indices=out_indices if not skip_out_indices else None, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5c2893bd0926..05e7438cd8a4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2570,10 +2570,8 @@ def execute_model( logits = model_output_broadcast_data["logits"] # Apply structured output bitmasks if present - if scheduler_output.grammar_bitmask is not None: - apply_grammar_bitmask( - scheduler_output, self.input_batch, logits, self.device - ) + if scheduler_output.structured_output_request_ids: + apply_grammar_bitmask(scheduler_output, self.input_batch, logits) with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 6fd71259dbcb..bd5434db9238 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1963,12 +1963,8 @@ def prepare_structured_decoding_input( self.grammar_bitmask_cpu.zero_() self.require_structured_out_cpu.zero_() - sorted_struct_requests = sorted( - scheduler_output.structured_output_request_ids.items(), - key=lambda item: item[1], - ) cumulative_mask_idx = 0 - for req_id, _ in sorted_struct_requests: + for req_id in scheduler_output.structured_output_request_ids: if req_id not in self.input_batch.req_id_to_index: continue batch_index = self.input_batch.req_id_to_index[req_id]