examples/offline_inference/vision_language.py (9 additions, 1 deletion)
@@ -1716,6 +1716,13 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
 }
 
 
+MODELS_NEED_VIDEO_METADATA = [
+    "glm4_1v",
+    "glm4_5v",
+    "glm4_5v_fp8",
+]
+
+
 def get_multi_modal_input(args):
     """
     return {
@@ -1740,12 +1747,13 @@ def get_multi_modal_input(args):

     if args.modality == "video":
         # Input video and question
+        needs_metadata = args.model_type in MODELS_NEED_VIDEO_METADATA
         video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays
         metadata = VideoAsset(name="baby_reading", num_frames=args.num_frames).metadata
         vid_questions = ["Why is this video funny?"]
 
         return {
-            "data": [(video, metadata)] if args.model_type == "glm4_1v" else video,
+            "data": ([(video, metadata)] if needs_metadata else video),
             "questions": vid_questions,
         }
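
For any model type in MODELS_NEED_VIDEO_METADATA, the video is now paired with its metadata dict. A minimal sketch of what this feeds into vLLM; the checkpoint name, prompt template, and single-request call shape are illustrative assumptions, not part of this diff:

```python
from vllm import LLM
from vllm.assets.video import VideoAsset

asset = VideoAsset(name="baby_reading", num_frames=16)
video_with_meta = (asset.np_ndarrays, asset.metadata)

llm = LLM(model="zai-org/GLM-4.5V")  # assumed checkpoint for model_type glm4_5v
outputs = llm.generate({
    # Placeholder syntax is illustrative; the script builds the real
    # prompt per model type.
    "prompt": "<|begin_of_video|><|video|><|end_of_video|>Why is this video funny?",
    "multi_modal_data": {"video": video_with_meta},
})
```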

tests/models/multimodal/processing/test_common.py (6 additions, 3 deletions)
@@ -32,11 +32,14 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
     # Ensure video metadata is included
     if "video" in mm_data:
         video = mm_data["video"]
+        num_frames = len(video)
         mm_data["video"] = (video, {
-            "total_num_frames": len(video),
-            "fps": len(video),
+            "total_num_frames": num_frames,
+            "fps": num_frames,
             "duration": 1,
-            "video_backend": "opencv"
+            "frames_indices": [i for i in range(num_frames)],
+            "video_backend": "opencv",
+            "do_sample_frames": True,
         })
     return mm_data
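
The patched helper now emits the extra frames_indices and do_sample_frames fields that the GLM-4.1V processor reads. A quick sketch of its output on dummy frames:

```python
import numpy as np

# Four dummy RGB frames standing in for a decoded clip.
mm_data = {"video": np.zeros((4, 32, 32, 3), dtype=np.uint8)}
video, meta = glm4_1v_patch_mm_data(mm_data)["video"]

assert meta["frames_indices"] == [0, 1, 2, 3]
assert meta["do_sample_frames"] is True  # frames were not pre-sampled
```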

tests/models/multimodal/processing/test_glm4_1v.py (14 additions, 3 deletions)
@@ -12,8 +12,19 @@

@pytest.mark.parametrize("model_id", ["zai-org/GLM-4.1V-9B-Thinking"])
@pytest.mark.parametrize("expected_toks_per_frame", [299])
@pytest.mark.parametrize("num_frames", [32, 128])
@pytest.mark.parametrize("fps, expected_grid_t", [(1, 5), (2, 10)])
@pytest.mark.parametrize(
"num_frames, fps, expected_grid_t",
[
# pre-sampled fixed frames (unexpected behavior,
# but we still expect it to work without errors)
(32, 1, 16),
(32, 2, 16),
(128, 1, 64),
(128, 2, 64),
# post-sampled frames (expected behavior)
(-1, 1, 5),
(-1, 2, 10),
])
def test_processor_override(
model_id: str,
expected_toks_per_frame: int,
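
The expected grid_t values are consistent with the processor merging two frames per temporal grid step (temporal_patch_size=2 is an assumption here, inferred from the expected values): pre-sampled inputs keep all supplied frames (32 -> 16, 128 -> 64), while num_frames=-1 lets the dynamic loader sample first, so the roughly 10 s test clip yields 10 frames at 1 fps and 20 at 2 fps. A sketch of the arithmetic:

```python
import math

def grid_t(sampled_frames: int, temporal_patch_size: int = 2) -> int:
    # Assumed relation for illustration: one temporal grid step per
    # temporal_patch_size frames.
    return math.ceil(sampled_frames / temporal_patch_size)

assert grid_t(32) == 16   # pre-sampled, all 32 frames kept
assert grid_t(128) == 64
assert grid_t(10) == 5    # ~10 s clip sampled at 1 fps
assert grid_t(20) == 10   # same clip at 2 fps
```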
@@ -80,7 +91,7 @@ def test_video_loader_consistency(

     static_video, static_metadata = OpenCVVideoBackend.load_bytes(video_bytes)
     dynamic_video, dynamic_metadata = OpenCVDynamicVideoBackend.load_bytes(
-        video_bytes, requested_fps=fps)
+        video_bytes, fps=fps)
 
     # pre-sampled loader shouldn't read all frames
     assert len(dynamic_video) < len(static_video)
vllm/assets/video.py (11 additions, 4 deletions)
@@ -76,7 +76,7 @@ def video_to_pil_images_list(path: str,
     return [Image.fromarray(frame) for frame in frames]
 
 
-def video_get_metadata(path: str) -> dict[str, Any]:
+def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]:
     cap = cv2.VideoCapture(path)
     if not cap.isOpened():
         raise ValueError(f"Could not open video file {path}")
@@ -85,11 +85,18 @@ def video_get_metadata(path: str) -> dict[str, Any]:
     fps = cap.get(cv2.CAP_PROP_FPS)
     duration = total_frames / fps if fps > 0 else 0
 
+    if num_frames == -1 or num_frames > total_frames:
+        num_frames = total_frames
+
     metadata = {
-        "total_num_frames": total_frames,
+        "total_num_frames": num_frames,
         "fps": fps,
         "duration": duration,
-        "video_backend": "opencv"
+        "video_backend": "opencv",
+        "frames_indices": list(range(num_frames)),
+        # extra field used to control hf processor's video
+        # sampling behavior
+        "do_sample_frames": num_frames == total_frames,
     }
     return metadata
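
Callers that truncate a video up front now get metadata describing the truncated clip, with do_sample_frames flipped off so the HF processor keeps the frames as given. For instance, assuming a source file with more than 16 frames (the path is illustrative):

```python
meta = video_get_metadata("baby_reading.mp4", num_frames=16)
assert meta["total_num_frames"] == 16
assert meta["frames_indices"] == list(range(16))
assert meta["do_sample_frames"] is False  # clip was already sampled

# The default num_frames=-1 keeps the old full-clip behavior.
assert video_get_metadata("baby_reading.mp4")["do_sample_frames"] is True
```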

@@ -126,7 +133,7 @@ def np_ndarrays(self) -> npt.NDArray:

     @property
     def metadata(self) -> dict[str, Any]:
-        ret = video_get_metadata(self.video_path)
+        ret = video_get_metadata(self.video_path, self.num_frames)
         return ret
 
     def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
vllm/model_executor/models/glm4_1v.py (50 additions, 38 deletions)
@@ -36,7 +36,9 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
+from packaging.version import Version
 from transformers import BatchFeature
+from transformers import __version__ as TRANSFORMERS_VERSION
 from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
 from transformers.models.glm4v.image_processing_glm4v import (
     Glm4vImageProcessor, smart_resize)
@@ -1001,28 +1003,32 @@ def _get_video_second_idx(self, metadata: dict[str, Any],
         max_frame_idx = meta_frames - 1
         duration = metadata.get("duration",
                                 round(max_frame_idx / video_fps) + 1)
-        if duration <= video_processor.max_duration:
-            n = int(math.floor(duration * video_processor.fps))
-            frame_indices = [
-                min(
-                    max_frame_idx,
-                    int(math.ceil(i * video_fps / video_processor.fps)),
-                ) for i in range(n)
-            ]
+        do_sample_frames = metadata["do_sample_frames"]
+        if not do_sample_frames:
+            frame_indices = metadata["frames_indices"]
         else:
-            num_samples = int(video_processor.max_duration *
-                              video_processor.fps)
-            if num_samples >= meta_frames:
-                frame_indices = list(range(meta_frames))
-            else:
-                target_seconds = np.linspace(0,
-                                             duration,
-                                             num_samples,
-                                             endpoint=True)
+            if duration <= video_processor.max_duration:
+                n = int(math.floor(duration * video_processor.fps))
                 frame_indices = [
-                    min(max_frame_idx, int(math.ceil(t * video_fps)))
-                    for t in target_seconds
+                    min(
+                        max_frame_idx,
+                        int(math.ceil(i * video_fps / video_processor.fps)),
+                    ) for i in range(n)
                 ]
+            else:
+                num_samples = int(video_processor.max_duration *
+                                  video_processor.fps)
+                if num_samples >= meta_frames:
+                    frame_indices = list(range(meta_frames))
+                else:
+                    target_seconds = np.linspace(0,
+                                                 duration,
+                                                 num_samples,
+                                                 endpoint=True)
+                    frame_indices = [
+                        min(max_frame_idx, int(math.ceil(t * video_fps)))
+                        for t in target_seconds
+                    ]
 
         seen, uniq = set(), []
         for idx in frame_indices:
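
Pre-sampled inputs (do_sample_frames=False) now reuse the loader's frames_indices verbatim; only unsampled videos fall through to the fps-based resampling. A worked example of that resampling branch, with illustrative numbers:

```python
import math

# Illustrative: a 10 s clip decoded at 24 fps, processor fps of 2,
# duration under max_duration so the first branch applies.
video_fps, duration, processor_fps, max_frame_idx = 24.0, 10, 2, 239

n = int(math.floor(duration * processor_fps))  # 20 samples
frame_indices = [
    min(max_frame_idx, int(math.ceil(i * video_fps / processor_fps)))
    for i in range(n)
]
assert frame_indices[:4] == [0, 12, 24, 36]  # one frame every 0.5 s
```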
Expand Down Expand Up @@ -1139,7 +1145,9 @@ def _get_dummy_videos(
"fps": 2.0,
"duration": num_frames / 2.0,
"total_num_frames": num_frames,
"frames_indices": [i for i in range(num_frames)],
"video_backend": "opencv",
"do_sample_frames": False,
}
video_item = (video.copy(), video_metadata)
video_items.append(video_item)
Expand Down Expand Up @@ -1172,34 +1180,37 @@ def _call_hf_processor(
             for item in mm_data.pop("videos", []):
                 video_array, metadata = item
 
-                if metadata["video_backend"] == "opencv_dynamic":
-                    mm_kwargs["do_sample_frames"] = False
-
-                elif metadata["total_num_frames"] != len(video_array):
-                    logger.warning(
-                        "Total frames in metadata "
-                        "(%s) does not match the length of "
-                        "video array %s. This can "
-                        "be because the video is resampled "
-                        "in advance. This may cause "
-                        "a divergence with HF implementation.",
-                        metadata["total_num_frames"],
-                        len(video_array),
-                    )
-                    metadata["total_num_frames"] = len(video_array)
+                # don't update mm_kwargs inplace
+                video_mm_kwargs = dict(**mm_kwargs)
+                video_mm_kwargs["do_sample_frames"] = metadata.get(
+                    "do_sample_frames", True)
 
                 video_mm_data = dict()
                 video_mm_data["videos"] = [[video_array]]
-                video_mm_data["video_metadata"] = [[VideoMetadata(**metadata)]]
+
+                # backward compatibility for Transformers 4.55
+                unuse_metadata = ["do_sample_frames"]
+                if not hasattr(
+                        VideoMetadata,
+                        "frames_indices") and "frames_indices" in metadata:
+                    unuse_metadata.append("frames_indices")
+
+                video_mm_data["video_metadata"] = [[
+                    VideoMetadata(
+                        **{
+                            k: metadata[k]
+                            for k in metadata if k not in unuse_metadata
+                        })
+                ]]
 
                 video_outputs = super()._call_hf_processor(
                     prompt="<|begin_of_video|><|video|><|end_of_video|>",
                     mm_data=video_mm_data,
-                    mm_kwargs=mm_kwargs,
+                    mm_kwargs=video_mm_kwargs,
                     tok_kwargs=tok_kwargs,
                 )
-                if "do_sample_frames" in mm_kwargs and not mm_kwargs[
-                        "do_sample_frames"]:
+                if not video_mm_kwargs["do_sample_frames"] and Version(
+                        TRANSFORMERS_VERSION) < Version("4.56.0"):
                     # Transformers v4.55 has incorrect timestamps issue for
                     # skip sampling. We construct the placeholder manually to
                     # get placeholders with correct timestamps.
@@ -1218,6 +1229,7 @@ def _call_hf_processor(
                     prompt = prompt.replace(
                         "<|begin_of_video|><|video|><|end_of_video|>",
                         video_placeholder,
+                        1,
                     )
 
                 video_grid_thw_lst.append(video_outputs["video_grid_thw"])
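
The new count argument of 1 matters once a prompt carries several videos: each loop iteration should substitute only the current video's placeholder, leaving later occurrences for later iterations. A tiny illustration:

```python
p = "<|begin_of_video|><|video|><|end_of_video|>" * 2
tok = "<|begin_of_video|><|video|><|end_of_video|>"

# Without a count, one pass rewrites both placeholders with this video's tokens:
assert p.replace(tok, "V0").count("V0") == 2
# With count=1, only the first occurrence is rewritten on this iteration:
assert p.replace(tok, "V0", 1).count("V0") == 1
```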
vllm/multimodal/video.py (28 additions, 21 deletions)
@@ -121,14 +121,6 @@ def load_bytes(
         original_fps = cap.get(cv2.CAP_PROP_FPS)
         duration = total_frames_num / original_fps if original_fps > 0 else 0
 
-        # Use transformers transformers.video_utils.VideoMetadata format
-        metadata = {
-            "total_num_frames": total_frames_num,
-            "fps": original_fps,
-            "duration": duration,
-            "video_backend": "opencv"
-        }
-
         # resample video to target num_frames
         full_read = num_frames == -1 or total_frames_num < num_frames
         if full_read:
@@ -159,6 +151,20 @@ def load_bytes(
         assert i == num_frames, (f"Expected reading {num_frames} frames, "
                                  f"but only loaded {i} frames from video.")
 
+        # Use transformers transformers.video_utils.VideoMetadata format
+        # NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata
+        # can cause incorrect timestamp calculation without num_frames=-1.
+        metadata = {
+            "total_num_frames": num_frames,
+            "fps": original_fps,
+            "duration": duration,
+            "video_backend": "opencv",
+            "frames_indices": list(range(num_frames)),
+            # extra field used to control hf processor's video
+            # sampling behavior
+            "do_sample_frames": num_frames == total_frames_num,
+        }
+
         return frames, metadata
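
Since total_num_frames now reflects the frames actually returned, do_sample_frames tells the downstream HF processor whether sampling is still its job. A sketch of both cases, assuming video_bytes holds an encoded clip longer than 16 frames:

```python
# Full read: sampling is left to the processor.
frames, meta = OpenCVVideoBackend.load_bytes(video_bytes, num_frames=-1)
assert meta["do_sample_frames"] is True

# Truncated read: the processor must keep the 16 frames as given.
frames, meta = OpenCVVideoBackend.load_bytes(video_bytes, num_frames=16)
assert meta["total_num_frames"] == 16
assert meta["do_sample_frames"] is False
```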


@@ -170,7 +176,7 @@ def load_bytes(
         cls,
         data: bytes,
         num_frames: int = -1,
-        requested_fps: int = 2,
+        fps: int = 2,
         max_duration: int = 300,
         **kwargs,
     ) -> tuple[npt.NDArray, dict[str, Any]]:
@@ -185,14 +191,6 @@ def load_bytes(
         original_fps = cap.get(cv2.CAP_PROP_FPS)
         duration = total_frames_num / original_fps if original_fps > 0 else 0
 
-        # Use transformers transformers.video_utils.VideoMetadata format
-        metadata = {
-            "total_num_frames": total_frames_num,
-            "fps": original_fps,
-            "duration": duration,
-            "video_backend": "opencv_dynamic"
-        }
-
        # resample video to target num_frames
         max_frame_idx = total_frames_num - 1
         duration = duration or round(max_frame_idx / original_fps) + 1
@@ -201,14 +199,13 @@ def load_bytes(
         # https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140
         frame_indices: Union[range, list[int]]
         if duration <= max_duration:
-            n = int(math.floor(duration * requested_fps))
+            n = int(math.floor(duration * fps))
             frame_indices = sorted({
-                min(max_frame_idx,
-                    int(math.ceil(i * original_fps / requested_fps)))
+                min(max_frame_idx, int(math.ceil(i * original_fps / fps)))
                 for i in range(n)
             })
         else:
-            num_samples = int(max_duration * requested_fps)
+            num_samples = int(max_duration * fps)
             if num_samples >= total_frames_num:
                 frame_indices = range(total_frames_num)
             else:
@@ -241,6 +238,16 @@ def load_bytes(
f"Expected reading {len(frame_indices)} frames, "
f"but only loaded {i} frames from video.")

# Use transformers transformers.video_utils.VideoMetadata format
metadata = {
"total_num_frames": total_frames_num,
"fps": original_fps,
"duration": duration,
"video_backend": "opencv_dynamic",
"frames_indices": list(frame_indices),
"do_sample_frames": False,
}

return frames, metadata
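
Because the dynamic backend already decodes only the sampled subset, its metadata pins the exact source indices and disables further sampling. A sketch of the invariant the consistency test above relies on (video_bytes assumed as before):

```python
static_frames, static_meta = OpenCVVideoBackend.load_bytes(video_bytes)
dynamic_frames, dynamic_meta = OpenCVDynamicVideoBackend.load_bytes(
    video_bytes, fps=2)

# The dynamic loader reads fewer frames than a full decode...
assert len(dynamic_frames) < len(static_frames)
# ...and its metadata already records which source frames were kept, so
# the HF processor must not resample them.
assert dynamic_meta["do_sample_frames"] is False
assert len(dynamic_meta["frames_indices"]) == len(dynamic_frames)
```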

