11 changes: 6 additions & 5 deletions tests/models/test_vision.py
@@ -18,7 +18,7 @@


@pytest.mark.parametrize(
("feature_sample_layers", "num_layers_loaded", "max_possible_layers",
("select_layers", "num_layers_loaded", "max_possible_layers",
"expected_features"),
[
# All layers loaded
@@ -28,8 +28,8 @@
([1, 10], 10, 20, [1, 10]),
([-20, -11], 10, 20, [1, 10]),
])
def test_resolve_visual_encoder_outputs(feature_sample_layers,
num_layers_loaded, max_possible_layers,
def test_resolve_visual_encoder_outputs(select_layers, num_layers_loaded,
max_possible_layers,
expected_features):
"""
Test that offsets are correctly handled for vision feature layers.
@@ -39,9 +39,10 @@ def test_resolve_visual_encoder_outputs(feature_sample_layers,
]
output_tensor = resolve_visual_encoder_outputs(
encoder_outputs=encoder_outputs,
feature_sample_layers=feature_sample_layers,
post_layer_norm=None,
max_possible_layers=max_possible_layers)
select_layers=select_layers,
max_possible_layers=max_possible_layers,
)
assert torch.equal(torch.tensor(expected_features), output_tensor)
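
A minimal sketch of the indexing behavior this test exercises, assuming (for illustration only, not as the exact vLLM implementation) that negative layer indices count back from the full-depth model, i.e. from the max_possible_layers + 1 hidden states that include the embedding output:

    import torch

    def select_hidden_states_sketch(hidden_states, select_layers, max_possible_layers):
        # Negative indices are resolved against the full-depth model
        # (embeddings plus max_possible_layers encoder layers), not against
        # the possibly shorter list of layers actually loaded.
        resolved = [
            idx if idx >= 0 else idx + max_possible_layers + 1
            for idx in select_layers
        ]
        # Concatenate the selected hidden states along the feature dimension.
        return torch.cat([hidden_states[idx] for idx in resolved], dim=-1)

    # With 10 of 20 layers loaded, [-20, -11] resolves to the same layers as [1, 10]:
    hidden_states = [torch.tensor([i]) for i in range(10 + 1)]
    print(select_hidden_states_sketch(hidden_states, [-20, -11], 20))  # tensor([ 1, 10])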


27 changes: 4 additions & 23 deletions vllm/model_executor/models/aya_vision.py
@@ -27,7 +27,6 @@
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -350,29 +349,11 @@ def _image_pixels_to_features(
self,
vision_tower: SiglipVisionModel,
pixel_values: torch.Tensor,
**kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
target_dtype: torch.dtype = \
vision_tower.get_input_embeddings().weight.dtype
image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
vision_tower(pixel_values.to(dtype=target_dtype), **kwargs)

def select_features(leaf: torch.Tensor):
return self._select_image_features(
leaf,
strategy=self.config.vision_feature_select_strategy,
)

return json_map_leaves(select_features, image_features)

def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features

raise ValueError(f"Unexpected select feature strategy: {strategy}")
return vision_tower(
pixel_values.to(dtype=vision_tower.dtype),
feature_select_strategy=self.config.vision_feature_select_strategy,
)

def _process_image_input(self, image_input: AyaVisionImagePixelInputs,
**kwargs) -> list[torch.Tensor]:
31 changes: 21 additions & 10 deletions vllm/model_executor/models/clip.py
@@ -19,7 +19,8 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsQuant

from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
resolve_visual_encoder_outputs)


class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
@@ -308,24 +309,29 @@ def __init__(
def forward(
self,
pixel_values: torch.Tensor,
feature_sample_layers: Optional[list[int]] = None,
*,
select_layers: Optional[list[int]] = None,
feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
) -> torch.Tensor:

hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)

return_all_hidden_states = feature_sample_layers is not None

# Produces either the last layer output or all of the hidden states,
# depending on if we have feature_sample_layers or not
# depending on if we have select_layers or not
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
return_all_hidden_states=return_all_hidden_states)
return_all_hidden_states=select_layers is not None,
)

# Handle post-norm (if applicable) and stacks feature layers if needed
encoder_outputs = resolve_visual_encoder_outputs(
encoder_outputs, feature_sample_layers, self.post_layernorm,
self.config.num_hidden_layers)
encoder_outputs,
self.post_layernorm,
select_layers=select_layers,
max_possible_layers=self.config.num_hidden_layers,
feature_select_strategy=feature_select_strategy,
)

return encoder_outputs

@@ -355,9 +361,14 @@ def __init__(
def forward(
self,
pixel_values: torch.Tensor,
feature_sample_layers: Optional[list[int]] = None,
select_layers: Optional[list[int]] = None,
feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
) -> torch.Tensor:
return self.vision_model(pixel_values, feature_sample_layers)
return self.vision_model(
pixel_values,
select_layers=select_layers,
feature_select_strategy=feature_select_strategy,
)

@property
def device(self):
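
Every _select_image_features helper removed in this PR implemented the same token selection; a sketch of the strategy handling that the vision towers now apply centrally (the real logic lives in the shared vision helpers and may differ in detail):

    import torch

    def apply_feature_select_strategy_sketch(strategy, image_features):
        # "default" drops the leading CLS token; "full" keeps every token.
        if strategy == "default":
            return image_features[:, 1:]
        if strategy == "full":
            return image_features
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    patch_embeds = torch.randn(2, 577, 1024)  # (batch, CLS + 576 patches, hidden)
    assert apply_feature_select_strategy_sketch("default", patch_embeds).shape == (2, 576, 1024)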
25 changes: 4 additions & 21 deletions vllm/model_executor/models/llava.py
@@ -33,7 +33,6 @@
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .clip import CLIPVisionModel
@@ -604,16 +603,6 @@ def _parse_and_validate_image_input(

raise AssertionError("This line should be unreachable.")

def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features

raise ValueError(f"Unexpected select feature strategy: {strategy}")

def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
@@ -622,16 +611,10 @@ def _image_pixels_to_features(
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
vision_tower(pixel_values)

def select_features(leaf: torch.Tensor):
return self._select_image_features(
leaf,
strategy=self.config.vision_feature_select_strategy,
)

return json_map_leaves(select_features, image_features)
return vision_tower(
pixel_values,
feature_select_strategy=self.config.vision_feature_select_strategy,
)

def _process_image_pixels(
self,
25 changes: 6 additions & 19 deletions vllm/model_executor/models/llava_next.py
@@ -235,12 +235,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
# Determine the layer up to which we will initialize the vision tower
if isinstance(vision_feature_layer, int):
vision_hidden_size = config.vision_config.hidden_size
self.feature_sample_layers = None
self.select_layers = None
# Used for multimodal granite models to control encoder outputs
elif isinstance(vision_feature_layer, (list, tuple)):
vision_hidden_size = config.vision_config.hidden_size * len(
vision_feature_layer)
self.feature_sample_layers = vision_feature_layer
self.select_layers = vision_feature_layer
else:
raise TypeError(
f"vision_layer_feature type: {type(vision_feature_layer)}"
@@ -312,30 +312,17 @@ def _parse_and_validate_image_input(

raise AssertionError("This line should be unreachable.")

def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features

raise ValueError(f"Unexpected select feature strategy: {strategy}")

def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:

# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(
pixel_values, feature_sample_layers=self.feature_sample_layers)

return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
return vision_tower(
pixel_values,
select_layers=self.select_layers,
feature_select_strategy=self.config.vision_feature_select_strategy,
)

# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
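
To see why vision_hidden_size is scaled by the number of selected layers above, a toy shape check (sizes and layer indices are assumptions chosen for illustration, not taken from a real checkpoint):

    import torch

    hidden_size = 1024                         # hypothetical config.vision_config.hidden_size
    vision_feature_layer = [-24, -12, -6, -1]  # hypothetical granite-style layer list

    # The projector input width grows with the number of selected layers.
    vision_hidden_size = hidden_size * len(vision_feature_layer)

    # One hidden state per selected layer, shape (batch, num_tokens, hidden_size).
    hidden_states = [torch.randn(2, 576, hidden_size) for _ in vision_feature_layer]

    # The tower concatenates them along the feature dimension before projection.
    image_features = torch.cat(hidden_states, dim=-1)
    assert image_features.shape == (2, 576, vision_hidden_size)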
17 changes: 3 additions & 14 deletions vllm/model_executor/models/llava_next_video.py
@@ -349,27 +349,16 @@ def _parse_and_validate_video_input(
"w": expected_w,
})

def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features

raise ValueError(f"Unexpected select feature strategy: {strategy}")

def _video_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:

# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
image_features = self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
image_features = vision_tower(
pixel_values,
feature_select_strategy=self.config.vision_feature_select_strategy,
)
image_features = self.vision_resampler(image_features)
image_features = self.multi_modal_projector(image_features)
25 changes: 6 additions & 19 deletions vllm/model_executor/models/llava_onevision.py
@@ -577,27 +577,16 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:

return mm_input_by_modality

def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features

raise ValueError(f"Unexpected select feature strategy: {strategy}")

def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:

# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
return vision_tower(
pixel_values,
feature_select_strategy=self.config.vision_feature_select_strategy,
)

# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
@@ -750,13 +739,11 @@ def _video_pixels_to_features(
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:

# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
video_features = vision_tower(pixel_values)
video_features = self._select_image_features(
video_features,
strategy=self.config.vision_feature_select_strategy,
video_features = vision_tower(
pixel_values,
feature_select_strategy=self.config.vision_feature_select_strategy,
)
video_features = self.multi_modal_projector(video_features)
video_features = self.apply_pooling(video_features)
24 changes: 4 additions & 20 deletions vllm/model_executor/models/minimax_vl_01.py
@@ -17,7 +17,6 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .clip import CLIPVisionModel
@@ -221,15 +220,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
def get_language_model(self) -> torch.nn.Module:
return self.language_model

def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features

raise ValueError(f"Unexpected select feature strategy: {strategy}")

def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
@@ -238,16 +228,10 @@ def _image_pixels_to_features(
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features: tuple[torch.Tensor, ...] = \
tuple(vision_tower(p) for p in pixel_values)

def select_features(leaf: torch.Tensor):
return self._select_image_features(
leaf,
strategy=self.config.vision_feature_select_strategy,
)

return json_map_leaves(select_features, image_features)
feature_select_strategy = self.config.vision_feature_select_strategy
return tuple(
vision_tower(p, feature_select_strategy=feature_select_strategy)
for p in pixel_values)

# adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631
def pack_image_features(self, image_features: list[torch.Tensor],