From 4da9d23bf87e232fc9f04b8abe756a10e8bd15c7 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Wed, 24 Sep 2025 16:08:29 +0800
Subject: [PATCH 01/11] init

Signed-off-by: Isotr0py
---
 vllm/model_executor/models/qwen2_vl.py |  6 ++--
 vllm/model_executor/models/qwen3_vl.py | 39 ++++++++++++++++++++++++--
 vllm/v1/worker/gpu_model_runner.py     |  2 +-
 3 files changed, 41 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 472e8b061a9e..fd8940e3b5ed 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -972,10 +972,12 @@ def get_max_image_tokens(self) -> int:
             image_processor=None,
         )
 
-    def _get_max_video_frames(self, max_tokens: int) -> int:
+    def _get_max_video_frames(self,
+                              max_tokens: int,
+                              start_num_frames: int = 1) -> int:
         target_width, target_height = self.get_image_size_with_most_features()
 
-        num_frames = 0
+        num_frames = start_num_frames
 
         while True:
             next_num_frames = num_frames + 1
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index ee6703f7229e..2b7c11a09ff7 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -33,11 +33,14 @@
 import torch.nn.functional as F
 from transformers import BatchFeature
 from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
-from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
+    smart_resize as image_smart_resize)
 from transformers.models.qwen3_vl import (Qwen3VLProcessor,
                                           Qwen3VLVideoProcessor)
 from transformers.models.qwen3_vl.configuration_qwen3_vl import (
     Qwen3VLConfig, Qwen3VLVisionConfig)
+from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
+    smart_resize as video_smart_resize)
 from transformers.video_utils import VideoMetadata
 
 from vllm.attention.layer import check_upstream_fa_availability
@@ -572,11 +575,16 @@ def _get_vision_info(
         image_height: int,
         num_frames: int = 2,
         do_resize: bool = True,
-        image_processor: Optional[Qwen2VLImageProcessorFast],
+        image_processor: Optional[Union[Qwen2VLImageProcessorFast,
+                                        Qwen3VLVideoProcessor]],
     ) -> tuple[ImageSize, int]:
-        if image_processor is None:
+        if image_processor is None and num_frames > 1:
+            image_processor = self.get_video_processor()
+        elif image_processor is None:
             image_processor = self.get_image_processor()
 
+        is_video = isinstance(image_processor, Qwen3VLVideoProcessor)
+
         hf_config = self.get_hf_config()
         vision_config = hf_config.vision_config
         patch_size = vision_config.patch_size
@@ -584,12 +592,24 @@
         temporal_patch_size = vision_config.temporal_patch_size
 
         if do_resize:
+            if is_video:
+                smart_resize = video_smart_resize
+                extra_kwargs = {
+                    "num_frames": num_frames,
+                    "temporal_factor": temporal_patch_size
+                }
+            else:
+                smart_resize = image_smart_resize
+                extra_kwargs = {}
             resized_height, resized_width = smart_resize(
+                # num_frames=num_frames,
                 height=image_height,
                 width=image_width,
+                # temporal_factor=temporal_patch_size,
                 factor=patch_size * merge_size,
                 min_pixels=image_processor.size["shortest_edge"],
                 max_pixels=image_processor.size["longest_edge"],
+                **extra_kwargs,
             )
             preprocessed_size = ImageSize(width=resized_width,
                                           height=resized_height)
@@ -608,6 +628,12 @@
 
         return preprocessed_size, num_vision_tokens
 
+    def _get_max_video_frames(self,
+                              max_tokens: int,
+                              start_num_frames: int = 2) -> int:
+        return super()._get_max_video_frames(max_tokens,
+                                             start_num_frames=start_num_frames)
+
     def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
                               video_fps: float, merge_size: int):
         if not isinstance(indices, list):
@@ -677,6 +703,9 @@ def get_dummy_mm_data(
             self.info.get_image_size_with_most_features())
         target_num_frames = self.info.get_num_frames_with_most_features(
             seq_len, mm_counts)
+        print("seq_len:", seq_len)
+        print("target_num_frames, target_width, target_height:",
+              target_num_frames, target_width, target_height)
         return {
             "image":
             self._get_dummy_images(width=target_width,
@@ -699,6 +728,7 @@ def _get_dummy_videos(
         num_frames: int,
         num_videos: int,
     ) -> list[VideoItem]:
+        print("num_frames, width, height:", num_frames, width, height)
        num_frames = max(num_frames, 2)
         video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
         video_items = []
@@ -783,6 +813,7 @@ def _call_hf_processor(
                 video_grid_thw_lst.append(video_outputs["video_grid_thw"])
                 pixel_values_videos_lst.append(
                     video_outputs["pixel_values_videos"])
+            print("video_grid_thw_lst:", video_grid_thw_lst)
             video_outputs = dict(
                 pixel_values_videos=torch.cat(pixel_values_videos_lst),
                 video_grid_thw=torch.cat(video_grid_thw_lst),
@@ -1292,9 +1323,11 @@ def get_multimodal_embeddings(
             multimodal_input = mm_input_by_modality[modality]
             if modality == "image":
                 vision_embeddings = self._process_image_input(multimodal_input)
+                print("image", [x.shape for x in vision_embeddings])
                 multimodal_embeddings += vision_embeddings
             if modality == "video":
                 video_embeddings = self._process_video_input(multimodal_input)
+                print("video", [x.shape for x in video_embeddings])
                 multimodal_embeddings += video_embeddings
 
         return multimodal_embeddings
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index eebdbcc621c6..ee339e22cea9 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2807,7 +2807,7 @@ def _get_mm_dummy_batch(
 
         dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
             model_config=self.model_config,
-            seq_len=self.max_num_tokens,
+            seq_len=self.max_model_len,
             mm_counts={modality: 1},
             cache=self.mm_budget.cache,
         )

From 43f24abfd8d611f894f568cb2f6477bb85fffbc1 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Wed, 24 Sep 2025 16:14:00 +0800
Subject: [PATCH 02/11] cleanup debug code

Signed-off-by: Isotr0py
---
 vllm/model_executor/models/qwen3_vl.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 2b7c11a09ff7..c66872e1861a 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -602,10 +602,8 @@ def _get_vision_info(
                 smart_resize = image_smart_resize
                 extra_kwargs = {}
             resized_height, resized_width = smart_resize(
-                # num_frames=num_frames,
                 height=image_height,
                 width=image_width,
-                # temporal_factor=temporal_patch_size,
                 factor=patch_size * merge_size,
                 min_pixels=image_processor.size["shortest_edge"],
                 max_pixels=image_processor.size["longest_edge"],
@@ -703,9 +701,6 @@ def get_dummy_mm_data(
             self.info.get_image_size_with_most_features())
         target_num_frames = self.info.get_num_frames_with_most_features(
             seq_len, mm_counts)
-        print("seq_len:", seq_len)
-        print("target_num_frames, target_width, target_height:",
-              target_num_frames, target_width, target_height)
         return {
             "image":
             self._get_dummy_images(width=target_width,
@@ -728,7 +723,6 @@ def _get_dummy_videos(
         num_frames: int,
         num_videos: int,
     ) -> list[VideoItem]:
-        print("num_frames, width, height:", num_frames, width, height)
         num_frames = max(num_frames, 2)
         video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
         video_items = []
@@ -813,7 +807,6 @@ def _call_hf_processor(
                 video_grid_thw_lst.append(video_outputs["video_grid_thw"])
                 pixel_values_videos_lst.append(
                     video_outputs["pixel_values_videos"])
-            print("video_grid_thw_lst:", video_grid_thw_lst)
             video_outputs = dict(
                 pixel_values_videos=torch.cat(pixel_values_videos_lst),
                 video_grid_thw=torch.cat(video_grid_thw_lst),
@@ -1323,11 +1316,9 @@ def get_multimodal_embeddings(
             multimodal_input = mm_input_by_modality[modality]
             if modality == "image":
                 vision_embeddings = self._process_image_input(multimodal_input)
-                print("image", [x.shape for x in vision_embeddings])
                 multimodal_embeddings += vision_embeddings
             if modality == "video":
                 video_embeddings = self._process_video_input(multimodal_input)
-                print("video", [x.shape for x in video_embeddings])
                 multimodal_embeddings += video_embeddings
 
         return multimodal_embeddings

From 67a0dc9e22dc9f82fb32d7dae3e89f703b3fdf33 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Wed, 24 Sep 2025 16:15:23 +0800
Subject: [PATCH 03/11] revert gpu runner to avoid conflict

Signed-off-by: Isotr0py
---
 vllm/v1/worker/gpu_model_runner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ee339e22cea9..eebdbcc621c6 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2807,7 +2807,7 @@ def _get_mm_dummy_batch(
 
         dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
             model_config=self.model_config,
-            seq_len=self.max_model_len,
+            seq_len=self.max_num_tokens,
             mm_counts={modality: 1},
             cache=self.mm_budget.cache,
         )

From 085bbd1c19baf208b771feb49c826a92fcd5d719 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Wed, 24 Sep 2025 16:18:58 +0800
Subject: [PATCH 04/11] hardcode num_frames=1 for image

Signed-off-by: Isotr0py
---
 vllm/model_executor/models/qwen2_vl.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index fd8940e3b5ed..cbdcda77c9cc 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -935,6 +935,7 @@ def get_num_image_tokens(
         _, num_image_tokens = self._get_vision_info(
             image_width=image_width,
             image_height=image_height,
+            num_frames=1,
             image_processor=image_processor,
         )
         return num_image_tokens

From 8a14ae1bdeb17c3f2301b737f93b56c94c9be2ba Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Wed, 24 Sep 2025 16:20:32 +0800
Subject: [PATCH 05/11] miss hardcode num_frames=1

Signed-off-by: Isotr0py
---
 vllm/model_executor/models/qwen2_vl.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index cbdcda77c9cc..0e1c4618efdc 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -960,6 +960,7 @@ def get_image_size_with_most_features(self) -> ImageSize:
         max_image_size, _ = self._get_vision_info(
             image_width=9999999,
             image_height=9999999,
+            num_frames=1,
             image_processor=None,
         )
         return max_image_size

From d21f7909c541ffc69b2dd1d636a45ccdbeb3cb24 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Wed, 24 Sep 2025 19:51:00 +0800
Subject: [PATCH 06/11] fix max frames per video

Signed-off-by: Isotr0py
---
 vllm/model_executor/models/qwen2_vl.py |  7 ++++---
 vllm/model_executor/models/qwen3_vl.py | 21 +++++++++++++++++++--
 2 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 0e1c4618efdc..96289a2c57f0 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -82,7 +82,7 @@
 logger = init_logger(__name__)
 
 # For profile run
-_MAX_FRAMES_PER_VIDEO = 600
+_MAX_FRAMES_PER_VIDEO = 14
 
 # === Vision Inputs === #
 
@@ -982,7 +982,7 @@ def _get_max_video_frames(self,
         num_frames = start_num_frames
 
         while True:
-            next_num_frames = num_frames + 1
+            next_num_frames = num_frames + 2
             next_max_tokens = self.get_num_video_tokens(
                 image_width=target_width,
                 image_height=target_height,
@@ -1001,12 +1001,13 @@
     def get_num_frames_with_most_features(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
+        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
     ) -> int:
         max_videos = mm_counts.get("video", 0)
 
         max_total_frames = self._get_max_video_frames(seq_len)
         max_frames_per_video = min(max_total_frames // max(max_videos, 1),
-                                   _MAX_FRAMES_PER_VIDEO)
+                                   max_frames_per_video)
 
         return max(max_frames_per_video, 1)
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index c66872e1861a..dc5f3e5e158a 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -90,6 +90,9 @@
 logger = init_logger(__name__)
 
+# For profile run
+_MAX_FRAMES_PER_VIDEO = 65536
+
 
 class Qwen3_VisionPatchEmbed(nn.Module):
 
@@ -632,6 +635,14 @@ def _get_max_video_frames(self,
         return super()._get_max_video_frames(max_tokens,
                                              start_num_frames=start_num_frames)
 
+    def get_num_frames_with_most_features(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        return super().get_num_frames_with_most_features(
+            seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO)
+
     def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
                               video_fps: float, merge_size: int):
         if not isinstance(indices, list):
@@ -701,6 +712,12 @@ def get_dummy_mm_data(
             self.info.get_image_size_with_most_features())
         target_num_frames = self.info.get_num_frames_with_most_features(
             seq_len, mm_counts)
+        target_video_size, _ = self.info._get_vision_info(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=target_num_frames,
+            image_processor=self.info.get_video_processor(),
+        )
         return {
             "image":
             self._get_dummy_images(width=target_width,
@@ -708,8 +725,8 @@ def get_dummy_mm_data(
                                    num_images=num_images),
             "video":
             self._get_dummy_videos(
-                width=target_width,
-                height=target_height,
+                width=target_video_size.width,
+                height=target_video_size.height,
                 num_frames=target_num_frames,
                 num_videos=num_videos,
             ),

From 26d915bbb6c0deff120a8e09ba127605f26c67bb Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Wed, 24 Sep 2025 19:55:23 +0800
Subject: [PATCH 07/11] ooops

Signed-off-by: Isotr0py
---
 vllm/model_executor/models/qwen2_vl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 96289a2c57f0..94d1f7abb528 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -982,7 +982,7 @@ def _get_max_video_frames(self,
         num_frames = start_num_frames
 
         while True:
-            next_num_frames = num_frames + 2
+            next_num_frames = num_frames + 1
             next_max_tokens = self.get_num_video_tokens(
                 image_width=target_width,
                 image_height=target_height,

From 2a581656c2e59a3cc5a4189a756876d4f7bb3d31 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Sat, 27 Sep 2025 17:15:43 -0700
Subject: [PATCH 08/11] update

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/qwen3_vl.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 6768d1cb45e7..f8269b6a6802 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -88,8 +88,8 @@
 logger = init_logger(__name__)
 
-# For profile run
-_MAX_FRAMES_PER_VIDEO = 65536
+# Official recommended max pixels is 24756 * 32 * 32
+_MAX_FRAMES_PER_VIDEO = 24756
 
 
 class Qwen3_VisionPatchEmbed(nn.Module):
@@ -664,6 +664,25 @@ def get_num_frames_with_most_features(
         return super().get_num_frames_with_most_features(
             seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO)
 
+    def get_max_video_tokens(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+        video_soft_tokens = self.get_num_video_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=self.get_num_frames_with_most_features(
+                seq_len, mm_counts),
+            image_processor=None,
+        )
+
+        # NOTE: By default in Qwen3-VL, one video token is converted to
+        # "" (on average 9 tokens) + vision_start_token + video_token + vision_end_token  # noqa: E501
+        formatted_video_soft_tokens = video_soft_tokens * 12
+        return formatted_video_soft_tokens
+
     def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
                               video_fps: float, merge_size: int):
         if not isinstance(indices, list):

From c3f71b88fac0edc883904cbfd9f978a21a3e9460 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Sat, 27 Sep 2025 17:19:23 -0700
Subject: [PATCH 09/11] update

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/qwen3_vl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index f8269b6a6802..433d8e68ed2a 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -679,7 +679,7 @@ def get_max_video_tokens(
         )
 
         # NOTE: By default in Qwen3-VL, one video token is converted to
-        # "" (on average 9 tokens) + vision_start_token + video_token + vision_end_token  # noqa: E501
+        # "<{timestamp} seconds>" (at most 9 tokens) + vision_start_token + video_token + vision_end_token  # noqa: E501
         formatted_video_soft_tokens = video_soft_tokens * 12
         return formatted_video_soft_tokens

From 241f3db7b368bb31d35365e10028c590b1fe4d91 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Sat, 27 Sep 2025 17:25:03 -0700
Subject: [PATCH 10/11] update estimation

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/qwen3_vl.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 433d8e68ed2a..8abfd2b2c14e 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -679,8 +679,8 @@ def get_max_video_tokens(
         )
 
         # NOTE: By default in Qwen3-VL, one video token is converted to
-        # "<{timestamp} seconds>" (at most 9 tokens) + vision_start_token + video_token + vision_end_token  # noqa: E501
-        formatted_video_soft_tokens = video_soft_tokens * 12
+        # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token  # noqa: E501
+        formatted_video_soft_tokens = video_soft_tokens * 12.5
         return formatted_video_soft_tokens
 
     def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
                               video_fps: float, merge_size: int):

From f2f8c427cdc1254ff235783009edcf8b8a47bf8c Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Sat, 27 Sep 2025 18:46:22 -0700
Subject: [PATCH 11/11] typo

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/qwen3_vl.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 8abfd2b2c14e..c8f91dd48969 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -88,8 +88,8 @@
 logger = init_logger(__name__)
 
-# Official recommended max pixels is 24756 * 32 * 32
-_MAX_FRAMES_PER_VIDEO = 24756
+# Official recommended max pixels is 24576 * 32 * 32
+_MAX_FRAMES_PER_VIDEO = 24576
 
 
 class Qwen3_VisionPatchEmbed(nn.Module):
@@ -681,7 +681,7 @@ def get_max_video_tokens(
         # NOTE: By default in Qwen3-VL, one video token is converted to
         # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token  # noqa: E501
         formatted_video_soft_tokens = video_soft_tokens * 12.5
-        return formatted_video_soft_tokens
+        return int(formatted_video_soft_tokens)
 
     def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
                               video_fps: float, merge_size: int):
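
For reference, the worst-case budget that patches 08-11 converge on can be
reproduced outside vLLM with a short standalone sketch. The helper below is
illustrative only (its name, signature, and the example grid values are
assumptions, not vLLM's API); it mirrors the arithmetic in
get_max_video_tokens: count the vision tower's soft tokens for one video,
then pad by the 12.5x formatting factor from the NOTE and floor to an int.

def estimate_max_video_tokens(grid_t: int, grid_h: int, grid_w: int,
                              merge_size: int = 2) -> int:
    """Upper-bound the prompt tokens a single video can occupy."""
    # Soft tokens emitted by the vision tower after spatial merging.
    video_soft_tokens = (grid_t * grid_h * grid_w) // (merge_size ** 2)
    # Per the NOTE: each video token expands to "<{timestamp} seconds>"
    # (~9.5 text tokens) plus vision_start_token, video_token and
    # vision_end_token, hence the 12.5x factor, truncated to an int.
    return int(video_soft_tokens * 12.5)

# Example with assumed grids: 16 frames at 448x448 with 16-px patches and
# temporal_patch_size=2 give grid_t=8, grid_h=grid_w=28, so
# 8 * 28 * 28 // 4 = 1568 soft tokens -> int(1568 * 12.5) = 19600 tokens.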