68 changes: 39 additions & 29 deletions tests/utils_/test_tensor_schema.py
@@ -6,37 +6,39 @@
 
 from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
 from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs
+from vllm.model_executor.models.hyperclovax_vision import (
+    HCXVisionVideoPixelInputs)
 from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
 
 
 def test_tensor_schema_valid_tensor():
     Phi3VImagePixelInputs(
-        data=torch.randn(16, 64, 3, 32, 32),
+        pixel_values=torch.randn(16, 64, 3, 32, 32),
         image_sizes=torch.randint(0, 256, (16, 2)),
     )
 
 
 def test_tensor_schema_optional_fields():
     Phi3VImagePixelInputs(
-        data=torch.randn(16, 64, 3, 32, 32),
+        pixel_values=torch.randn(16, 64, 3, 32, 32),
         image_sizes=None,
     )
 
-    Phi3VImagePixelInputs(data=torch.randn(16, 64, 3, 32, 32), )
+    Phi3VImagePixelInputs(pixel_values=torch.randn(16, 64, 3, 32, 32))
 
 
 def test_tensor_schema_constant_dim_failure():
     with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"):
         Phi3VImagePixelInputs(
-            data=torch.randn(16, 64, 4, 32, 32),  # dim[2] = 4
+            pixel_values=torch.randn(16, 64, 4, 32, 32),  # dim[2] = 4
             image_sizes=torch.randint(0, 256, (16, 2)),
         )
 
 
 def test_tensor_schema_invalid_types_in_list():
-    with pytest.raises(ValueError, match="is not a torch.Tensor"):
+    with pytest.raises(TypeError, match="is not one of the expected types"):
         Phi3VImagePixelInputs(
-            data=[
+            pixel_values=[
                 torch.randn(64, 3, 32, 32),
                 "not_a_tensor",
                 torch.randn(64, 3, 32, 32),
@@ -48,67 +50,75 @@ def test_tensor_schema_invalid_types_in_list():
 def test_tensor_schema_rank_mismatch():
     with pytest.raises(ValueError, match="has rank 3 but expected 5"):
         Phi3VImagePixelInputs(
-            data=torch.randn(16, 64, 3),
+            pixel_values=torch.randn(16, 64, 3),
             image_sizes=torch.randint(0, 256, (16, 2)),
         )
 
 
 def test_tensor_schema_missing_required_field():
-    with pytest.raises(ValueError, match="Required field 'data' is missing"):
+    with pytest.raises(ValueError,
+                       match="Required field 'pixel_values' is missing"):
         Phi3VImagePixelInputs(image_sizes=torch.randint(0, 256, (16, 2)), )
 
 
 def test_tensor_schema_symbolic_dim_mismatch():
     with pytest.raises(ValueError, match="expected 'bn'=12, got 16"):
         Phi3VImagePixelInputs(
-            data=torch.randn(12, 64, 3, 32, 32),
+            pixel_values=torch.randn(12, 64, 3, 32, 32),
             image_sizes=torch.randint(0, 256, (16, 2)),
         )
 
 
 def test_tensor_schema_list_tensor_valid():
     Phi3VImagePixelInputs(
-        data=[torch.randn(64, 3, 32, 32) for _ in range(16)],
+        pixel_values=[torch.randn(64, 3, 32, 32) for _ in range(16)],
        image_sizes=torch.randint(0, 256, (16, 2)),
     )
 
 
 def test_tensor_schema_variable_patch_counts_valid():
     # Each image has a different number of patches (p)
     # Each tensor has shape (p, 3, 32, 32)
-    data = [
-        torch.randn(16, 3, 32, 32),  # p = 16
-        torch.randn(32, 3, 32, 32),  # p = 32
-        torch.randn(64, 3, 32, 32),  # p = 64
-    ]
-    image_sizes = torch.randint(0, 256, (3, 2))  # bn = 3
     Phi3VImagePixelInputs(
-        data=data,
-        image_sizes=image_sizes,
+        pixel_values=[
+            torch.randn(16, 3, 32, 32),  # p = 16
+            torch.randn(32, 3, 32, 32),  # p = 32
+            torch.randn(64, 3, 32, 32),  # p = 64
+        ],
+        image_sizes=torch.randint(0, 256, (3, 2)),  # bn = 3
     )
 
 
 def test_tensor_schema_tuple_tensor_valid():
     Phi3VImagePixelInputs(
-        data=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)),
+        pixel_values=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)),
         image_sizes=torch.randint(0, 256, (16, 2)),
     )
 
 
+def test_tensor_schema_double_nested_tensors():
+    x = torch.rand(4, 3, 32, 32)
+    y = torch.rand(2, 3, 32, 32)
+
+    HCXVisionVideoPixelInputs(pixel_values_videos=([x, y, x], [y], [x, y]))
+
+
 def test_tensor_schema_inconsistent_shapes_in_list():
     with pytest.raises(ValueError, match="contains inconsistent shapes"):
         Phi3VImagePixelInputs(
-            data=[torch.randn(64, 3, 32, 32),
-                  torch.randn(64, 3, 16, 16)] +
-            [torch.randn(64, 3, 32, 32) for _ in range(14)],
+            pixel_values=[
+                torch.randn(64, 3, 32, 32),
+                torch.randn(64, 3, 16, 16),
+                *(torch.randn(64, 3, 32, 32) for _ in range(14)),
+            ],
             image_sizes=torch.randint(0, 256, (16, 2)),
         )
 
 
 def test_tensor_schema_empty_list():
-    with pytest.raises(ValueError, match="is an empty list"):
+    with pytest.raises(ValueError, match="is an empty sequence"):
         Phi3VImagePixelInputs(
-            data=[],
+            pixel_values=[],
             image_sizes=torch.randint(0, 256, (0, 2)),
         )
 
@@ -117,18 +127,18 @@ def test_tensor_schema_validation_disabled_skips_shape_check():
     # This should NOT raise, because validation is turned off
     # This would normally fail (dim[2] should be 3, not 4)
     Phi3VImagePixelInputs(
-        data=torch.randn(16, 64, 4, 32, 32),
+        pixel_values=torch.randn(16, 64, 4, 32, 32),
         image_sizes=torch.randint(0, 256, (16, 2)),
         validate=False,
     )
 
 
 def test_tensor_schema_with_valid_resolve_binding_dims():
-    data = torch.randn(16, 64, 3, 336, 336)  # h=336, w=336
+    pixel_values = torch.randn(16, 64, 3, 336, 336)  # h=336, w=336
     image_sizes = torch.randint(0, 256, (16, 2))
 
     Phi3VImagePixelInputs(
-        data=data,
+        pixel_values=pixel_values,
         image_sizes=image_sizes,
         resolve_bindings={
             "h": 336,
@@ -138,13 +148,13 @@ def test_tensor_schema_with_valid_resolve_binding_dims():
 
 
 def test_tensor_schema_with_invalid_resolve_binding_dims():
-    data = torch.randn(16, 64, 3, 36, 36)  # h=36, w=36
+    pixel_values = torch.randn(16, 64, 3, 36, 36)  # h=36, w=36
     image_sizes = torch.randint(0, 256, (16, 2))
 
     # Should raise because 'h' and 'w' don't match resolve bindings
     with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"):
         Phi3VImagePixelInputs(
-            data=data,
+            pixel_values=pixel_values,
             image_sizes=image_sizes,
             resolve_bindings={
                 "h": 336,
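For context on what these tests exercise: the PR renames the schema's primary field from data to pixel_values and tightens the validation errors (a TypeError for wrong element types, "empty sequence" instead of "empty list"). The sketch below is a rough guess at what the renamed schema looks like, not the actual class from vllm/model_executor/models/phi3v.py: the import path is inferred from the test location (tests/utils_/test_tensor_schema.py), and the dynamic_dims escape hatch for the per-image patch count is an assumption based on the variable-patch-counts test passing while mismatched h/w shapes fail.

# Rough sketch of the schema under test; TensorSchema/TensorShape import
# path and the dynamic_dims argument are assumptions, not confirmed API.
from typing import Annotated, Optional, Union

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class Phi3VImagePixelInputs(TensorSchema):
    """
    Dimensions:
        bn:  batch size * number of images
        p:   patches per image (allowed to vary, hence dynamic)
        h/w: patch height/width (pinnable via resolve_bindings)
    """
    # One rank-5 tensor, or a list/tuple of rank-4 tensors (one per image);
    # the channel dim is the constant 3 checked by the dim[2] tests above.
    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}),
    ]
    # Optional (bn, 2) tensor of original image sizes.
    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]

The last four tests also rely on constructor-level switches visible in the diff: validate=False skips all shape checks, and resolve_bindings={"h": 336, "w": 336} pins the symbolic h/w dims before validation runs.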
4 changes: 2 additions & 2 deletions vllm/model_executor/models/glm4_1v.py
@@ -29,7 +29,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Annotated, Any, Callable, Literal, Optional, Union, override
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -1170,7 +1170,7 @@ def _get_dummy_videos(
                 "video.height override (%d) exceeds model's "
                 "maximum height (%d), will be ignored",
                 overrides.height, height)
-            height = min(height, override.height)
+            height = min(height, overrides.height)
 
     video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
     video_items = []
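The glm4_1v.py half of the PR is a typo fix: the code meant the local overrides object but wrote override, which resolved to typing.override (imported at the top of the module until this change) and would raise AttributeError as soon as a dummy-video height override was clamped; the now-unused import is dropped in the same commit. Below is a standalone sketch of the intended clamp-and-warn behavior, with hypothetical names (clamp_height is not a real vLLM helper, and the guard condition is assumed from the warning text).

# Hypothetical standalone version of the corrected logic; in the real code
# `overrides` comes from the dummy-video options and `logger` is the
# module-level vLLM logger.
import logging

logger = logging.getLogger(__name__)


def clamp_height(height: int, override_height: int) -> int:
    if override_height > height:
        logger.warning(
            "video.height override (%d) exceeds model's "
            "maximum height (%d), will be ignored",
            override_height, height)
    # min() discards an oversized override and applies a valid one.
    return min(height, override_height)


assert clamp_height(336, 1024) == 336  # oversized override ignored
assert clamp_height(336, 128) == 128   # valid override applied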