52 changes: 51 additions & 1 deletion tests/v1/core/test_scheduler.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import Optional
from unittest.mock import Mock

@@ -1899,4 +1900,53 @@ def test_priority_scheduling_preemption_when_out_of_kv():
assert output.scheduled_cached_reqs.num_reqs == 1
assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 1
assert len(scheduler.running) == 1


@pytest.mark.parametrize(
("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"),
[
(True, False, True),
(False, False, False),
# Encoder-decoder models should always have it disabled
(False, True, False),
(True, True, False),
])
def test_chunked_prefill_disabled_for_encoder_decoder(
enable_chunked_prefill: bool, is_encoder_decoder: bool,
expect_enabled: bool) -> None:
"""Validate that chunked prefill is appropriately disabled for
encoder-decoder models."""
scheduler_config = SchedulerConfig(
enable_chunked_prefill=enable_chunked_prefill,
is_encoder_decoder=is_encoder_decoder,
)

# `is_encoder_decoder` should only be used during construction of the
# config; it is otherwise stored in the model config.
assert "is_encoder_decoder" not in vars(scheduler_config)
assert "is_encoder_decoder" not in [
f.name for f in dataclasses.fields(scheduler_config)
]
_validate_chunked_prefill_settings_for_encoder_decoder(
scheduler_config, is_encoder_decoder, expect_enabled)

# Ensure it is retained in VllmConfig, even after its post-init.
vllm_config = VllmConfig(scheduler_config=scheduler_config)
_validate_chunked_prefill_settings_for_encoder_decoder(
vllm_config.scheduler_config, is_encoder_decoder, expect_enabled)


def _validate_chunked_prefill_settings_for_encoder_decoder(
scheduler_config: SchedulerConfig, is_encoder_decoder: bool,
expect_enabled: bool) -> None:
"""Validate chunked prefill settings in the scheduler config for
encoder-decoder models."""
assert scheduler_config.chunked_prefill_enabled is expect_enabled
assert scheduler_config.enable_chunked_prefill is expect_enabled
if is_encoder_decoder:
# Encoder-decoder models should automatically disable chunked multimodal
# inputs as well
assert scheduler_config.disable_chunked_mm_input is not expect_enabled
if is_encoder_decoder and not expect_enabled:
assert scheduler_config.long_prefill_token_threshold == 0
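
The `vars()` / `dataclasses.fields()` assertions in the new test rely on standard `dataclasses.InitVar` behavior. As a minimal standalone sketch (the `ExampleConfig` class below is illustrative, not vLLM code), an `InitVar` is accepted by the generated `__init__`, forwarded to `__post_init__`, and never stored on the instance or registered as a field:

import dataclasses
from dataclasses import InitVar, dataclass, field


@dataclass
class ExampleConfig:
    enable_feature: bool = True
    # Consumed at construction time only; never stored on the instance.
    is_encoder_decoder: InitVar[bool] = False
    feature_enabled: bool = field(init=False, default=True)

    def __post_init__(self, is_encoder_decoder: bool) -> None:
        # InitVar values are passed here in declaration order, then discarded.
        if is_encoder_decoder:
            self.enable_feature = False
        self.feature_enabled = self.enable_feature


cfg = ExampleConfig(enable_feature=True, is_encoder_decoder=True)
assert cfg.enable_feature is False
assert "is_encoder_decoder" not in vars(cfg)
assert "is_encoder_decoder" not in [f.name for f in dataclasses.fields(cfg)]
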
21 changes: 19 additions & 2 deletions vllm/config/scheduler.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
from dataclasses import field
from dataclasses import InitVar, field
from typing import Any, Literal, Union

from pydantic import SkipValidation, model_validator
@@ -84,6 +84,13 @@ class SchedulerConfig:
is_multimodal_model: bool = False
"""True if the model is multimodal."""

is_encoder_decoder: InitVar[bool] = False
"""True if the model is an encoder-decoder model.

Note: This is stored in the ModelConfig, and is used only here to
disable chunked prefill and prefix caching for encoder-decoder models.
"""

# TODO (ywang96): Make this configurable.
max_num_encoder_input_tokens: int = field(init=False)
"""Multimodal encoder compute budget, only used in V1.
@@ -161,13 +168,23 @@ def compute_hash(self) -> str:
usedforsecurity=False).hexdigest()
return hash_str

def __post_init__(self) -> None:
def __post_init__(self, is_encoder_decoder: bool) -> None:
if self.max_model_len is None:
self.max_model_len = 8192

if self.max_num_seqs is None:
self.max_num_seqs = 128

if is_encoder_decoder:
# Chunked prefill should be disabled for encoder-decoder models.
self.disable_chunked_mm_input = True
self.chunked_prefill_enabled = False
self.enable_chunked_prefill = False
self.long_prefill_token_threshold = 0
logger.info(
"Encoder-decoder models do not support chunked prefill nor"
" prefix caching; disabling both.")

if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS
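
For reference, a hedged usage sketch of the constructor change above, mirroring the new test; it assumes the remaining `SchedulerConfig` fields can keep their defaults and that the class is importable from `vllm.config` (the file lives at vllm/config/scheduler.py per this diff):

# Usage sketch of the InitVar added above; the import path is assumed.
from vllm.config import SchedulerConfig

config = SchedulerConfig(enable_chunked_prefill=True, is_encoder_decoder=True)
# __post_init__ overrides the chunked-prefill settings for encoder-decoder models.
assert config.enable_chunked_prefill is False
assert config.chunked_prefill_enabled is False
assert config.disable_chunked_mm_input is True
assert config.long_prefill_token_threshold == 0
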
9 changes: 4 additions & 5 deletions vllm/config/vllm.py
@@ -386,10 +386,6 @@ def __post_init__(self):
"Encoder-decoder model detected: setting "
"`max_num_encoder_input_tokens` to encoder length (%s)",
self.scheduler_config.max_num_encoder_input_tokens)
self.scheduler_config.disable_chunked_mm_input = True
disable_chunked_prefill_reasons.append(
"Encoder-decoder models do not support chunked prefill nor"
" prefix caching; disabling both.")
if (self.model_config.architecture
== "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD")
@@ -400,7 +396,10 @@ def __post_init__(self):
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'.")

if disable_chunked_prefill_reasons:
# Disable prefix caching only if chunked prefill is explicitly disabled
# (and not merely unset)
if (self.scheduler_config.chunked_prefill_enabled is False
or disable_chunked_prefill_reasons):
for reason in disable_chunked_prefill_reasons:
logger.info(reason)
self.scheduler_config.chunked_prefill_enabled = False
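
The `is False` identity check above distinguishes an explicit opt-out from a flag that is simply unset; a minimal sketch of the distinction (the values are illustrative):

# Why `chunked_prefill_enabled is False` rather than a plain truthiness check:
# an unset value (e.g. None) is falsy, but it is not an explicit opt-out.
for flag in (True, False, None):
    explicitly_disabled = flag is False  # fires only for an explicit False
    merely_falsy = not flag              # fires for False *and* None
    print(f"{flag!r:5}  explicitly_disabled={explicitly_disabled}  merely_falsy={merely_falsy}")
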
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
@@ -1367,6 +1367,7 @@ def create_engine_config(
enable_chunked_prefill=self.enable_chunked_prefill,
disable_chunked_mm_input=self.disable_chunked_mm_input,
is_multimodal_model=model_config.is_multimodal_model,
is_encoder_decoder=model_config.is_encoder_decoder,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray),
policy=self.scheduling_policy,