From 409da2498f20e3e084b47d8af6ca91d79d8519ee Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 23 Oct 2025 16:35:26 -0700 Subject: [PATCH 01/20] Extend on-device sampling support for dual QPC VLMs Signed-off-by: quic-xiyushi --- .../transformers/models/modeling_auto.py | 122 +++++++++++++++++- .../transformers/models/pytorch_transforms.py | 4 + QEfficient/transformers/sampler/sampler.py | 56 +++++--- 3 files changed, 164 insertions(+), 18 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 60f60c768..8b021314e 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -721,7 +721,12 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model, **kwargs): + def __init__( + self, + model, + qaic_config: Optional[dict] = None, + **kwargs + ): """ Initializes the language decoder component for multimodal models. @@ -729,12 +734,24 @@ def __init__(self, model, **kwargs): ---------- model : nn.Module The full HuggingFace multimodal model from which the language decoder is extracted. + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. + Only the following keys are supported by the text model of the dual QPC multimodal model: + - **include_sampler** (bool): If True, enables on-device sampling of next tokens. + - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. + Additional keys will be ignored. **kwargs : Additional keyword arguments passed to the base class constructor. """ super().__init__(model, **kwargs) self.model = model.get_qeff_language_decoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ + self.model.qaic_config = qaic_config + # ---Sampling--- + # Note: SamplerTransform should be applied after all other transforms + # are done. The role of the sampler is to just add nodes at the output of the + # previous transform function. + self.model, _ = SamplerTransform.apply(self.model, qaic_config, **kwargs) def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): """ @@ -758,10 +775,95 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt str Path to the generated ONNX graph file for the language decoder. """ + if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False): + inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(inputs, output_names, dynamic_axes) return self._export( inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights ) + def get_sampling_inputs_and_outputs( + self, + example_inputs: Dict[str, torch.Tensor], + output_names: List[str], + dynamic_axes: Dict[str, Dict[int, str]], + ): + """ + Updates the example inputs, output names, and dynamic axes to include + parameters relevant for on-device sampling during ONNX export. + + Parameters + ---------- + example_inputs : Dict[str, torch.Tensor] + Current dictionary of example inputs. + output_names : List[str] + Current list of output names. + dynamic_axes : Dict[str, Dict[int, str]] + Current dictionary of dynamic axes configurations. + + Returns + ------- + Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] + Updated example inputs, output names, and dynamic axes including + sampling-related parameters. 
+ """ + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + + assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling" + + logits_index = output_names.index("logits") + output_names[logits_index] = "next_tokens" + + example_inputs["last_accepted_output_tokens"] = torch.zeros( + (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 + ) + dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} + + example_inputs["past_repetition_penalty_buffer"] = torch.zeros( + (bs, self.model.language_model.config.vocab_size), dtype=torch.bool + ) + dynamic_axes["past_repetition_penalty_buffer"] = { + 0: "batch_size", + } + output_names.append("past_repetition_penalty_buffer_RetainedState") + + example_inputs["repetition_penalties"] = ( + torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES + ) + dynamic_axes["repetition_penalties"] = {0: "batch_size"} + + example_inputs["past_presence_penalty_buffer"] = torch.zeros( + (bs, self.model.language_model.config.vocab_size), dtype=torch.bool + ) + dynamic_axes["past_presence_penalty_buffer"] = { + 0: "batch_size", + } + output_names.append("past_presence_penalty_buffer_RetainedState") + + example_inputs["presence_penalties"] = ( + torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES + ) + dynamic_axes["presence_penalties"] = {0: "batch_size"} + + example_inputs["temperatures"] = ( + torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES + ) + dynamic_axes["temperatures"] = {0: "batch_size"} + + max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) + example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) + dynamic_axes["top_ks"] = {0: "batch_size"} + + example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS + dynamic_axes["top_ps"] = {0: "batch_size"} + + example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS + dynamic_axes["min_ps"] = {0: "batch_size"} + + example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float) + dynamic_axes["random_numbers"] = {0: "batch_size"} + + return example_inputs, output_names, dynamic_axes + def compile( self, compile_dir, @@ -1499,6 +1601,8 @@ def __init__( """ if kwargs.pop("full_batch_size", None): raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + if kwargs.pop("qaic_config", None): + raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.") super().__init__(model, **kwargs) # to handle internvl models @@ -2023,6 +2127,7 @@ def from_pretrained( pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, continuous_batching: bool = False, + qaic_config: Optional[dict] = None, **kwargs, ): """ @@ -2036,6 +2141,12 @@ def from_pretrained( If True, uses the dual QPC approach (vision encoder KV offloaded). If False, uses the single QPC approach (entire model in one QPC). If None, the default behavior of the internal classes is used (typically dual QPC). + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. + Only the following keys are supported by the text model of the dual QPC multimodal model: + - **include_sampler** (bool): If True, enables on-device sampling of next tokens. 
+ - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. + Additional keys will be ignored. **kwargs : Additional arguments passed to HuggingFace's ``from_pretrained``. @@ -2063,11 +2174,14 @@ def from_pretrained( logger.warning("Updating low_cpu_mem_usage=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + if qaic_config is not None: + qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls( model, kv_offload=kv_offload, continuous_batching=continuous_batching, + qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs, ) @@ -2273,7 +2387,11 @@ def from_pretrained( if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( - model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs + model, + kv_offload=kv_offload, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs ) return cls( model, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 773ce178c..c750a8c66 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -289,6 +289,7 @@ QEffGrok1MultiHeadAttention, ) from QEfficient.transformers.models.internvl.modeling_internvl import ( + QEffInternDecoderWrapper, QEffInternVisionEmbeddings, QEffInternVLModel, ) @@ -392,6 +393,7 @@ QEffQwen2_5_VLModel, QEffQwen2_5_VLTextModel, QEffQwen2_5_VLVisionAttention, + QEffQwen_2_5_vl_DecoderWrapper, QEffQwen_2_5_vl_ForConditionalGeneration, ) from QEfficient.transformers.models.qwen3.modeling_qwen3 import ( @@ -707,10 +709,12 @@ class SamplerTransform: QEffGPTJForCausalLM, QEffGraniteForCausalLM, QEffGraniteMoeForCausalLM, + QEffInternDecoderWrapper, QEffLlamaForCausalLM, QEffMptForCausalLM, QEffPhi3ForCausalLM, QEffQwen2ForCausalLM, + QEffQwen_2_5_vl_DecoderWrapper, } @classmethod diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 96846e712..4a9aa6034 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -24,6 +24,8 @@ class SamplerOutput(ModelOutput): probs: torch.FloatTensor = None next_tokens: torch.IntTensor = None + vision_embeds: Optional[torch.FloatTensor] = None # For VLMs + image_idx: Optional[torch.IntTensor] = None # for VLMs past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_repetition_penalty_buffer: Optional[torch.Tensor] = None past_presence_penalty_buffer: Optional[torch.Tensor] = None @@ -122,6 +124,8 @@ def sampler_forward( top_ps: Optional[torch.Tensor] = None, min_ps: Optional[torch.Tensor] = None, random_numbers: Optional[torch.Tensor] = None, + vision_embeds: Optional[torch.Tensor] = None, + image_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, SamplerOutput]: r""" Perform the sampling of next tokens on the QAIC device (instead of the host) @@ -170,20 +174,36 @@ def sampler_forward( Sampling parameter that represents the random seeds to use for random sampling. Must be in [-1, 1]. 
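        For reference, the host-side equivalent of the filtering this kernel applies can be
        sketched as below. This is a minimal illustration only: parameters are scalars, the
        helper name `reference_sample` is made up, and the device implementation further down
        works on batched tensors, retained-state penalty buffers, and a fixed `max_top_k_ids`
        window, so the sketch is not a drop-in substitute.

            import torch

            def reference_sample(logits, seen, rep_pen, pres_pen, temp, top_k, top_p, min_p):
                # logits: (batch, vocab); seen: bool (batch, vocab) marking tokens already produced.
                neg_inf = torch.finfo(torch.float16).min
                # Repetition penalty: damp every token already seen in the prompt/output.
                penalized = torch.where(logits > 0, logits / rep_pen, logits * rep_pen)
                logits = torch.where(seen, penalized, logits)
                # Presence penalty: flat subtraction for tokens already generated.
                logits = logits - pres_pen * seen.float()
                # Temperature scaling.
                logits = logits / temp
                # Top-k: drop everything below the k-th largest logit.
                kth = torch.topk(logits, top_k, dim=-1).values[..., -1:]
                logits = torch.where(logits < kth, torch.full_like(logits, neg_inf), logits)
                # Top-p (nucleus): keep the smallest prefix of sorted tokens covering top_p mass.
                sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = torch.softmax(sorted_logits, dim=-1)
                drop = (sorted_probs.cumsum(-1) - sorted_probs) > top_p
                sorted_logits = torch.where(drop, torch.full_like(sorted_logits, neg_inf), sorted_logits)
                logits = sorted_logits.gather(-1, sorted_idx.argsort(-1))
                # Min-p: drop tokens whose probability falls below min_p * max probability.
                probs = torch.softmax(logits, dim=-1)
                floor = min_p * probs.amax(-1, keepdim=True)
                logits = torch.where(probs < floor, torch.full_like(logits, neg_inf), logits)
                # Sample one token per row.
                return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)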
""" - - outputs = self.old_forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - batch_index=batch_index, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) + if vision_embeds is not None: + logits, vision_embeds, image_idx, past_key_values = self.old_forward( + input_ids=input_ids, + vision_embeds=vision_embeds, + position_ids=position_ids, + image_idx=image_idx, + past_key_values=past_key_values + ) + outputs = dict( + logits=logits, + vision_embeds=vision_embeds, + image_idx=image_idx, + past_key_values=past_key_values + ) + if position_ids.dim() == 3: # For models using m-rope + position_ids = position_ids[0] + else: + outputs = self.old_forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) logits = outputs.get("logits", None) assert logits is not None, f"{self.model.__class__.__name__} does not return logits." @@ -230,7 +250,9 @@ def sampler_forward( return SamplerOutput( probs=None, next_tokens=greedy_samples.reshape(-1, spec_length, 1), # Return sampled next tokens instead of logits - past_key_values=outputs.past_key_values, + vision_embeds=outputs.get("vision_embeds", None), + image_idx=outputs.get("image_idx", None), + past_key_values=outputs.get("past_key_values", None), past_repetition_penalty_buffer=past_repetition_penalty_buffer, past_presence_penalty_buffer=past_presence_penalty_buffer, ) @@ -314,7 +336,9 @@ def sampler_forward( return SamplerOutput( probs=probs, next_tokens=next_tokens, # Return sampled next tokens instead of logits - past_key_values=outputs.past_key_values, + vision_embeds=outputs.get("vision_embeds", None), + image_idx=outputs.get("image_idx", None), + past_key_values=outputs.get("past_key_values", None), past_repetition_penalty_buffer=past_repetition_penalty_buffer, past_presence_penalty_buffer=past_presence_penalty_buffer, ) From e06e1758bad19feae3dbb7c38a2f349c82b0a585 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 30 Oct 2025 00:04:01 -0700 Subject: [PATCH 02/20] Fix random_numbers shape Signed-off-by: quic-xiyushi --- .../transformers/models/modeling_auto.py | 25 ++++++++----------- QEfficient/transformers/sampler/sampler.py | 22 ++++++---------- 2 files changed, 19 insertions(+), 28 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 8b021314e..6168f4492 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -721,12 +721,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__( - self, - model, - qaic_config: Optional[dict] = None, - **kwargs - ): + def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): """ Initializes the language decoder component for multimodal models. @@ -735,7 +730,7 @@ def __init__( model : nn.Module The full HuggingFace multimodal model from which the language decoder is extracted. qaic_config : dict, optional - A dictionary for QAIC-specific configurations. 
+ A dictionary for QAIC-specific configurations. Only the following keys are supported by the text model of the dual QPC multimodal model: - **include_sampler** (bool): If True, enables on-device sampling of next tokens. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. @@ -776,7 +771,9 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt Path to the generated ONNX graph file for the language decoder. """ if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False): - inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(inputs, output_names, dynamic_axes) + inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs( + inputs, output_names, dynamic_axes + ) return self._export( inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights ) @@ -807,7 +804,7 @@ def get_sampling_inputs_and_outputs( sampling-related parameters. """ bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - + assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling" logits_index = output_names.index("logits") @@ -859,7 +856,7 @@ def get_sampling_inputs_and_outputs( example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS dynamic_axes["min_ps"] = {0: "batch_size"} - example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float) + example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) dynamic_axes["random_numbers"] = {0: "batch_size"} return example_inputs, output_names, dynamic_axes @@ -2142,7 +2139,7 @@ def from_pretrained( If False, uses the single QPC approach (entire model in one QPC). If None, the default behavior of the internal classes is used (typically dual QPC). qaic_config : dict, optional - A dictionary for QAIC-specific configurations. + A dictionary for QAIC-specific configurations. Only the following keys are supported by the text model of the dual QPC multimodal model: - **include_sampler** (bool): If True, enables on-device sampling of next tokens. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. 
@@ -2181,7 +2178,7 @@ def from_pretrained( model, kv_offload=kv_offload, continuous_batching=continuous_batching, - qaic_config=qaic_config, + qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs, ) @@ -2391,7 +2388,7 @@ def from_pretrained( kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, - **kwargs + **kwargs, ) return cls( model, @@ -2594,7 +2591,7 @@ def get_sampling_inputs_and_outputs( example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS dynamic_axes["min_ps"] = {0: "batch_size"} - example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float) + example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) dynamic_axes["random_numbers"] = {0: "batch_size"} return example_inputs, output_names, dynamic_axes diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 4a9aa6034..a15e156ff 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -24,8 +24,8 @@ class SamplerOutput(ModelOutput): probs: torch.FloatTensor = None next_tokens: torch.IntTensor = None - vision_embeds: Optional[torch.FloatTensor] = None # For VLMs - image_idx: Optional[torch.IntTensor] = None # for VLMs + vision_embeds: Optional[torch.FloatTensor] = None # For VLMs + image_idx: Optional[torch.IntTensor] = None # for VLMs past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_repetition_penalty_buffer: Optional[torch.Tensor] = None past_presence_penalty_buffer: Optional[torch.Tensor] = None @@ -176,19 +176,14 @@ def sampler_forward( """ if vision_embeds is not None: logits, vision_embeds, image_idx, past_key_values = self.old_forward( - input_ids=input_ids, - vision_embeds=vision_embeds, - position_ids=position_ids, - image_idx=image_idx, - past_key_values=past_key_values - ) - outputs = dict( - logits=logits, + input_ids=input_ids, vision_embeds=vision_embeds, + position_ids=position_ids, image_idx=image_idx, - past_key_values=past_key_values + past_key_values=past_key_values, ) - if position_ids.dim() == 3: # For models using m-rope + outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values) + if position_ids.dim() == 3: # For models using m-rope position_ids = position_ids[0] else: outputs = self.old_forward( @@ -322,9 +317,8 @@ def sampler_forward( ) # (batch_size, spec_length, vocab_size) # Random Sampling - topk_probs_asc = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, max_top_k_ids) gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1))) # Gumbel-Max Trick - y = topk_probs_asc + gumbel_noise + y = topk_values_asc + gumbel_noise # (batch_size * spec_length, max_top_k_ids) random_samples_indices = torch.argmax(y, dim=1, keepdim=True) random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices) # (batch_size * spec_length, 1) From 3e242ce85b42e5babcee1ee87ce2dacb0c0565e9 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 30 Oct 2025 00:04:01 -0700 Subject: [PATCH 03/20] Update example with new random sampling logic Signed-off-by: quic-sanising Signed-off-by: sanising --- examples/on_device_sampling.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/on_device_sampling.py b/examples/on_device_sampling.py index 00d8c2430..108e5390e 100644 --- 
a/examples/on_device_sampling.py +++ b/examples/on_device_sampling.py @@ -28,6 +28,7 @@ def main(args, **kwargs): if include_sampler is not None: return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true" max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512)) + np.random.seed(int(args.random_number)) sampling_params = { "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), "presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), @@ -36,7 +37,9 @@ def main(args, **kwargs): "top_ks": np.array(args.top_k, dtype=np.int32).repeat(bs).reshape(-1, 1), "top_ps": np.array(args.top_p, dtype=np.float32).repeat(bs).reshape(-1, 1), "min_ps": np.array(args.min_p, dtype=np.float32).repeat(bs).reshape(-1, 1), - "random_numbers": np.array(args.random_number, dtype=np.float32).repeat(bs).reshape(-1, 1), + "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=max_top_k_ids), (bs, 1)).astype( + np.float32 + ), } qaic_config = { k: v @@ -110,10 +113,10 @@ def main(args, **kwargs): --repetition-penalty 1.9 \ --presence-penalty 0.8 \ --temperature 0.67 \ - --top-k 54720 \ + --top-k 54 \ --top-p 0.89 \ --min-p 0.6 \ - --random-number 0.26 + --random-number 26 2. For non-continuous batching: python3.10 examples/on_device_sampling.py \ @@ -130,10 +133,10 @@ def main(args, **kwargs): --repetition-penalty 1.9 \ --presence-penalty 0.8 \ --temperature 0.67 \ - --top-k 54720 \ + --top-k 54 \ --top-p 0.89 \ --min-p 0.6 \ - --random-number 0.26 + --random-number 26 """ parser = argparse.ArgumentParser(description="Run QEfficient model with On Device Sampling") From 1a01d57a9d737c0e06ea1c87cf91ce3408dbb324 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Mon, 10 Nov 2025 16:33:06 -0800 Subject: [PATCH 04/20] Update to align with recent VLM CB changes Signed-off-by: quic-xiyushi --- QEfficient/transformers/models/modeling_auto.py | 17 +++++++++++------ QEfficient/transformers/sampler/sampler.py | 6 +++++- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 6168f4492..c110b3ce5 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -721,7 +721,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): + def __init__(self, model, continuous_batching: bool = False, qaic_config: Optional[dict] = None, **kwargs): """ Initializes the language decoder component for multimodal models. @@ -729,6 +729,9 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): ---------- model : nn.Module The full HuggingFace multimodal model from which the language decoder is extracted. + continuous_batching : bool, optional + If True, enables continuous batching mode for future compilation and execution. + This setting must be consistent across `from_pretrained` and `compile` calls. Default is False. qaic_config : dict, optional A dictionary for QAIC-specific configurations. 
Only the following keys are supported by the text model of the dual QPC multimodal model: @@ -741,6 +744,7 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): super().__init__(model, **kwargs) self.model = model.get_qeff_language_decoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ + self.continuous_batching = continuous_batching self.model.qaic_config = qaic_config # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms @@ -804,6 +808,7 @@ def get_sampling_inputs_and_outputs( sampling-related parameters. """ bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling" @@ -816,10 +821,10 @@ def get_sampling_inputs_and_outputs( dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} example_inputs["past_repetition_penalty_buffer"] = torch.zeros( - (bs, self.model.language_model.config.vocab_size), dtype=torch.bool + (fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool ) dynamic_axes["past_repetition_penalty_buffer"] = { - 0: "batch_size", + 0: "full_batch_size" if self.continuous_batching else "batch_size", } output_names.append("past_repetition_penalty_buffer_RetainedState") @@ -829,10 +834,10 @@ def get_sampling_inputs_and_outputs( dynamic_axes["repetition_penalties"] = {0: "batch_size"} example_inputs["past_presence_penalty_buffer"] = torch.zeros( - (bs, self.model.language_model.config.vocab_size), dtype=torch.bool + (fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool ) dynamic_axes["past_presence_penalty_buffer"] = { - 0: "batch_size", + 0: "full_batch_size" if self.continuous_batching else "batch_size", } output_names.append("past_presence_penalty_buffer_RetainedState") @@ -981,7 +986,7 @@ def __init__( self.model = model self.config = model.config self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) - self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) + self.lang_model = QEffCausalLMForTextImageToTextModel(model, continuous_batching=continuous_batching, **kwargs) self.continuous_batching = continuous_batching self.input_shapes, self.output_names = None, None diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index a15e156ff..1075db784 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -175,13 +175,17 @@ def sampler_forward( Must be in [-1, 1]. 
""" if vision_embeds is not None: - logits, vision_embeds, image_idx, past_key_values = self.old_forward( + forward_kwargs = dict( input_ids=input_ids, vision_embeds=vision_embeds, position_ids=position_ids, image_idx=image_idx, past_key_values=past_key_values, ) + if batch_index is not None: + forward_kwargs["batch_index"] = batch_index + + logits, vision_embeds, image_idx, past_key_values = self.old_forward(**forward_kwargs) outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values) if position_ids.dim() == 3: # For models using m-rope position_ids = position_ids[0] From 30d60618e43b7931d7a2a090e2fb4268510d1337 Mon Sep 17 00:00:00 2001 From: sanising Date: Mon, 10 Nov 2025 19:28:38 -0600 Subject: [PATCH 05/20] Update tests with new random sampling logic Signed-off-by: sanising --- tests/transformers/sampler/test_sampler.py | 52 +++++++++++----------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 9335e1d91..8d437eee8 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -211,7 +211,7 @@ def test_greedy_sampling( "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "random_numbers": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.zeros((full_batch_size, 512), dtype=np.float32), }, ) model_wo_sampler_exec_info = model_wo_sampler.generate( @@ -233,7 +233,6 @@ def test_greedy_sampling( @pytest.mark.on_qaic -@pytest.mark.skip @pytest.mark.parametrize( "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", random_sampling_configs, @@ -291,6 +290,7 @@ def test_random_sampling( # Generate texts from prompts tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model) + np.random.seed(0) model_w_sampler_exec_info = model_w_sampler.generate( tokenizer=tokenizer, prompts=prompts, @@ -301,11 +301,13 @@ def test_random_sampling( "repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), # "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "temperatures": np.array(100.1, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(4.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), "top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "random_numbers": np.array(0.26, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=512), (full_batch_size, 1)).astype( + np.float32 + ), }, ) model_wo_sampler_exec_info = model_wo_sampler.generate( @@ -319,32 +321,32 @@ def test_random_sampling( # Compare generated texts golden_texts = { - "w_sampler": "Raymond and my favorite color, alongside reds or purples (I can’t have them both", + "w_sampler": "Aiden and I am a 
freelance writer who loves to explore the world. With over", "wo_sampler": "John Smith and I am a software engineer. I have been working in the industry for the past ", } golden_ids = { "w_sampler": [ [ - 21380, + 319, + 3615, 322, - 590, - 25448, - 2927, - 29892, - 19963, - 2654, - 29879, - 470, - 3708, - 2701, - 313, - 29902, - 508, - 30010, - 29873, - 505, - 963, - 1716, + 306, + 626, + 263, + 3005, + 295, + 749, + 9227, + 1058, + 12355, + 267, + 304, + 26987, + 278, + 3186, + 29889, + 2973, + 975, ] ], "wo_sampler": [ From 78ef180c554c413296325f0abdccaf7ab6cafcec Mon Sep 17 00:00:00 2001 From: sanising Date: Tue, 11 Nov 2025 14:44:51 -0600 Subject: [PATCH 06/20] Add code to perform guided decoding Signed-off-by: sanising --- QEfficient/transformers/sampler/sampler.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 1075db784..8c0f7b7b9 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -126,6 +126,7 @@ def sampler_forward( random_numbers: Optional[torch.Tensor] = None, vision_embeds: Optional[torch.Tensor] = None, image_idx: Optional[torch.Tensor] = None, + bitmask: Optional[torch.Tensor] = None, ) -> Union[Tuple, SamplerOutput]: r""" Perform the sampling of next tokens on the QAIC device (instead of the host) @@ -173,6 +174,11 @@ def sampler_forward( random_numbers (`torch.Tensor`, *optional*): Sampling parameter that represents the random seeds to use for random sampling. Must be in [-1, 1]. + + bitmask (`torch.Tensor`, *optional*): + A boolean mask used to guide token-level filtering during decoding. Each + element of this tensor indicates whether the corresponding token should be + kept (1) or masked (0). 
Shape: (batch_size, vocab_size) """ if vision_embeds is not None: forward_kwargs = dict( @@ -216,6 +222,13 @@ def sampler_forward( batch_index = torch.arange(batch_size).view(-1, 1) batch_index_reshaped = batch_index.view(-1) + + # Guided decoding + if (bitmask != 1).any(): + assert spec_length == 1, "Currently, guided decoding is not supported with Speculative Decoding" + # Mask logits where bitmask is 0 with -inf + logits = torch.where(bitmask == 1, logits, torch.finfo(torch.float16).min) + # Prefill past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path( input_ids=input_ids, From 1fafcdb50ac225db5335f281bf2bdd279c90bfba Mon Sep 17 00:00:00 2001 From: sanising Date: Tue, 11 Nov 2025 18:58:20 -0600 Subject: [PATCH 07/20] Add bitmask to example inputs and dynamic axes Signed-off-by: sanising --- QEfficient/transformers/models/modeling_auto.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index c110b3ce5..38bceb68f 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -864,6 +864,9 @@ def get_sampling_inputs_and_outputs( example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) dynamic_axes["random_numbers"] = {0: "batch_size"} + example_inputs["bitmask"] = torch.ones((bs, self.model.language_model.config.vocab_size), dtype=torch.bool) + dynamic_axes["bitmask"] = {0: "batch_size"} + return example_inputs, output_names, dynamic_axes def compile( @@ -2599,6 +2602,9 @@ def get_sampling_inputs_and_outputs( example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) dynamic_axes["random_numbers"] = {0: "batch_size"} + example_inputs["bitmask"] = torch.ones((bs, self.model.config.vocab_size), dtype=torch.bool) + dynamic_axes["bitmask"] = {0: "batch_size"} + return example_inputs, output_names, dynamic_axes def build_prefill_specialization( From 18ab856b7eb66afffe6e17c905223bb523be5f1a Mon Sep 17 00:00:00 2001 From: sanising Date: Tue, 11 Nov 2025 20:06:22 -0600 Subject: [PATCH 08/20] Rename bitmask to token_bitmasks Signed-off-by: sanising --- QEfficient/transformers/models/modeling_auto.py | 10 ++++++---- QEfficient/transformers/sampler/sampler.py | 12 ++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 38bceb68f..a71c4905f 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -864,8 +864,10 @@ def get_sampling_inputs_and_outputs( example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) dynamic_axes["random_numbers"] = {0: "batch_size"} - example_inputs["bitmask"] = torch.ones((bs, self.model.language_model.config.vocab_size), dtype=torch.bool) - dynamic_axes["bitmask"] = {0: "batch_size"} + example_inputs["token_bitmasks"] = torch.ones( + (bs, self.model.language_model.config.vocab_size), dtype=torch.bool + ) + dynamic_axes["token_bitmasks"] = {0: "batch_size"} return example_inputs, output_names, dynamic_axes @@ -2602,8 +2604,8 @@ def get_sampling_inputs_and_outputs( example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) dynamic_axes["random_numbers"] = {0: "batch_size"} - example_inputs["bitmask"] = torch.ones((bs, self.model.config.vocab_size), dtype=torch.bool) - 
dynamic_axes["bitmask"] = {0: "batch_size"} + example_inputs["token_bitmasks"] = torch.zeros((bs, self.model.config.vocab_size), dtype=torch.bool) + dynamic_axes["token_bitmasks"] = {0: "batch_size"} return example_inputs, output_names, dynamic_axes diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 8c0f7b7b9..5d6a8b8e2 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -126,7 +126,7 @@ def sampler_forward( random_numbers: Optional[torch.Tensor] = None, vision_embeds: Optional[torch.Tensor] = None, image_idx: Optional[torch.Tensor] = None, - bitmask: Optional[torch.Tensor] = None, + token_bitmasks: Optional[torch.Tensor] = None, ) -> Union[Tuple, SamplerOutput]: r""" Perform the sampling of next tokens on the QAIC device (instead of the host) @@ -175,8 +175,8 @@ def sampler_forward( Sampling parameter that represents the random seeds to use for random sampling. Must be in [-1, 1]. - bitmask (`torch.Tensor`, *optional*): - A boolean mask used to guide token-level filtering during decoding. Each + token_bitmasks (`torch.Tensor`, *optional*): + Boolean mask used to guide token-level filtering during decoding. Each element of this tensor indicates whether the corresponding token should be kept (1) or masked (0). Shape: (batch_size, vocab_size) """ @@ -224,10 +224,10 @@ def sampler_forward( batch_index_reshaped = batch_index.view(-1) # Guided decoding - if (bitmask != 1).any(): + if (token_bitmasks != 1).any(): assert spec_length == 1, "Currently, guided decoding is not supported with Speculative Decoding" - # Mask logits where bitmask is 0 with -inf - logits = torch.where(bitmask == 1, logits, torch.finfo(torch.float16).min) + # Mask logits where token_bitmasks is 0 with -inf + logits = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min) # Prefill past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path( From b1c049c022ed28a9e90ad6abdb5bd2de35e67230 Mon Sep 17 00:00:00 2001 From: sanising Date: Wed, 12 Nov 2025 12:30:04 -0600 Subject: [PATCH 09/20] Fix typo Signed-off-by: sanising --- QEfficient/transformers/models/modeling_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index a71c4905f..b97b48528 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -864,7 +864,7 @@ def get_sampling_inputs_and_outputs( example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) dynamic_axes["random_numbers"] = {0: "batch_size"} - example_inputs["token_bitmasks"] = torch.ones( + example_inputs["token_bitmasks"] = torch.zeros( (bs, self.model.language_model.config.vocab_size), dtype=torch.bool ) dynamic_axes["token_bitmasks"] = {0: "batch_size"} From 151549711eb0e7931d8c73400d2365f3c60e8de7 Mon Sep 17 00:00:00 2001 From: sanising Date: Tue, 18 Nov 2025 18:43:08 -0600 Subject: [PATCH 10/20] Add flag to enable guided decoding Signed-off-by: sanising --- .../transformers/models/modeling_auto.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index a36ef9470..31d38c423 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -738,6 +738,8 @@ def 
__init__(self, model, continuous_batching: bool = False, qaic_config: Option Only the following keys are supported by the text model of the dual QPC multimodal model: - **include_sampler** (bool): If True, enables on-device sampling of next tokens. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. + - **include_guided_decoding** (bool): If True, enables guided token-level filtering + during decoding. Only works when include_sampler=True. Additional keys will be ignored. **kwargs : Additional keyword arguments passed to the base class constructor. @@ -865,10 +867,11 @@ def get_sampling_inputs_and_outputs( example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) dynamic_axes["random_numbers"] = {0: "batch_size"} - example_inputs["token_bitmasks"] = torch.zeros( - (bs, self.model.language_model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["token_bitmasks"] = {0: "batch_size"} + if self.model.qaic_config.get("include_guided_decoding", False): + example_inputs["token_bitmasks"] = torch.zeros( + (bs, self.model.language_model.config.vocab_size), dtype=torch.bool + ) + dynamic_axes["token_bitmasks"] = {0: "batch_size"} return example_inputs, output_names, dynamic_axes @@ -2271,6 +2274,8 @@ def from_pretrained( Only the following keys are supported by the text model of the dual QPC multimodal model: - **include_sampler** (bool): If True, enables on-device sampling of next tokens. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. + - **include_guided_decoding** (bool): If True, enables guided token-level filtering + during decoding. Only works when include_sampler=True. Additional keys will be ignored. **kwargs : Additional arguments passed to HuggingFace's ``from_pretrained``. @@ -2376,6 +2381,8 @@ def __init__( - **return_pdfs** (bool): If True, returns probability distributions along with sampled tokens. For Speculative Decoding Target Language Models, this is always True. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. + - **include_guided_decoding** (bool): If True, enables guided token-level filtering + during decoding. Only works when include_sampler=True. **kwargs : Additional keyword arguments passed to the base class constructor. @@ -2478,6 +2485,8 @@ def from_pretrained( and ``return_pdfs=False`` for regular model. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. The values provided in ``top_ks`` tensor must be less than this maximum limit. + - **include_guided_decoding** (bool): If True, enables guided token-level filtering + during decoding. Only works when include_sampler=True. *args : Positional arguments passed directly to `cls._hf_auto_class.from_pretrained`. 
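The `include_guided_decoding` flag documented above gates one extra vocab-sized boolean input,
`token_bitmasks`. A minimal sketch of how a caller might build such a mask, and of the masking
the sampler applies before sampling, is shown below. This is illustrative only: in practice the
mask would come from a structured-generation backend (e.g. a grammar-driven logits processor),
`build_token_bitmask` is a hypothetical helper, and the vocab size is an assumed example value.

    import numpy as np

    def build_token_bitmask(allowed_token_ids, vocab_size, batch_size=1):
        # True = token may be sampled at this step, False = token is filtered out.
        mask = np.zeros((batch_size, vocab_size), dtype=bool)
        mask[:, list(allowed_token_ids)] = True
        return mask

    token_bitmasks = build_token_bitmask({11, 42, 1337}, vocab_size=32000, batch_size=2)

    # Device-side effect (see sampler_forward): masked positions are floored to the
    # smallest fp16 value so they can never survive top-k / top-p / argmax:
    #     logits = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min)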
@@ -2729,8 +2738,9 @@ def get_sampling_inputs_and_outputs( example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) dynamic_axes["random_numbers"] = {0: "batch_size"} - example_inputs["token_bitmasks"] = torch.zeros((bs, self.model.config.vocab_size), dtype=torch.bool) - dynamic_axes["token_bitmasks"] = {0: "batch_size"} + if self.model.qaic_config.get("include_guided_decoding", False): + example_inputs["token_bitmasks"] = torch.zeros((bs, self.model.config.vocab_size), dtype=torch.bool) + dynamic_axes["token_bitmasks"] = {0: "batch_size"} return example_inputs, output_names, dynamic_axes From 97e4bafec0f006d2eaa024b6c3e62f4819e847bc Mon Sep 17 00:00:00 2001 From: sanising Date: Tue, 18 Nov 2025 19:57:11 -0600 Subject: [PATCH 11/20] Add flag to enable guided decoding Signed-off-by: sanising --- .../generation/text_generation_inference.py | 16 ++++++-- QEfficient/utils/sampler_utils.py | 6 ++- examples/on_device_sampling.py | 39 +++++++++++++++++-- 3 files changed, 52 insertions(+), 9 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 7da2300d6..4fb77f272 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -329,6 +329,7 @@ def cloud_ai_100_exec_kv( is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, ): """ @@ -356,6 +357,8 @@ def cloud_ai_100_exec_kv( next tokens. For Speculative Decoding Target Language Model, `return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative Decoding Draft Language Model and `return_pdfs`=False for regular model. + :include_guided_decoding (bool, default=False): If True, enables guided token-level filtering + during decoding. Only works when `include_sampler`=True. sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend. 
The dictionary should contain the following keys: `repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`, @@ -394,6 +397,7 @@ def cloud_ai_100_exec_kv( is_tlm=is_tlm, include_sampler=include_sampler, return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, sampling_params=sampling_params, ) @@ -442,6 +446,7 @@ def __init__( is_tlm: Optional[int] = None, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, activate: bool = True, ) -> None: @@ -451,6 +456,7 @@ def __init__( self._write_io_dir = write_io_dir self.is_tlm = is_tlm self.return_pdfs = return_pdfs + self.include_guided_decoding = include_guided_decoding self.sampling_params = sampling_params self._qpc_path = qpc_path # Store qpc_path for later use @@ -461,7 +467,9 @@ def __init__( # Validate sampler inputs for On-Device Sampling self.include_sampler = validate_sampler_inputs( - session_inputs=set(self._session.input_names), include_sampler=include_sampler + session_inputs=set(self._session.input_names), + include_sampler=include_sampler, + include_guided_decoding=include_guided_decoding, ) # Fetch the variables from the QPC @@ -628,7 +636,7 @@ def prepare_decode_inputs(self): decode_inputs["batch_index"] = self.batch_index if self.include_sampler: decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] - for op in Constants.SAMPLER_OPS: + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): if self.batch_index is not None: decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()] else: @@ -795,7 +803,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i inputs["num_logits_to_keep"] = np.zeros((1, 1)) if self.include_sampler: inputs["last_accepted_output_tokens"] = inputs["input_ids"] - for op in Constants.SAMPLER_OPS: + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): if decode_batch_id is not None: inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] else: @@ -1067,6 +1075,7 @@ def __init__( is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, ) -> None: self._qaic_model = QEffTextGenerationBase( @@ -1082,6 +1091,7 @@ def __init__( is_tlm=is_tlm, include_sampler=include_sampler, return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, sampling_params=sampling_params, ) self._full_batch_size = self._qaic_model.full_batch_size diff --git a/QEfficient/utils/sampler_utils.py b/QEfficient/utils/sampler_utils.py index 6fb1b326f..fd743b1e8 100644 --- a/QEfficient/utils/sampler_utils.py +++ b/QEfficient/utils/sampler_utils.py @@ -11,7 +11,9 @@ from QEfficient.utils.logging_utils import logger -def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[bool] = None) -> bool: +def validate_sampler_inputs( + session_inputs: Set[str], include_sampler: Optional[bool] = None, include_guided_decoding: Optional[bool] = None +) -> bool: """ Validates whether the `QAICInferenceSession` inputs match inputs required for on-device sampling. @@ -28,7 +30,7 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[ ValueError if partial support is detected or if user intent conflicts with QPC capabilities. 
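
    A hedged usage sketch (the input-name set below is abbreviated and assumed for
    illustration; the real set comes from `QAICInferenceSession.input_names`):

        session_inputs = {
            "input_ids", "position_ids", "last_accepted_output_tokens",
            "repetition_penalties", "presence_penalties", "temperatures",
            "top_ks", "top_ps", "min_ps", "random_numbers", "token_bitmasks",
        }
        include_sampler = validate_sampler_inputs(
            session_inputs, include_sampler=True, include_guided_decoding=True
        )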
""" - sampler_inputs = Constants.SAMPLER_INPUTS + sampler_inputs = Constants.SAMPLER_INPUTS | ({"token_bitmasks"} if include_guided_decoding else set()) count = len(sampler_inputs & session_inputs) session_includes_sampler = True diff --git a/examples/on_device_sampling.py b/examples/on_device_sampling.py index 108e5390e..99712d091 100644 --- a/examples/on_device_sampling.py +++ b/examples/on_device_sampling.py @@ -21,6 +21,7 @@ def main(args, **kwargs): include_sampler = None return_pdfs = None max_top_k_ids = None + include_guided_decoding = None sampling_params = None bs = args.full_batch_size if args.full_batch_size is not None else args.batch_size if args.override_qaic_config is not None: @@ -29,6 +30,7 @@ def main(args, **kwargs): return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true" max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512)) np.random.seed(int(args.random_number)) + include_guided_decoding = args.override_qaic_config.get("aic_include_guided_decoding", None) == "true" sampling_params = { "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), "presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), @@ -47,13 +49,12 @@ def main(args, **kwargs): "include_sampler": include_sampler, "return_pdfs": return_pdfs, "max_top_k_ids": max_top_k_ids, + "include_guided_decoding": include_guided_decoding, }.items() if v is not None } print("qaic_config:") pprint(qaic_config) - print("sampling_params:") - pprint(sampling_params) # Load model with On Device Sampler enabled qeff_model = AutoModelForCausalLM.from_pretrained( @@ -63,6 +64,15 @@ def main(args, **kwargs): ) print(f"{args.model_name} optimized for AI 100 \n", qeff_model) + if include_guided_decoding: + # Ideally this should come from a logits processor like xgrammar, but for the sake of the + # example, we generate a random bitmask + sampling_params.update( + {"token_bitmasks": np.random.choice([True, False], size=(bs, qeff_model.model.config.vocab_size))} + ) + print("sampling_params:") + pprint(sampling_params) + # Compile the model for inference generated_qpc_path = qeff_model.compile( prefill_seq_len=args.prompt_len, @@ -91,6 +101,7 @@ def main(args, **kwargs): generation_len=args.generation_len, include_sampler=include_sampler, return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, sampling_params=sampling_params, ) @@ -109,7 +120,7 @@ def main(args, **kwargs): --num-cores 16 \ --mxint8-kv-cache \ --mxfp6-matmul \ - --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:false" \ --repetition-penalty 1.9 \ --presence-penalty 0.8 \ --temperature 0.67 \ @@ -129,7 +140,27 @@ def main(args, **kwargs): --num-cores 16 \ --mxint8-kv-cache \ --mxfp6-matmul \ - --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:false" \ + --repetition-penalty 1.9 \ + --presence-penalty 0.8 \ + --temperature 0.67 \ + --top-k 54 \ + --top-p 0.89 \ + --min-p 0.6 \ + --random-number 26 + + 3. 
With guided decoding: + python3.10 examples/on_device_sampling.py \ + --model-name 'meta-llama/Llama-3.1-8B' \ + --prompt-len 128 \ + --ctx-len 256 \ + --generation-len 20 \ + --full-batch-size 2 \ + --device-group [0,1,2,3] \ + --num-cores 16 \ + --mxint8-kv-cache \ + --mxfp6-matmul \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:true" \ --repetition-penalty 1.9 \ --presence-penalty 0.8 \ --temperature 0.67 \ From 7b7677bbafb8d28a719ec4ee35ce696aa684098c Mon Sep 17 00:00:00 2001 From: sanising Date: Wed, 19 Nov 2025 14:05:58 -0600 Subject: [PATCH 12/20] Update test_sampler_transform for guided decoding Signed-off-by: sanising --- tests/transformers/sampler/test_sampler.py | 34 ++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 8d437eee8..9d9a032c1 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -80,6 +80,17 @@ def test_sampler_transform( "max_top_k_ids": 512, }, ) + model_w_sampler_w_guided_decoding = QEFFAutoModelForCausalLM.from_pretrained( + model, + continuous_batching=True, + num_hidden_layers=2, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 512, + "include_guided_decoding": True, + }, + ) model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( model, continuous_batching=True, @@ -99,6 +110,16 @@ def test_sampler_transform( mxint8_kv_cache=True, mxfp6_matmul=True, ) + model_w_sampler_w_guided_decoding_qpc_path: str = model_w_sampler_w_guided_decoding.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) model_wo_sampler_qpc_path: str = model_wo_sampler.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, @@ -112,6 +133,7 @@ def test_sampler_transform( # Init qaic session model_w_sampler_session = QAICInferenceSession(model_w_sampler_qpc_path) + model_w_sampler_w_guided_decoding_session = QAICInferenceSession(model_w_sampler_w_guided_decoding_qpc_path) model_wo_sampler_session = QAICInferenceSession(model_wo_sampler_qpc_path) # Skip inputs/outputs buffers @@ -119,6 +141,12 @@ def test_sampler_transform( model_w_sampler_session.skip_buffers( set([x for x in model_w_sampler_session.output_names if x.endswith("_RetainedState")]) ) + model_w_sampler_w_guided_decoding_session.skip_buffers( + set([x for x in model_w_sampler_w_guided_decoding_session.input_names if x.startswith("past_")]) + ) + model_w_sampler_w_guided_decoding_session.skip_buffers( + set([x for x in model_w_sampler_w_guided_decoding_session.output_names if x.endswith("_RetainedState")]) + ) model_wo_sampler_session.skip_buffers( set([x for x in model_wo_sampler_session.input_names if x.startswith("past_")]) ) @@ -132,9 +160,15 @@ def test_sampler_transform( assert input_name in model_w_sampler_session.input_names, ( f"Sampler input {input_name} not found in QPC compiled with On Device Sampler" ) + assert input_name in model_w_sampler_w_guided_decoding_session.input_names, ( + f"Sampler input {input_name} not found in QPC compiled with On Device Sampler and Guided Decoding" + ) assert input_name not in model_wo_sampler_session.input_names, ( f"Sampler input {input_name} found in QPC compiled without On Device Sampler" ) + assert "token_bitmasks" in 
model_w_sampler_w_guided_decoding_session.input_names, ( + "Sampler input token_bitmasks not found in QPC compiled with On Device Sampler and Guided Decoding" + ) @pytest.mark.on_qaic From 7cf106e39f7a448aed031cfb66852227348e9215 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Wed, 19 Nov 2025 10:53:24 -0800 Subject: [PATCH 13/20] Refactor Signed-off-by: quic-xiyushi --- .../transformers/models/modeling_auto.py | 209 ++---------------- QEfficient/utils/sampler_utils.py | 91 +++++++- 2 files changed, 114 insertions(+), 186 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index a1a333317..242063ee9 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -61,6 +61,7 @@ ) from QEfficient.utils.check_ccl_specializations import process_ccl_specializations from QEfficient.utils.logging_utils import logger +from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs class QEFFTransformersBase(QEFFBaseModel): @@ -730,28 +731,12 @@ def __init__(self, model, continuous_batching: bool = False, qaic_config: Option ---------- model : nn.Module The full HuggingFace multimodal model from which the language decoder is extracted. - continuous_batching : bool, optional - If True, enables continuous batching mode for future compilation and execution. - This setting must be consistent across `from_pretrained` and `compile` calls. Default is False. - qaic_config : dict, optional - A dictionary for QAIC-specific configurations. - Only the following keys are supported by the text model of the dual QPC multimodal model: - - **include_sampler** (bool): If True, enables on-device sampling of next tokens. - - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. - Additional keys will be ignored. **kwargs : Additional keyword arguments passed to the base class constructor. """ super().__init__(model, **kwargs) self.model = model.get_qeff_language_decoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ - self.continuous_batching = continuous_batching - self.model.qaic_config = qaic_config - # ---Sampling--- - # Note: SamplerTransform should be applied after all other transforms - # are done. The role of the sampler is to just add nodes at the output of the - # previous transform function. - self.model, _ = SamplerTransform.apply(self.model, qaic_config, **kwargs) def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): """ @@ -775,98 +760,10 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt str Path to the generated ONNX graph file for the language decoder. """ - if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False): - inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs( - inputs, output_names, dynamic_axes - ) return self._export( inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights ) - def get_sampling_inputs_and_outputs( - self, - example_inputs: Dict[str, torch.Tensor], - output_names: List[str], - dynamic_axes: Dict[str, Dict[int, str]], - ): - """ - Updates the example inputs, output names, and dynamic axes to include - parameters relevant for on-device sampling during ONNX export. - - Parameters - ---------- - example_inputs : Dict[str, torch.Tensor] - Current dictionary of example inputs. 
- output_names : List[str] - Current list of output names. - dynamic_axes : Dict[str, Dict[int, str]] - Current dictionary of dynamic axes configurations. - - Returns - ------- - Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] - Updated example inputs, output names, and dynamic axes including - sampling-related parameters. - """ - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS - - assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling" - - logits_index = output_names.index("logits") - output_names[logits_index] = "next_tokens" - - example_inputs["last_accepted_output_tokens"] = torch.zeros( - (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 - ) - dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} - - example_inputs["past_repetition_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["past_repetition_penalty_buffer"] = { - 0: "full_batch_size" if self.continuous_batching else "batch_size", - } - output_names.append("past_repetition_penalty_buffer_RetainedState") - - example_inputs["repetition_penalties"] = ( - torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES - ) - dynamic_axes["repetition_penalties"] = {0: "batch_size"} - - example_inputs["past_presence_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["past_presence_penalty_buffer"] = { - 0: "full_batch_size" if self.continuous_batching else "batch_size", - } - output_names.append("past_presence_penalty_buffer_RetainedState") - - example_inputs["presence_penalties"] = ( - torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES - ) - dynamic_axes["presence_penalties"] = {0: "batch_size"} - - example_inputs["temperatures"] = ( - torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES - ) - dynamic_axes["temperatures"] = {0: "batch_size"} - - max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) - example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) - dynamic_axes["top_ks"] = {0: "batch_size"} - - example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS - dynamic_axes["top_ps"] = {0: "batch_size"} - - example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS - dynamic_axes["min_ps"] = {0: "batch_size"} - - example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) - dynamic_axes["random_numbers"] = {0: "batch_size"} - - return example_inputs, output_names, dynamic_axes - def compile( self, compile_dir, @@ -993,7 +890,13 @@ def __init__( self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, continuous_batching=continuous_batching, **kwargs) self.continuous_batching = continuous_batching + self.lang_model.model.qaic_config = qaic_config self.input_shapes, self.output_names = None, None + # ---Sampling--- + # Note: SamplerTransform should be applied after all other transforms + # are done. 
The role of the sampler is to just add nodes at the output of the + # previous transform function. + self.lang_model.model, _ = SamplerTransform.apply(self.lang_model.model, qaic_config, **kwargs) @property def model_name(self) -> str: @@ -1115,6 +1018,19 @@ def export( kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode ) output_names = self.model.get_output_names(kv_offload=True) + if self.lang_model.model.qaic_config is not None and self.lang_model.model.qaic_config.get( + "include_sampler", False + ): + logits_index = output_names["lang"].index("logits") + output_names["lang"][logits_index] = "next_tokens" + inputs["lang"], output_names["lang"], dynamic_axes["lang"] = get_sampling_inputs_and_outputs( + example_inputs=inputs["lang"], + output_names=output_names["lang"], + dynamic_axes=dynamic_axes["lang"], + continuous_batching=self.continuous_batching, + vocab_size=self.lang_model.model.config.vocab_size, + qaic_config=self.lang_model.model.qaic_config, + ) self.vision_model.export( inputs["vision"], @@ -2300,7 +2216,6 @@ def from_pretrained( model, kv_offload=kv_offload, continuous_batching=continuous_batching, - qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, **kwargs, @@ -2634,10 +2549,13 @@ def export(self, export_dir: Optional[str] = None) -> str: dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"} if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False): - example_inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs( + example_inputs, output_names, dynamic_axes = get_sampling_inputs_and_outputs( example_inputs=example_inputs, output_names=output_names, dynamic_axes=dynamic_axes, + continuous_batching=self.continuous_batching, + vocab_size=self.model.config.vocab_size, + qaic_config=self.model.qaic_config, ) return self._export( @@ -2647,85 +2565,6 @@ def export(self, export_dir: Optional[str] = None) -> str: export_dir=export_dir, ) - def get_sampling_inputs_and_outputs( - self, - example_inputs: Dict[str, torch.Tensor], - output_names: List[str], - dynamic_axes: Dict[str, Dict[int, str]], - ): - """ - Updates the example inputs, output names, and dynamic axes to include - parameters relevant for on-device sampling during ONNX export. - - Parameters - ---------- - example_inputs : Dict[str, torch.Tensor] - Current dictionary of example inputs. - output_names : List[str] - Current list of output names. - dynamic_axes : Dict[str, Dict[int, str]] - Current dictionary of dynamic axes configurations. - - Returns - ------- - Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] - Updated example inputs, output names, and dynamic axes including - sampling-related parameters. 
- """ - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS - - example_inputs["last_accepted_output_tokens"] = torch.zeros( - (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 - ) - dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} - - example_inputs["past_repetition_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["past_repetition_penalty_buffer"] = { - 0: "full_batch_size" if self.continuous_batching else "batch_size", - } - output_names.append("past_repetition_penalty_buffer_RetainedState") - - example_inputs["repetition_penalties"] = ( - torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES - ) - dynamic_axes["repetition_penalties"] = {0: "batch_size"} - - example_inputs["past_presence_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["past_presence_penalty_buffer"] = { - 0: "full_batch_size" if self.continuous_batching else "batch_size", - } - output_names.append("past_presence_penalty_buffer_RetainedState") - - example_inputs["presence_penalties"] = ( - torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES - ) - dynamic_axes["presence_penalties"] = {0: "batch_size"} - - example_inputs["temperatures"] = ( - torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES - ) - dynamic_axes["temperatures"] = {0: "batch_size"} - - max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) - example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) - dynamic_axes["top_ks"] = {0: "batch_size"} - - example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS - dynamic_axes["top_ps"] = {0: "batch_size"} - - example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS - dynamic_axes["min_ps"] = {0: "batch_size"} - - example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) - dynamic_axes["random_numbers"] = {0: "batch_size"} - - return example_inputs, output_names, dynamic_axes - def build_prefill_specialization( self, prefill_seq_len: int = 32, diff --git a/QEfficient/utils/sampler_utils.py b/QEfficient/utils/sampler_utils.py index 6fb1b326f..0460eeb3a 100644 --- a/QEfficient/utils/sampler_utils.py +++ b/QEfficient/utils/sampler_utils.py @@ -5,8 +5,11 @@ # # ----------------------------------------------------------------------------- -from typing import Optional, Set +from typing import Dict, List, Optional, Set +import torch + +from QEfficient.utils import constants from QEfficient.utils.constants import Constants from QEfficient.utils.logging_utils import logger @@ -56,3 +59,89 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[ ) return session_includes_sampler + + +def get_sampling_inputs_and_outputs( + example_inputs: Dict[str, torch.Tensor], + output_names: List[str], + dynamic_axes: Dict[str, Dict[int, str]], + continuous_batching: bool, + vocab_size: int, + qaic_config: Dict, +): + """ + Updates the example inputs, output names, and dynamic axes to include + parameters relevant for on-device sampling during ONNX export. 
+ + Parameters + ---------- + example_inputs : Dict[str, torch.Tensor] + Current dictionary of example inputs. + output_names : List[str] + Current list of output names. + dynamic_axes : Dict[str, Dict[int, str]] + Current dictionary of dynamic axes configurations. + continuous_batching : bool + Whether this model will be used for continuous batching in the future. + vocab_size: int + Vocabulary size for this model. + qaic_config : Dict + QAIC config dictionary. + + Returns + ------- + Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] + Updated example inputs, output names, and dynamic axes including + sampling-related parameters. + """ + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + + example_inputs["last_accepted_output_tokens"] = torch.zeros( + (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 + ) + dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} + + example_inputs["past_repetition_penalty_buffer"] = torch.zeros( + (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool + ) + dynamic_axes["past_repetition_penalty_buffer"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + } + output_names.append("past_repetition_penalty_buffer_RetainedState") + + example_inputs["repetition_penalties"] = ( + torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES + ) + dynamic_axes["repetition_penalties"] = {0: "batch_size"} + + example_inputs["past_presence_penalty_buffer"] = torch.zeros( + (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool + ) + dynamic_axes["past_presence_penalty_buffer"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + } + output_names.append("past_presence_penalty_buffer_RetainedState") + + example_inputs["presence_penalties"] = ( + torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES + ) + dynamic_axes["presence_penalties"] = {0: "batch_size"} + + example_inputs["temperatures"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES + dynamic_axes["temperatures"] = {0: "batch_size"} + + max_top_k_ids = qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) + example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) + dynamic_axes["top_ks"] = {0: "batch_size"} + + example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS + dynamic_axes["top_ps"] = {0: "batch_size"} + + example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS + dynamic_axes["min_ps"] = {0: "batch_size"} + + example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) + dynamic_axes["random_numbers"] = {0: "batch_size"} + + return example_inputs, output_names, dynamic_axes From 45aed11cf908615eadb1366416e5df6f5953d48b Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 20 Nov 2025 10:45:58 -0800 Subject: [PATCH 14/20] Add unit tests Signed-off-by: quic-xiyushi --- QEfficient/generation/vlm_generation.py | 13 ++ .../transformers/models/modeling_auto.py | 18 ++- tests/transformers/sampler/test_sampler.py | 142 ++++++++++++++---- 3 files changed, 142 insertions(+), 31 deletions(-) diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 5eb91d142..6c028a12f 100644 --- a/QEfficient/generation/vlm_generation.py +++ 
b/QEfficient/generation/vlm_generation.py @@ -36,6 +36,7 @@ write_io_files, ) from QEfficient.utils import LRUCache +from QEfficient.utils.constants import Constants from QEfficient.utils.logging_utils import logger @@ -303,6 +304,13 @@ def _execute_chunked_prefill( prefill_ccl_id = 0 lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + if self.include_sampler: + for op in Constants.SAMPLER_OPS: + if decode_batch_id is not None: + lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] + else: + lang_inputs[op] = self.sampling_params[op] + for i in range(num_chunks): input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len] position_ids_slice = lang_inputs["position_ids"][ @@ -328,6 +336,11 @@ def _execute_chunked_prefill( chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"] + if self.include_sampler: + chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] + for op in Constants.SAMPLER_OPS: + chunk_inputs[op] = lang_inputs[op] + outputs = self._session.run(chunk_inputs) if "image_idx_output" in outputs: diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 242063ee9..2bf81f68f 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -881,7 +881,10 @@ def __init__( If `full_batch_size` is provided. """ if kwargs.pop("full_batch_size", None): - raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + continuous_batching = True + warnings.warn( + "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 + ) self.model = model self.config = model.config @@ -1028,7 +1031,7 @@ def export( output_names=output_names["lang"], dynamic_axes=dynamic_axes["lang"], continuous_batching=self.continuous_batching, - vocab_size=self.lang_model.model.config.vocab_size, + vocab_size=self.config.vocab_size, qaic_config=self.lang_model.model.qaic_config, ) @@ -1235,6 +1238,7 @@ def generate( device_ids: List[int] = None, runtime_ai100: bool = True, generation_len: Optional[int] = None, + **kwargs, ) -> Union[torch.Tensor, np.ndarray]: """ Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards. @@ -1293,6 +1297,7 @@ def generate( full_batch_size=fbs, comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, + **kwargs, ) # Call generate method @@ -1572,11 +1577,16 @@ def __init__( Raises ------ NotImplementedError - If `full_batch_size` is provided. + If `full_batch_size` is provided or `continuous_batching` is True or `include_sampler` is True. """ if kwargs.pop("full_batch_size", None): + warnings.warn( + "full_batch_size argument is deprecated. 
Use continuous_batching=True instead.", DeprecationWarning, 2 + ) + raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + if kwargs.pop("continuous_batching", None): raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") - if kwargs.pop("qaic_config", None): + if qaic_config is not None and qaic_config.pop("include_sampler", False): raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.") super().__init__(model, **kwargs) diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 8d437eee8..d31dfef37 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -5,12 +5,13 @@ # # ----------------------------------------------------------------------------- -from typing import List +from typing import List, Union +from transformers import AutoConfig, AutoProcessor import numpy as np import pytest -from QEfficient import QEFFAutoModelForCausalLM +from QEfficient import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils import load_hf_tokenizer from QEfficient.utils.constants import Constants @@ -24,6 +25,20 @@ 20, # generation_len 2, # full_batch_size 1, # spec_length + False, # is_vlm + ), + pytest.param( + "Qwen/Qwen2.5-VL-3B-Instruct", # model + ( + ["https://picsum.photos/id/237/536/354"] * 2, + ["Can you describe the image in detail."] * 2, + ), # images and prompts + 128, # prefill_seq_len + 4096, # ctx_len + 20, # generation_len + 2, # full_batch_size + None, # spec_length + True, # is_vlm ), ] greedy_sampling_configs = [ @@ -35,6 +50,20 @@ 20, # generation_len 4, # full_batch_size 1, # spec_length + False, # is_vlm + ), + pytest.param( + "Qwen/Qwen2.5-VL-3B-Instruct", # model + ( + ["https://picsum.photos/id/237/536/354"] * 2, + ["Can you describe the image in detail."] * 2, + ), # images and prompts + 128, # prefill_seq_len + 4096, # ctx_len + 20, # generation_len + 2, # full_batch_size + None, # spec_length + True, # is_vlm ), ] random_sampling_configs = [ @@ -46,23 +75,38 @@ 20, # generation_len 4, # full_batch_size 1, # spec_length + False, # is_vlm ), + # pytest.param( + # "Qwen/Qwen2.5-VL-3B-Instruct", # model + # ( + # ["https://picsum.photos/id/237/536/354"] * 2, + # ["Can you describe the image in detail."] * 2, + # ), # images and prompts + # 128, # prefill_seq_len + # 4096, # ctx_len + # 20, # generation_len + # 2, # full_batch_size + # None, # spec_length + # True, # is_vlm + # ), ] @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", sampler_transform_configs, ) def test_sampler_transform( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, full_batch_size: int, spec_length: int, + is_vlm: bool, ): """ Test if `SamplerTransform` adds nodes at the output of a `QEffForCausalLM model` to enable the @@ -70,45 +114,56 @@ def test_sampler_transform( next tokens and/or probability distributions. 
""" # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + additional_configs = {} + if is_vlm: + additional_configs["kv_offload"] = True + qeff_class = QEFFAutoModelForImageTextToText + else: + additional_configs["num_hidden_layers"] = 2 + qeff_class = QEFFAutoModelForCausalLM + spec_length -= 1 + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=2, qaic_config={ "include_sampler": True, "return_pdfs": False, "max_top_k_ids": 512, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=2, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, ) - model_w_sampler_qpc_path: str = model_w_sampler.compile( + model_w_sampler_qpc_path = model_w_sampler.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) - model_wo_sampler_qpc_path: str = model_wo_sampler.compile( + model_wo_sampler_qpc_path = model_wo_sampler.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) + if is_vlm: + model_w_sampler_qpc_path = model_w_sampler_qpc_path[1] + model_wo_sampler_qpc_path = model_wo_sampler_qpc_path[1] # Init qaic session model_w_sampler_session = QAICInferenceSession(model_w_sampler_qpc_path) @@ -139,40 +194,54 @@ def test_sampler_transform( @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", - greedy_sampling_configs, + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", + sampler_transform_configs, ) def test_greedy_sampling( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, full_batch_size: int, spec_length: int, + is_vlm: bool, ): """ Test greedy sampling with QPC compiled with and without On Device Sampling. 
""" # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + additional_configs = {} + additional_params = {} + if is_vlm: + additional_configs["kv_offload"] = True + qeff_class = QEFFAutoModelForImageTextToText + assert isinstance(prompts, tuple) + additional_params["images"] = prompts[0] + additional_params["processor"] = AutoProcessor.from_pretrained(model) + prompts = prompts[1] + else: + additional_configs["num_hidden_layers"] = 2 + qeff_class = QEFFAutoModelForCausalLM + spec_length -= 1 + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=4, qaic_config={ "include_sampler": True, "return_pdfs": False, "max_top_k_ids": 512, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=4, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, ) model_w_sampler.compile( prefill_seq_len=prefill_seq_len, @@ -180,7 +249,7 @@ def test_greedy_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -190,7 +259,7 @@ def test_greedy_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -213,6 +282,7 @@ def test_greedy_sampling( "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "random_numbers": np.zeros((full_batch_size, 512), dtype=np.float32), }, + **additional_params, ) model_wo_sampler_exec_info = model_wo_sampler.generate( tokenizer=tokenizer, @@ -221,6 +291,7 @@ def test_greedy_sampling( include_sampler=False, return_pdfs=False, sampling_params=None, + **additional_params, ) # Compare generated texts and ids @@ -234,23 +305,36 @@ def test_greedy_sampling( @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", random_sampling_configs, ) def test_random_sampling( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, full_batch_size: int, spec_length: int, + is_vlm: bool, ): """ Test random sampling with QPC compiled with and without On Device Sampling. 
""" # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + additional_configs = {} + additional_params = {} + if is_vlm: + additional_configs["kv_offload"] = True + qeff_class = QEFFAutoModelForImageTextToText + assert isinstance(prompts, tuple) + additional_params["images"] = prompts[0] + additional_params["processor"] = AutoProcessor.from_pretrained(model) + prompts = prompts[1] + else: + qeff_class = QEFFAutoModelForCausalLM + spec_length -= 1 + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, qaic_config={ @@ -258,14 +342,16 @@ def test_random_sampling( "return_pdfs": False, "max_top_k_ids": 512, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, ) model_w_sampler.compile( prefill_seq_len=prefill_seq_len, @@ -273,7 +359,7 @@ def test_random_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -283,7 +369,7 @@ def test_random_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -309,6 +395,7 @@ def test_random_sampling( np.float32 ), }, + **additional_params, ) model_wo_sampler_exec_info = model_wo_sampler.generate( tokenizer=tokenizer, @@ -317,6 +404,7 @@ def test_random_sampling( include_sampler=False, return_pdfs=False, sampling_params=None, + **additional_params, ) # Compare generated texts From 6273ab5c156ee53a6134872c29dd05067e055aa9 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 20 Nov 2025 11:07:52 -0800 Subject: [PATCH 15/20] Clean up Signed-off-by: quic-xiyushi --- .../transformers/models/modeling_auto.py | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 2bf81f68f..189017507 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -723,7 +723,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model, continuous_batching: bool = False, qaic_config: Optional[dict] = None, **kwargs): + def __init__(self, model, **kwargs): """ Initializes the language decoder component for multimodal models. @@ -872,13 +872,10 @@ def __init__( ---------- model : nn.Module The full HuggingFace multimodal model. + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. **kwargs : - Additional keyword arguments. `full_batch_size` is not supported here. - - Raises - ------ - NotImplementedError - If `full_batch_size` is provided. + Additional keyword arguments. 
""" if kwargs.pop("full_batch_size", None): continuous_batching = True @@ -891,7 +888,7 @@ def __init__( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config) self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) - self.lang_model = QEffCausalLMForTextImageToTextModel(model, continuous_batching=continuous_batching, **kwargs) + self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) self.continuous_batching = continuous_batching self.lang_model.model.qaic_config = qaic_config self.input_shapes, self.output_names = None, None @@ -1577,15 +1574,13 @@ def __init__( Raises ------ NotImplementedError - If `full_batch_size` is provided or `continuous_batching` is True or `include_sampler` is True. + If `full_batch_size` is provided or `include_sampler` is True. """ if kwargs.pop("full_batch_size", None): warnings.warn( "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 ) raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") - if kwargs.pop("continuous_batching", None): - raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") if qaic_config is not None and qaic_config.pop("include_sampler", False): raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.") super().__init__(model, **kwargs) @@ -2189,10 +2184,6 @@ def from_pretrained( If None, the default behavior of the internal classes is used (typically dual QPC). qaic_config : dict, optional A dictionary for QAIC-specific configurations. - Only the following keys are supported by the text model of the dual QPC multimodal model: - - **include_sampler** (bool): If True, enables on-device sampling of next tokens. - - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. - Additional keys will be ignored. **kwargs : Additional arguments passed to HuggingFace's ``from_pretrained``. 
From 60312b309c1d12c66d31c919a755ddd578029c27 Mon Sep 17 00:00:00 2001 From: sanising Date: Thu, 20 Nov 2025 13:34:46 -0600 Subject: [PATCH 16/20] Add test for guided decoding Signed-off-by: sanising --- QEfficient/transformers/sampler/sampler.py | 2 +- tests/transformers/sampler/test_sampler.py | 120 ++++++++++++++++++++- 2 files changed, 119 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 5d6a8b8e2..78457d46e 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -224,7 +224,7 @@ def sampler_forward( batch_index_reshaped = batch_index.view(-1) # Guided decoding - if (token_bitmasks != 1).any(): + if token_bitmasks is not None and (token_bitmasks != 1).any(): assert spec_length == 1, "Currently, guided decoding is not supported with Speculative Decoding" # Mask logits where token_bitmasks is 0 with -inf logits = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min) diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 9d9a032c1..76b3b6d7f 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -48,6 +48,17 @@ 1, # spec_length ), ] +guided_decoding_configs = [ + pytest.param( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # model + Constants.INPUT_STR * 4, # prompts + 32, # prefill_seq_len + 64, # ctx_len + 20, # generation_len + 4, # full_batch_size + 1, # spec_length + ), +] @pytest.mark.on_qaic @@ -186,7 +197,7 @@ def test_greedy_sampling( spec_length: int, ): """ - Test greedy sampling with QPC compiled with and without On Device Sampling. + Test greedy sampling with QPCs compiled with and without On Device Sampling. """ # Export and compile QEfficient models model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( @@ -281,7 +292,7 @@ def test_random_sampling( spec_length: int, ): """ - Test random sampling with QPC compiled with and without On Device Sampling. + Test random sampling with QPCs compiled with and without On Device Sampling. """ # Export and compile QEfficient models model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( @@ -421,3 +432,108 @@ def test_random_sampling( assert (model_wo_sampler_exec_info.generated_ids[i][:generation_len] == golden_ids["wo_sampler"]).all(), ( "Without sampler generated ids do not match" ) + + +@pytest.mark.on_qaic +@pytest.mark.parametrize( + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + guided_decoding_configs, +) +def test_guided_decoding( + model: str, + prompts: List[str], + prefill_seq_len: int, + ctx_len: int, + generation_len: int, + full_batch_size: int, + spec_length: int, +): + """ + Test with QPCs compiled with and without guided decoding. 
+ """ + # Export and compile QEfficient models + model_w_sampler_w_guided_decoding = QEFFAutoModelForCausalLM.from_pretrained( + model, + continuous_batching=True, + num_hidden_layers=2, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 1024, + "include_guided_decoding": True, + }, + ) + model_w_sampler_wo_guided_decoding = QEFFAutoModelForCausalLM.from_pretrained( + model, + continuous_batching=True, + num_hidden_layers=2, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 1024, + }, + ) + model_w_sampler_w_guided_decoding.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + model_w_sampler_wo_guided_decoding.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + + # Generate texts from prompts + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model) + np.random.seed(0) + sampling_params = { + "repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + # "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(4.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "top_ks": np.array(1024, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), + "top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=1024), (full_batch_size, 1)).astype( + np.float32 + ), + } + model_w_sampler_w_guided_decoding_exec_info = model_w_sampler_w_guided_decoding.generate( + tokenizer=tokenizer, + prompts=prompts, + generation_len=generation_len, + include_sampler=True, + return_pdfs=False, + include_guided_decoding=True, + sampling_params={ + **sampling_params, + **{ + "token_bitmasks": np.random.choice( + [True, False], size=(full_batch_size, model_w_sampler_w_guided_decoding.model.config.vocab_size) + ) + }, + }, + ) + model_w_sampler_wo_guided_decoding_exec_info = model_w_sampler_wo_guided_decoding.generate( + tokenizer=tokenizer, + prompts=prompts, + generation_len=generation_len, + include_sampler=True, + return_pdfs=False, + sampling_params=sampling_params, + ) + assert ( + model_w_sampler_w_guided_decoding_exec_info.generated_ids + != model_w_sampler_wo_guided_decoding_exec_info.generated_ids + ), "Sampler outputs with and without guided decoding should not match" From 3789d5a36f4a268251ac26e9f1f3c3e907c77c55 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 20 Nov 2025 13:24:45 -0800 Subject: [PATCH 17/20] Update test_sampler.py Signed-off-by: quic-xiyushi --- tests/transformers/sampler/test_sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index d31dfef37..ca4a3abef 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -195,7 +195,7 @@ def test_sampler_transform( @pytest.mark.on_qaic @pytest.mark.parametrize( "model, prompts, 
prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", - sampler_transform_configs, + greedy_sampling_configs, ) def test_greedy_sampling( model: str, @@ -221,7 +221,7 @@ def test_greedy_sampling( additional_params["processor"] = AutoProcessor.from_pretrained(model) prompts = prompts[1] else: - additional_configs["num_hidden_layers"] = 2 + additional_configs["num_hidden_layers"] = 4 qeff_class = QEFFAutoModelForCausalLM spec_length -= 1 model_w_sampler = qeff_class.from_pretrained( From a24a55d4958c5f4e65d77399fdb4ab967a7f6379 Mon Sep 17 00:00:00 2001 From: sanising Date: Thu, 20 Nov 2025 16:13:18 -0600 Subject: [PATCH 18/20] Enable guided decoding in vlm generation Signed-off-by: sanising --- QEfficient/generation/vlm_generation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 6c028a12f..f7bc78cf6 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -92,6 +92,7 @@ def __init__( is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, ): """ @@ -111,6 +112,7 @@ def __init__( is_tlm: Target language model flag include_sampler: Enable on-device sampling (new feature) return_pdfs: Return probability distributions + include_guided_decoding: Enable guided decoding in on-device sampling sampling_params: Sampling parameters for on-device sampling """ # Validate required parameters @@ -134,6 +136,7 @@ def __init__( is_tlm=is_tlm, include_sampler=include_sampler, return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, sampling_params=sampling_params, activate=False, # vision components need to be initialized first ) @@ -305,7 +308,7 @@ def _execute_chunked_prefill( lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] if self.include_sampler: - for op in Constants.SAMPLER_OPS: + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): if decode_batch_id is not None: lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] else: @@ -338,7 +341,7 @@ def _execute_chunked_prefill( if self.include_sampler: chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] - for op in Constants.SAMPLER_OPS: + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): chunk_inputs[op] = lang_inputs[op] outputs = self._session.run(chunk_inputs) @@ -793,6 +796,7 @@ def generate_stream_tokens( is_tlm=self.is_tlm, include_sampler=self.include_sampler, return_pdfs=self.return_pdfs, + include_guided_decoding=self.include_guided_decoding, sampling_params=self.sampling_params, ) From 55e76e9bca1a92d2f078d6c4fa99e268c966a60c Mon Sep 17 00:00:00 2001 From: sanising Date: Thu, 20 Nov 2025 16:25:19 -0600 Subject: [PATCH 19/20] Fix bug Signed-off-by: sanising --- examples/performance/on_device_sampling.py | 6 +++++- tests/transformers/sampler/test_sampler.py | 23 +++++++++++----------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/examples/performance/on_device_sampling.py b/examples/performance/on_device_sampling.py index 4b46ab78f..da9c5b43b 100644 --- a/examples/performance/on_device_sampling.py +++ b/examples/performance/on_device_sampling.py @@ -68,7 +68,11 @@ def main(args, **kwargs): # Ideally this should come from a logits processor like 
xgrammar, but for the sake of the # example, we generate a random bitmask sampling_params.update( - {"token_bitmasks": np.random.choice([True, False], size=(bs, qeff_model.model.config.vocab_size))} + { + "token_bitmasks": np.tile( + np.random.choice([True, False], size=(qeff_model.model.config.vocab_size,)), (bs, 1) + ) + } ) print("sampling_params:") pprint(sampling_params) diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 4cbf7f392..9a075f04e 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -615,16 +615,14 @@ def test_guided_decoding( tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model) np.random.seed(0) sampling_params = { - "repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - # "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "temperatures": np.array(4.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "repetition_penalties": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "presence_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + # "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "top_ks": np.array(1024, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), - "top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=1024), (full_batch_size, 1)).astype( - np.float32 - ), + "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.zeros((full_batch_size, 1024), dtype=np.float32), } model_w_sampler_w_guided_decoding_exec_info = model_w_sampler_w_guided_decoding.generate( tokenizer=tokenizer, @@ -636,8 +634,9 @@ def test_guided_decoding( sampling_params={ **sampling_params, **{ - "token_bitmasks": np.random.choice( - [True, False], size=(full_batch_size, model_w_sampler_w_guided_decoding.model.config.vocab_size) + "token_bitmasks": np.tile( + np.random.choice([True, False], size=(model_w_sampler_w_guided_decoding.model.config.vocab_size,)), + (full_batch_size, 1), ) }, }, @@ -653,4 +652,4 @@ def test_guided_decoding( assert ( model_w_sampler_w_guided_decoding_exec_info.generated_ids != model_w_sampler_wo_guided_decoding_exec_info.generated_ids - ), "Sampler outputs with and without guided decoding should not match" + ).any(), "Sampler outputs with and without guided decoding should not match" From f9355d4bde9927d37464a2afe87d171fa8fe6b7f Mon Sep 17 00:00:00 2001 From: sanising Date: Thu, 20 Nov 2025 17:48:46 -0600 Subject: [PATCH 20/20] Fix bug Signed-off-by: sanising --- tests/transformers/sampler/test_sampler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 9a075f04e..715ef35d5 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -639,6 +639,7 @@ def test_guided_decoding( (full_batch_size, 1), ) }, + 
**additional_params, }, ) model_w_sampler_wo_guided_decoding_exec_info = model_w_sampler_wo_guided_decoding.generate( @@ -648,6 +649,7 @@ def test_guided_decoding( include_sampler=True, return_pdfs=False, sampling_params=sampling_params, + **additional_params, ) assert ( model_w_sampler_w_guided_decoding_exec_info.generated_ids
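The last two fixes make the guided-decoding path reproducible: the example and the test now build one boolean mask per vocabulary and tile it across the batch, and the extra VLM generation parameters are threaded through test_guided_decoding. For reference, here is a standalone sketch (not part of the patch) of how a caller might assemble the token_bitmasks entry of sampling_params; the vocabulary size and allowed token ids are hypothetical, and in practice the mask would come from a structured-generation engine such as xgrammar, as the example comment above notes.

# Illustrative sketch: constructing the boolean "token_bitmasks" buffer that
# guided decoding consumes. The allowed ids below are made up for demonstration.
import numpy as np


def make_token_bitmasks(allowed_token_ids, vocab_size, full_batch_size):
    """Return a (full_batch_size, vocab_size) bool mask; True keeps a token."""
    row = np.zeros(vocab_size, dtype=bool)
    row[list(allowed_token_ids)] = True
    # One mask shared by every sequence in the batch, mirroring the np.tile
    # pattern introduced above in examples/performance/on_device_sampling.py.
    return np.tile(row, (full_batch_size, 1))


vocab_size = 32000                    # assumption: read from model.config.vocab_size in real code
allowed_token_ids = [13, 450, 29871]  # hypothetical ids permitted by the grammar
token_bitmasks = make_token_bitmasks(allowed_token_ids, vocab_size, full_batch_size=4)

# Passed as sampling_params["token_bitmasks"] (with include_guided_decoding=True),
# the sampler suppresses every masked position before sampling, i.e.
#   logits = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min)
# Note that the sampler currently asserts spec_length == 1, so guided decoding is
# not combined with speculative decoding in this series.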