Merged
27 commits
999068b
Continuous Batching for VLMs
asmigosw Nov 5, 2025
1220cf9
Added CB support for InternVL
asmigosw Nov 10, 2025
c39ae01
Added CB support for Mistral3
asmigosw Nov 10, 2025
39f5c16
Updated test_image_text_to_text for CB tests
asmigosw Nov 11, 2025
9a42a08
Ruff format
asmigosw Nov 11, 2025
c1465c8
Added CB update for Molmo
asmigosw Nov 16, 2025
a6f1182
Added mistral CB support
asmigosw Nov 17, 2025
9e658bc
Merge branch 'main' into CB_VLM_update
asmigosw Nov 19, 2025
a6ee63f
Merge branch 'main' into CB_VLM_update
asmigosw Nov 20, 2025
94552e0
Added CB Test for InternVL
asmigosw Nov 20, 2025
e8af917
Ruff format
asmigosw Nov 20, 2025
f8d67e4
Merge branch 'main' into CB_VLM_update
asmigosw Nov 21, 2025
7ed78bc
Merge branch 'main' into CB_VLM_update
asmigosw Nov 25, 2025
eea2ffa
Resolving CI issues
asmigosw Nov 25, 2025
ee54215
Added InternVL example file for CB
asmigosw Nov 25, 2025
542d60f
Merge branch 'main' into CB_VLM_update
asmigosw Nov 26, 2025
77d07ea
Merge branch 'main' into CB_VLM_update
asmigosw Nov 26, 2025
b8b2299
Merge branch 'main' into CB_VLM_update
asmigosw Nov 27, 2025
9866d9b
Merge branch 'main' into CB_VLM_update
asmigosw Dec 2, 2025
453bd9e
Addressed Comments
asmigosw Dec 2, 2025
c2fe7ff
Comments Addressed
asmigosw Dec 2, 2025
e60bb46
Merge branch 'main' into CB_VLM_update
quic-mamta Dec 2, 2025
dc724c7
Merge branch 'main' into CB_VLM_update
quic-hemagnih Dec 3, 2025
dc5d8f5
Merge branch 'main' into CB_VLM_update
quic-hemagnih Dec 4, 2025
7202b67
Added CB test file
asmigosw Dec 4, 2025
7fb8a1d
Added llava_next CB support
asmigosw Dec 4, 2025
6c56af2
Added llava_next CB support
asmigosw Dec 4, 2025
158 changes: 155 additions & 3 deletions QEfficient/generation/embedding_handler.py
@@ -12,15 +12,17 @@
operations, separating them from the main text generation logic.
"""

from typing import Any, Dict, Optional, Tuple
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor
from transformers import AutoImageProcessor, AutoTokenizer

from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import constants
from QEfficient.utils.logging_utils import logger


@@ -37,6 +39,9 @@ def __init__(
qeff_model: Optional[QAICInferenceSession],
vision_session: Optional[QAICInferenceSession],
processor: Optional[AutoImageProcessor],
tokenizer: Optional[AutoTokenizer],
image_height: Optional[int] = None,
image_width: Optional[int] = None,
config: Optional[Dict[str, Any]] = None,
lang_session: Optional[QAICInferenceSession] = None,
):
@@ -46,12 +51,18 @@
Args:
vision_session: QAICInferenceSession for vision model
processor: AutoImageProcessor for image preprocessing
tokenizer: AutoTokenizer for text tokenization
image_height: Desired image height for resizing
image_width: Desired image width for resizing
config: Configuration dictionary with vision model parameters
lang_session: Optional language session for coordination (to avoid resource conflicts)
"""
self._qeff_model = qeff_model
self._vision_session = vision_session
self._processor = processor
self._tokenizer = tokenizer
self._image_height = image_height
self._image_width = image_width
self._config = config or {}
self._lang_session = lang_session # Store language session for coordination

@@ -70,13 +81,132 @@ def is_available(self) -> bool:
"""
return self._vision_session is not None and self._processor is not None

def prepare_internVL_inputs(self, img_url: str, prompt: str) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
"""
Prepare inputs for InternVL model

Args:
img_url: URL or path to image
prompt: Text query to process with image

Returns:
Tuple of (vision_inputs, lang_inputs) dictionaries
"""
if not self._tokenizer:
raise ValueError("Tokenizer is required for InternVL input preparation")
pixel_values = []
num_patches_list = []
questions = []
img = requests.get(img_url, stream=True)
image = Image.open(BytesIO(img.content)).convert("RGB")

if self._image_height and self._image_width:
# PIL's Image.resize expects (width, height)
image = image.resize((self._image_width, self._image_height))
else:
logger.warning("Height and width not specified; using the default InternVL image size (num_patches = 13).")
image = image.resize((constants.INTERN_IMAGE_WIDTH, constants.INTERN_IMAGE_HEIGHT))

# preprocess the resized image
pixel_value = self._processor.load_image(image, max_num=12)
num_patches_list.append(pixel_value.shape[0])
pixel_values.append(pixel_value)

question = "<image>\n" + prompt
questions.append(question)

pixel_values = torch.cat(pixel_values, dim=0)

# Chat Template information for prompt preprocessing
messages: List[List[str]] = []
roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
prompt = self._processor(pixel_values, questions, messages, roles, num_patches_list=num_patches_list)

inputs = self._tokenizer(prompt, return_tensors="pt")
inputs["pixel_values"] = pixel_values.clone()

# Convert to numpy arrays
vision_inputs = {}
for k, v in inputs.items():
if k in {
"pixel_values",
"image_masks",
"image_input_idx",
"valid_idx",
"aspect_ratio_ids",
"aspect_ratio_mask",
}:
vision_inputs[k] = np.array(v)

# Convert specific inputs to float16
vision_inputs_fp16 = {"pixel_values", "image_masks"}
for k in vision_inputs_fp16:
if k in vision_inputs:
vision_inputs[k] = vision_inputs[k].astype("float16")

lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}

return vision_inputs, lang_inputs

def prepare_molmo_inputs(self, image_url: str, query: str) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
"""
Download and preprocess image into Molmo model inputs

Args:
image_url: URL or path to image
query: Text query to process with image

Returns:
Tuple of (vision_inputs, lang_inputs) dictionaries

Raises:
ValueError: If vision handler is not properly initialized
RuntimeError: If image processing fails
"""
if not self.is_available():
raise ValueError("Vision handler not properly initialized. Need both vision_session and processor.")

try:
# Download image
if image_url.startswith(("http://", "https://")):
image = Image.open(requests.get(image_url, stream=True).raw)
else:
image = Image.open(image_url)
image = image.resize((constants.MOLMO_IMAGE_WIDTH, constants.MOLMO_IMAGE_HEIGHT))  # PIL expects (width, height)
inputs = self._processor.process(images=[image], text=query)
inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64)
valid = inputs["image_input_idx"] > 0
valid = valid.reshape(1, -1)
inputs["valid_idx"] = torch.nonzero(valid)[:, 1].unsqueeze(0)
inputs["pixel_values"] = inputs.pop("images")

# Convert to numpy arrays
vision_inputs = {}
for k, v in inputs.items():
if k in {
"pixel_values",
"image_masks",
"image_input_idx",
"valid_idx",
"aspect_ratio_ids",
"aspect_ratio_mask",
}:
vision_inputs[k] = np.array(v)

# Convert specific inputs to float16
vision_inputs_fp16 = {"pixel_values", "image_masks"}
for k in vision_inputs_fp16:
if k in vision_inputs:
vision_inputs[k] = vision_inputs[k].astype("float16")

lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}

return vision_inputs, lang_inputs
except Exception as e:
raise RuntimeError(f"Failed to process image {image_url}: {e}") from e

def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -> Dict[str, np.ndarray]:
"""
Download and preprocess image into model inputs

Args:
image_url: URL or path to image
query: Text query to process with image
prefill_seq_len: Padded sequence length for language model

Returns:
Dictionary of vision model inputs
@@ -95,6 +225,17 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -
else:
image = Image.open(image_url)

if self._image_height and self._image_width:
image = image.resize((self._image_width, self._image_height))
else:
logger.warning("Height and width not specified; using the default image size for this model type.")
# PIL's Image.resize expects (width, height)
if "mistral3" in self._qeff_model.model.config.model_type:
image = image.resize((constants.MISTRAL3_IMAGE_WIDTH, constants.MISTRAL3_IMAGE_HEIGHT))
elif "llava_next" in self._qeff_model.model.config.model_type:
image = image.resize(
(constants.GRANITEVISION_IMG_SIZE_WIDTH, constants.GRANITEVISION_IMG_SIZE_HEIGHT)
)

# Prepare conversation format
conversation = [
{
@@ -323,7 +464,18 @@ def get_processed_inputs(

try:
## Get vlm inputs ##
vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)
model_type = getattr(self._qeff_model.model.config, "model_type", None)
if model_type == "internvl_chat":
vision_inputs, lang_inputs = self.prepare_internVL_inputs(image_url, query)
elif model_type == "molmo":
vision_inputs, lang_inputs = self.prepare_molmo_inputs(image_url, query)
else:
vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)

# Handle padding for language model
pad_token_id = 1
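A minimal usage sketch of the new dispatch path, for reviewers. Assumptions: the handler class is the one defined in embedding_handler.py (called VisionHandler here), and the session, processor, and tokenizer objects are constructed elsewhere; this mirrors the diff above, it is not the library's documented API.

# Hedged sketch: all objects below are assumed to exist already.
handler = VisionHandler(
    qeff_model=qeff_model,          # QEff-wrapped VLM; config.model_type drives dispatch
    vision_session=vision_session,  # QAICInferenceSession for the vision encoder
    processor=processor,            # AutoImageProcessor
    tokenizer=tokenizer,            # AutoTokenizer (required by the InternVL path)
    image_height=448,               # optional resize targets; model defaults apply if omitted
    image_width=448,
    lang_session=lang_session,      # shared language session, avoids resource conflicts
)
# get_processed_inputs() now routes on config.model_type:
#   "internvl_chat" -> prepare_internVL_inputs (tokenizer + InternVL chat template)
#   "molmo"         -> prepare_molmo_inputs    (processor.process + valid_idx)
#   anything else   -> prepare_vlm_inputs      (generic conversation format)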
10 changes: 10 additions & 0 deletions QEfficient/generation/vlm_generation.py
@@ -88,6 +88,8 @@
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
full_batch_size: Optional[int] = None,
image_height: Optional[int] = None,
image_width: Optional[int] = None,
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
@@ -107,6 +109,8 @@ def __init__(
enable_debug_logs: Enable debug logging
write_io_dir: Directory for I/O file writing
full_batch_size: Enable continuous batching (new feature)
image_height: Desired image height for resizing
image_width: Desired image width for resizing
is_tlm: Target language model flag
include_sampler: Enable on-device sampling (new feature)
return_pdfs: Return probability distributions
@@ -143,6 +147,9 @@ def __init__(
)
self.qeff_model = qeff_model
self.processor = processor
self.tokenizer = tokenizer
self.image_height = image_height
self.image_width = image_width
self._vision_qpc_path = vision_qpc_path
self.device_id = device_id # Store device_id for vision components
self.enable_debug_logs = enable_debug_logs # Store for vision components
@@ -173,6 +180,9 @@ def _init_vision_components(self):
qeff_model=self.qeff_model,
vision_session=self._vision_session,
processor=self.processor,
tokenizer=self.tokenizer,
image_height=self.image_height,
image_width=self.image_width,
config=vision_config,
lang_session=self._session, # Pass language session for coordination
)
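A companion sketch of how the new image_height/image_width arguments thread from the generation wrapper into the vision handler. The wrapper's class name and the unrelated constructor arguments are assumptions based on this diff, not verbatim API.

generator = VisionLanguageGeneration(  # class name assumed; defined in vlm_generation.py
    qeff_model=qeff_model,
    processor=processor,
    tokenizer=tokenizer,
    vision_qpc_path="path/to/vision.qpc",
    full_batch_size=4,    # enables continuous batching
    image_height=448,     # forwarded to the vision handler for resizing
    image_width=448,
)
# Internally, _init_vision_components() passes tokenizer, image_height and
# image_width straight through to the handler constructed in embedding_handler.py.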