diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 4b75eb19fcf9..67a978ad2aae 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1716,6 +1716,13 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData: } +MODELS_NEED_VIDEO_METADATA = [ + "glm4_1v", + "glm4_5v", + "glm4_5v_fp8", +] + + def get_multi_modal_input(args): """ return { @@ -1740,12 +1747,13 @@ def get_multi_modal_input(args): if args.modality == "video": # Input video and question + needs_metadata = args.model_type in MODELS_NEED_VIDEO_METADATA video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays metadata = VideoAsset(name="baby_reading", num_frames=args.num_frames).metadata vid_questions = ["Why is this video funny?"] return { - "data": [(video, metadata)] if args.model_type == "glm4_1v" else video, + "data": ([(video, metadata)] if needs_metadata else video), "questions": vid_questions, } diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index ced0ab3377a9..8bd93bd838fe 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -32,11 +32,14 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: # Ensure video metadata is included if "video" in mm_data: video = mm_data["video"] + num_frames = len(video) mm_data["video"] = (video, { - "total_num_frames": len(video), - "fps": len(video), + "total_num_frames": num_frames, + "fps": num_frames, "duration": 1, - "video_backend": "opencv" + "frames_indices": [i for i in range(num_frames)], + "video_backend": "opencv", + "do_sample_frames": True, }) return mm_data diff --git a/tests/models/multimodal/processing/test_glm4_1v.py b/tests/models/multimodal/processing/test_glm4_1v.py index dfb8d9b2a038..070ddcd89ee9 100644 --- a/tests/models/multimodal/processing/test_glm4_1v.py +++ b/tests/models/multimodal/processing/test_glm4_1v.py @@ -12,8 +12,19 @@ @pytest.mark.parametrize("model_id", ["zai-org/GLM-4.1V-9B-Thinking"]) @pytest.mark.parametrize("expected_toks_per_frame", [299]) -@pytest.mark.parametrize("num_frames", [32, 128]) -@pytest.mark.parametrize("fps, expected_grid_t", [(1, 5), (2, 10)]) +@pytest.mark.parametrize( + "num_frames, fps, expected_grid_t", + [ + # pre-sampled fixed frames (unexpected behavior, + # but we still expect it to work without errors) + (32, 1, 16), + (32, 2, 16), + (128, 1, 64), + (128, 2, 64), + # post-sampled frames (expected behavior) + (-1, 1, 5), + (-1, 2, 10), + ]) def test_processor_override( model_id: str, expected_toks_per_frame: int, @@ -80,7 +91,7 @@ def test_video_loader_consistency( static_video, static_metadata = OpenCVVideoBackend.load_bytes(video_bytes) dynamic_video, dynamic_metadata = OpenCVDynamicVideoBackend.load_bytes( - video_bytes, requested_fps=fps) + video_bytes, fps=fps) # pre-sampled loader shouldn't read all frames assert len(dynamic_video) < len(static_video) diff --git a/vllm/assets/video.py b/vllm/assets/video.py index 983e9114cccf..5c9e403c4b91 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -76,7 +76,7 @@ def video_to_pil_images_list(path: str, return [Image.fromarray(frame) for frame in frames] -def video_get_metadata(path: str) -> dict[str, Any]: +def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]: cap = cv2.VideoCapture(path) if not cap.isOpened(): raise ValueError(f"Could not open video file {path}") @@ -85,11 +85,18 @@ def video_get_metadata(path: str) -> dict[str, Any]: fps = cap.get(cv2.CAP_PROP_FPS) duration = total_frames / fps if fps > 0 else 0 + if num_frames == -1 or num_frames > total_frames: + num_frames = total_frames + metadata = { - "total_num_frames": total_frames, + "total_num_frames": num_frames, "fps": fps, "duration": duration, - "video_backend": "opencv" + "video_backend": "opencv", + "frames_indices": list(range(num_frames)), + # extra field used to control hf processor's video + # sampling behavior + "do_sample_frames": num_frames == total_frames, } return metadata @@ -126,7 +133,7 @@ def np_ndarrays(self) -> npt.NDArray: @property def metadata(self) -> dict[str, Any]: - ret = video_get_metadata(self.video_path) + ret = video_get_metadata(self.video_path, self.num_frames) return ret def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray: diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 22386a5e819a..c4e702625914 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -36,7 +36,9 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange +from packaging.version import Version from transformers import BatchFeature +from transformers import __version__ as TRANSFORMERS_VERSION from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig from transformers.models.glm4v.image_processing_glm4v import ( Glm4vImageProcessor, smart_resize) @@ -1001,28 +1003,32 @@ def _get_video_second_idx(self, metadata: dict[str, Any], max_frame_idx = meta_frames - 1 duration = metadata.get("duration", round(max_frame_idx / video_fps) + 1) - if duration <= video_processor.max_duration: - n = int(math.floor(duration * video_processor.fps)) - frame_indices = [ - min( - max_frame_idx, - int(math.ceil(i * video_fps / video_processor.fps)), - ) for i in range(n) - ] + do_sample_frames = metadata["do_sample_frames"] + if not do_sample_frames: + frame_indices = metadata["frames_indices"] else: - num_samples = int(video_processor.max_duration * - video_processor.fps) - if num_samples >= meta_frames: - frame_indices = list(range(meta_frames)) - else: - target_seconds = np.linspace(0, - duration, - num_samples, - endpoint=True) + if duration <= video_processor.max_duration: + n = int(math.floor(duration * video_processor.fps)) frame_indices = [ - min(max_frame_idx, int(math.ceil(t * video_fps))) - for t in target_seconds + min( + max_frame_idx, + int(math.ceil(i * video_fps / video_processor.fps)), + ) for i in range(n) ] + else: + num_samples = int(video_processor.max_duration * + video_processor.fps) + if num_samples >= meta_frames: + frame_indices = list(range(meta_frames)) + else: + target_seconds = np.linspace(0, + duration, + num_samples, + endpoint=True) + frame_indices = [ + min(max_frame_idx, int(math.ceil(t * video_fps))) + for t in target_seconds + ] seen, uniq = set(), [] for idx in frame_indices: @@ -1139,7 +1145,9 @@ def _get_dummy_videos( "fps": 2.0, "duration": num_frames / 2.0, "total_num_frames": num_frames, + "frames_indices": [i for i in range(num_frames)], "video_backend": "opencv", + "do_sample_frames": False, } video_item = (video.copy(), video_metadata) video_items.append(video_item) @@ -1172,34 +1180,37 @@ def _call_hf_processor( for item in mm_data.pop("videos", []): video_array, metadata = item - if metadata["video_backend"] == "opencv_dynamic": - mm_kwargs["do_sample_frames"] = False - - elif metadata["total_num_frames"] != len(video_array): - logger.warning( - "Total frames in metadata " - "(%s) does not match the length of " - "video array %s. This can " - "be because the video is resampled " - "in advance. This may cause " - "a divergence with HF implementation.", - metadata["total_num_frames"], - len(video_array), - ) - metadata["total_num_frames"] = len(video_array) + # don't update mm_kwargs inplace + video_mm_kwargs = dict(**mm_kwargs) + video_mm_kwargs["do_sample_frames"] = metadata.get( + "do_sample_frames", True) video_mm_data = dict() video_mm_data["videos"] = [[video_array]] - video_mm_data["video_metadata"] = [[VideoMetadata(**metadata)]] + + # backward compatibility for Transformers 4.55 + unuse_metadata = ["do_sample_frames"] + if not hasattr( + VideoMetadata, + "frames_indices") and "frames_indices" in metadata: + unuse_metadata.append("frames_indices") + + video_mm_data["video_metadata"] = [[ + VideoMetadata( + **{ + k: metadata[k] + for k in metadata if k not in unuse_metadata + }) + ]] video_outputs = super()._call_hf_processor( prompt="<|begin_of_video|><|video|><|end_of_video|>", mm_data=video_mm_data, - mm_kwargs=mm_kwargs, + mm_kwargs=video_mm_kwargs, tok_kwargs=tok_kwargs, ) - if "do_sample_frames" in mm_kwargs and not mm_kwargs[ - "do_sample_frames"]: + if not video_mm_kwargs["do_sample_frames"] and Version( + TRANSFORMERS_VERSION) < Version("4.56.0"): # Transformers v4.55 has incorrect timestamps issue for # skip sampling. We construct the placeholder manually to # get placeholders with correct timestamps. @@ -1218,6 +1229,7 @@ def _call_hf_processor( prompt = prompt.replace( "<|begin_of_video|><|video|><|end_of_video|>", video_placeholder, + 1, ) video_grid_thw_lst.append(video_outputs["video_grid_thw"]) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index df6e19da82ca..fb2dcac49ee9 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -121,14 +121,6 @@ def load_bytes( original_fps = cap.get(cv2.CAP_PROP_FPS) duration = total_frames_num / original_fps if original_fps > 0 else 0 - # Use transformers transformers.video_utils.VideoMetadata format - metadata = { - "total_num_frames": total_frames_num, - "fps": original_fps, - "duration": duration, - "video_backend": "opencv" - } - # resample video to target num_frames full_read = num_frames == -1 or total_frames_num < num_frames if full_read: @@ -159,6 +151,20 @@ def load_bytes( assert i == num_frames, (f"Expected reading {num_frames} frames, " f"but only loaded {i} frames from video.") + # Use transformers transformers.video_utils.VideoMetadata format + # NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata + # can cause incorrect timestamp calculation without num_frames=-1. + metadata = { + "total_num_frames": num_frames, + "fps": original_fps, + "duration": duration, + "video_backend": "opencv", + "frames_indices": list(range(num_frames)), + # extra field used to control hf processor's video + # sampling behavior + "do_sample_frames": num_frames == total_frames_num, + } + return frames, metadata @@ -170,7 +176,7 @@ def load_bytes( cls, data: bytes, num_frames: int = -1, - requested_fps: int = 2, + fps: int = 2, max_duration: int = 300, **kwargs, ) -> tuple[npt.NDArray, dict[str, Any]]: @@ -185,14 +191,6 @@ def load_bytes( original_fps = cap.get(cv2.CAP_PROP_FPS) duration = total_frames_num / original_fps if original_fps > 0 else 0 - # Use transformers transformers.video_utils.VideoMetadata format - metadata = { - "total_num_frames": total_frames_num, - "fps": original_fps, - "duration": duration, - "video_backend": "opencv_dynamic" - } - # resample video to target num_frames max_frame_idx = total_frames_num - 1 duration = duration or round(max_frame_idx / original_fps) + 1 @@ -201,14 +199,13 @@ def load_bytes( # https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140 frame_indices: Union[range, list[int]] if duration <= max_duration: - n = int(math.floor(duration * requested_fps)) + n = int(math.floor(duration * fps)) frame_indices = sorted({ - min(max_frame_idx, - int(math.ceil(i * original_fps / requested_fps))) + min(max_frame_idx, int(math.ceil(i * original_fps / fps))) for i in range(n) }) else: - num_samples = int(max_duration * requested_fps) + num_samples = int(max_duration * fps) if num_samples >= total_frames_num: frame_indices = range(total_frames_num) else: @@ -241,6 +238,16 @@ def load_bytes( f"Expected reading {len(frame_indices)} frames, " f"but only loaded {i} frames from video.") + # Use transformers transformers.video_utils.VideoMetadata format + metadata = { + "total_num_frames": total_frames_num, + "fps": original_fps, + "duration": duration, + "video_backend": "opencv_dynamic", + "frames_indices": list(frame_indices), + "do_sample_frames": False, + } + return frames, metadata