Merged

Changes from all commits (30 commits)
fca0edb
calculate profiling size
wwl2755 Sep 24, 2025
f027a29
init
wwl2755 Sep 25, 2025
8b5f8c0
remove unused code
wwl2755 Sep 25, 2025
43472b9
remove unused code
wwl2755 Sep 25, 2025
0ecd273
remove unused code
wwl2755 Sep 25, 2025
f9edb4f
update logic
wwl2755 Sep 25, 2025
85acac5
fix lint
wwl2755 Sep 25, 2025
122efc9
fix lint
wwl2755 Sep 25, 2025
c2913ee
refactor
wwl2755 Sep 26, 2025
57f30c6
change default logic
wwl2755 Sep 26, 2025
d85c8e8
move mm_option to get_dummy_images
wwl2755 Sep 26, 2025
1462478
fix docstring
wwl2755 Sep 26, 2025
f2c4e9a
fix pre-commit
wwl2755 Sep 26, 2025
9c902b3
fix pre-commit
wwl2755 Sep 26, 2025
c93672c
add qwen2_audio and refactor limit_mm_per_prompt
wwl2755 Sep 27, 2025
7ca9c7f
pass only one modality each time and rename
wwl2755 Sep 27, 2025
9142c36
preserve compatibility for OOT models
wwl2755 Sep 27, 2025
2e0ee53
fix comments
wwl2755 Oct 1, 2025
37c1c49
fix mypy
wwl2755 Oct 1, 2025
5347ed9
fix doc-build
wwl2755 Oct 1, 2025
79a9068
fix type mistake
wwl2755 Oct 1, 2025
27e4792
Merge branch 'main' into pr/wwl2755/25631
hmellor Oct 1, 2025
6a6f0ab
Use pydantic to validate the new classes
hmellor Oct 1, 2025
d8e3872
Validate `limit_per_prompt` inside `MultiModalConfig`
hmellor Oct 1, 2025
ee66a9e
add all models
wwl2755 Oct 2, 2025
738e33d
Merge branch 'main' of github.com:wwl2755/vllm into mm-profiling
wwl2755 Oct 2, 2025
c3c79ca
Merge branch 'main' of github.com:wwl2755/vllm into mm-profiling
wwl2755 Oct 2, 2025
bd74d03
fix import error
wwl2755 Oct 2, 2025
8b0bea0
fix tests
wwl2755 Oct 2, 2025
79343c3
fix mllama4 test
wwl2755 Oct 3, 2025
12 changes: 10 additions & 2 deletions docs/contributing/model/multimodal.md
@@ -258,17 +258,21 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
+        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
     ) -> MultiModalDataDict:
         num_images = mm_counts.get("image", 0)
 
         target_width, target_height = \
             self.info.get_image_size_with_most_features()
 
+        image_overrides = mm_options.get("image") if mm_options else None
+
         return {
             "image":
             self._get_dummy_images(width=target_width,
                                    height=target_height,
-                                   num_images=num_images)
+                                   num_images=num_images,
+                                   overrides=image_overrides)
         }
 ```
 
@@ -438,16 +442,20 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
+        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
     ) -> MultiModalDataDict:
         target_width, target_height = \
             self.info.get_image_size_with_most_features()
         num_images = mm_counts.get("image", 0)
 
+        image_overrides = mm_options.get("image") if mm_options else None
+
         return {
             "image":
             self._get_dummy_images(width=target_width,
                                    height=target_height,
-                                   num_images=num_images)
+                                   num_images=num_images,
+                                   overrides=image_overrides)
         }
 ```
 
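To make the new hook concrete, here is a minimal sketch of the override semantics the documentation snippets above rely on. The helper below is hypothetical (a stand-in for vLLM's internal `_get_dummy_images`); only `ImageDummyOptions` and its fields come from this PR.

```python
from typing import Optional

from PIL import Image

from vllm.config.multimodal import ImageDummyOptions


def make_dummy_images(width: int, height: int, num_images: int,
                      overrides: Optional[ImageDummyOptions] = None):
    # Hypothetical stand-in for `_get_dummy_images`: profiling-only
    # overrides replace the model's worst-case image size when set.
    if overrides is not None:
        width = overrides.width or width
        height = overrides.height or height
    return [Image.new("RGB", (width, height)) for _ in range(num_images)]


# Profile with 512x512 dummy images instead of the worst-case size.
imgs = make_dummy_images(3840, 2160, 2,
                         ImageDummyOptions(width=512, height=512))
assert len(imgs) == 2 and imgs[0].size == (512, 512)
```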
22 changes: 19 additions & 3 deletions tests/models/multimodal/processing/test_common.py
@@ -12,6 +12,8 @@
 from PIL import Image
 
 from vllm.config import ModelConfig
+from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions,
+                                    ImageDummyOptions, VideoDummyOptions)
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
 from vllm.multimodal.cache import MultiModalProcessorOnlyCache
 from vllm.multimodal.inputs import MultiModalInputs
@@ -112,12 +114,26 @@ def _test_processing_correctness(
 
     processing_info = factories.info(ctx)
     supported_mm_limits = processing_info.get_supported_mm_limits()
-    limit_mm_per_prompt = {
+    # Keep integer limits for local data generation
+    limit_mm_per_prompt_ints = {
         modality: 3 if limit is None else limit
         for modality, limit in supported_mm_limits.items()
     }
 
-    model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
+    def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions:
+        if modality == "video":
+            return VideoDummyOptions(count=count)
+        if modality == "image":
+            return ImageDummyOptions(count=count)
+        if modality == "audio":
+            return AudioDummyOptions(count=count)
+        return BaseDummyOptions(count=count)
+
+    # Assign normalized DummyOptions to the model config
+    model_config.get_multimodal_config().limit_per_prompt = {
+        modality: _to_dummy_options(modality, count)
+        for modality, count in limit_mm_per_prompt_ints.items()
+    }
 
     baseline_processor = factories.build_processor(ctx, cache=None)
     cached_processor = factories.build_processor(ctx, cache=cache)
@@ -150,7 +166,7 @@ def _test_processing_correctness(
         k:
         [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
          for _ in range(rng.randint(limit + 1))]
-        for k, limit in limit_mm_per_prompt.items()
+        for k, limit in limit_mm_per_prompt_ints.items()
     }
 
     mm_counts = {k: len(vs) for k, vs in mm_data.items()}
10 changes: 5 additions & 5 deletions tests/models/multimodal/processing/test_mllama4.py
@@ -17,23 +17,23 @@ def test_profiling(model_id: str, max_model_len: int):
     model_config_kwargs = {
         "max_model_len": max_model_len,
     }
+    mm_counts = {"image": 1}
     ctx = build_model_context(
         model_id,
         model_config_kwargs=model_config_kwargs,
-        limit_mm_per_prompt={"image": 1},
+        limit_mm_per_prompt=mm_counts,
     )
 
-    mm_config = ctx.get_mm_config()
     processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
     profiler = MultiModalProfiler(processor)
 
     decoder_dummy_data = profiler.get_decoder_dummy_data(
         max_model_len,
-        mm_counts=mm_config.limit_per_prompt,
+        mm_counts=mm_counts,
     )
     dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
         max_model_len,
-        mm_counts=mm_config.limit_per_prompt,
+        mm_counts=mm_counts,
     )
 
     hf_config = ctx.get_hf_config(Llama4Config)
@@ -58,7 +58,7 @@
 
     profiled_tokens = profiler.get_mm_max_contiguous_tokens(
         max_model_len,
-        mm_counts=mm_config.limit_per_prompt,
+        mm_counts=mm_counts,
     )
 
     assert total_tokens == profiled_tokens["image"]
17 changes: 16 additions & 1 deletion tests/models/multimodal/processing/test_tensor_schema.py
@@ -15,6 +15,8 @@
 from PIL import Image
 
 from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
+from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions,
+                                    ImageDummyOptions, VideoDummyOptions)
 from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
@@ -236,7 +238,20 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
         modality: 3 if limit is None else limit
         for modality, limit in supported_mm_limits.items()
     }
-    model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
+
+    def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions:
+        if modality == "video":
+            return VideoDummyOptions(count=count)
+        if modality == "image":
+            return ImageDummyOptions(count=count)
+        if modality == "audio":
+            return AudioDummyOptions(count=count)
+        return BaseDummyOptions(count=count)
+
+    model_config.get_multimodal_config().limit_per_prompt = {
+        modality: _to_dummy_options(modality, count)
+        for modality, count in limit_mm_per_prompt.items()
+    }
     processor = factories.build_processor(ctx, cache=None)
 
     with initialize_dummy_model(model_cls, model_config) as model:
4 changes: 3 additions & 1 deletion vllm/config/model.py
@@ -276,7 +276,9 @@ class ModelConfig:
     multimodal_config: Optional[MultiModalConfig] = None
     """Configuration for multimodal model. If `None`, this will be inferred
     from the architecture of `self.model`."""
-    limit_mm_per_prompt: InitVar[Optional[dict[str, int]]] = None
+    limit_mm_per_prompt: InitVar[Optional[dict[str, Union[int,
+                                                          dict[str,
+                                                               int]]]]] = None
     media_io_kwargs: InitVar[Optional[dict[str, dict[str, Any]]]] = None
     mm_processor_kwargs: InitVar[Optional[dict[str, Any]]] = None
    mm_processor_cache_gb: InitVar[Optional[float]] = None
95 changes: 83 additions & 12 deletions vllm/config/multimodal.py
@@ -4,28 +4,68 @@
 import hashlib
 from collections.abc import Mapping
 from dataclasses import field
-from typing import Any, Literal, Optional
+from typing import Any, Literal, Optional, Union
 
+from pydantic import ConfigDict, Field, field_validator
 from pydantic.dataclasses import dataclass
 
-import vllm.envs as envs
 from vllm.config.utils import config
 
 
+@dataclass
+class BaseDummyOptions:
+    """Base options for generating dummy data during profiling."""
+    count: int = Field(999, ge=0)
+
+
+@dataclass(config=ConfigDict(extra="forbid"))
+class VideoDummyOptions(BaseDummyOptions):
+    """Options for generating dummy video data during profiling."""
+    num_frames: Optional[int] = Field(None, gt=0)
+    width: Optional[int] = Field(None, gt=0)
+    height: Optional[int] = Field(None, gt=0)
+
+
+@dataclass(config=ConfigDict(extra="forbid"))
+class ImageDummyOptions(BaseDummyOptions):
+    """Options for generating dummy image data during profiling."""
+    width: Optional[int] = Field(None, gt=0)
+    height: Optional[int] = Field(None, gt=0)
+
+
+@dataclass(config=ConfigDict(extra="forbid"))
+class AudioDummyOptions(BaseDummyOptions):
+    """Options for generating dummy audio data during profiling."""
+    length: Optional[int] = Field(None, gt=0)
+
+
 MMEncoderTPMode = Literal["weights", "data"]
 MMCacheType = Literal["shm", "lru"]
+DummyOptions = Union[BaseDummyOptions, VideoDummyOptions, ImageDummyOptions,
+                     AudioDummyOptions]
 
 
 @config
 @dataclass
 class MultiModalConfig:
     """Controls the behavior of multimodal models."""
 
-    limit_per_prompt: dict[str, int] = field(default_factory=dict)
-    """The maximum number of input items allowed per prompt for each modality.
-    Defaults to 1 (V0) or 999 (V1) for each modality.
+    limit_per_prompt: dict[str, DummyOptions] = field(default_factory=dict)
+    """The maximum number of input items and options allowed per
+    prompt for each modality.
+    Defaults to 999 for each modality.
 
-    For example, to allow up to 16 images and 2 videos per prompt:
-    `{"image": 16, "video": 2}`"""
+    Legacy format (count only):
+        {"image": 16, "video": 2}
+
+    Configurable format (with options):
+        {"video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
+         "image": {"count": 5, "width": 512, "height": 512}}
+
+    Mixed format (combining both):
+        {"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512,
+                                "height": 512}}
+    """
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
     """Additional args passed to process media inputs, keyed by modalities.
     For example, to set num_frames for video, set
@@ -84,6 +124,27 @@ class MultiModalConfig:
     from each video to be pruned.
     """
 
+    @field_validator("limit_per_prompt", mode="before")
+    @classmethod
+    def _validate_limit_per_prompt(
+            cls, value: dict[str, Union[int,
+                                        dict[str,
+                                             int]]]) -> dict[str, DummyOptions]:
+        for k, v in value.items():
+            # Handle legacy format where only count is specified
+            if isinstance(v, int):
+                v = {"count": v}
+            # Convert to the appropriate DummyOptions subclass
+            if k == "video":
+                value[k] = VideoDummyOptions(**v)
+            elif k == "image":
+                value[k] = ImageDummyOptions(**v)
+            elif k == "audio":
+                value[k] = AudioDummyOptions(**v)
+            else:
+                value[k] = BaseDummyOptions(**v)
+        return value
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -106,12 +167,22 @@ def compute_hash(self) -> str:
     def get_limit_per_prompt(self, modality: str) -> int:
         """
         Get the maximum number of input items allowed per prompt
-        for the given modality.
+        for the given modality (backward compatible).
         """
-        return self.limit_per_prompt.get(
-            modality,
-            999 if envs.VLLM_USE_V1 else 1,
-        )
+        limit_data = self.limit_per_prompt.get(modality)
+
+        if limit_data is None:
+            # Unspecified modality is set to 999 by default
+            return 999
+        return limit_data.count
+
+    def get_dummy_options(self, modality: str) -> Optional[BaseDummyOptions]:
+        """
+        Get the configurable dummy data options for a modality.
+        Returns None if no options are configured for this modality.
+        """
+        # All values are now DummyOptions after normalization
+        return self.limit_per_prompt.get(modality)
 
     def merge_mm_processor_kwargs(
         self,
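Assuming the validator and accessors behave as in the hunks above, legacy counts and option dicts should normalize to the same representation; a small sketch (not taken from the PR's test suite):

```python
from vllm.config.multimodal import MultiModalConfig, VideoDummyOptions

# Mixed input: a bare count for images, an options dict for videos.
mm_config = MultiModalConfig(limit_per_prompt={
    "image": 16,
    "video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
})

assert mm_config.get_limit_per_prompt("image") == 16
assert mm_config.get_limit_per_prompt("audio") == 999  # unspecified default
video_opts = mm_config.get_dummy_options("video")
assert isinstance(video_opts, VideoDummyOptions)
assert video_opts.num_frames == 32
```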
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -376,7 +376,7 @@ class EngineArgs:
     quantization: Optional[QuantizationMethods] = ModelConfig.quantization
     enforce_eager: bool = ModelConfig.enforce_eager
     disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
-    limit_mm_per_prompt: dict[str, int] = \
+    limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = \
         get_field(MultiModalConfig, "limit_per_prompt")
     interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
     media_io_kwargs: dict[str, dict[str,
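Since `EngineArgs.limit_mm_per_prompt` now accepts either shape, programmatic callers should be able to mix them; a sketch under that assumption (the model name is illustrative only):

```python
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="Qwen/Qwen2.5-VL-7B-Instruct",  # hypothetical choice
    limit_mm_per_prompt={
        "image": 5,  # legacy count-only form still works
        # Constrain the dummy video data used during memory profiling.
        "video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
    },
)
```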
7 changes: 6 additions & 1 deletion vllm/model_executor/models/aria.py
@@ -10,6 +10,7 @@
 from transformers.models.aria.processing_aria import AriaProcessor
 
 from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -431,17 +432,21 @@ def get_dummy_mm_data(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
+        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
     ) -> MultiModalDataDict:
         vision_config = self.info.get_vision_config()
 
         max_image_size = vision_config.image_size
         num_images = mm_counts.get("image", 0)
 
+        image_overrides = mm_options.get("image") if mm_options else None
+
         return {
             "image":
             self._get_dummy_images(width=max_image_size,
                                    height=max_image_size,
-                                   num_images=num_images)
+                                   num_images=num_images,
+                                   overrides=image_overrides)
         }
 
 
7 changes: 6 additions & 1 deletion vllm/model_executor/models/aya_vision.py
@@ -16,6 +16,7 @@
     get_optimal_tiled_canvas)
 
 from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
@@ -166,16 +167,20 @@ def get_dummy_mm_data(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
+        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
     ) -> MultiModalDataDict:
         num_images = mm_counts.get("image", 0)
         image_size = \
             self.info.get_image_size_with_most_features()
 
+        image_overrides = mm_options.get("image") if mm_options else None
+
         return {
             "image":
             self._get_dummy_images(width=image_size.width,
                                    height=image_size.height,
-                                   num_images=num_images)
+                                   num_images=num_images,
+                                   overrides=image_overrides)
         }
 
 