From a329eeda472f207dc7a30ffa40fcf5444c376fff Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Wed, 17 Sep 2025 16:35:49 +0200 Subject: [PATCH 1/9] Move disabling of chunked prefill for enc-dec to SchedulerConfig Signed-off-by: simondanielsson --- vllm/config/scheduler.py | 13 +++++++++++++ vllm/config/vllm.py | 9 ++++----- vllm/engine/arg_utils.py | 1 + 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index daf094d2df5c..7da834190f1d 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -84,6 +84,9 @@ class SchedulerConfig: is_multimodal_model: bool = False """True if the model is multimodal.""" + is_encoder_decoder: bool = False + """True if the model is an encoder-decoder model.""" + # TODO (ywang96): Make this configurable. max_num_encoder_input_tokens: int = field(init=False) """Multimodal encoder compute budget, only used in V1. @@ -168,6 +171,16 @@ def __post_init__(self) -> None: if self.max_num_seqs is None: self.max_num_seqs = 128 + if self.is_encoder_decoder: + # Chunked prefill should be disabled for encoder-decoder models. + self.disable_chunked_mm_input = True + self.chunked_prefill_enabled = False + self.enable_chunked_prefill = False + self.long_prefill_token_threshold = 0 + logger.info( + "Encoder-decoder models do not support chunked prefill nor" + " prefix caching; disabling both.") + if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 7336f5756527..585d3997cc3a 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -386,10 +386,6 @@ def __post_init__(self): "Encoder-decoder model detected: setting " "`max_num_encoder_input_tokens` to encoder length (%s)", self.scheduler_config.max_num_encoder_input_tokens) - self.scheduler_config.disable_chunked_mm_input = True - disable_chunked_prefill_reasons.append( - "Encoder-decoder models do not support chunked prefill nor" - " prefix caching; disabling both.") if (self.model_config.architecture == "WhisperForConditionalGeneration" and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") @@ -400,7 +396,10 @@ def __post_init__(self): "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " "to 'spawn'.") - if disable_chunked_prefill_reasons: + # Disable prefix caching only if chunked prefill is explicitly disabled + # (and not merely unset) + if (self.scheduler_config.chunked_prefill_enabled is False + or disable_chunked_prefill_reasons): for reason in disable_chunked_prefill_reasons: logger.info(reason) self.scheduler_config.chunked_prefill_enabled = False diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6bb794177db8..ce0f1708235f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1367,6 +1367,7 @@ def create_engine_config( enable_chunked_prefill=self.enable_chunked_prefill, disable_chunked_mm_input=self.disable_chunked_mm_input, is_multimodal_model=model_config.is_multimodal_model, + is_encoder_decoder=model_config.is_encoder_decoder, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, From 5a4623189915ef7d3c9ce32e56e9667d3e80477b Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Wed, 17 Sep 2025 19:31:31 +0000 Subject: [PATCH 2/9] Add tests Signed-off-by: simondanielsson --- tests/core/test_chunked_prefill_scheduler.py | 899 +++++++++++++++++++ 1 file changed, 899 insertions(+) create mode 100644 
tests/core/test_chunked_prefill_scheduler.py diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py new file mode 100644 index 000000000000..c009e691545d --- /dev/null +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -0,0 +1,899 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock + +import pytest # noqa + +from vllm.config import CacheConfig, SchedulerConfig, VllmConfig +from vllm.core.scheduler import Scheduler +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.sampling_params import SamplingParams +from vllm.sequence import Logprob, SequenceGroup + +from .utils import create_dummy_prompt + + +def get_sequence_groups(scheduler_output): + return [s.seq_group for s in scheduler_output.scheduled_seq_groups] + + +def append_new_token(seq_group: SequenceGroup, token_id: int): + for seq in seq_group.get_seqs(): + seq.append_token_id(token_id, {token_id: Logprob(token_id)}) + + +def schedule_and_update_computed_tokens(scheduler): + metas, out, _ = scheduler.schedule() + for s, meta in zip(out.scheduled_seq_groups, metas): + s.seq_group.update_num_computed_tokens(meta.token_chunk_size) + return metas, out + + +def test_simple(): + """Verify basic scheduling works.""" + block_size = 4 + num_seq_group = 4 + max_model_len = 16 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig("generate", + max_num_batched_tokens, + num_seq_group, + max_model_len, + enable_chunked_prefill=True) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: list[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Schedule seq groups prompts. + num_tokens = block_size * num_seq_group + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert out.num_batched_tokens == num_tokens + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == num_seq_group + for s in running: + append_new_token(s, 1) + + # Schedule seq groups generation. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert out.num_batched_tokens == num_seq_group + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == num_seq_group + + +def test_chunk(): + """Verify prefills are chunked properly.""" + block_size = 4 + max_seqs = 60 + max_model_len = 80 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 32 + cache_config.num_gpu_blocks = 32 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: list[SequenceGroup] = [] + + # Add seq groups to scheduler. 
+ for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Verify the second request is chunked. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + print() + assert set(get_sequence_groups(out)) == set(running) + assert seq_group_meta[0].token_chunk_size == 60 + # Verify it is chunked. + assert seq_group_meta[1].token_chunk_size == 4 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 64 + # Only the first seq group has a new token appended. + append_new_token(running[0], 1) + + # One chunked prefill, and one decoding. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + # The first one is prefill. Scheduler guarantees ordering. + assert seq_group_meta[0].token_chunk_size == 56 + # The second one is a chunked prefill. + assert seq_group_meta[1].token_chunk_size == 1 + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 57 + + +def test_concurrent_chunking(): + """Verify prefills are chunked properly when + --max-num-partial-prefills is > 1""" + block_size = 4 + max_seqs = 60 + max_model_len = 2000 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + max_num_partial_prefills=2, # Up to 2 partial prefills at a time + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 32 + cache_config.num_gpu_blocks = 32 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: list[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Verify both requests are chunked with half of max_num_batched_tokens each + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert seq_group_meta[0].token_chunk_size == 32 + assert seq_group_meta[1].token_chunk_size == 32 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 64 + + # After one iteration, both should have 60 - 32 = 28 tokens left to prefill + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert seq_group_meta[0].token_chunk_size == 28 + assert seq_group_meta[1].token_chunk_size == 28 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 56 + + +def test_concurrent_chunking_large_requests(): + """Verify large prefill requests are run one at a time""" + block_size = 4 + max_seqs = 60 + max_model_len = 2000 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + max_num_partial_prefills=2, # Up to 2 partial prefills at a time + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests + cache_config.num_gpu_blocks = 3200 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq groups to scheduler. 
+ for i in range(2): + _, seq_group = create_dummy_prompt( + str(i), + prompt_length=1200, # Very large prompt + block_size=block_size) + scheduler.add_seq_group(seq_group) + + # Verify only a single request is chunked, and it gets all 64 tokens + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 1 + assert seq_group_meta[0].token_chunk_size == 64 + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 64 + + +def test_short_prompts_jump_long_prompts_in_queue(): + """Verify large prefill requests are punted behind smaller ones if + another large prefill request is already running""" + block_size = 4 + max_seqs = 60 + max_model_len = 2000 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + max_num_partial_prefills=2, # Up to 2 partial prefills at a time + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests + cache_config.num_gpu_blocks = 3200 + scheduler = Scheduler(scheduler_config, cache_config, None) + long_seqs: list[SequenceGroup] = [] + short_seqs: list[SequenceGroup] = [] + + # Add 2 large seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt( + str(i), + prompt_length=1200, # Very large prompt + block_size=block_size) + scheduler.add_seq_group(seq_group) + long_seqs.append(seq_group) + assert seq_group.is_prefill() + + # Add 2 small seq groups behind them + for i in range(2): + _, seq_group = create_dummy_prompt( + str(i + 2), + prompt_length=40, # Very small prompt + block_size=block_size) + scheduler.add_seq_group(seq_group) + short_seqs.append(seq_group) + assert seq_group.is_prefill() + + # Verify one large req and 1 small req chunked + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[0].token_chunk_size == 32 # large req gets 32 tokens + assert seq_group_meta[1].token_chunk_size == 32 # small req gets 32 tokens + + # all 4 are prefilling + assert long_seqs[0].is_prefill() + assert long_seqs[1].is_prefill() + assert short_seqs[0].is_prefill() + assert short_seqs[1].is_prefill() + # First short and first long sequences have been scheduled + assert long_seqs[0].first_seq.get_num_computed_tokens() == 32 + assert long_seqs[1].first_seq.get_num_computed_tokens() == 0 + assert short_seqs[0].first_seq.get_num_computed_tokens() == 32 + assert short_seqs[1].first_seq.get_num_computed_tokens() == 0 + + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 64 + + # in the second iteration, + # the first small request had only 8 tokens left + # so it went to decode + # The other small req is scheduled + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + # the new small req got 64 - (32+8) tokens + assert seq_group_meta[0].token_chunk_size == 24 + assert seq_group_meta[1].token_chunk_size == 32 # large req still got 32 + # the other small request had only 8 tokens left + assert seq_group_meta[2].token_chunk_size == 8 # 40-32 + + # The first small request got to decode now + assert long_seqs[0].is_prefill() + assert long_seqs[1].is_prefill() + assert not short_seqs[0].is_prefill() + assert short_seqs[1].is_prefill() + # Both small requests have started in front of the second long request + assert long_seqs[0].first_seq.get_num_computed_tokens() == 64 + assert long_seqs[1].first_seq.get_num_computed_tokens() == 
0 + assert short_seqs[0].first_seq.get_num_computed_tokens() == 40 + assert short_seqs[1].first_seq.get_num_computed_tokens() == 24 + + assert out.num_prefill_groups == 3 + assert out.num_batched_tokens == 64 + # the first small seq group has a new token appended. + append_new_token(short_seqs[0], 1) + + # in the third iteration, + # the first small request is already decoding + # the second small request only has 16 tokens left and will enter decoding + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[0].token_chunk_size == 32 # large still got 32 + # small req finished prefilling 40-24=16 tokens + assert seq_group_meta[1].token_chunk_size == 16 + assert seq_group_meta[2].token_chunk_size == 1 # decode + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 49 # (32+16+1 decode) + + # both small requests have now reached decode + assert long_seqs[0].is_prefill() + assert long_seqs[1].is_prefill() + assert not short_seqs[0].is_prefill() + assert not short_seqs[1].is_prefill() + assert long_seqs[0].first_seq.get_num_computed_tokens() == 96 + assert long_seqs[1].first_seq.get_num_computed_tokens() == 0 + assert short_seqs[0].first_seq.get_num_computed_tokens() == 41 + assert short_seqs[1].first_seq.get_num_computed_tokens() == 40 + + # both the small seq groups have a new token appended + append_new_token(short_seqs[0], 1) + append_new_token(short_seqs[1], 1) + + # in the fourth iteration, both small requests are decoding + # so large request gets all the budget + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + + # large req gets 62 tokens (minus 2 for decode) + assert seq_group_meta[0].token_chunk_size == 62 + assert seq_group_meta[1].token_chunk_size == 1 # decode + assert seq_group_meta[2].token_chunk_size == 1 # decode + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 64 + + assert long_seqs[0].first_seq.get_num_computed_tokens() == 158 + + # assert long_seqs[0].is_prefill() + # assert long_seqs[1].is_prefill() + # assert not short_seqs[0].is_prefill() + # assert not short_seqs[1].is_prefill() + + # # both the small seq groups have a new token appended + # append_new_token(short_seqs[0], 1) + # append_new_token(short_seqs[1], 1) + + # # in the fifth iteration, large request gets all the budget + # # while both small requests are decoding + # seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + # assert seq_group_meta[0].token_chunk_size == 62 + # assert seq_group_meta[1].token_chunk_size == 1 # decode + # assert seq_group_meta[2].token_chunk_size == 1 # decode + # assert out.num_prefill_groups == 1 + # assert out.num_batched_tokens == 64 + + +def test_complex(): + block_size = 4 + max_seqs = 60 + max_model_len = 80 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 64 + cache_config.num_gpu_blocks = 64 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: list[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() + + # Verify the second request is chunked. 
+ seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + + assert set(get_sequence_groups(out)) == set(running) + assert seq_group_meta[0].token_chunk_size == 60 + # Verify it is chunked. + assert seq_group_meta[1].token_chunk_size == 4 + assert not running[0].is_prefill() + assert running[1].is_prefill() + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 64 + # Only the first seq group has a new token appended. + append_new_token(running[0], 1) + + # Add 2 more requests. + for i in range(2, 4): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Decoding & chunked prefill & first chunk of 3rd request is scheduled. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 3 + # The first one is the first chunked prefill. + assert seq_group_meta[0].token_chunk_size == 7 + # The second one is the second new chunked prefill. + assert seq_group_meta[1].token_chunk_size == 56 + # The last one is decode. + assert seq_group_meta[2].token_chunk_size == 1 + # Two of them are in chunked prefill. + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 64 + # The first 2 requests are now in decodine phase. + append_new_token(running[0], 1) + assert not running[0].is_prefill() + append_new_token(running[1], 1) + assert not running[1].is_prefill() + # The third request is still in prefill stage. + assert running[2].is_prefill() + + +def test_maximal_decoding(): + """Verify decoding requests are prioritized.""" + block_size = 4 + max_seqs = 2 + max_model_len = 8 + max_num_batched_tokens = 2 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: list[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=2, + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() + + # The first prefill is scheduled. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 1 + assert seq_group_meta[0].token_chunk_size == 2 + assert not running[0].is_prefill() + assert running[1].is_prefill() + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 2 + # Only the first seq group has a new token appended. + append_new_token(running[0], 1) + + # Create one more seq_group. + _, seq_group = create_dummy_prompt("3", + prompt_length=2, + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() + # The first decoding + second chunk is scheduled. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 2 + assert seq_group_meta[0].token_chunk_size == 1 + assert seq_group_meta[1].token_chunk_size == 1 + assert not running[0].is_prefill() + assert running[1].is_prefill() + assert running[2].is_prefill() + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 2 + append_new_token(running[0], 1) + + # Decoding + running prefill is prioritized. 
+ seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 2 + assert seq_group_meta[0].token_chunk_size == 1 + assert seq_group_meta[1].token_chunk_size == 1 + assert not running[0].is_prefill() + assert not running[1].is_prefill() + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 2 + append_new_token(running[0], 1) + append_new_token(running[1], 1) + + # Only decoding is prioritized. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 2 + assert seq_group_meta[0].token_chunk_size == 1 + assert seq_group_meta[1].token_chunk_size == 1 + assert not running[0].is_prefill() + assert not running[1].is_prefill() + assert out.num_prefill_groups == 0 + assert out.num_batched_tokens == 2 + append_new_token(running[0], 1) + append_new_token(running[1], 1) + + # After aborting the decoding request, the fcfs new prefill is prioritized. + scheduler.abort_seq_group(running[0].request_id) + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 2 + assert seq_group_meta[0].token_chunk_size == 1 + assert seq_group_meta[1].token_chunk_size == 1 + assert not running[1].is_prefill() + assert running[2].is_prefill() + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 2 + + +def test_prompt_limit(): + """Verify max_num_batched_tokens < max_model_len is possible.""" + block_size = 4 + max_seqs = 32 + max_model_len = 64 + max_num_batched_tokens = 32 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: list[SequenceGroup] = [] + + _, seq_group = create_dummy_prompt("1", + prompt_length=48, + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() + + # The prompt length > max_num_batched_tokens should be still scheduled. 
+ seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 1 + assert seq_group_meta[0].token_chunk_size == 32 + assert running[0].is_prefill() + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 32 + + +def test_prompt_limit_exceed(): + block_size = 4 + max_seqs = 64 + max_model_len = 32 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig("generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: list[SequenceGroup] = [] + _, seq_group = create_dummy_prompt("2", + prompt_length=48, + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.ignored_seq_groups) == 1 + assert out.ignored_seq_groups[0] == seq_group + + +def test_chunked_prefill_preempt(): + """Verify preempt works with chunked prefill requests""" + block_size = 4 + max_seqs = 30 + max_model_len = 200 + max_num_batched_tokens = 30 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 + scheduler = Scheduler(scheduler_config, cache_config, None) + + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + block_size=block_size) + scheduler.add_seq_group(seq_group) + _, out = schedule_and_update_computed_tokens(scheduler) + # The request is chunked. + # prefill scheduled now. + assert len(out.scheduled_seq_groups) == 1 + assert out.num_prefill_groups == 1 + assert seq_group.is_prefill() + assert out.num_batched_tokens == max_num_batched_tokens + + # The request should be preempted. + scheduler.block_manager.can_append_slots = MagicMock() + + def cannot_append_second_group1(seq_group, num_lookahead_slots): + return seq_group.request_id != "1" + + scheduler.block_manager.can_append_slots.side_effect = ( + cannot_append_second_group1) + + # The running prefill is now preempted. + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 0 + assert out.num_batched_tokens == 0 + assert out.blocks_to_swap_out == [] + assert out.blocks_to_swap_in == [] + + # Make sure we can reschedule preempted request. + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + assert out.num_prefill_groups == 1 + assert seq_group.is_prefill() + assert out.num_batched_tokens == max_num_batched_tokens + assert seq_group.get_num_uncomputed_tokens() == 30 + + # We should be able to run prefill twice as it is chunked. 
+ def cannot_append_second_group2(seq_group, num_lookahead_slots): + return True + + scheduler.block_manager.can_append_slots.side_effect = ( + cannot_append_second_group2) + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + assert out.num_prefill_groups == 1 + assert not seq_group.is_prefill() + assert out.num_batched_tokens == max_num_batched_tokens + + +def test_chunked_prefill_spec_prefill(): + """Verify that the num_lookahead_slots is set appropriately for an all""" + """prefill batch.""" + block_size = 4 + max_seqs = 30 + max_model_len = 200 + max_num_batched_tokens = 30 + num_lookahead_slots = 4 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + num_lookahead_slots=num_lookahead_slots, + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 + scheduler = Scheduler(scheduler_config, cache_config, None) + + _, seq_group = create_dummy_prompt("1", + prompt_length=30, + block_size=block_size) + scheduler.add_seq_group(seq_group) + _, out = schedule_and_update_computed_tokens(scheduler) + # The request is chunked. + # prefill scheduled now. + assert len(out.scheduled_seq_groups) == 1 + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == max_num_batched_tokens + print(out.num_lookahead_slots) + assert out.num_lookahead_slots == 0 + + +def test_chunked_prefill_max_seqs(): + block_size = 4 + max_seqs = 2 + max_model_len = 80 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 128 + cache_config.num_gpu_blocks = 128 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: list[SequenceGroup] = [] + + _, seq_group = create_dummy_prompt("1", + prompt_length=65, + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + # The first prefill is chunked. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[0].token_chunk_size == max_num_batched_tokens + assert len(get_sequence_groups(out)) == 1 + + # Add new requests. + for i in range(4): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=65, + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Make sure only 2 requests are scheduled. + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert out.num_batched_tokens == max_num_batched_tokens + assert len(get_sequence_groups(out)) == 2 + assert not running[0].is_prefill() + assert running[1].is_prefill() + append_new_token(running[0], 1) + + # Although we have enough token budget, we can only schedule max_seqs. 
+ seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[0].token_chunk_size == 2 + assert seq_group_meta[1].token_chunk_size == 1 + assert out.num_batched_tokens == 3 + assert len(get_sequence_groups(out)) == max_seqs + assert not running[0].is_prefill() + assert not running[1].is_prefill() + + +def test_prefix_caching(): + """Verify allocating full blocks when prefix caching is enabled.""" + block_size = 4 + max_seqs = 10 + max_model_len = 80 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + ) + cache_config = CacheConfig(block_size, + 1.0, + 1, + "auto", + enable_prefix_caching=True) + cache_config.num_cpu_blocks = 0 + cache_config.num_gpu_blocks = 32 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: list[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + block_size=block_size, + prompt_length=50) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert seq_group_meta[0].token_chunk_size == 50 + # Verify it is chunked. Note that although the budget is 64-50=14, + # we only allocate full blocks for prefix caching, so only 4*(14//4)=12 + # tokens are allocated. + assert seq_group_meta[1].token_chunk_size == 12 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 62 + + +def test_prefix_caching_with_concurrent_partial_prefills(): + """Verify allocating full blocks when prefix caching is enabled with + --max-num-partial-prefills > 1.""" + block_size = 4 + max_seqs = 10 + max_model_len = 8000 + max_num_batched_tokens = 60 # With two slots, each slot will get 30 tokens + scheduler_config = SchedulerConfig("generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + max_num_partial_prefills=2) + cache_config = CacheConfig(block_size, + 1.0, + 1, + "auto", + enable_prefix_caching=True) + cache_config.num_cpu_blocks = 0 + cache_config.num_gpu_blocks = 32 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: list[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + block_size=block_size, + prompt_length=50) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + # To partially prefill both sequences, both can chunk up to 30 tokens + # But the next lowest multiple of the block size (4) is 28 + assert seq_group_meta[0].token_chunk_size == 28 + assert seq_group_meta[1].token_chunk_size == 28 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 56 + + # On the next iteration, both sequences should finish prefill + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + # Both sequences have 50 - 28 = 22 tokens left to prefill. 
+ # This is not a multiple of the block size, but we don't care since we don't + # cache the final partial block of prefix sequences + assert seq_group_meta[0].token_chunk_size == 22 + assert seq_group_meta[1].token_chunk_size == 22 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 44 + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8]) +def test_chunked_prefill_with_actual_engine(model: str, + max_num_partial_prefills: int): + """Make sure the model can actually sample with concurrent + partial prefills + """ + + prompt = "hello" * 40 + + engine_args = EngineArgs( + model=model, + max_num_partial_prefills=max_num_partial_prefills, + max_num_batched_tokens=40, + max_num_seqs=8, + enable_chunked_prefill=True, + gpu_memory_utilization=0.8, + ) + + engine = LLMEngine.from_engine_args(engine_args) + sampling_params = SamplingParams(temperature=0) + + for req_num in range(max_num_partial_prefills): + engine.add_request(f"{req_num}", prompt, sampling_params) + # first step + request_outputs = engine.step() + # means all are prefilling + assert len(request_outputs) == 0 + assert len(engine.scheduler[0].running) == max_num_partial_prefills + + +@pytest.mark.parametrize( + ("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"), [ + (True, False, True), + (False, False, False), + (False, True, False), + (True, True, False), + ]) +def test_chunked_prefill_disabled_for_encoder_decoder( + enable_chunked_prefill: bool, is_encoder_decoder: bool, + expect_enabled: bool) -> None: + """Validate that chunked prefill is appropriately disabled for + encoder-decoder models.""" + scheduler_config = SchedulerConfig( + enable_chunked_prefill=enable_chunked_prefill, + is_encoder_decoder=is_encoder_decoder, + ) + + _validate_chunked_prefill_settings_for_encoder_decoder( + scheduler_config, is_encoder_decoder, expect_enabled) + + # Ensure it is retained in VllmConfig, even after its post-init. 
+ vllm_config = VllmConfig(scheduler_config=scheduler_config) + _validate_chunked_prefill_settings_for_encoder_decoder( + vllm_config.scheduler_config, is_encoder_decoder, expect_enabled) + + +def _validate_chunked_prefill_settings_for_encoder_decoder( + scheduler_config: SchedulerConfig, is_encoder_decoder: bool, + expect_enabled: bool) -> None: + """Validate chunked prefill settings in the scheduler config for + encoder-decoder models.""" + assert scheduler_config.chunked_prefill_enabled is expect_enabled + assert scheduler_config.enable_chunked_prefill is expect_enabled + if is_encoder_decoder: + # Encoder-decoder models should automatically disable chunked multimodal + # inputs as well + assert scheduler_config.disable_chunked_mm_input is not expect_enabled + if is_encoder_decoder and not expect_enabled: + assert scheduler_config.long_prefill_token_threshold == 0 From fd6d12fe060a490a05118d57cff0ec76df0cf80f Mon Sep 17 00:00:00 2001 From: simondanielsson Date: Wed, 17 Sep 2025 19:53:41 +0000 Subject: [PATCH 3/9] Move test to v1 package Signed-off-by: simondanielsson --- tests/core/test_chunked_prefill_scheduler.py | 899 ------------------- tests/v1/core/test_scheduler.py | 45 +- 2 files changed, 44 insertions(+), 900 deletions(-) delete mode 100644 tests/core/test_chunked_prefill_scheduler.py diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py deleted file mode 100644 index c009e691545d..000000000000 --- a/tests/core/test_chunked_prefill_scheduler.py +++ /dev/null @@ -1,899 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from unittest.mock import MagicMock - -import pytest # noqa - -from vllm.config import CacheConfig, SchedulerConfig, VllmConfig -from vllm.core.scheduler import Scheduler -from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine -from vllm.sampling_params import SamplingParams -from vllm.sequence import Logprob, SequenceGroup - -from .utils import create_dummy_prompt - - -def get_sequence_groups(scheduler_output): - return [s.seq_group for s in scheduler_output.scheduled_seq_groups] - - -def append_new_token(seq_group: SequenceGroup, token_id: int): - for seq in seq_group.get_seqs(): - seq.append_token_id(token_id, {token_id: Logprob(token_id)}) - - -def schedule_and_update_computed_tokens(scheduler): - metas, out, _ = scheduler.schedule() - for s, meta in zip(out.scheduled_seq_groups, metas): - s.seq_group.update_num_computed_tokens(meta.token_chunk_size) - return metas, out - - -def test_simple(): - """Verify basic scheduling works.""" - block_size = 4 - num_seq_group = 4 - max_model_len = 16 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig("generate", - max_num_batched_tokens, - num_seq_group, - max_model_len, - enable_chunked_prefill=True) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=block_size, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Schedule seq groups prompts. 
- num_tokens = block_size * num_seq_group - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert out.num_batched_tokens == num_tokens - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == num_seq_group - for s in running: - append_new_token(s, 1) - - # Schedule seq groups generation. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert out.num_batched_tokens == num_seq_group - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == num_seq_group - - -def test_chunk(): - """Verify prefills are chunked properly.""" - block_size = 4 - max_seqs = 60 - max_model_len = 80 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 32 - cache_config.num_gpu_blocks = 32 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Verify the second request is chunked. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - print() - assert set(get_sequence_groups(out)) == set(running) - assert seq_group_meta[0].token_chunk_size == 60 - # Verify it is chunked. - assert seq_group_meta[1].token_chunk_size == 4 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 64 - # Only the first seq group has a new token appended. - append_new_token(running[0], 1) - - # One chunked prefill, and one decoding. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - # The first one is prefill. Scheduler guarantees ordering. - assert seq_group_meta[0].token_chunk_size == 56 - # The second one is a chunked prefill. - assert seq_group_meta[1].token_chunk_size == 1 - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 57 - - -def test_concurrent_chunking(): - """Verify prefills are chunked properly when - --max-num-partial-prefills is > 1""" - block_size = 4 - max_seqs = 60 - max_model_len = 2000 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - max_num_partial_prefills=2, # Up to 2 partial prefills at a time - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 32 - cache_config.num_gpu_blocks = 32 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. 
- for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Verify both requests are chunked with half of max_num_batched_tokens each - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert seq_group_meta[0].token_chunk_size == 32 - assert seq_group_meta[1].token_chunk_size == 32 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 64 - - # After one iteration, both should have 60 - 32 = 28 tokens left to prefill - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert seq_group_meta[0].token_chunk_size == 28 - assert seq_group_meta[1].token_chunk_size == 28 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 56 - - -def test_concurrent_chunking_large_requests(): - """Verify large prefill requests are run one at a time""" - block_size = 4 - max_seqs = 60 - max_model_len = 2000 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - max_num_partial_prefills=2, # Up to 2 partial prefills at a time - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests - cache_config.num_gpu_blocks = 3200 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt( - str(i), - prompt_length=1200, # Very large prompt - block_size=block_size) - scheduler.add_seq_group(seq_group) - - # Verify only a single request is chunked, and it gets all 64 tokens - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 1 - assert seq_group_meta[0].token_chunk_size == 64 - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 64 - - -def test_short_prompts_jump_long_prompts_in_queue(): - """Verify large prefill requests are punted behind smaller ones if - another large prefill request is already running""" - block_size = 4 - max_seqs = 60 - max_model_len = 2000 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - max_num_partial_prefills=2, # Up to 2 partial prefills at a time - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests - cache_config.num_gpu_blocks = 3200 - scheduler = Scheduler(scheduler_config, cache_config, None) - long_seqs: list[SequenceGroup] = [] - short_seqs: list[SequenceGroup] = [] - - # Add 2 large seq groups to scheduler. 
- for i in range(2): - _, seq_group = create_dummy_prompt( - str(i), - prompt_length=1200, # Very large prompt - block_size=block_size) - scheduler.add_seq_group(seq_group) - long_seqs.append(seq_group) - assert seq_group.is_prefill() - - # Add 2 small seq groups behind them - for i in range(2): - _, seq_group = create_dummy_prompt( - str(i + 2), - prompt_length=40, # Very small prompt - block_size=block_size) - scheduler.add_seq_group(seq_group) - short_seqs.append(seq_group) - assert seq_group.is_prefill() - - # Verify one large req and 1 small req chunked - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert seq_group_meta[0].token_chunk_size == 32 # large req gets 32 tokens - assert seq_group_meta[1].token_chunk_size == 32 # small req gets 32 tokens - - # all 4 are prefilling - assert long_seqs[0].is_prefill() - assert long_seqs[1].is_prefill() - assert short_seqs[0].is_prefill() - assert short_seqs[1].is_prefill() - # First short and first long sequences have been scheduled - assert long_seqs[0].first_seq.get_num_computed_tokens() == 32 - assert long_seqs[1].first_seq.get_num_computed_tokens() == 0 - assert short_seqs[0].first_seq.get_num_computed_tokens() == 32 - assert short_seqs[1].first_seq.get_num_computed_tokens() == 0 - - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 64 - - # in the second iteration, - # the first small request had only 8 tokens left - # so it went to decode - # The other small req is scheduled - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - # the new small req got 64 - (32+8) tokens - assert seq_group_meta[0].token_chunk_size == 24 - assert seq_group_meta[1].token_chunk_size == 32 # large req still got 32 - # the other small request had only 8 tokens left - assert seq_group_meta[2].token_chunk_size == 8 # 40-32 - - # The first small request got to decode now - assert long_seqs[0].is_prefill() - assert long_seqs[1].is_prefill() - assert not short_seqs[0].is_prefill() - assert short_seqs[1].is_prefill() - # Both small requests have started in front of the second long request - assert long_seqs[0].first_seq.get_num_computed_tokens() == 64 - assert long_seqs[1].first_seq.get_num_computed_tokens() == 0 - assert short_seqs[0].first_seq.get_num_computed_tokens() == 40 - assert short_seqs[1].first_seq.get_num_computed_tokens() == 24 - - assert out.num_prefill_groups == 3 - assert out.num_batched_tokens == 64 - # the first small seq group has a new token appended. 
- append_new_token(short_seqs[0], 1) - - # in the third iteration, - # the first small request is already decoding - # the second small request only has 16 tokens left and will enter decoding - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert seq_group_meta[0].token_chunk_size == 32 # large still got 32 - # small req finished prefilling 40-24=16 tokens - assert seq_group_meta[1].token_chunk_size == 16 - assert seq_group_meta[2].token_chunk_size == 1 # decode - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 49 # (32+16+1 decode) - - # both small requests have now reached decode - assert long_seqs[0].is_prefill() - assert long_seqs[1].is_prefill() - assert not short_seqs[0].is_prefill() - assert not short_seqs[1].is_prefill() - assert long_seqs[0].first_seq.get_num_computed_tokens() == 96 - assert long_seqs[1].first_seq.get_num_computed_tokens() == 0 - assert short_seqs[0].first_seq.get_num_computed_tokens() == 41 - assert short_seqs[1].first_seq.get_num_computed_tokens() == 40 - - # both the small seq groups have a new token appended - append_new_token(short_seqs[0], 1) - append_new_token(short_seqs[1], 1) - - # in the fourth iteration, both small requests are decoding - # so large request gets all the budget - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - - # large req gets 62 tokens (minus 2 for decode) - assert seq_group_meta[0].token_chunk_size == 62 - assert seq_group_meta[1].token_chunk_size == 1 # decode - assert seq_group_meta[2].token_chunk_size == 1 # decode - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 64 - - assert long_seqs[0].first_seq.get_num_computed_tokens() == 158 - - # assert long_seqs[0].is_prefill() - # assert long_seqs[1].is_prefill() - # assert not short_seqs[0].is_prefill() - # assert not short_seqs[1].is_prefill() - - # # both the small seq groups have a new token appended - # append_new_token(short_seqs[0], 1) - # append_new_token(short_seqs[1], 1) - - # # in the fifth iteration, large request gets all the budget - # # while both small requests are decoding - # seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - # assert seq_group_meta[0].token_chunk_size == 62 - # assert seq_group_meta[1].token_chunk_size == 1 # decode - # assert seq_group_meta[2].token_chunk_size == 1 # decode - # assert out.num_prefill_groups == 1 - # assert out.num_batched_tokens == 64 - - -def test_complex(): - block_size = 4 - max_seqs = 60 - max_model_len = 80 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 64 - cache_config.num_gpu_blocks = 64 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - assert seq_group.is_prefill() - - # Verify the second request is chunked. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - - assert set(get_sequence_groups(out)) == set(running) - assert seq_group_meta[0].token_chunk_size == 60 - # Verify it is chunked. 
- assert seq_group_meta[1].token_chunk_size == 4 - assert not running[0].is_prefill() - assert running[1].is_prefill() - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 64 - # Only the first seq group has a new token appended. - append_new_token(running[0], 1) - - # Add 2 more requests. - for i in range(2, 4): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Decoding & chunked prefill & first chunk of 3rd request is scheduled. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 3 - # The first one is the first chunked prefill. - assert seq_group_meta[0].token_chunk_size == 7 - # The second one is the second new chunked prefill. - assert seq_group_meta[1].token_chunk_size == 56 - # The last one is decode. - assert seq_group_meta[2].token_chunk_size == 1 - # Two of them are in chunked prefill. - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 64 - # The first 2 requests are now in decodine phase. - append_new_token(running[0], 1) - assert not running[0].is_prefill() - append_new_token(running[1], 1) - assert not running[1].is_prefill() - # The third request is still in prefill stage. - assert running[2].is_prefill() - - -def test_maximal_decoding(): - """Verify decoding requests are prioritized.""" - block_size = 4 - max_seqs = 2 - max_model_len = 8 - max_num_batched_tokens = 2 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=2, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - assert seq_group.is_prefill() - - # The first prefill is scheduled. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 1 - assert seq_group_meta[0].token_chunk_size == 2 - assert not running[0].is_prefill() - assert running[1].is_prefill() - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 2 - # Only the first seq group has a new token appended. - append_new_token(running[0], 1) - - # Create one more seq_group. - _, seq_group = create_dummy_prompt("3", - prompt_length=2, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - assert seq_group.is_prefill() - # The first decoding + second chunk is scheduled. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 2 - assert seq_group_meta[0].token_chunk_size == 1 - assert seq_group_meta[1].token_chunk_size == 1 - assert not running[0].is_prefill() - assert running[1].is_prefill() - assert running[2].is_prefill() - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 2 - append_new_token(running[0], 1) - - # Decoding + running prefill is prioritized. 
- seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 2 - assert seq_group_meta[0].token_chunk_size == 1 - assert seq_group_meta[1].token_chunk_size == 1 - assert not running[0].is_prefill() - assert not running[1].is_prefill() - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 2 - append_new_token(running[0], 1) - append_new_token(running[1], 1) - - # Only decoding is prioritized. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 2 - assert seq_group_meta[0].token_chunk_size == 1 - assert seq_group_meta[1].token_chunk_size == 1 - assert not running[0].is_prefill() - assert not running[1].is_prefill() - assert out.num_prefill_groups == 0 - assert out.num_batched_tokens == 2 - append_new_token(running[0], 1) - append_new_token(running[1], 1) - - # After aborting the decoding request, the fcfs new prefill is prioritized. - scheduler.abort_seq_group(running[0].request_id) - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 2 - assert seq_group_meta[0].token_chunk_size == 1 - assert seq_group_meta[1].token_chunk_size == 1 - assert not running[1].is_prefill() - assert running[2].is_prefill() - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 2 - - -def test_prompt_limit(): - """Verify max_num_batched_tokens < max_model_len is possible.""" - block_size = 4 - max_seqs = 32 - max_model_len = 64 - max_num_batched_tokens = 32 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 - cache_config.num_gpu_blocks = 16 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - _, seq_group = create_dummy_prompt("1", - prompt_length=48, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - assert seq_group.is_prefill() - - # The prompt length > max_num_batched_tokens should be still scheduled. 
- seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 1 - assert seq_group_meta[0].token_chunk_size == 32 - assert running[0].is_prefill() - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 32 - - -def test_prompt_limit_exceed(): - block_size = 4 - max_seqs = 64 - max_model_len = 32 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig("generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 - cache_config.num_gpu_blocks = 16 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - _, seq_group = create_dummy_prompt("2", - prompt_length=48, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - assert seq_group.is_prefill() - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.ignored_seq_groups) == 1 - assert out.ignored_seq_groups[0] == seq_group - - -def test_chunked_prefill_preempt(): - """Verify preempt works with chunked prefill requests""" - block_size = 4 - max_seqs = 30 - max_model_len = 200 - max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 - cache_config.num_gpu_blocks = 16 - scheduler = Scheduler(scheduler_config, cache_config, None) - - _, seq_group = create_dummy_prompt("1", - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - _, out = schedule_and_update_computed_tokens(scheduler) - # The request is chunked. - # prefill scheduled now. - assert len(out.scheduled_seq_groups) == 1 - assert out.num_prefill_groups == 1 - assert seq_group.is_prefill() - assert out.num_batched_tokens == max_num_batched_tokens - - # The request should be preempted. - scheduler.block_manager.can_append_slots = MagicMock() - - def cannot_append_second_group1(seq_group, num_lookahead_slots): - return seq_group.request_id != "1" - - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group1) - - # The running prefill is now preempted. - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 0 - assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out == [] - assert out.blocks_to_swap_in == [] - - # Make sure we can reschedule preempted request. - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - assert out.num_prefill_groups == 1 - assert seq_group.is_prefill() - assert out.num_batched_tokens == max_num_batched_tokens - assert seq_group.get_num_uncomputed_tokens() == 30 - - # We should be able to run prefill twice as it is chunked. 
-    def cannot_append_second_group2(seq_group, num_lookahead_slots):
-        return True
-
-    scheduler.block_manager.can_append_slots.side_effect = (
-        cannot_append_second_group2)
-    _, out = schedule_and_update_computed_tokens(scheduler)
-    assert len(out.scheduled_seq_groups) == 1
-    assert out.num_prefill_groups == 1
-    assert not seq_group.is_prefill()
-    assert out.num_batched_tokens == max_num_batched_tokens
-
-
-def test_chunked_prefill_spec_prefill():
-    """Verify that the num_lookahead_slots is set appropriately for an all"""
-    """prefill batch."""
-    block_size = 4
-    max_seqs = 30
-    max_model_len = 200
-    max_num_batched_tokens = 30
-    num_lookahead_slots = 4
-    scheduler_config = SchedulerConfig(
-        "generate",
-        max_num_batched_tokens,
-        max_seqs,
-        max_model_len,
-        enable_chunked_prefill=True,
-        num_lookahead_slots=num_lookahead_slots,
-    )
-    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
-    cache_config.num_cpu_blocks = 16
-    cache_config.num_gpu_blocks = 16
-    scheduler = Scheduler(scheduler_config, cache_config, None)
-
-    _, seq_group = create_dummy_prompt("1",
-                                       prompt_length=30,
-                                       block_size=block_size)
-    scheduler.add_seq_group(seq_group)
-    _, out = schedule_and_update_computed_tokens(scheduler)
-    # The request is chunked.
-    # prefill scheduled now.
-    assert len(out.scheduled_seq_groups) == 1
-    assert out.num_prefill_groups == 1
-    assert out.num_batched_tokens == max_num_batched_tokens
-    print(out.num_lookahead_slots)
-    assert out.num_lookahead_slots == 0
-
-
-def test_chunked_prefill_max_seqs():
-    block_size = 4
-    max_seqs = 2
-    max_model_len = 80
-    max_num_batched_tokens = 64
-    scheduler_config = SchedulerConfig(
-        "generate",
-        max_num_batched_tokens,
-        max_seqs,
-        max_model_len,
-        enable_chunked_prefill=True,
-    )
-    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
-    cache_config.num_cpu_blocks = 128
-    cache_config.num_gpu_blocks = 128
-    scheduler = Scheduler(scheduler_config, cache_config, None)
-    running: list[SequenceGroup] = []
-
-    _, seq_group = create_dummy_prompt("1",
-                                       prompt_length=65,
-                                       block_size=block_size)
-    scheduler.add_seq_group(seq_group)
-    running.append(seq_group)
-    # The first prefill is chunked.
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert seq_group_meta[0].token_chunk_size == max_num_batched_tokens
-    assert len(get_sequence_groups(out)) == 1
-
-    # Add new requests.
-    for i in range(4):
-        _, seq_group = create_dummy_prompt(str(i),
-                                           prompt_length=65,
-                                           block_size=block_size)
-        scheduler.add_seq_group(seq_group)
-        running.append(seq_group)
-
-    # Make sure only 2 requests are scheduled.
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert out.num_batched_tokens == max_num_batched_tokens
-    assert len(get_sequence_groups(out)) == 2
-    assert not running[0].is_prefill()
-    assert running[1].is_prefill()
-    append_new_token(running[0], 1)
-
-    # Although we have enough token budget, we can only schedule max_seqs.
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert seq_group_meta[0].token_chunk_size == 2
-    assert seq_group_meta[1].token_chunk_size == 1
-    assert out.num_batched_tokens == 3
-    assert len(get_sequence_groups(out)) == max_seqs
-    assert not running[0].is_prefill()
-    assert not running[1].is_prefill()
-
-
-def test_prefix_caching():
-    """Verify allocating full blocks when prefix caching is enabled."""
-    block_size = 4
-    max_seqs = 10
-    max_model_len = 80
-    max_num_batched_tokens = 64
-    scheduler_config = SchedulerConfig(
-        "generate",
-        max_num_batched_tokens,
-        max_seqs,
-        max_model_len,
-        enable_chunked_prefill=True,
-    )
-    cache_config = CacheConfig(block_size,
-                               1.0,
-                               1,
-                               "auto",
-                               enable_prefix_caching=True)
-    cache_config.num_cpu_blocks = 0
-    cache_config.num_gpu_blocks = 32
-    scheduler = Scheduler(scheduler_config, cache_config, None)
-    running: list[SequenceGroup] = []
-
-    # Add seq groups to scheduler.
-    for i in range(2):
-        _, seq_group = create_dummy_prompt(str(i),
-                                           block_size=block_size,
-                                           prompt_length=50)
-        scheduler.add_seq_group(seq_group)
-        running.append(seq_group)
-
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert set(get_sequence_groups(out)) == set(running)
-    assert seq_group_meta[0].token_chunk_size == 50
-    # Verify it is chunked. Note that although the budget is 64-50=14,
-    # we only allocate full blocks for prefix caching, so only 4*(14//4)=12
-    # tokens are allocated.
-    assert seq_group_meta[1].token_chunk_size == 12
-    assert out.num_prefill_groups == 2
-    assert out.num_batched_tokens == 62
-
-
-def test_prefix_caching_with_concurrent_partial_prefills():
-    """Verify allocating full blocks when prefix caching is enabled with
-    --max-num-partial-prefills > 1."""
-    block_size = 4
-    max_seqs = 10
-    max_model_len = 8000
-    max_num_batched_tokens = 60  # With two slots, each slot will get 30 tokens
-    scheduler_config = SchedulerConfig("generate",
-                                       max_num_batched_tokens,
-                                       max_seqs,
-                                       max_model_len,
-                                       enable_chunked_prefill=True,
-                                       max_num_partial_prefills=2)
-    cache_config = CacheConfig(block_size,
-                               1.0,
-                               1,
-                               "auto",
-                               enable_prefix_caching=True)
-    cache_config.num_cpu_blocks = 0
-    cache_config.num_gpu_blocks = 32
-    scheduler = Scheduler(scheduler_config, cache_config, None)
-    running: list[SequenceGroup] = []
-
-    # Add seq groups to scheduler.
-    for i in range(2):
-        _, seq_group = create_dummy_prompt(str(i),
-                                           block_size=block_size,
-                                           prompt_length=50)
-        scheduler.add_seq_group(seq_group)
-        running.append(seq_group)
-
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert set(get_sequence_groups(out)) == set(running)
-    # To partially prefill both sequences, both can chunk up to 30 tokens
-    # But the next lowest multiple of the block size (4) is 28
-    assert seq_group_meta[0].token_chunk_size == 28
-    assert seq_group_meta[1].token_chunk_size == 28
-    assert out.num_prefill_groups == 2
-    assert out.num_batched_tokens == 56
-
-    # On the next iteration, both sequences should finish prefill
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert set(get_sequence_groups(out)) == set(running)
-    # Both sequences have 50 - 28 = 22 tokens left to prefill.
-    # This is not a multiple of the block size, but we don't care since we don't
-    # cache the final partial block of prefix sequences
-    assert seq_group_meta[0].token_chunk_size == 22
-    assert seq_group_meta[1].token_chunk_size == 22
-    assert out.num_prefill_groups == 2
-    assert out.num_batched_tokens == 44
-
-
-@pytest.mark.parametrize("model", ["facebook/opt-125m"])
-@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8])
-def test_chunked_prefill_with_actual_engine(model: str,
-                                            max_num_partial_prefills: int):
-    """Make sure the model can actually sample with concurrent
-    partial prefills
-    """
-
-    prompt = "hello" * 40
-
-    engine_args = EngineArgs(
-        model=model,
-        max_num_partial_prefills=max_num_partial_prefills,
-        max_num_batched_tokens=40,
-        max_num_seqs=8,
-        enable_chunked_prefill=True,
-        gpu_memory_utilization=0.8,
-    )
-
-    engine = LLMEngine.from_engine_args(engine_args)
-    sampling_params = SamplingParams(temperature=0)
-
-    for req_num in range(max_num_partial_prefills):
-        engine.add_request(f"{req_num}", prompt, sampling_params)
-    # first step
-    request_outputs = engine.step()
-    # means all are prefilling
-    assert len(request_outputs) == 0
-    assert len(engine.scheduler[0].running) == max_num_partial_prefills
-
-
-@pytest.mark.parametrize(
-    ("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"), [
-        (True, False, True),
-        (False, False, False),
-        (False, True, False),
-        (True, True, False),
-    ])
-def test_chunked_prefill_disabled_for_encoder_decoder(
-        enable_chunked_prefill: bool, is_encoder_decoder: bool,
-        expect_enabled: bool) -> None:
-    """Validate that chunked prefill is appropriately disabled for
-    encoder-decoder models."""
-    scheduler_config = SchedulerConfig(
-        enable_chunked_prefill=enable_chunked_prefill,
-        is_encoder_decoder=is_encoder_decoder,
-    )
-
-    _validate_chunked_prefill_settings_for_encoder_decoder(
-        scheduler_config, is_encoder_decoder, expect_enabled)
-
-    # Ensure it is retained in VllmConfig, even after its post-init.
-    vllm_config = VllmConfig(scheduler_config=scheduler_config)
-    _validate_chunked_prefill_settings_for_encoder_decoder(
-        vllm_config.scheduler_config, is_encoder_decoder, expect_enabled)
-
-
-def _validate_chunked_prefill_settings_for_encoder_decoder(
-        scheduler_config: SchedulerConfig, is_encoder_decoder: bool,
-        expect_enabled: bool) -> None:
-    """Validate chunked prefill settings in the scheduler config for
-    encoder-decoder models."""
-    assert scheduler_config.chunked_prefill_enabled is expect_enabled
-    assert scheduler_config.enable_chunked_prefill is expect_enabled
-    if is_encoder_decoder:
-        # Encoder-decoder models should automatically disable chunked multimodal
-        # inputs as well
-        assert scheduler_config.disable_chunked_mm_input is not expect_enabled
-    if is_encoder_decoder and not expect_enabled:
-        assert scheduler_config.long_prefill_token_threshold == 0
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index f6fc1e6d37d1..e498b3cd0c6d 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -1899,4 +1899,47 @@ def test_priority_scheduling_preemption_when_out_of_kv():
     assert output.scheduled_cached_reqs.num_reqs == 1
     assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
     assert len(scheduler.waiting) == 1
-    assert len(scheduler.running) == 1
\ No newline at end of file
+    assert len(scheduler.running) == 1
+
+
+@pytest.mark.parametrize(
+    ("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"),
+    [
+        (True, False, True),
+        (False, False, False),
+        # Encoder-decoder models should always have it disabled
+        (False, True, False),
+        (True, True, False),
+    ])
+def test_chunked_prefill_disabled_for_encoder_decoder(
+        enable_chunked_prefill: bool, is_encoder_decoder: bool,
+        expect_enabled: bool) -> None:
+    """Validate that chunked prefill is appropriately disabled for
+    encoder-decoder models."""
+    scheduler_config = SchedulerConfig(
+        enable_chunked_prefill=enable_chunked_prefill,
+        is_encoder_decoder=is_encoder_decoder,
+    )
+
+    _validate_chunked_prefill_settings_for_encoder_decoder(
+        scheduler_config, is_encoder_decoder, expect_enabled)
+
+    # Ensure it is retained in VllmConfig, even after its post-init.
+    vllm_config = VllmConfig(scheduler_config=scheduler_config)
+    _validate_chunked_prefill_settings_for_encoder_decoder(
+        vllm_config.scheduler_config, is_encoder_decoder, expect_enabled)
+
+
+def _validate_chunked_prefill_settings_for_encoder_decoder(
+        scheduler_config: SchedulerConfig, is_encoder_decoder: bool,
+        expect_enabled: bool) -> None:
+    """Validate chunked prefill settings in the scheduler config for
+    encoder-decoder models."""
+    assert scheduler_config.chunked_prefill_enabled is expect_enabled
+    assert scheduler_config.enable_chunked_prefill is expect_enabled
+    if is_encoder_decoder:
+        # Encoder-decoder models should automatically disable chunked multimodal
+        # inputs as well
+        assert scheduler_config.disable_chunked_mm_input is not expect_enabled
+    if is_encoder_decoder and not expect_enabled:
+        assert scheduler_config.long_prefill_token_threshold == 0

From a9882193be4fc9872f4a78682126157135565a00 Mon Sep 17 00:00:00 2001
From: simondanielsson
Date: Thu, 18 Sep 2025 15:02:01 +0200
Subject: [PATCH 4/9] Make is_encoder_decoder an init var

Signed-off-by: simondanielsson
---
 tests/v1/core/test_scheduler.py |  2 +
 vllm/config/scheduler.py        | 91 ++++++++++++++++++++-------------
 2 files changed, 57 insertions(+), 36 deletions(-)

diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index e498b3cd0c6d..7eb8084eac44 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -1921,6 +1921,8 @@ def test_chunked_prefill_disabled_for_encoder_decoder(
         is_encoder_decoder=is_encoder_decoder,
     )
 
+    # `is_encoder_decoder` should only be used during construction of the config
+    assert not hasattr(scheduler_config, "is_encoder_decoder")
     _validate_chunked_prefill_settings_for_encoder_decoder(
         scheduler_config, is_encoder_decoder, expect_enabled)
 
diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index 7da834190f1d..396258aac287 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import hashlib
-from dataclasses import field
+from dataclasses import InitVar, field
 from typing import Any, Literal, Union
 
 from pydantic import SkipValidation, model_validator
@@ -11,9 +11,11 @@
 from vllm.config.utils import config
 from vllm.logger import init_logger
-from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
-                        MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
-                        POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)
+from vllm.utils import (
+    DEFAULT_MAX_NUM_BATCHED_TOKENS,
+    MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+    POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
+)
 
 logger = init_logger(__name__)
 
 
@@ -84,8 +86,12 @@ class SchedulerConfig:
     is_multimodal_model: bool = False
     """True if the model is multimodal."""
 
-    is_encoder_decoder: bool = False
-    """True if the model is an encoder-decoder model."""
+    is_encoder_decoder: InitVar[bool] = False
+    """True if the model is an encoder-decoder model.
+
+    Note: This is stored in the ModelConfig, and is used only here to
+    disable chunked prefill and prefix caching for encoder-decoder models.
+    """
 
     # TODO (ywang96): Make this configurable.
     max_num_encoder_input_tokens: int = field(init=False)
@@ -160,18 +166,17 @@ def compute_hash(self) -> str:
         # no factors to consider.
         # this config will not affect the computation graph.
         factors: list[Any] = []
-        hash_str = hashlib.md5(str(factors).encode(),
-                               usedforsecurity=False).hexdigest()
+        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
 
-    def __post_init__(self) -> None:
+    def __post_init__(self, is_encoder_decoder: bool) -> None:
         if self.max_model_len is None:
             self.max_model_len = 8192
 
         if self.max_num_seqs is None:
             self.max_num_seqs = 128
 
-        if self.is_encoder_decoder:
+        if is_encoder_decoder:
             # Chunked prefill should be disabled for encoder-decoder models.
             self.disable_chunked_mm_input = True
             self.chunked_prefill_enabled = False
@@ -179,7 +184,8 @@ def __post_init__(self) -> None:
             self.long_prefill_token_threshold = 0
             logger.info(
                 "Encoder-decoder models do not support chunked prefill nor"
-                " prefix caching; disabling both.")
+                " prefix caching; disabling both."
+            )
 
         if self.max_num_batched_tokens is None:
             if self.enable_chunked_prefill:
@@ -189,7 +195,8 @@ def __post_init__(self) -> None:
                 # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
                 # for higher throughput.
                 self.max_num_batched_tokens = max(
-                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
+                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
+                )
 
             if self.runner_type == "pooling":
                 # Choose specific value for higher throughput
@@ -208,8 +215,8 @@ def __post_init__(self) -> None:
             # Ensure max_num_batched_tokens does not exceed model limit.
             # Some models (e.g., Whisper) have embeddings tied to max length.
             self.max_num_batched_tokens = min(
-                self.max_num_seqs * self.max_model_len,
-                self.max_num_batched_tokens)
+                self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
+            )
 
         self.max_num_encoder_input_tokens = self.max_num_batched_tokens
         self.encoder_cache_size = self.max_num_batched_tokens
@@ -217,20 +224,22 @@ def __post_init__(self) -> None:
         if self.enable_chunked_prefill:
             logger.info(
                 "Chunked prefill is enabled with max_num_batched_tokens=%d.",
-                self.max_num_batched_tokens)
+                self.max_num_batched_tokens,
+            )
 
         self.chunked_prefill_enabled = self.enable_chunked_prefill
         if self.max_num_partial_prefills > 1:
             if self.long_prefill_token_threshold == 0:
-                self.long_prefill_token_threshold = int(self.max_model_len *
-                                                        0.04)
+                self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
 
             logger.info(
                 "Concurrent partial prefills enabled with "
                 "max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
                 "long_prefill_token_threshold=%d",
-                self.max_num_partial_prefills, self.max_long_partial_prefills,
-                self.long_prefill_token_threshold)
+                self.max_num_partial_prefills,
+                self.max_long_partial_prefills,
+                self.long_prefill_token_threshold,
+            )
 
         # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)].
         # This avoids OOM in tight memory scenarios with small max_num_seqs,
@@ -240,61 +249,71 @@ def __post_init__(self) -> None:
             self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
 
         if self.async_scheduling:
-            self.scheduler_cls = (
-                "vllm.v1.core.sched.async_scheduler.AsyncScheduler")
+            self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler"
 
-    @model_validator(mode='after')
+    @model_validator(mode="after")
     def _verify_args(self) -> Self:
-        if (self.max_num_batched_tokens < self.max_model_len
-                and not self.chunked_prefill_enabled):
+        if (
+            self.max_num_batched_tokens < self.max_model_len
+            and not self.chunked_prefill_enabled
+        ):
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                 f"smaller than max_model_len ({self.max_model_len}). "
                 "This effectively limits the maximum sequence length to "
                 "max_num_batched_tokens and makes vLLM reject longer "
                 "sequences. Please increase max_num_batched_tokens or "
-                "decrease max_model_len.")
+                "decrease max_model_len."
+            )
 
         if self.max_num_batched_tokens < self.max_num_seqs:
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                 "be greater than or equal to max_num_seqs "
-                f"({self.max_num_seqs}).")
+                f"({self.max_num_seqs})."
+            )
 
         if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
             logger.warning(
                 "max_num_batched_tokens (%d) exceeds max_num_seqs "
                 "* max_model_len (%d). This may lead to unexpected behavior.",
                 self.max_num_batched_tokens,
-                self.max_num_seqs * self.max_model_len)
+                self.max_num_seqs * self.max_model_len,
+            )
 
         if self.num_lookahead_slots < 0:
             raise ValueError(
                 "num_lookahead_slots "
                 f"({self.num_lookahead_slots}) must be greater than or "
-                "equal to 0.")
+                "equal to 0."
+            )
 
         if self.max_num_partial_prefills < 1:
             raise ValueError(
                 f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
-                "must be greater than or equal to 1.")
+                "must be greater than or equal to 1."
+            )
         elif self.max_num_partial_prefills > 1:
             if not self.chunked_prefill_enabled:
-                raise ValueError("Chunked prefill must be enabled to set "
-                                 "max_num_partial_prefills > 1.")
+                raise ValueError(
+                    "Chunked prefill must be enabled to set "
+                    "max_num_partial_prefills > 1."
+                )
 
             if self.long_prefill_token_threshold > self.max_model_len:
                 raise ValueError(
                     "long_prefill_token_threshold "
                     f"({self.long_prefill_token_threshold}) cannot be greater "
-                    f"than the max_model_len ({self.max_model_len}).")
+                    f"than the max_model_len ({self.max_model_len})."
+                )
 
-        if (self.max_long_partial_prefills
-                < 1) or (self.max_long_partial_prefills
-                         > self.max_num_partial_prefills):
+        if (self.max_long_partial_prefills < 1) or (
+            self.max_long_partial_prefills > self.max_num_partial_prefills
+        ):
             raise ValueError(
                 f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
                 "must be greater than or equal to 1 and less than or equal to "
-                f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
+                f"max_num_partial_prefills ({self.max_num_partial_prefills})."
+            )
 
         return self

From 0e3a5e9735a3c60744681cb1980ceb7ba12220b3 Mon Sep 17 00:00:00 2001
From: simondanielsson
Date: Fri, 26 Sep 2025 20:47:44 +0000
Subject: [PATCH 5/9] Run pre-commit

Signed-off-by: simondanielsson
---
 vllm/config/scheduler.py | 63 +++++++++++++++++-----------------
 1 file changed, 26 insertions(+), 37 deletions(-)

diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index 396258aac287..e8cc0f5d555e 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -11,11 +11,9 @@
 from vllm.config.utils import config
 from vllm.logger import init_logger
-from vllm.utils import (
-    DEFAULT_MAX_NUM_BATCHED_TOKENS,
-    MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
-    POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
-)
+from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
+                        MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+                        POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)
 
 logger = init_logger(__name__)
 
 
@@ -166,7 +164,8 @@ def compute_hash(self) -> str:
         # no factors to consider.
         # this config will not affect the computation graph.
         factors: list[Any] = []
-        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
+        hash_str = hashlib.md5(str(factors).encode(),
+                               usedforsecurity=False).hexdigest()
         return hash_str
 
     def __post_init__(self, is_encoder_decoder: bool) -> None:
@@ -184,8 +183,7 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
             self.long_prefill_token_threshold = 0
             logger.info(
                 "Encoder-decoder models do not support chunked prefill nor"
-                " prefix caching; disabling both."
-            )
+                " prefix caching; disabling both.")
 
         if self.max_num_batched_tokens is None:
@@ -195,8 +193,7 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
                 # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
                 # for higher throughput.
                 self.max_num_batched_tokens = max(
-                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
-                )
+                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
 
             if self.runner_type == "pooling":
                 # Choose specific value for higher throughput
@@ -215,8 +212,8 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
             # Ensure max_num_batched_tokens does not exceed model limit.
             # Some models (e.g., Whisper) have embeddings tied to max length.
             self.max_num_batched_tokens = min(
-                self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
-            )
+                self.max_num_seqs * self.max_model_len,
+                self.max_num_batched_tokens)
 
         self.max_num_encoder_input_tokens = self.max_num_batched_tokens
         self.encoder_cache_size = self.max_num_batched_tokens
@@ -230,7 +227,8 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
         self.chunked_prefill_enabled = self.enable_chunked_prefill
         if self.max_num_partial_prefills > 1:
             if self.long_prefill_token_threshold == 0:
-                self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
+                self.long_prefill_token_threshold = int(self.max_model_len *
+                                                        0.04)
 
             logger.info(
                 "Concurrent partial prefills enabled with "
@@ -249,29 +247,26 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
             self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
 
         if self.async_scheduling:
-            self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler"
+            self.scheduler_cls = (
+                "vllm.v1.core.sched.async_scheduler.AsyncScheduler")
 
     @model_validator(mode="after")
     def _verify_args(self) -> Self:
-        if (
-            self.max_num_batched_tokens < self.max_model_len
-            and not self.chunked_prefill_enabled
-        ):
+        if (self.max_num_batched_tokens < self.max_model_len
+                and not self.chunked_prefill_enabled):
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                 f"smaller than max_model_len ({self.max_model_len}). "
                 "This effectively limits the maximum sequence length to "
                 "max_num_batched_tokens and makes vLLM reject longer "
                 "sequences. Please increase max_num_batched_tokens or "
-                "decrease max_model_len."
-            )
+                "decrease max_model_len.")
 
         if self.max_num_batched_tokens < self.max_num_seqs:
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                 "be greater than or equal to max_num_seqs "
-                f"({self.max_num_seqs})."
-            )
+                f"({self.max_num_seqs}).")
 
         if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
             logger.warning(
                 "max_num_batched_tokens (%d) exceeds max_num_seqs "
                 "* max_model_len (%d). This may lead to unexpected behavior.",
                 self.max_num_batched_tokens,
-                self.max_num_seqs * self.max_model_len,
-            )
+                self.max_num_seqs * self.max_model_len)
 
         if self.num_lookahead_slots < 0:
             raise ValueError(
                 "num_lookahead_slots "
                 f"({self.num_lookahead_slots}) must be greater than or "
-                "equal to 0."
-            )
+                "equal to 0.")
 
         if self.max_num_partial_prefills < 1:
             raise ValueError(
                 f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
-                "must be greater than or equal to 1."
-            )
+                "must be greater than or equal to 1.")
         elif self.max_num_partial_prefills > 1:
             if not self.chunked_prefill_enabled:
-                raise ValueError(
-                    "Chunked prefill must be enabled to set "
-                    "max_num_partial_prefills > 1."
-                )
+                raise ValueError("Chunked prefill must be enabled to set "
+                                 "max_num_partial_prefills > 1.")
 
             if self.long_prefill_token_threshold > self.max_model_len:
                 raise ValueError(
                     "long_prefill_token_threshold "
                     f"({self.long_prefill_token_threshold}) cannot be greater "
-                    f"than the max_model_len ({self.max_model_len})."
-                )
+                    f"than the max_model_len ({self.max_model_len}).")
 
-        if (self.max_long_partial_prefills < 1) or (
-            self.max_long_partial_prefills > self.max_num_partial_prefills
-        ):
+        if (self.max_long_partial_prefills
+                < 1) or (self.max_long_partial_prefills
+                         > self.max_num_partial_prefills):
             raise ValueError(
                 f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
                 "must be greater than or equal to 1 and less than or equal to "
-                f"max_num_partial_prefills ({self.max_num_partial_prefills})."
-            )
+                f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
 
         return self

From 6e7208381a2ae30c3a95030b872e8bfb8458fc2c Mon Sep 17 00:00:00 2001
From: simondanielsson
Date: Fri, 26 Sep 2025 20:50:09 +0000
Subject: [PATCH 6/9] Fix formatting

Signed-off-by: simondanielsson
---
 vllm/config/scheduler.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index e8cc0f5d555e..1b0a10d3a069 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -221,8 +221,7 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
         if self.enable_chunked_prefill:
             logger.info(
                 "Chunked prefill is enabled with max_num_batched_tokens=%d.",
-                self.max_num_batched_tokens,
-            )
+                self.max_num_batched_tokens)
 
         self.chunked_prefill_enabled = self.enable_chunked_prefill
         if self.max_num_partial_prefills > 1:
@@ -234,10 +233,8 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
                 "Concurrent partial prefills enabled with "
                 "max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
                 "long_prefill_token_threshold=%d",
-                self.max_num_partial_prefills,
-                self.max_long_partial_prefills,
-                self.long_prefill_token_threshold,
-            )
+                self.max_num_partial_prefills, self.max_long_partial_prefills,
+                self.long_prefill_token_threshold)
 
         # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)].
         # This avoids OOM in tight memory scenarios with small max_num_seqs,
@@ -250,7 +247,7 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
             self.scheduler_cls = (
                 "vllm.v1.core.sched.async_scheduler.AsyncScheduler")
 
-    @model_validator(mode="after")
+    @model_validator(mode='after')
     def _verify_args(self) -> Self:
         if (self.max_num_batched_tokens < self.max_model_len
                 and not self.chunked_prefill_enabled):
@@ -273,8 +270,7 @@ def _verify_args(self) -> Self:
                 "max_num_batched_tokens (%d) exceeds max_num_seqs "
                 "* max_model_len (%d). This may lead to unexpected behavior.",
                 self.max_num_batched_tokens,
-                self.max_num_seqs * self.max_model_len,
-            )
+                self.max_num_seqs * self.max_model_len)
 
         if self.num_lookahead_slots < 0:
             raise ValueError(

From a5f9caecc92c78f6fa0df03a51bde26a132facdc Mon Sep 17 00:00:00 2001
From: simondanielsson
Date: Sat, 27 Sep 2025 09:31:09 +0200
Subject: [PATCH 7/9] Make test compatible with pydantic.dataclasses

Signed-off-by: simondanielsson
---
 tests/v1/core/test_scheduler.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 7eb8084eac44..5e2bdaa75d3f 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import dataclasses
 from typing import Optional
 from unittest.mock import Mock
 
@@ -1921,8 +1922,12 @@ def test_chunked_prefill_disabled_for_encoder_decoder(
         is_encoder_decoder=is_encoder_decoder,
     )
 
-    # `is_encoder_decoder` should only be used during construction of the config
-    assert not hasattr(scheduler_config, "is_encoder_decoder")
+    # `is_encoder_decoder` should only be used during construction
+    # of the config, and otherwise stored in the model config.
+    assert "is_encoder_decoder" not in vars(scheduler_config)
+    assert "is_encoder_decoder" not in [
+        f.name for f in dataclasses.fields(scheduler_config)
+    ]
     _validate_chunked_prefill_settings_for_encoder_decoder(
         scheduler_config, is_encoder_decoder, expect_enabled)
 

From 46594df3631243d323aa3fb2192f72215b073930 Mon Sep 17 00:00:00 2001
From: simondanielsson
Date: Mon, 29 Sep 2025 18:30:56 +0000
Subject: [PATCH 8/9] Disable prefix caching only if chunked prefill is explicitly disabled

Signed-off-by: simondanielsson
---
 vllm/config/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index c909265c071d..1010747abd98 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -25,6 +25,7 @@
 from vllm.config.speculative import SpeculativeConfig
 from vllm.config.speech_to_text import SpeechToTextConfig
 from vllm.config.structured_outputs import StructuredOutputsConfig
+<<<<<<< HEAD
 from vllm.config.utils import (ConfigType, SupportsMetricsInfo, config,
                                get_attr_docs, is_init_field, update_config)
 from vllm.config.vllm import (VllmConfig, get_cached_compilation_config,

From 89eff01f2f1c55e133233e19eff28b5eac5205ad Mon Sep 17 00:00:00 2001
From: simondanielsson
Date: Tue, 30 Sep 2025 08:35:12 +0200
Subject: [PATCH 9/9] Remove unused code

Signed-off-by: simondanielsson
---
 vllm/config/__init__.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 1010747abd98..c909265c071d 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -25,7 +25,6 @@
 from vllm.config.speculative import SpeculativeConfig
 from vllm.config.speech_to_text import SpeechToTextConfig
 from vllm.config.structured_outputs import StructuredOutputsConfig
-<<<<<<< HEAD
 from vllm.config.utils import (ConfigType, SupportsMetricsInfo, config,
                                get_attr_docs, is_init_field, update_config)
 from vllm.config.vllm import (VllmConfig, get_cached_compilation_config,