diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 023b5edb2c34..c1bfe727d86e 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -251,7 +251,7 @@ def __init__( self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids # Store last speculative tokens for sampler. - self.spec_token_ids: list[list[int] | None] = [] + self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)] # This is updated each time the batch constituents change. self.sampling_metadata = self._make_sampling_metadata() @@ -313,7 +313,7 @@ def add_request( else: self._req_ids[req_index] = req_id self.req_output_token_ids[req_index] = request.output_token_ids - self.spec_token_ids[req_index] = [] + self.spec_token_ids[req_index].clear() self.req_id_to_index[req_id] = req_index @@ -462,7 +462,7 @@ def remove_request(self, req_id: str) -> int | None: self.batch_update_builder.removed_append(req_index) self._req_ids[req_index] = None self.req_output_token_ids[req_index] = None - self.spec_token_ids[req_index] = None + self.spec_token_ids[req_index].clear() # LoRA lora_id = self.request_lora_mapping[req_index] @@ -654,9 +654,15 @@ def condense(self) -> None: self.req_output_token_ids[last_req_index] = None self.req_id_to_index[req_id] = empty_index - spec_token_ids = self.spec_token_ids[last_req_index] - self.spec_token_ids[empty_index] = spec_token_ids - self.spec_token_ids[last_req_index] = None + if last_req_index != empty_index: + ( + self.spec_token_ids[last_req_index], + self.spec_token_ids[empty_index], + ) = ( + self.spec_token_ids[empty_index], + self.spec_token_ids[last_req_index], + ) + self.spec_token_ids[last_req_index].clear() num_tokens = self.num_tokens[last_req_index] self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3b00085b6bb9..0c35f1330e9f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -892,7 +892,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # conform to the schema. This can result in # scheduler_output.scheduled_spec_decode_tokens being empty, # even when speculative decoding is enabled. - self.input_batch.spec_token_ids[req_index] = spec_token_ids + self.input_batch.spec_token_ids[req_index].clear() + self.input_batch.spec_token_ids[req_index].extend(spec_token_ids) # there are no draft tokens with async scheduling, # we clear the spec_decoding info in scheduler_output and