diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py index 76da7afc2..e07b5dd04 100644 --- a/QEfficient/generation/embedding_handler.py +++ b/QEfficient/generation/embedding_handler.py @@ -12,15 +12,17 @@ operations, separating them from the main text generation logic. """ -from typing import Any, Dict, Optional, Tuple +from io import BytesIO +from typing import Any, Dict, List, Optional, Tuple import numpy as np import requests import torch from PIL import Image -from transformers import AutoImageProcessor +from transformers import AutoImageProcessor, AutoTokenizer from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils import constants from QEfficient.utils.logging_utils import logger @@ -37,6 +39,9 @@ def __init__( qeff_model: Optional[QAICInferenceSession], vision_session: Optional[QAICInferenceSession], processor: Optional[AutoImageProcessor], + tokenizer: Optional[AutoTokenizer], + image_height: Optional[int] = None, + image_width: Optional[int] = None, config: Optional[Dict[str, Any]] = None, lang_session: Optional[QAICInferenceSession] = None, ): @@ -46,12 +51,18 @@ def __init__( Args: vision_session: QAICInferenceSession for vision model processor: AutoImageProcessor for image preprocessing + tokenizer: AutoTokenizer for text tokenization + image_height: Desired image height for resizing + image_width: Desired image width for resizing config: Configuration dictionary with vision model parameters lang_session: Optional language session for coordination (to avoid resource conflicts) """ self._qeff_model = qeff_model self._vision_session = vision_session self._processor = processor + self._tokenizer = tokenizer + self._image_height = image_height + self._image_width = image_width self._config = config or {} self._lang_session = lang_session # Store language session for coordination @@ -70,6 +81,124 @@ def is_available(self) -> bool: """ return self._vision_session is not None and self._processor is not None + def prepare_internVL_inputs(self, img_url: str, prompt: str) -> Dict[str, np.ndarray]: + """ + Prepare inputs for InternVL model + + Args: + image_url: URL or path to image + prompt: Text query to process with image + """ + if not self._tokenizer: + raise ValueError("Tokenizer is required for InternVL input preparation") + pixel_values = [] + num_patches_list = [] + questions = [] + img = requests.get(img_url, stream=True) + image = Image.open(BytesIO(img.content)).convert("RGB") + + if self._image_height and self._image_width: + image = image.resize((self._image_height, self._image_width)) + else: + logger.warning("Height and Width not specified. 
Using default image size for num_patches = 13.") + image = image.resize((constants.INTERN_IMAGE_HEIGHT, constants.INTERN_IMAGE_WIDTH)) + + # preprocess the resized image + pixel_value = self._processor.load_image(image, max_num=12) + num_patches_list.append(pixel_value.shape[0]) + pixel_values.append(pixel_value) + + question = "\n" + prompt + questions.append(question) + + pixel_values = torch.cat(pixel_values, dim=0) + + # Chat Template information for prompt preprocessing + messages: List[List[str]] = [] + roles = ("<|im_start|>user\n", "<|im_start|>assistant\n") + prompt = self._processor(pixel_values, questions, messages, roles, num_patches_list=num_patches_list) + + inputs = self._tokenizer(prompt, return_tensors="pt") + inputs["pixel_values"] = pixel_values.clone() + + # Convert to numpy arrays + vision_inputs = {} + for k, v in inputs.items(): + if k in { + "pixel_values", + "image_masks", + "image_input_idx", + "valid_idx", + "aspect_ratio_ids", + "aspect_ratio_mask", + }: + vision_inputs[k] = np.array(v) + + # Convert specific inputs to float16 + vision_inputs_fp16 = {"pixel_values", "image_masks"} + for k in vision_inputs_fp16: + if k in vision_inputs: + vision_inputs[k] = vision_inputs[k].astype("float16") + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + + return vision_inputs, lang_inputs + + def prepare_molmo_inputs(self, image_url: str, query: str) -> Dict[str, np.ndarray]: + """ + Download and preprocess image into model inputs + Args: + image_url: URL or path to image + query: Text query to process with image + Returns: + Dictionary of vision model inputs + Raises: + ValueError: If vision handler is not properly initialized + RuntimeError: If image processing fails + """ + if not self.is_available(): + raise ValueError("Vision handler not properly initialized. 
Need both vision_session and processor.") + + try: + # Download image + if image_url.startswith(("http://", "https://")): + image = Image.open(requests.get(image_url, stream=True).raw) + else: + image = Image.open(image_url) + image = image.resize((constants.MOLMO_IMAGE_HEIGHT, constants.MOLMO_IMAGE_WIDTH)) + inputs = self._processor.process(images=[image], text=query) + inputs = {k: v.unsqueeze(0) for k, v in inputs.items()} + inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64) + valid = inputs["image_input_idx"] > 0 + valid = valid.reshape(1, -1) + inputs["valid_idx"] = torch.nonzero(valid)[:, 1].unsqueeze(0) + inputs["pixel_values"] = inputs.pop("images") + + # Convert to numpy arrays + vision_inputs = {} + for k, v in inputs.items(): + if k in { + "pixel_values", + "image_masks", + "image_input_idx", + "valid_idx", + "aspect_ratio_ids", + "aspect_ratio_mask", + }: + vision_inputs[k] = np.array(v) + + # Convert specific inputs to float16 + vision_inputs_fp16 = {"pixel_values", "image_masks"} + for k in vision_inputs_fp16: + if k in vision_inputs: + vision_inputs[k] = vision_inputs[k].astype("float16") + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + + return vision_inputs, lang_inputs + except Exception as e: + raise RuntimeError(f"Failed to process image {image_url}: {str(e)}") + def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -> Dict[str, np.ndarray]: """ Download and preprocess image into model inputs @@ -77,6 +206,7 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) - Args: image_url: URL or path to image query: Text query to process with image + prefill_seq_len: Padded sequence length for language model Returns: Dictionary of vision model inputs @@ -95,6 +225,17 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) - else: image = Image.open(image_url) + if self._image_height and self._image_width: + image = image.resize((self._image_width, self._image_height)) + else: + logger.warning("Height and Width not specified. 
Using default image size.") + if "mistral3" in self._qeff_model.model.config.model_type: + image = image.resize((constants.MISTRAL3_IMAGE_HEIGHT, constants.MISTRAL3_IMAGE_WIDTH)) + if "llava_next" in self._qeff_model.model.config.model_type: + image = image.resize( + (constants.GRANITEVISION_IMG_SIZE_HEIGHT, constants.GRANITEVISION_IMG_SIZE_WIDTH) + ) + # Prepare conversation format conversation = [ { @@ -323,7 +464,18 @@ def get_processed_inputs( try: ## Get vlm inputs ## - vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len) + if ( + hasattr(self._qeff_model.model.config, "model_type") + and self._qeff_model.model.config.model_type == "internvl_chat" + ): + vision_inputs, lang_inputs = self.prepare_internVL_inputs(image_url, query) + elif ( + hasattr(self._qeff_model.model.config, "model_type") + and self._qeff_model.model.config.model_type == "molmo" + ): + vision_inputs, lang_inputs = self.prepare_molmo_inputs(image_url, query) + else: + vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len) # Handle padding for language model pad_token_id = 1 diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 5eb91d142..b37fdc74a 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -88,6 +88,8 @@ def __init__( enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, full_batch_size: Optional[int] = None, + image_height: Optional[int] = None, + image_width: Optional[int] = None, is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, @@ -107,6 +109,8 @@ def __init__( enable_debug_logs: Enable debug logging write_io_dir: Directory for I/O file writing full_batch_size: Enable continuous batching (new feature) + image_height: Desired image height for resizing + image_width: Desired image width for resizing is_tlm: Target language model flag include_sampler: Enable on-device sampling (new feature) return_pdfs: Return probability distributions @@ -143,6 +147,9 @@ def __init__( ) self.qeff_model = qeff_model self.processor = processor + self.tokenizer = tokenizer + self.image_height = image_height + self.image_width = image_width self._vision_qpc_path = vision_qpc_path self.device_id = device_id # Store device_id for vision components self.enable_debug_logs = enable_debug_logs # Store for vision components @@ -173,6 +180,9 @@ def _init_vision_components(self): qeff_model=self.qeff_model, vision_session=self._vision_session, processor=self.processor, + tokenizer=self.tokenizer, + image_height=self.image_height, + image_width=self.image_width, config=vision_config, lang_session=self._session, # Pass language session for coordination ) diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 398259d8b..c91d2fe32 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -610,6 +610,7 @@ def forward( image_idx, past_key_values, comp_ctx_lengths: Optional[List[int]] = None, + batch_index: Optional[torch.LongTensor] = None, ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape @@ -625,6 +626,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) @@ -684,6 +686,9 @@ 
def get_specializations( comp_ctx_lengths_prefill: Optional[List[int]] = None, comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): prefill_seq_len = prefill_seq_len if prefill_seq_len else 32 @@ -707,50 +712,72 @@ def get_specializations( lang = [] for i in range(0, len(comp_ctx_lengths_prefill)): - lang.append( - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "comp_ctx_lengths": comp_ctx_lengths_prefill[i], - "sliding_window": self.language_model.config.sliding_window, - "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, - } - ) - - for i in range(0, len(comp_ctx_lengths_decode)): - lang.append( - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "comp_ctx_lengths": comp_ctx_lengths_decode[i], - "sliding_window": self.language_model.config.sliding_window, - "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, - } - ) - - else: - lang = [ - { - "batch_size": batch_size, + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], "sliding_window": self.language_model.config.sliding_window, "img_size": img_size, "mm_tokens_per_image": mm_tokens_per_image, - }, - { - "batch_size": batch_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": "1", "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], "sliding_window": self.language_model.config.sliding_window, "img_size": img_size, "mm_tokens_per_image": mm_tokens_per_image, - }, - ] + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang.append(lang_decode) + + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang = [lang_prefill, lang_decode] specializations = {} @@ -761,17 +788,21 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: 
bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} - lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "mm_tokens_per_image"} + lang_dynamic_axes["vision_embeds"] = {0: "vision_batch_size", 1: "mm_tokens_per_image"} + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"} - pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"} - pkv_dynamic_sliding_axes = {0: "batch_size", 2: "sliding_window"} + pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"} + pkv_dynamic_sliding_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "sliding_window"} layer_switch = ( self.language_model.config.sliding_window_pattern if hasattr(self.language_model.config, "sliding_window_pattern") @@ -837,7 +868,9 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): + def get_dummy_inputs( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 896) else: @@ -876,15 +909,21 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV lang_inputs["past_key_values"] = self.get_dummy_pkv_cache( config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) if comp_ctx_lengths is not None: lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) inputs = {} if kv_offload: diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index 96c59325f..85c331aa8 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -44,6 +44,7 @@ def forward( image_idx, past_key_values, comp_ctx_lengths: Optional[List[int]] = None, + batch_index: Optional[torch.LongTensor] = None, ): input_embeds = self.model.language_model.get_input_embeddings()(input_ids) B, N, C = input_embeds.shape @@ -69,6 +70,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) @@ -91,6 +93,9 @@ def get_specializations( comp_ctx_lengths_prefill: Optional[List[int]] = None, comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = 
None, + full_batch_size: Optional[int] = None, **compiler_options, ): num_patches = compiler_options.pop("num_patches", None) @@ -124,50 +129,71 @@ def get_specializations( lang = [] for i in range(0, len(comp_ctx_lengths_prefill)): - lang.append( - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "comp_ctx_lengths": comp_ctx_lengths_prefill[i], - "num_patches": num_patches, - "img_size": img_size, - "vision_size": vision_size, - } - ) - - for i in range(0, len(comp_ctx_lengths_decode)): - lang.append( - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "comp_ctx_lengths": comp_ctx_lengths_decode[i], - "num_patches": num_patches, - "img_size": img_size, - "vision_size": vision_size, - } - ) - - else: - lang = [ - { - "batch_size": batch_size, + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], "num_patches": num_patches, "img_size": img_size, "vision_size": vision_size, - }, - { - "batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": "1", "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], "num_patches": num_patches, "img_size": img_size, "vision_size": vision_size, - }, - ] + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang.append(lang_decode) + + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [lang_prefill, lang_decode] specializations = {} @@ -176,18 +202,24 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["vision_embeds"] = {1: "vision_size"} + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = 
{0: "batched_num_patches", 2: "img_size", 3: "img_size"} - pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"} + pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"} for i in range(self.language_model.config.num_hidden_layers): for kv in ["key", "value"]: lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes @@ -222,7 +254,9 @@ def get_output_names(self, kv_offload: bool = False): return lang_output_names return output_names - def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): + def get_dummy_inputs( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", constants.INTERN_IMG_SIZE) else: @@ -271,10 +305,13 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl ) lang_inputs["image_idx"] = torch.zeros((1, 1), dtype=torch.int64) + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV kv_cache_shape = get_padding_shape_from_config( config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -285,6 +322,8 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl if comp_ctx_lengths is not None: lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) inputs = {} if kv_offload: diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 0bcdf8ae0..7a2f687fe 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -1065,9 +1065,7 @@ def get_specializations( else: lang_decode["batch_size"] = kv_cache_batch_size - lang = [] - lang.append(lang_prefill) - lang.append(lang_decode) + lang = [lang_prefill, lang_decode] specializations = {} diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index dc6653db0..d5f5ee920 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -18,6 +18,7 @@ from QEfficient.utils.logging_utils import logger BS = 1 +FBS = 4 NUM_CHANNEL = 3 SEQ_LEN = 592 CTX_LEN = 1024 @@ -61,6 +62,7 @@ def forward( image_idx, past_key_values, comp_ctx_lengths: Optional[List[int]] = None, + batch_index: Optional[torch.LongTensor] = None, ): inputs_embeds = self.model.get_input_embeddings()(input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) @@ -76,6 +78,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, return_dict=True, ) @@ -140,7 +143,13 @@ def forward( image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx) return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + **kwargs, + ): 
num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -165,8 +174,8 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl for i in range(num_layers): lang_inputs["past_key_values"].append( ( - torch.zeros(BS, num_key_value_heads, CTX_LEN, head_dim), - torch.zeros(BS, num_key_value_heads, CTX_LEN, head_dim), + torch.zeros(FBS if continuous_batching else BS, num_key_value_heads, CTX_LEN, head_dim), + torch.zeros(FBS if continuous_batching else BS, num_key_value_heads, CTX_LEN, head_dim), ) ) lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1) @@ -174,6 +183,8 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl if comp_ctx_lengths is not None: lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(BS).view(BS, 1) inputs = {} if kv_offload: @@ -193,6 +204,9 @@ def get_specializations( comp_ctx_lengths_prefill: Optional[List[int]] = None, comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): max_num_images = compiler_options.pop("max_num_images", 1) @@ -218,49 +232,72 @@ def get_specializations( lang = [] for i in range(0, len(comp_ctx_lengths_prefill)): - lang.append( - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "comp_ctx_lengths": comp_ctx_lengths_prefill[i], - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - } - ) - - for i in range(0, len(comp_ctx_lengths_decode)): - lang.append( - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "comp_ctx_lengths": comp_ctx_lengths_decode[i], - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - } - ) - else: - lang = [ - { - "batch_size": batch_size, + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], "max_num_images": max_num_images, "img_size": img_size, "vision_size": vision_size, - }, - { - "batch_size": batch_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": "1", "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], "max_num_images": max_num_images, "img_size": img_size, "vision_size": vision_size, - }, - ] + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang.append(lang_decode) + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + "vision_batch_size": 
batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [lang_prefill, lang_decode] specializations = {} @@ -269,9 +306,13 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -281,11 +322,19 @@ def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, - "vision_embeds": {0: "batch_size", 1: "vision_size"}, + "vision_embeds": {0: "vision_batch_size", 1: "vision_size"}, } + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } if comp_ctx_lengths is not None: lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index 2e4848b6b..878d04a45 100755 --- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -20,6 +20,9 @@ from QEfficient.utils._utils import IOInfo from QEfficient.utils.logging_utils import logger +BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE +FBS = constants.ONNX_EXPORT_EXAMPLE_FBS + class QEffLlavaNextEncoderWrapper(nn.Module): def __init__(self, model): @@ -133,6 +136,7 @@ def forward( image_idx, past_key_values, comp_ctx_lengths: Optional[List[int]] = None, + batch_index: Optional[torch.LongTensor] = None, ): inputs_embeds = self.model.get_input_embeddings()(input_ids) image_features = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) @@ -149,6 +153,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -165,7 +170,13 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffLlavaNextDecoderWrapper(self) - def get_dummy_inputs(self, comp_ctx_lengths: 
Optional[List[int]] = None, kv_offload: bool = False, **kwargs): + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + **kwargs, + ): num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -214,13 +225,13 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl lang_inputs["past_key_values"].append( ( torch.zeros( - constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + FBS if continuous_batching else BS, num_key_value_heads, constants.GRANITEVISION_CTX_LEN, head_dim, ), torch.zeros( - constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + FBS if continuous_batching else BS, num_key_value_heads, constants.GRANITEVISION_CTX_LEN, head_dim, @@ -232,6 +243,9 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl if comp_ctx_lengths is not None: lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(BS).view(BS, 1) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -250,6 +264,9 @@ def get_specializations( comp_ctx_lengths_prefill: Optional[List[int]] = None, comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): max_num_images = compiler_options.pop("max_num_images", 1) @@ -306,62 +323,85 @@ def get_specializations( lang = [] for i in range(0, len(comp_ctx_lengths_prefill)): - lang.append( - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "comp_ctx_lengths": comp_ctx_lengths_prefill[i], - "image_size_height": image_size_height, - "image_size_width": image_size_width, - "num_patches": num_patches, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - } - ) - - # Remaining elements use comp_ctx_lengths[1:] in a loop - for i in range(0, len(comp_ctx_lengths_decode)): - lang.append( - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "comp_ctx_lengths": comp_ctx_lengths_decode[i], - "image_size_height": image_size_height, - "image_size_width": image_size_width, - "num_patches": num_patches, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - } - ) - else: - lang = [ - { - "batch_size": batch_size, + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], "image_size_height": image_size_height, "image_size_width": image_size_width, "num_patches": num_patches, "max_num_images": max_num_images, "img_size": img_size, "vision_size": vision_size, - }, - { - "batch_size": batch_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + lang.append(lang_prefill) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": "1", "ctx_len": 
ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], "image_size_height": image_size_height, "image_size_width": image_size_width, "num_patches": num_patches, "max_num_images": max_num_images, "img_size": img_size, "vision_size": vision_size, - }, - ] + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang.append(lang_decode) + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [lang_prefill, lang_decode] specializations = {} if kv_offload: @@ -369,9 +409,13 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers vision_dynamic_axes = { @@ -381,11 +425,19 @@ def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, - "vision_embeds": {0: "batch_size", 1: "vision_size"}, + "vision_embeds": {0: "vision_batch_size", 1: "vision_size"}, } + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } if comp_ctx_lengths is not None: lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index 694ed4cde..89e19c65b 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -176,20 +176,22 @@ def forward( image_idx, past_key_values, comp_ctx_lengths: 
Optional[List[int]] = None, + batch_index: Optional[torch.LongTensor] = None, ): - inputs_embeds = self.model.get_input_embeddings()(input_ids) - vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) mask = input_ids == self.model.config.image_token_index indices1 = mask.to(torch.int64).cumsum(1) - 1 indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) indices0 = torch.arange(mask.shape[0]).view(-1, 1) image_features_expanded = vision_embeds.unsqueeze(0)[indices0, indices1] - inputs_embeds_1 = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) - outputs = self.model.model( - inputs_embeds=inputs_embeds_1, + image_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) + outputs = self.language_model( + inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, ) # Cast to int32 to avoid ONNXRT issue @@ -250,7 +252,13 @@ def forward( return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + **kwargs, + ): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) height = self.config.vision_config.image_size @@ -290,10 +298,14 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV kv_cache_shape = get_padding_shape_from_config( - config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + config=self.model.config.text_config, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -304,6 +316,8 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl if comp_ctx_lengths is not None: lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) inputs = {} if kv_offload: @@ -324,6 +338,9 @@ def get_specializations( comp_ctx_lengths_prefill: Optional[List[int]] = None, comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): if img_size is None and hasattr(self.config.vision_config, "image_size"): @@ -352,46 +369,66 @@ def get_specializations( lang = [] for i in range(0, len(comp_ctx_lengths_prefill)): - lang.append( - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "comp_ctx_lengths": comp_ctx_lengths_prefill[i], - "image_size": img_size, - "vision_size": vision_size, - } - ) - - # Remaining elements use comp_ctx_lengths[1:] in a loop - for i in range(0, len(comp_ctx_lengths_decode)): - lang.append( - { - "batch_size": 
batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "comp_ctx_lengths": comp_ctx_lengths_decode[i], - "image_size": img_size, - "vision_size": vision_size, - } - ) - else: - lang = [ - { - "batch_size": batch_size, + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], "image_size": img_size, "vision_size": vision_size, - }, - { - "batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + lang.append(lang_prefill) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": "1", "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], "image_size": img_size, "vision_size": vision_size, - }, - ] + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang.append(lang_decode) + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "image_size": img_size, + "vision_size": vision_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "image_size": img_size, + "vision_size": vision_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang = [lang_prefill, lang_decode] specializations = {} @@ -404,7 +441,9 @@ def get_specializations( lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -417,9 +456,18 @@ def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv "vision_embeds": {0: "vision_size"}, } + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} + for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } if comp_ctx_lengths is not None: lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index f3618cb1e..91866e4c0 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1284,6 +1284,8 @@ def generate( device_ids: List[int] 
= None, runtime_ai100: bool = True, generation_len: Optional[int] = None, + image_height: Optional[int] = None, + image_width: Optional[int] = None, ) -> Union[torch.Tensor, np.ndarray]: """ Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards. @@ -1342,6 +1344,8 @@ def generate( full_batch_size=fbs, comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, + image_height=image_height, + image_width=image_width, ) # Call generate method @@ -2493,6 +2497,7 @@ def from_pretrained( kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, + continuous_batching=continuous_batching, **kwargs, ) return cls( diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py index c088158c4..7bfa58fc0 100644 --- a/QEfficient/transformers/models/molmo/modeling_molmo.py +++ b/QEfficient/transformers/models/molmo/modeling_molmo.py @@ -43,14 +43,14 @@ def eager_attention_forward( if num_q_heads != num_kv_heads: assert num_q_heads % num_kv_heads == 0 repeat_factor = num_q_heads // num_kv_heads - _, _, S, D = k.shape + B, _, S, D = k.shape k = k.unsqueeze(2) k = k.expand(-1, -1, repeat_factor, -1, -1) - k = k.reshape(1, num_q_heads, S, D) + k = k.reshape(B, num_q_heads, S, D) v = v.unsqueeze(2) v = v.expand(-1, -1, repeat_factor, -1, -1) - v = v.reshape(1, num_q_heads, S, D) + v = v.reshape(B, num_q_heads, S, D) attn_weights = torch.matmul(q, k.transpose(2, 3)) * scale_factor @@ -596,6 +596,7 @@ def forward( image_idx, past_key_values, comp_ctx_lengths: Optional[List[int]] = None, + batch_index: Optional[torch.LongTensor] = None, ): if input_ids is not None: input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) @@ -613,6 +614,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) @@ -694,6 +696,9 @@ def get_specializations( comp_ctx_lengths_decode: Optional[List[int]] = None, valid_size: int = None, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): prefill_seq_len = prefill_seq_len if prefill_seq_len else 1024 @@ -725,12 +730,20 @@ def get_specializations( for i in range(0, len(comp_ctx_lengths_prefill)): lang_prefill = { - "batch_size": batch_size, + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "comp_ctx_lengths": comp_ctx_lengths_prefill[i], "valid_size": valid_size, + "vision_batch_size": batch_size, } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size if kv_offload: values = { "img_size": img_size, @@ -746,12 +759,17 @@ def get_specializations( for i in range(0, len(comp_ctx_lengths_decode)): lang_decode = { - "batch_size": batch_size, + "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": "1", "ctx_len": ctx_len, "comp_ctx_lengths": comp_ctx_lengths_decode[i], "valid_size": valid_size, + "vision_batch_size": batch_size, } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size if kv_offload: 
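# --- Illustrative aside (not part of the diff) -------------------------------
# The prefill/decode specialization pattern this PR applies to every VLM:
# under continuous batching, prefill runs one sequence at a time against a KV
# cache with kv_cache_batch_size slots, while decode steps full_batch_size
# sequences together. This helper is a hypothetical condensation of that
# repeated logic, not a function in QEfficient.
from typing import Any, Dict, List, Optional

def build_lang_specializations(
    batch_size: int,
    prefill_seq_len: int,
    ctx_len: int,
    continuous_batching: bool = False,
    kv_cache_batch_size: Optional[int] = None,
    full_batch_size: Optional[int] = None,
    **extra: Any,
) -> List[Dict[str, Any]]:
    lang_prefill = {
        "batch_size": 1 if continuous_batching else batch_size,
        "seq_len": prefill_seq_len,
        "ctx_len": ctx_len,
        **extra,  # e.g. img_size, vision_batch_size, sliding_window, ...
    }
    lang_decode = {
        "batch_size": full_batch_size if continuous_batching else batch_size,
        "seq_len": "1",
        "ctx_len": ctx_len,
        **extra,
    }
    for spec in (lang_prefill, lang_decode):
        if continuous_batching:
            spec["full_batch_size"] = kv_cache_batch_size
        else:
            spec["batch_size"] = kv_cache_batch_size
    if full_batch_size:
        lang_prefill["full_batch_exec_size"] = full_batch_size
    return [lang_prefill, lang_decode]
# ------------------------------------------------------------------------------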
values = { "img_size": img_size, @@ -767,13 +785,33 @@ def get_specializations( else: lang_prefill = { - "batch_size": batch_size, + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "valid_size": valid_size, + "vision_batch_size": batch_size, } - lang_decode = {"batch_size": batch_size, "seq_len": "1", "ctx_len": ctx_len, "valid_size": valid_size} + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "valid_size": valid_size, + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size if kv_offload: values = { @@ -787,9 +825,7 @@ def get_specializations( lang_prefill[key] = value lang_decode[key] = value - lang = [] - lang.append(lang_prefill) - lang.append(lang_decode) + lang = [lang_prefill, lang_decode] specializations = {} @@ -800,13 +836,15 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} - lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "valid_size"} + lang_dynamic_axes["vision_embeds"] = {0: "vision_batch_size", 1: "valid_size"} vision_dynamic_axes["pixel_values"] = {0: "batch_size", 1: "num_images", 2: "img_tile", 3: "img_size"} vision_dynamic_axes["image_input_idx"] = {0: "batch_size", 1: "num_images", 2: "num_patch"} @@ -816,8 +854,17 @@ def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv num_layers = self.model.config.n_layers for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} if comp_ctx_lengths is not None: lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} @@ -851,7 +898,13 @@ def get_output_names(self, kv_offload: bool = False): return lang_output_names return output_names - def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + **kwargs, + ): inputs_shapes = {} inputs_shapes_lang = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) @@ -902,10 +955,14 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 
1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV kv_cache_shape = get_padding_shape_from_config( config=self.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -916,6 +973,8 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl if comp_ctx_lengths is not None: lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) inputs = {} if kv_offload: diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 33a434db1..63e046600 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1169,9 +1169,7 @@ def smart_resize( else: lang_decode["batch_size"] = kv_cache_batch_size - lang = [] - lang.append(lang_prefill) - lang.append(lang_decode) + lang = [lang_prefill, lang_decode] specializations = {} diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 3752db40c..e0b003422 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -100,6 +100,8 @@ def get_models_dir(): INTERN_CTX_LEN = 4096 INTERN_PREFILL_SEQ_LEN = INTERN_CTX_LEN - 256 # 4096-256 INTERN_NUM_CHANNELS = 3 +INTERN_IMAGE_HEIGHT = 1000 +INTERN_IMAGE_WIDTH = 747 INTERN_IMG_CONTEXT_TOKEN = 151667 # Specific to InternVL3_5 series, same token won't work for InternVL2_5 series @@ -135,6 +137,14 @@ def get_models_dir(): # Modules to cache while clearing the pytorch weights CACHE_MODULES = ["get_output_names", "get_dummy_inputs", "get_onnx_dynamic_axes", "get_specializations"] +# Mistral3 Constants +MISTRAL3_IMAGE_HEIGHT = 1540 +MISTRAL3_IMAGE_WIDTH = 1540 + +# Molmo Constants +MOLMO_IMAGE_HEIGHT = 536 +MOLMO_IMAGE_WIDTH = 354 + class Constants: # Export Constants. diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py index c54dadeac..61553e7ea 100644 --- a/QEfficient/utils/run_utils.py +++ b/QEfficient/utils/run_utils.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import os +from typing import List import numpy as np import onnx @@ -276,6 +277,54 @@ def __init__( self.config = config self.gen_len = max_gen_len + @torch.no_grad() + def run_vlm_hf_model_on_pytorch_CB(self, model, images, queries): + """ + Function responsible for running HuggingFace ``PyTorch`` model for continuous batching + and return the output tokens for each prompt/image pair. 
+ + ``Mandatory`` Args: + :model (torch.nn.module): Original ``PyTorch`` model + :images (List[PIL.Image]): List of input images + :queries (List[str]): List of input queries + + Return: + :List[numpy.ndarray]: List of generated output tokens for each prompt + """ + generated_ids = [] + + for idx, (image, query) in enumerate(zip(images, queries)): + # Prepare conversation format for each image-query pair + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": query}, + {"type": "image"}, + ], + }, + ] + prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True) + + # Process inputs + inputs = self.processor(images=image, text=prompt, return_tensors="pt") + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + + # Generate tokens + output = model.generate(**inputs, max_new_tokens=self.gen_len, do_sample=False) + offset_output = output[0, inputs["input_ids"].shape[1] :] + + # Decode and print output + py_output = self.processor.tokenizer.decode(offset_output).strip() + print(f"Original HF Model Outputs (Torch CPU) for prompt {idx}:") + print("Query:", repr(query)) + print("Completion:", repr(py_output)) + + generated_ids.append(offset_output.numpy()) + + return generated_ids + @torch.no_grad() def run_vlm_hf_model_on_pytorch(self, model, inputs): output = model.generate(**inputs, max_new_tokens=self.gen_len, do_sample=False) @@ -448,6 +497,57 @@ def __init__(self, batch_size, processor, config, image, prompt, prompt_len, ctx self.config = config self.gen_len = max_gen_len + @torch.no_grad() + def run_vlm_hf_model_on_pytorch_CB(self, model, images, queries): + """ + Function responsible for running HuggingFace ``PyTorch`` model for continuous batching + and return the output tokens for each prompt/image pair. 
+ + ``Mandatory`` Args: + :model (torch.nn.module): Original ``PyTorch`` model + :images (List[PIL.Image]): List of input images + :queries (List[str]): List of input queries + + Return: + :List[numpy.ndarray]: List of generated output tokens for each prompt + """ + generated_ids = [] + + for idx, (image, query) in enumerate(zip(images, queries)): + num_patches_list = [] + pixel_values = [] + questions = [] + + pixel_value = self.processor.load_image(image, max_num=12) + num_patches_list.append(pixel_value.shape[0]) + question = "\n" + query + + pixel_values.append(pixel_value) + pixel_values = torch.cat(pixel_values, dim=0) + questions.append(question) + + # Chat Template information for prompt preprocessing + messages: List[List[str]] = [] + roles = ("<|im_start|>user\n", "<|im_start|>assistant\n") + prompt = self.processor(pixel_values, questions, messages, roles, num_patches_list=num_patches_list) + + inputs = self.processor.tokenizer(prompt, return_tensors="pt") + inputs["pixel_values"] = pixel_values.clone() + + generation_config = dict(max_new_tokens=self.gen_len, do_sample=False) + generation_config["eos_token_id"] = self.processor.tokenizer.convert_tokens_to_ids("<|im_end|>\n".strip()) + + # Decode and print output + outputs = model.generate(**inputs, **generation_config) + offset_output = outputs[0].detach().numpy() + + py_output = self.processor.tokenizer.decode(offset_output, skip_special_tokens=True).strip() + print(f"Original HF Model Outputs (Torch CPU) for prompt {idx}:") + print("Completion:", repr(py_output)) + generated_ids.append(offset_output) + + return generated_ids + @torch.no_grad() def run_vlm_hf_model_on_pytorch(self, model, inputs, generation_config): outputs = model.generate(**inputs, **generation_config) @@ -490,3 +590,34 @@ def run_vlm_hf_model_on_pytorch(self, model, inputs, generation_config): print("Original HF Model Outputs (Torch CPU):") print("Completion:", repr(py_output)) return generated_ids + + @torch.no_grad() + def run_vlm_hf_model_on_pytorch_CB(self, model, images, queries, generation_config): + """ + Function responsible for running HuggingFace ``PyTorch`` model for continuous batching + and return the output tokens for each prompt/image pair. 
+ + ``Mandatory`` Args: + :model (torch.nn.module): Original ``PyTorch`` model + :images (List[PIL.Image]): List of input images + :queries (List[str]): List of input queries + :generation_config (dict): Generation configuration parameters + + Return: + :List[numpy.ndarray]: List of generated output tokens for each prompt + """ + generated_ids = [] + for idx, (image, query) in enumerate(zip(images, queries)): + inputs = self.processor.process(images=[image], text=query) + inputs = {k: v.unsqueeze(0) for k, v in inputs.items()} + outputs = model.generate_from_batch( + inputs, generation_config, tokenizer=self.processor.tokenizer, do_sample=False + ) + + offset_output = outputs[0, inputs["input_ids"].size(1) :] + + py_output = self.processor.tokenizer.decode(offset_output, skip_special_tokens=True).strip() + print(f"Original HF Model Outputs (Torch CPU) for prompt {idx}:") + print("Completion:", repr(py_output)) + generated_ids.append(offset_output) + return generated_ids diff --git a/examples/image_text_to_text/models/granite_vision/continuous_batching.py b/examples/image_text_to_text/models/granite_vision/continuous_batching.py new file mode 100644 index 000000000..22c4270bc --- /dev/null +++ b/examples/image_text_to_text/models/granite_vision/continuous_batching.py @@ -0,0 +1,67 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +## For AWQ model update pytorch version to 2.8.* +model_id = "ibm-granite/granite-vision-3.2-2b" +config = AutoConfig.from_pretrained(model_id) +config.text_config.num_hidden_layers = 2 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +batch_size = 1 +## Vision + Text ## +qeff_model.compile( + batch_size=batch_size, + full_batch_size=4, + prefill_seq_len=5500, + ctx_len=6000, + num_cores=16, + num_devices=4, + img_size=384, + mxfp6_matmul=False, +) + +image_urls = [ + "http://images.cocodataset.org/val2017/000000039769.jpg", + "http://images.cocodataset.org/val2017/000000039769.jpg", + "http://images.cocodataset.org/val2017/000000039769.jpg", + "http://images.cocodataset.org/val2017/000000039769.jpg", +] + +prompts = [ + "Describe the image", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +streamer = TextStreamer(tokenizer) +output = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + generation_len=10, + image_height=1610, + image_width=1109, +) +print(output.generated_ids) +print(tokenizer.batch_decode(output.generated_ids)) +print(output.generated_texts) diff --git a/examples/image_text_to_text/models/internvl/continuous_batching.py b/examples/image_text_to_text/models/internvl/continuous_batching.py new file mode 100644 index 000000000..ca3e0ede3 --- /dev/null +++ b/examples/image_text_to_text/models/internvl/continuous_batching.py @@ -0,0 +1,100 @@ +# 
----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.test_utils import InternProcessor + +model_id = "OpenGVLab/InternVL2_5-1B" +config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) +# For Testing Purpose Only +config.llm_config.num_hidden_layers = 2 +config.vision_config.num_hidden_layers = 2 + +# The original Intern-VL model, despite being multimodal, is loaded using `AutoModelForCausalLM` in Huggingface. +# To maintain compatibility, we load this model using `QEFFAutoModelForCausalLM`. +model_hf = AutoModelForCausalLM.from_pretrained( + model_id, + low_cpu_mem_usage=False, + trust_remote_code=True, + config=config, +) + +tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False) +processor = InternProcessor(model_hf, tokenizer) + + +continuous_batching = True +if continuous_batching: + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, + trust_remote_code=True, + ) + + qeff_model.compile( + num_patches=13, # Set num_patches according to image_height and image_width, default is 13 (747 x 1000) + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + batch_size=1, + full_batch_size=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) +else: + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, attn_implementation="eager", kv_offload=True, config=config, trust_remote_code=True + ) + + qeff_model.compile( + num_patches=13, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + batch_size=1, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + ) + +image_urls = [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +exec_info = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + device_ids=[0, 1, 2, 3], + generation_len=10, + image_height=747, + image_width=1000, +) + +print("Generated texts:", exec_info.generated_texts) +print("Generated IDs:", exec_info.generated_ids) +print(exec_info) diff --git a/tests/transformers/models/image_text_to_text/test_continuous_batching.py b/tests/transformers/models/image_text_to_text/test_continuous_batching.py new file mode 100644 index 000000000..2f33b7ee8 --- /dev/null +++ b/tests/transformers/models/image_text_to_text/test_continuous_batching.py @@ -0,0 +1,720 @@ +# 
----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from io import BytesIO +from typing import List + +import pytest +import requests +from PIL import Image +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForImageTextToText, + AutoProcessor, + AutoTokenizer, + GenerationConfig, +) + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText +from QEfficient.utils import hf_download +from QEfficient.utils._utils import get_num_layers_vlm +from QEfficient.utils.device_utils import get_available_device_id +from QEfficient.utils.run_utils import ApiRunnerInternVL, ApiRunnerMolmo, ApiRunnerVlm +from QEfficient.utils.test_utils import InternProcessor + +NEW_GENERATION_TOKENS = 10 + +# TODO: Add CB support for kv_offload=False case +test_models_config = [ + # CONFIG PARAMS NEEDED FOR A MODEL TO BE TESTED + # ( + # model_name, + # kv_offload, + # batch_size, + # prompt_len, + # ctx_len, + # img_size, + # img_url_list", + # text_prompt_list, + # number of layers of the model, + # full_batch_size + # ), + ( + "llava-hf/llava-1.5-7b-hf", + True, + 1, + 784, + 1024, + 336, + [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + ], + [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", + ], + 1, + 4, + ), + # Disabled in CI due to performance issues + # ( + # "meta-llama/Llama-4-Scout-17B-16E-Instruct", + # True, + # 1, + # 128, + # 3072, + # 336, + # ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg",], + # ["Can you describe the image in detail?", + # "What are the objects in the image?", + # "What is the main subject of the image?", + # "What colors are predominant in the image?"], + # 4, + # 4, + # ), + ( + "google/gemma-3-4b-it", + True, + 1, + 128, + 3072, + 896, + [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + ], + [ 
+ "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", + ], + 1, + 4, + ), + ( + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", + True, + 1, + 128, + 4096, + 1540, + [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + ], + [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", + ], + 1, + 4, + ), + ( + "Qwen/Qwen2.5-VL-3B-Instruct", + True, + 1, + 128, + 4096, + 1540, + [ + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + ], + [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", + ], + 2, + 4, + ), + # ( + # "meta-llama/Llama-3.2-11B-Vision-Instruct", + # True, + # 1, + # 32, + # 512, + # 560, + # ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg",], + # ["Can you describe the image in detail?", + # "What are the objects in the image?", + # "What is the main subject of the image?", + # "What colors are predominant in the image?"], + # 7, + # 4, + # ), +] + +intern_model_config = [ + ( + "OpenGVLab/InternVL2_5-1B", + True, + 1, + 384, + 512, + [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + ], + [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", + ], + 2, + 4, + ), + ( + "OpenGVLab/InternVL3_5-1B", + True, + 1, + 384, + 512, + [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + 
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + ], + [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", + ], + 2, + 4, + ), +] + +molmo_model_config = [ + # Disabled in CI due to HF issues + # ( + # "allenai/Molmo-7B-D-0924", + # True, + # 1, + # 128, + # 4096, + # ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg",], + # ["Can you describe the image in detail?", + # "What are the objects in the image?", + # "What is the main subject of the image?", + # "What colors are predominant in the image?"], + # 2, + # 4, + # ), +] + + +def load_image_text_to_text_model(model_config): + model_path = hf_download( + repo_id=model_config._name_or_path, + ignore_patterns=["*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.h5", "*.msgpack"], + ) + try: + model_hf = AutoModelForImageTextToText.from_pretrained( + model_path, + low_cpu_mem_usage=False, + config=model_config, + ) + except ValueError: + model_hf = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=False, + trust_remote_code=True, + config=model_config, + ) + params = sum(p.numel() for p in model_hf.parameters()) + model_hf.eval() + return model_hf, params + + +def set_num_layers(config, n_layer=1): + ## -1 indicates use all the layers of the model. 
+ if n_layer == -1: + return config + elif hasattr(config, "model_type") and "mllama" in config.model_type: + config.text_config.num_hidden_layers = n_layer + config.text_config.cross_attention_layers = [ + x for x in config.text_config.cross_attention_layers if x < n_layer + ] + elif hasattr(config, "text_config"): + config.text_config.num_hidden_layers = n_layer + config.vision_config.num_hidden_layers = n_layer + elif hasattr(config, "llm_config"): + config.llm_config.num_hidden_layers = n_layer + config.vision_config.num_hidden_layers = n_layer + else: + config.num_hidden_layers = n_layer + return config + + +def check_image_text_to_text_pytorch_vs_ai100_continuous_batching( + model_name: str, + img_size: int, + image_urls: List[str], + queries: List[str], + prompt_len: int, + ctx_len: int, + max_gen_len: int = 20, + batch_size: int = 1, + n_layer: int = 1, + num_devices: int = 1, + full_batch_size: int = 4, + kv_offload: bool = True, +): + model_config = {"model_name": model_name} + model_config["img_size"] = img_size + config = AutoConfig.from_pretrained(model_config["model_name"], trust_remote_code=True) + config = set_num_layers(config, n_layer=n_layer) + model_hf, _ = load_image_text_to_text_model(config) + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True, padding=True) + + n_layer = get_num_layers_vlm(config) + + image_height = None + image_width = None + + images = [] + for img_url in image_urls: + image = Image.open(requests.get(img_url, stream=True).raw) + if model_name == "mistralai/Mistral-Small-3.1-24B-Instruct-2503": + image_height = 1540 + image_width = 1540 + image = image.resize((image_height, image_width)) + images.append(image) + + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": queries[0]}, + {"type": "image"}, + ], + }, + ] + prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + api_runner = ApiRunnerVlm( + batch_size, + processor, + config, + images[0], + conversation, + prompt, + prompt_len, + ctx_len, + max_gen_len, + n_layer, + ) + + # For same prompt + image_list = [images[0]] * full_batch_size + prompt_list = [queries[0]] * full_batch_size + + pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch_CB(model_hf, image_list, prompt_list) + + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_config["model_name"], + kv_offload=kv_offload, + config=config, + continuous_batching=True, + ) + + qeff_model.export() + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + + qeff_model.compile( + img_size=model_config["img_size"], + num_cores=16, + num_devices=num_devices, + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + batch_size=batch_size, + full_batch_size=full_batch_size, + mxfp6_matmul=False, + ) + + print("QPC Outputs (QAIC):") + exec_info = qeff_model.generate( + tokenizer=processor.tokenizer, + processor=processor, + images=[image_urls[0]] * full_batch_size, + prompts=prompt_list, + generation_len=max_gen_len, + image_height=image_height, + image_width=image_width, + ) + + qpc_tokens = exec_info.generated_ids[:, :max_gen_len] + print("QPC Outputs (QAIC) for Continuous Batching with same prompt:") + print(exec_info.generated_texts) + + for i in range(full_batch_size): + assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), ( + f"Tokens don't match for prompt {i} between HF and QPC output for same prompts" + ) + + # For different prompts + pytorch_hf_tokens = 
api_runner.run_vlm_hf_model_on_pytorch_CB(model_hf, images, queries) + + print("QPC Outputs (QAIC):") + exec_info = qeff_model.generate( + tokenizer=processor.tokenizer, + processor=processor, + images=image_urls, + prompts=queries, + generation_len=max_gen_len, + image_height=image_height, + image_width=image_width, + ) + + qpc_tokens = exec_info.generated_ids[:, :max_gen_len] + print("QPC Outputs (QAIC) for Continuous Batching with different prompt:") + print(exec_info.generated_texts) + + for i in range(full_batch_size): + assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), ( + f"Tokens don't match for prompt {i} between HF and QPC output for different prompts" + ) + return + + +def check_molmo_image_text_to_text_pytorch_vs_ai100_continuous_batching( + model_name: str, + image_urls: List[str], + queries: List[str], + prompt_len: int, + ctx_len: int, + max_gen_len: int = 20, + batch_size: int = 1, + n_layer: int = 1, + num_devices: int = 1, + full_batch_size: int = 4, + kv_offload: bool = True, +): + model_config = {"model_name": model_name} + + config = AutoConfig.from_pretrained(model_config["model_name"], trust_remote_code=True) + config._attn_implementation = "eager" + config = set_num_layers(config, n_layer=n_layer) + model_hf, _ = load_image_text_to_text_model(config) + n_layer = (n_layer, n_layer) + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True, padding=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + images = [] + for img_url in image_urls: + img = requests.get(img_url, stream=True) + image = Image.open(BytesIO(img.content)).convert("RGB") + image = image.resize((536, 354)) + images.append(image) + + api_runner = ApiRunnerMolmo( + batch_size, + processor, + config, + images[0], + queries[0], + prompt_len, + ctx_len, + max_gen_len, + n_layer, + ) + + generation_config = GenerationConfig(max_new_tokens=NEW_GENERATION_TOKENS, stop_strings="<|endoftext|>") + + # For same prompt + image_list = [images[0]] * full_batch_size + prompt_list = [queries[0]] * full_batch_size + pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch_CB(model_hf, image_list, prompt_list, generation_config) + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + trust_remote_code=True, + attn_implementation="eager", + kv_offload=kv_offload, + config=config, + continuous_batching=True, + ) + + qeff_model.export() + + qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_devices=4, + batch_size=1, + full_batch_size=full_batch_size, + mxfp6_matmul=False, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + exec_info = qeff_model.generate( + tokenizer=tokenizer, + processor=processor, + images=[image_urls[0]] * full_batch_size, + prompts=prompt_list, + generation_len=max_gen_len, + ) + + qpc_tokens = exec_info.generated_ids[:, :max_gen_len] + print("QPC Outputs (QAIC) for Continuous Batching with same prompt:") + print(exec_info.generated_texts) + + for i in range(full_batch_size): + assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), ( + f"Tokens don't match for prompt {i} between HF and QPC output for same prompts" + ) + + # For different prompts + pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch_CB(model_hf, images, queries, generation_config) + exec_info = qeff_model.generate( + tokenizer=tokenizer, + processor=processor, + images=image_urls, + prompts=queries, + generation_len=max_gen_len, + ) + + qpc_tokens = exec_info.generated_ids[:, :max_gen_len] + 
print("QPC Outputs (QAIC) for Continuous Batching with different prompt:") + print(exec_info.generated_texts) + + for i in range(full_batch_size): + assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), ( + f"Tokens don't match for prompt {i} between HF and QPC output for different prompts" + ) + return + + +def check_intern_image_text_to_text_pytorch_vs_ai100_continuous_batching( + model_name: str, + image_urls: str, + queries: str, + prompt_len: int, + ctx_len: int, + max_gen_len: int = 20, + batch_size: int = 1, + n_layer: int = 1, + kv_offload: bool = True, + num_devices: int = 1, + full_batch_size: int = 4, +): + model_config = {"model_name": model_name} + + config = AutoConfig.from_pretrained(model_config["model_name"], trust_remote_code=True) + config._attn_implementation = "eager" + config = set_num_layers(config, n_layer=n_layer) + model_hf = AutoModelForCausalLM.from_pretrained( + model_name, + low_cpu_mem_usage=False, + trust_remote_code=True, + config=config, + ) + n_layer = get_num_layers_vlm(config) + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False) + processor = InternProcessor(model_hf, tokenizer) + + generation_config = dict(max_new_tokens=max_gen_len, do_sample=False) + generation_config["eos_token_id"] = tokenizer.convert_tokens_to_ids("<|im_end|>\n".strip()) + + images = [] + for img_url in image_urls: + img = requests.get(img_url, stream=True) + image = Image.open(BytesIO(img.content)).convert("RGB") + image = image.resize((448, 448)) + images.append(image) + + api_runner = ApiRunnerInternVL( + batch_size, + processor, + config, + images[0], + queries[0], + prompt_len, + ctx_len, + max_gen_len, + n_layer, + ) + + # For same prompt + image_list = [images[0]] * full_batch_size + prompt_list = [queries[0]] * full_batch_size + + pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch_CB(model_hf, image_list, prompt_list) + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + trust_remote_code=True, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, + ) + + qeff_model.export() + + qeff_model.compile( + num_patches=1, + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_devices=4, + batch_size=1, + full_batch_size=full_batch_size, + mxfp6_matmul=False, + ) + + exec_info = qeff_model.generate( + tokenizer=tokenizer, + processor=processor, + images=[image_urls[0]] * full_batch_size, + prompts=prompt_list, + generation_len=max_gen_len, + image_height=448, + image_width=448, + ) + + qpc_tokens = exec_info.generated_ids[:, :max_gen_len] + print("QPC Outputs (QAIC) for Continuous Batching for same prompts:") + print(exec_info.generated_texts) + + for i in range(full_batch_size): + assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), ( + f"Tokens don't match for prompt {i} between HF and QPC output for same prompts" + ) + + # For different prompts + pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch_CB(model_hf, images, queries) + + exec_info = qeff_model.generate( + tokenizer=tokenizer, + processor=processor, + images=image_urls, + prompts=queries, + generation_len=max_gen_len, + image_height=448, + image_width=448, + ) + + qpc_tokens = exec_info.generated_ids[:, :max_gen_len] + print("QPC Outputs (QAIC) for Continuous Batching for different prompts:") + print(exec_info.generated_texts) + + for i in range(full_batch_size): + assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), ( + f"Tokens don't match for prompt {i} between HF and QPC output for different 
prompts" + ) + return + + +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.parametrize( + "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_urls, queries, n_layer, full_batch_size", + test_models_config, +) +def test_image_text_to_text_pytorch_vs_ai100_continuous_batching( + model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_urls, queries, n_layer, full_batch_size +): + """ + Test function to validate the original PyTorch model against the Cloud AI 100 model with continuous batching enabled, for both identical and distinct image/prompt pairs. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``llava-hf/llava-1.5-7b-hf`` + """ + check_image_text_to_text_pytorch_vs_ai100_continuous_batching( + model_name=model_name, + prompt_len=prompt_len, + ctx_len=ctx_len, + max_gen_len=NEW_GENERATION_TOKENS, + img_size=img_size, + image_urls=img_urls, + queries=queries, + n_layer=n_layer, + batch_size=batch_size, + kv_offload=kv_offload, + full_batch_size=full_batch_size, + ) + + +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.parametrize( + "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_urls, queries, n_layer, full_batch_size", + molmo_model_config, +) +def test_image_text_to_text_molmo_pytorch_vs_ai100_continuous_batching( + model_name, kv_offload, batch_size, prompt_len, ctx_len, img_urls, queries, n_layer, full_batch_size +): + check_molmo_image_text_to_text_pytorch_vs_ai100_continuous_batching( + model_name=model_name, + prompt_len=prompt_len, + ctx_len=ctx_len, + max_gen_len=NEW_GENERATION_TOKENS, + image_urls=img_urls, + queries=queries, + n_layer=n_layer, + batch_size=batch_size, + kv_offload=kv_offload, + full_batch_size=full_batch_size, + ) + + +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.parametrize( + "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, queries, n_layer, full_batch_size", + intern_model_config, +) +def test_image_text_to_text_intern_pytorch_vs_ai100_continuous_batching( + model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, queries, n_layer, full_batch_size +): + check_intern_image_text_to_text_pytorch_vs_ai100_continuous_batching( + model_name=model_name, + prompt_len=prompt_len, + ctx_len=ctx_len, + max_gen_len=NEW_GENERATION_TOKENS, + image_urls=img_url, + queries=queries, + n_layer=n_layer, + batch_size=batch_size, + kv_offload=kv_offload, + full_batch_size=full_batch_size, + ) diff --git a/tests/transformers/models/test_image_text_to_text_models.py b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py similarity index 100% rename from tests/transformers/models/test_image_text_to_text_models.py rename to tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py
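Note on the get_dummy_inputs change above: with continuous batching enabled, the dummy KV cache is shaped by the full batch size (fbs) while the language inputs keep the decode batch size (bs), and a batch_index tensor maps each decode slot to its row in the cache. A minimal sketch of that relationship, using placeholder shapes rather than the real ONNX-export constants:

import torch

# Placeholder values; the exporter reads the real ones from QEfficient.utils.constants
# (ONNX_EXPORT_EXAMPLE_BATCH_SIZE / ONNX_EXPORT_EXAMPLE_FBS) and the model config.
bs = 1                                    # decode batch size
fbs = 4                                   # full batch size (concurrent sequences)
num_heads, ctx_len, head_dim = 8, 128, 64
continuous_batching = True

# The KV cache is allocated for every concurrent sequence when continuous batching is on.
kv_batch = fbs if continuous_batching else bs
past_key = torch.zeros(kv_batch, num_heads, ctx_len, head_dim)

# Each decode slot carries the index of the cache row it reads and writes.
batch_index = torch.arange(bs).view(bs, 1) if continuous_batching else None

print(past_key.shape)     # torch.Size([4, 8, 128, 64])
print(batch_index.shape)  # torch.Size([1, 1])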
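The InternVL example compiles with num_patches=13 and notes that this default corresponds to a 747 x 1000 image. As a rough aid for choosing that argument, the sketch below approximates InternVL-style dynamic tiling (448 x 448 tiles, at most 12 tiles, plus a thumbnail tile when the image is split); it illustrates the tiling rule only and is not the exact HuggingFace preprocessing code:

def estimate_num_patches(width: int, height: int, max_tiles: int = 12) -> int:
    # Pick the tile grid (cols x rows) whose aspect ratio is closest to the image's.
    aspect = width / height
    grids = [(c, r) for c in range(1, max_tiles + 1) for r in range(1, max_tiles + 1) if c * r <= max_tiles]
    cols, rows = min(grids, key=lambda g: abs(aspect - g[0] / g[1]))
    tiles = cols * rows
    # A thumbnail tile is appended whenever the image is split into more than one tile.
    return tiles + 1 if tiles > 1 else 1

print(estimate_num_patches(1000, 747))  # -> 13, matching the example's default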
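All three check_* helpers in the new test file end with the same per-sequence comparison between the HF CPU reference tokens and the first max_gen_len tokens produced on Cloud AI 100. A compact restatement of that pattern (the helper name is illustrative and not part of the test module):

import numpy as np

def assert_cb_tokens_match(pytorch_hf_tokens, generated_ids, max_gen_len):
    # generated_ids: (full_batch_size, seq_len) array returned by qeff_model.generate(...)
    qpc_tokens = generated_ids[:, :max_gen_len]
    for i, reference in enumerate(pytorch_hf_tokens):
        # The HF reference is generated with max_new_tokens == max_gen_len, so both slices align.
        assert (np.asarray(reference)[:max_gen_len] == qpc_tokens[i]).all(), f"Tokens don't match for prompt {i}"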