diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index f6fc1e6d37d1..5e2bdaa75d3f 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import dataclasses
 from typing import Optional
 from unittest.mock import Mock
 
@@ -1899,4 +1900,53 @@ def test_priority_scheduling_preemption_when_out_of_kv():
     assert output.scheduled_cached_reqs.num_reqs == 1
     assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
     assert len(scheduler.waiting) == 1
-    assert len(scheduler.running) == 1
\ No newline at end of file
+    assert len(scheduler.running) == 1
+
+
+@pytest.mark.parametrize(
+    ("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"),
+    [
+        (True, False, True),
+        (False, False, False),
+        # Encoder-decoder models should always have it disabled
+        (False, True, False),
+        (True, True, False),
+    ])
+def test_chunked_prefill_disabled_for_encoder_decoder(
+        enable_chunked_prefill: bool, is_encoder_decoder: bool,
+        expect_enabled: bool) -> None:
+    """Validate that chunked prefill is appropriately disabled for
+    encoder-decoder models."""
+    scheduler_config = SchedulerConfig(
+        enable_chunked_prefill=enable_chunked_prefill,
+        is_encoder_decoder=is_encoder_decoder,
+    )
+
+    # `is_encoder_decoder` should only be used during construction
+    # of the config, and otherwise stored in the model config.
+    assert "is_encoder_decoder" not in vars(scheduler_config)
+    assert "is_encoder_decoder" not in [
+        f.name for f in dataclasses.fields(scheduler_config)
+    ]
+    _validate_chunked_prefill_settings_for_encoder_decoder(
+        scheduler_config, is_encoder_decoder, expect_enabled)
+
+    # Ensure it is retained in VllmConfig, even after its post-init.
+    vllm_config = VllmConfig(scheduler_config=scheduler_config)
+    _validate_chunked_prefill_settings_for_encoder_decoder(
+        vllm_config.scheduler_config, is_encoder_decoder, expect_enabled)
+
+
+def _validate_chunked_prefill_settings_for_encoder_decoder(
+        scheduler_config: SchedulerConfig, is_encoder_decoder: bool,
+        expect_enabled: bool) -> None:
+    """Validate chunked prefill settings in the scheduler config for
+    encoder-decoder models."""
+    assert scheduler_config.chunked_prefill_enabled is expect_enabled
+    assert scheduler_config.enable_chunked_prefill is expect_enabled
+    if is_encoder_decoder:
+        # Encoder-decoder models should automatically disable chunked
+        # multimodal inputs as well
+        assert scheduler_config.disable_chunked_mm_input is not expect_enabled
+    if is_encoder_decoder and not expect_enabled:
+        assert scheduler_config.long_prefill_token_threshold == 0
diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index daf094d2df5c..1b0a10d3a069 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import hashlib
-from dataclasses import field
+from dataclasses import InitVar, field
 from typing import Any, Literal, Union
 
 from pydantic import SkipValidation, model_validator
@@ -84,6 +84,13 @@ class SchedulerConfig:
     is_multimodal_model: bool = False
     """True if the model is multimodal."""
 
+    is_encoder_decoder: InitVar[bool] = False
+    """True if the model is an encoder-decoder model.
+
+    Note: This is stored in the ModelConfig, and is used only here to
+    disable chunked prefill and prefix caching for encoder-decoder models.
+    """
+
     # TODO (ywang96): Make this configurable.
     max_num_encoder_input_tokens: int = field(init=False)
     """Multimodal encoder compute budget, only used in V1.
@@ -161,13 +168,23 @@ def compute_hash(self) -> str:
                                usedforsecurity=False).hexdigest()
         return hash_str
 
-    def __post_init__(self) -> None:
+    def __post_init__(self, is_encoder_decoder: bool) -> None:
         if self.max_model_len is None:
             self.max_model_len = 8192
 
         if self.max_num_seqs is None:
             self.max_num_seqs = 128
 
+        if is_encoder_decoder:
+            # Chunked prefill should be disabled for encoder-decoder models.
+            self.disable_chunked_mm_input = True
+            self.chunked_prefill_enabled = False
+            self.enable_chunked_prefill = False
+            self.long_prefill_token_threshold = 0
+            logger.info(
+                "Encoder-decoder models do not support chunked prefill nor"
+                " prefix caching; disabling both.")
+
         if self.max_num_batched_tokens is None:
             if self.enable_chunked_prefill:
                 self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 7336f5756527..585d3997cc3a 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -386,10 +386,6 @@ def __post_init__(self):
                 "Encoder-decoder model detected: setting "
                 "`max_num_encoder_input_tokens` to encoder length (%s)",
                 self.scheduler_config.max_num_encoder_input_tokens)
-            self.scheduler_config.disable_chunked_mm_input = True
-            disable_chunked_prefill_reasons.append(
-                "Encoder-decoder models do not support chunked prefill nor"
-                " prefix caching; disabling both.")
             if (self.model_config.architecture
                     == "WhisperForConditionalGeneration"
                     and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD")
@@ -400,7 +396,10 @@ def __post_init__(self):
                     "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
                     "to 'spawn'.")
 
-        if disable_chunked_prefill_reasons:
+        # Disable prefix caching only if chunked prefill is explicitly disabled
+        # (and not merely unset)
+        if (self.scheduler_config.chunked_prefill_enabled is False
+                or disable_chunked_prefill_reasons):
             for reason in disable_chunked_prefill_reasons:
                 logger.info(reason)
             self.scheduler_config.chunked_prefill_enabled = False
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 6bb794177db8..ce0f1708235f 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1367,6 +1367,7 @@ def create_engine_config(
             enable_chunked_prefill=self.enable_chunked_prefill,
             disable_chunked_mm_input=self.disable_chunked_mm_input,
             is_multimodal_model=model_config.is_multimodal_model,
+            is_encoder_decoder=model_config.is_encoder_decoder,
             send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                              and parallel_config.use_ray),
             policy=self.scheduling_policy,