diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 7da2300d6..4fb77f272 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -329,6 +329,7 @@ def cloud_ai_100_exec_kv( is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, ): """ @@ -356,6 +357,8 @@ def cloud_ai_100_exec_kv( next tokens. For Speculative Decoding Target Language Model, `return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative Decoding Draft Language Model and `return_pdfs`=False for regular model. + :include_guided_decoding (bool, default=False): If True, enables guided token-level filtering + during decoding. Only works when `include_sampler`=True. sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend. The dictionary should contain the following keys: `repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`, @@ -394,6 +397,7 @@ def cloud_ai_100_exec_kv( is_tlm=is_tlm, include_sampler=include_sampler, return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, sampling_params=sampling_params, ) @@ -442,6 +446,7 @@ def __init__( is_tlm: Optional[int] = None, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, activate: bool = True, ) -> None: @@ -451,6 +456,7 @@ def __init__( self._write_io_dir = write_io_dir self.is_tlm = is_tlm self.return_pdfs = return_pdfs + self.include_guided_decoding = include_guided_decoding self.sampling_params = sampling_params self._qpc_path = qpc_path # Store qpc_path for later use @@ -461,7 +467,9 @@ def __init__( # Validate sampler inputs for On-Device Sampling self.include_sampler = validate_sampler_inputs( - session_inputs=set(self._session.input_names), include_sampler=include_sampler + session_inputs=set(self._session.input_names), + include_sampler=include_sampler, + include_guided_decoding=include_guided_decoding, ) # Fetch the variables from the QPC @@ -628,7 +636,7 @@ def prepare_decode_inputs(self): decode_inputs["batch_index"] = self.batch_index if self.include_sampler: decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] - for op in Constants.SAMPLER_OPS: + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): if self.batch_index is not None: decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()] else: @@ -795,7 +803,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i inputs["num_logits_to_keep"] = np.zeros((1, 1)) if self.include_sampler: inputs["last_accepted_output_tokens"] = inputs["input_ids"] - for op in Constants.SAMPLER_OPS: + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): if decode_batch_id is not None: inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] else: @@ -1067,6 +1075,7 @@ def __init__( is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, ) -> None: self._qaic_model = QEffTextGenerationBase( @@ -1082,6 +1091,7 @@ def __init__( is_tlm=is_tlm, include_sampler=include_sampler, return_pdfs=return_pdfs, + 
include_guided_decoding=include_guided_decoding, sampling_params=sampling_params, ) self._full_batch_size = self._qaic_model.full_batch_size diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 5eb91d142..f7bc78cf6 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -36,6 +36,7 @@ write_io_files, ) from QEfficient.utils import LRUCache +from QEfficient.utils.constants import Constants from QEfficient.utils.logging_utils import logger @@ -91,6 +92,7 @@ def __init__( is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, ): """ @@ -110,6 +112,7 @@ def __init__( is_tlm: Target language model flag include_sampler: Enable on-device sampling (new feature) return_pdfs: Return probability distributions + include_guided_decoding: Enable guided decoding in on-device sampling sampling_params: Sampling parameters for on-device sampling """ # Validate required parameters @@ -133,6 +136,7 @@ def __init__( is_tlm=is_tlm, include_sampler=include_sampler, return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, sampling_params=sampling_params, activate=False, # vision components need to be initialized first ) @@ -303,6 +307,13 @@ def _execute_chunked_prefill( prefill_ccl_id = 0 lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + if self.include_sampler: + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): + if decode_batch_id is not None: + lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] + else: + lang_inputs[op] = self.sampling_params[op] + for i in range(num_chunks): input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len] position_ids_slice = lang_inputs["position_ids"][ @@ -328,6 +339,11 @@ def _execute_chunked_prefill( chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"] + if self.include_sampler: + chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): + chunk_inputs[op] = lang_inputs[op] + outputs = self._session.run(chunk_inputs) if "image_idx_output" in outputs: @@ -780,6 +796,7 @@ def generate_stream_tokens( is_tlm=self.is_tlm, include_sampler=self.include_sampler, return_pdfs=self.return_pdfs, + include_guided_decoding=self.include_guided_decoding, sampling_params=self.sampling_params, ) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index cbff5be91..2a35d7281 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -8,7 +8,7 @@ import warnings from pathlib import Path from time import perf_counter -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union import numpy as np import torch @@ -64,6 +64,7 @@ ) from QEfficient.utils.check_ccl_specializations import process_ccl_specializations from QEfficient.utils.logging_utils import logger +from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs class QEFFTransformersBase(QEFFBaseModel): @@ -919,16 +920,16 @@ def __init__( ---------- model : nn.Module The full HuggingFace multimodal model. 
+ qaic_config : dict, optional + A dictionary for QAIC-specific configurations. **kwargs : - Additional keyword arguments. `full_batch_size` is not supported here. - - Raises - ------ - NotImplementedError - If `full_batch_size` is provided. + Additional keyword arguments. """ if kwargs.pop("full_batch_size", None): - raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + continuous_batching = True + warnings.warn( + "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 + ) self.model = model self.config = model.config @@ -937,7 +938,13 @@ def __init__( self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) self.continuous_batching = continuous_batching + self.lang_model.model.qaic_config = qaic_config self.input_shapes, self.output_names = None, None + # ---Sampling--- + # Note: SamplerTransform should be applied after all other transforms + # are done. The role of the sampler is to just add nodes at the output of the + # previous transform function. + self.lang_model.model, _ = SamplerTransform.apply(self.lang_model.model, qaic_config, **kwargs) @property def model_name(self) -> str: @@ -1062,6 +1069,19 @@ def export( kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode ) output_names = self.model.get_output_names(kv_offload=True) + if self.lang_model.model.qaic_config is not None and self.lang_model.model.qaic_config.get( + "include_sampler", False + ): + logits_index = output_names["lang"].index("logits") + output_names["lang"][logits_index] = "next_tokens" + inputs["lang"], output_names["lang"], dynamic_axes["lang"] = get_sampling_inputs_and_outputs( + example_inputs=inputs["lang"], + output_names=output_names["lang"], + dynamic_axes=dynamic_axes["lang"], + continuous_batching=self.continuous_batching, + vocab_size=self.config.vocab_size, + qaic_config=self.lang_model.model.qaic_config, + ) self.vision_model.export( inputs["vision"], @@ -1279,6 +1299,7 @@ def generate( device_ids: List[int] = None, runtime_ai100: bool = True, generation_len: Optional[int] = None, + **kwargs, ) -> Union[torch.Tensor, np.ndarray]: """ Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards. @@ -1337,6 +1358,7 @@ def generate( full_batch_size=fbs, comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, + **kwargs, ) # Call generate method @@ -1616,10 +1638,15 @@ def __init__( Raises ------ NotImplementedError - If `full_batch_size` is provided. + If `full_batch_size` is provided or `include_sampler` is True. """ if kwargs.pop("full_batch_size", None): + warnings.warn( + "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 + ) raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + if qaic_config is not None and qaic_config.pop("include_sampler", False): + raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.") super().__init__(model, **kwargs) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config) @@ -2230,6 +2257,8 @@ def from_pretrained( If True, uses the dual QPC approach (vision encoder KV offloaded). If False, uses the single QPC approach (entire model in one QPC). 
If None, the default behavior of the internal classes is used (typically dual QPC). + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. **kwargs : Additional arguments passed to HuggingFace's ``from_pretrained``. @@ -2257,7 +2286,6 @@ def from_pretrained( logger.warning("Updating low_cpu_mem_usage=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls( model, @@ -2336,6 +2364,8 @@ def __init__( - **return_pdfs** (bool): If True, returns probability distributions along with sampled tokens. For Speculative Decoding Target Language Models, this is always True. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. + - **include_guided_decoding** (bool): If True, enables guided token-level filtering + during decoding. Only works when include_sampler=True. **kwargs : Additional keyword arguments passed to the base class constructor. @@ -2438,6 +2468,8 @@ def from_pretrained( and ``return_pdfs=False`` for regular model. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. The values provided in ``top_ks`` tensor must be less than this maximum limit. + - **include_guided_decoding** (bool): If True, enables guided token-level filtering + during decoding. Only works when include_sampler=True. *args : Positional arguments passed directly to `cls._hf_auto_class.from_pretrained`. @@ -2600,10 +2632,13 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"} if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False): - example_inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs( + example_inputs, output_names, dynamic_axes = get_sampling_inputs_and_outputs( example_inputs=example_inputs, output_names=output_names, dynamic_axes=dynamic_axes, + continuous_batching=self.continuous_batching, + vocab_size=self.model.config.vocab_size, + qaic_config=self.model.qaic_config, ) return self._export( @@ -2615,85 +2650,6 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = offload_pt_weights=kwargs.get("offload_pt_weights", True), ) - def get_sampling_inputs_and_outputs( - self, - example_inputs: Dict[str, torch.Tensor], - output_names: List[str], - dynamic_axes: Dict[str, Dict[int, str]], - ): - """ - Updates the example inputs, output names, and dynamic axes to include - parameters relevant for on-device sampling during ONNX export. - - Parameters - ---------- - example_inputs : Dict[str, torch.Tensor] - Current dictionary of example inputs. - output_names : List[str] - Current list of output names. - dynamic_axes : Dict[str, Dict[int, str]] - Current dictionary of dynamic axes configurations. - - Returns - ------- - Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] - Updated example inputs, output names, and dynamic axes including - sampling-related parameters. 
- """ - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS - - example_inputs["last_accepted_output_tokens"] = torch.zeros( - (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 - ) - dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} - - example_inputs["past_repetition_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["past_repetition_penalty_buffer"] = { - 0: "full_batch_size" if self.continuous_batching else "batch_size", - } - output_names.append("past_repetition_penalty_buffer_RetainedState") - - example_inputs["repetition_penalties"] = ( - torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES - ) - dynamic_axes["repetition_penalties"] = {0: "batch_size"} - - example_inputs["past_presence_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["past_presence_penalty_buffer"] = { - 0: "full_batch_size" if self.continuous_batching else "batch_size", - } - output_names.append("past_presence_penalty_buffer_RetainedState") - - example_inputs["presence_penalties"] = ( - torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES - ) - dynamic_axes["presence_penalties"] = {0: "batch_size"} - - example_inputs["temperatures"] = ( - torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES - ) - dynamic_axes["temperatures"] = {0: "batch_size"} - - max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) - example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) - dynamic_axes["top_ks"] = {0: "batch_size"} - - example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS - dynamic_axes["top_ps"] = {0: "batch_size"} - - example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS - dynamic_axes["min_ps"] = {0: "batch_size"} - - example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float) - dynamic_axes["random_numbers"] = {0: "batch_size"} - - return example_inputs, output_names, dynamic_axes - def build_prefill_specialization( self, prefill_seq_len: int = 32, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 62a873b9e..75dc5a483 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -289,6 +289,7 @@ QEffGrok1MultiHeadAttention, ) from QEfficient.transformers.models.internvl.modeling_internvl import ( + QEffInternDecoderWrapper, QEffInternVisionEmbeddings, QEffInternVLModel, ) @@ -392,6 +393,7 @@ QEffQwen2_5_VLModel, QEffQwen2_5_VLTextModel, QEffQwen2_5_VLVisionAttention, + QEffQwen_2_5_vl_DecoderWrapper, QEffQwen_2_5_vl_ForConditionalGeneration, ) from QEfficient.transformers.models.qwen3.modeling_qwen3 import ( @@ -707,10 +709,12 @@ class SamplerTransform: QEffGPTJForCausalLM, QEffGraniteForCausalLM, QEffGraniteMoeForCausalLM, + QEffInternDecoderWrapper, QEffLlamaForCausalLM, QEffMptForCausalLM, QEffPhi3ForCausalLM, QEffQwen2ForCausalLM, + QEffQwen_2_5_vl_DecoderWrapper, } @classmethod diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 96846e712..78457d46e 100644 
--- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -24,6 +24,8 @@ class SamplerOutput(ModelOutput): probs: torch.FloatTensor = None next_tokens: torch.IntTensor = None + vision_embeds: Optional[torch.FloatTensor] = None # For VLMs + image_idx: Optional[torch.IntTensor] = None # for VLMs past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_repetition_penalty_buffer: Optional[torch.Tensor] = None past_presence_penalty_buffer: Optional[torch.Tensor] = None @@ -122,6 +124,9 @@ def sampler_forward( top_ps: Optional[torch.Tensor] = None, min_ps: Optional[torch.Tensor] = None, random_numbers: Optional[torch.Tensor] = None, + vision_embeds: Optional[torch.Tensor] = None, + image_idx: Optional[torch.Tensor] = None, + token_bitmasks: Optional[torch.Tensor] = None, ) -> Union[Tuple, SamplerOutput]: r""" Perform the sampling of next tokens on the QAIC device (instead of the host) @@ -169,21 +174,41 @@ def sampler_forward( random_numbers (`torch.Tensor`, *optional*): Sampling parameter that represents the random seeds to use for random sampling. Must be in [-1, 1]. - """ - outputs = self.old_forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - batch_index=batch_index, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) + token_bitmasks (`torch.Tensor`, *optional*): + Boolean mask used to guide token-level filtering during decoding. Each + element of this tensor indicates whether the corresponding token should be + kept (1) or masked (0). Shape: (batch_size, vocab_size) + """ + if vision_embeds is not None: + forward_kwargs = dict( + input_ids=input_ids, + vision_embeds=vision_embeds, + position_ids=position_ids, + image_idx=image_idx, + past_key_values=past_key_values, + ) + if batch_index is not None: + forward_kwargs["batch_index"] = batch_index + + logits, vision_embeds, image_idx, past_key_values = self.old_forward(**forward_kwargs) + outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values) + if position_ids.dim() == 3: # For models using m-rope + position_ids = position_ids[0] + else: + outputs = self.old_forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) logits = outputs.get("logits", None) assert logits is not None, f"{self.model.__class__.__name__} does not return logits." 
@@ -197,6 +222,13 @@ def sampler_forward( batch_index = torch.arange(batch_size).view(-1, 1) batch_index_reshaped = batch_index.view(-1) + + # Guided decoding + if token_bitmasks is not None and (token_bitmasks != 1).any(): + assert spec_length == 1, "Currently, guided decoding is not supported with Speculative Decoding" + # Mask logits where token_bitmasks is 0 with -inf + logits = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min) + # Prefill past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path( input_ids=input_ids, @@ -230,7 +262,9 @@ def sampler_forward( return SamplerOutput( probs=None, next_tokens=greedy_samples.reshape(-1, spec_length, 1), # Return sampled next tokens instead of logits - past_key_values=outputs.past_key_values, + vision_embeds=outputs.get("vision_embeds", None), + image_idx=outputs.get("image_idx", None), + past_key_values=outputs.get("past_key_values", None), past_repetition_penalty_buffer=past_repetition_penalty_buffer, past_presence_penalty_buffer=past_presence_penalty_buffer, ) @@ -300,9 +334,8 @@ def sampler_forward( ) # (batch_size, spec_length, vocab_size) # Random Sampling - topk_probs_asc = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, max_top_k_ids) gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1))) # Gumbel-Max Trick - y = topk_probs_asc + gumbel_noise + y = topk_values_asc + gumbel_noise # (batch_size * spec_length, max_top_k_ids) random_samples_indices = torch.argmax(y, dim=1, keepdim=True) random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices) # (batch_size * spec_length, 1) @@ -314,7 +347,9 @@ def sampler_forward( return SamplerOutput( probs=probs, next_tokens=next_tokens, # Return sampled next tokens instead of logits - past_key_values=outputs.past_key_values, + vision_embeds=outputs.get("vision_embeds", None), + image_idx=outputs.get("image_idx", None), + past_key_values=outputs.get("past_key_values", None), past_repetition_penalty_buffer=past_repetition_penalty_buffer, past_presence_penalty_buffer=past_presence_penalty_buffer, ) diff --git a/QEfficient/utils/sampler_utils.py b/QEfficient/utils/sampler_utils.py index 6fb1b326f..42f6d5825 100644 --- a/QEfficient/utils/sampler_utils.py +++ b/QEfficient/utils/sampler_utils.py @@ -5,13 +5,18 @@ # # ----------------------------------------------------------------------------- -from typing import Optional, Set +from typing import Dict, List, Optional, Set +import torch + +from QEfficient.utils import constants from QEfficient.utils.constants import Constants from QEfficient.utils.logging_utils import logger -def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[bool] = None) -> bool: +def validate_sampler_inputs( + session_inputs: Set[str], include_sampler: Optional[bool] = None, include_guided_decoding: Optional[bool] = None +) -> bool: """ Validates whether the `QAICInferenceSession` inputs match inputs required for on-device sampling. @@ -28,7 +33,7 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[ ValueError if partial support is detected or if user intent conflicts with QPC capabilities. 
""" - sampler_inputs = Constants.SAMPLER_INPUTS + sampler_inputs = Constants.SAMPLER_INPUTS | ({"token_bitmasks"} if include_guided_decoding else set()) count = len(sampler_inputs & session_inputs) session_includes_sampler = True @@ -56,3 +61,93 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[ ) return session_includes_sampler + + +def get_sampling_inputs_and_outputs( + example_inputs: Dict[str, torch.Tensor], + output_names: List[str], + dynamic_axes: Dict[str, Dict[int, str]], + continuous_batching: bool, + vocab_size: int, + qaic_config: Dict, +): + """ + Updates the example inputs, output names, and dynamic axes to include + parameters relevant for on-device sampling during ONNX export. + + Parameters + ---------- + example_inputs : Dict[str, torch.Tensor] + Current dictionary of example inputs. + output_names : List[str] + Current list of output names. + dynamic_axes : Dict[str, Dict[int, str]] + Current dictionary of dynamic axes configurations. + continuous_batching : bool + Whether this model will be used for continuous batching in the future. + vocab_size: int + Vocabulary size for this model. + qaic_config : Dict + QAIC config dictionary. + + Returns + ------- + Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] + Updated example inputs, output names, and dynamic axes including + sampling-related parameters. + """ + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + + example_inputs["last_accepted_output_tokens"] = torch.zeros( + (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 + ) + dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} + + example_inputs["past_repetition_penalty_buffer"] = torch.zeros( + (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool + ) + dynamic_axes["past_repetition_penalty_buffer"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + } + output_names.append("past_repetition_penalty_buffer_RetainedState") + + example_inputs["repetition_penalties"] = ( + torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES + ) + dynamic_axes["repetition_penalties"] = {0: "batch_size"} + + example_inputs["past_presence_penalty_buffer"] = torch.zeros( + (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool + ) + dynamic_axes["past_presence_penalty_buffer"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + } + output_names.append("past_presence_penalty_buffer_RetainedState") + + example_inputs["presence_penalties"] = ( + torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES + ) + dynamic_axes["presence_penalties"] = {0: "batch_size"} + + example_inputs["temperatures"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES + dynamic_axes["temperatures"] = {0: "batch_size"} + + max_top_k_ids = qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) + example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) + dynamic_axes["top_ks"] = {0: "batch_size"} + + example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS + dynamic_axes["top_ps"] = {0: "batch_size"} + + example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS + dynamic_axes["min_ps"] = {0: "batch_size"} + + example_inputs["random_numbers"] = torch.rand((bs, 
max_top_k_ids), dtype=torch.float) + dynamic_axes["random_numbers"] = {0: "batch_size"} + + if qaic_config.get("include_guided_decoding", False): + example_inputs["token_bitmasks"] = torch.zeros((bs, vocab_size), dtype=torch.bool) + dynamic_axes["token_bitmasks"] = {0: "batch_size"} + + return example_inputs, output_names, dynamic_axes diff --git a/examples/performance/on_device_sampling.py b/examples/performance/on_device_sampling.py index 6cc72b715..da9c5b43b 100644 --- a/examples/performance/on_device_sampling.py +++ b/examples/performance/on_device_sampling.py @@ -21,6 +21,7 @@ def main(args, **kwargs): include_sampler = None return_pdfs = None max_top_k_ids = None + include_guided_decoding = None sampling_params = None bs = args.full_batch_size if args.full_batch_size is not None else args.batch_size if args.override_qaic_config is not None: @@ -28,6 +29,8 @@ def main(args, **kwargs): if include_sampler is not None: return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true" max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512)) + np.random.seed(int(args.random_number)) + include_guided_decoding = args.override_qaic_config.get("aic_include_guided_decoding", None) == "true" sampling_params = { "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), "presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), @@ -36,7 +39,9 @@ def main(args, **kwargs): "top_ks": np.array(args.top_k, dtype=np.int32).repeat(bs).reshape(-1, 1), "top_ps": np.array(args.top_p, dtype=np.float32).repeat(bs).reshape(-1, 1), "min_ps": np.array(args.min_p, dtype=np.float32).repeat(bs).reshape(-1, 1), - "random_numbers": np.array(args.random_number, dtype=np.float32).repeat(bs).reshape(-1, 1), + "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=max_top_k_ids), (bs, 1)).astype( + np.float32 + ), } qaic_config = { k: v @@ -44,13 +49,12 @@ def main(args, **kwargs): "include_sampler": include_sampler, "return_pdfs": return_pdfs, "max_top_k_ids": max_top_k_ids, + "include_guided_decoding": include_guided_decoding, }.items() if v is not None } print("qaic_config:") pprint(qaic_config) - print("sampling_params:") - pprint(sampling_params) # Load model with On Device Sampler enabled qeff_model = AutoModelForCausalLM.from_pretrained( @@ -60,6 +64,19 @@ def main(args, **kwargs): ) print(f"{args.model_name} optimized for AI 100 \n", qeff_model) + if include_guided_decoding: + # Ideally this should come from a logits processor like xgrammar, but for the sake of the + # example, we generate a random bitmask + sampling_params.update( + { + "token_bitmasks": np.tile( + np.random.choice([True, False], size=(qeff_model.model.config.vocab_size,)), (bs, 1) + ) + } + ) + print("sampling_params:") + pprint(sampling_params) + # Compile the model for inference generated_qpc_path = qeff_model.compile( prefill_seq_len=args.prompt_len, @@ -88,6 +105,7 @@ def main(args, **kwargs): generation_len=args.generation_len, include_sampler=include_sampler, return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, sampling_params=sampling_params, ) @@ -106,14 +124,14 @@ def main(args, **kwargs): --num-cores 16 \ --mxint8-kv-cache \ --mxfp6-matmul \ - --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:false" \ --repetition-penalty 1.9 \ 
--presence-penalty 0.8 \ --temperature 0.67 \ - --top-k 54720 \ + --top-k 54 \ --top-p 0.89 \ --min-p 0.6 \ - --random-number 0.26 + --random-number 26 2. For non-continuous batching: python3.10 examples/on_device_sampling.py \ @@ -126,14 +144,34 @@ def main(args, **kwargs): --num-cores 16 \ --mxint8-kv-cache \ --mxfp6-matmul \ - --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:false" \ + --repetition-penalty 1.9 \ + --presence-penalty 0.8 \ + --temperature 0.67 \ + --top-k 54 \ + --top-p 0.89 \ + --min-p 0.6 \ + --random-number 26 + + 3. With guided decoding: + python3.10 examples/on_device_sampling.py \ + --model-name 'meta-llama/Llama-3.1-8B' \ + --prompt-len 128 \ + --ctx-len 256 \ + --generation-len 20 \ + --full-batch-size 2 \ + --device-group [0,1,2,3] \ + --num-cores 16 \ + --mxint8-kv-cache \ + --mxfp6-matmul \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:true" \ --repetition-penalty 1.9 \ --presence-penalty 0.8 \ --temperature 0.67 \ - --top-k 54720 \ + --top-k 54 \ --top-p 0.89 \ --min-p 0.6 \ - --random-number 0.26 + --random-number 26 """ parser = argparse.ArgumentParser(description="Run QEfficient model with On Device Sampling") diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 9335e1d91..9a075f04e 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -5,12 +5,13 @@ # # ----------------------------------------------------------------------------- -from typing import List +from typing import List, Union import numpy as np import pytest +from transformers import AutoProcessor -from QEfficient import QEFFAutoModelForCausalLM +from QEfficient import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils import load_hf_tokenizer from QEfficient.utils.constants import Constants @@ -24,6 +25,20 @@ 20, # generation_len 2, # full_batch_size 1, # spec_length + False, # is_vlm + ), + pytest.param( + "Qwen/Qwen2.5-VL-3B-Instruct", # model + ( + ["https://picsum.photos/id/237/536/354"] * 2, + ["Can you describe the image in detail."] * 2, + ), # images and prompts + 128, # prefill_seq_len + 4096, # ctx_len + 20, # generation_len + 2, # full_batch_size + None, # spec_length + True, # is_vlm ), ] greedy_sampling_configs = [ @@ -35,6 +50,20 @@ 20, # generation_len 4, # full_batch_size 1, # spec_length + False, # is_vlm + ), + pytest.param( + "Qwen/Qwen2.5-VL-3B-Instruct", # model + ( + ["https://picsum.photos/id/237/536/354"] * 2, + ["Can you describe the image in detail."] * 2, + ), # images and prompts + 128, # prefill_seq_len + 4096, # ctx_len + 20, # generation_len + 2, # full_batch_size + None, # spec_length + True, # is_vlm ), ] random_sampling_configs = [ @@ -46,23 +75,63 @@ 20, # generation_len 4, # full_batch_size 1, # spec_length + False, # is_vlm + ), + # pytest.param( + # "Qwen/Qwen2.5-VL-3B-Instruct", # model + # ( + # ["https://picsum.photos/id/237/536/354"] * 2, + # ["Can you describe the image in detail."] * 2, + # ), # images and prompts + # 128, # prefill_seq_len + # 4096, # ctx_len + # 20, # generation_len + # 2, # full_batch_size + # None, # spec_length + # True, # is_vlm + # ), +] +guided_decoding_configs = [ + pytest.param( + 
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", # model + Constants.INPUT_STR * 4, # prompts + 32, # prefill_seq_len + 64, # ctx_len + 20, # generation_len + 4, # full_batch_size + 1, # spec_length + False, # is_vlm ), + # pytest.param( + # "Qwen/Qwen2.5-VL-3B-Instruct", # model + # ( + # ["https://picsum.photos/id/237/536/354"] * 2, + # ["Can you describe the image in detail."] * 2, + # ), # images and prompts + # 128, # prefill_seq_len + # 4096, # ctx_len + # 20, # generation_len + # 2, # full_batch_size + # None, # spec_length + # True, # is_vlm + # ), ] @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", sampler_transform_configs, ) def test_sampler_transform( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, full_batch_size: int, spec_length: int, + is_vlm: bool, ): """ Test if `SamplerTransform` adds nodes at the output of a `QEffForCausalLM model` to enable the @@ -70,48 +139,82 @@ def test_sampler_transform( next tokens and/or probability distributions. """ # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + additional_configs = {} + if is_vlm: + additional_configs["kv_offload"] = True + qeff_class = QEFFAutoModelForImageTextToText + else: + additional_configs["num_hidden_layers"] = 2 + qeff_class = QEFFAutoModelForCausalLM + spec_length -= 1 + model_w_sampler = qeff_class.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 512, + }, + **additional_configs, + ) + model_w_sampler_w_guided_decoding = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=2, qaic_config={ "include_sampler": True, "return_pdfs": False, "max_top_k_ids": 512, + "include_guided_decoding": True, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=2, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, + ) + model_w_sampler_qpc_path = model_w_sampler.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length, + mxint8_kv_cache=True, + mxfp6_matmul=True, ) - model_w_sampler_qpc_path: str = model_w_sampler.compile( + model_w_sampler_w_guided_decoding_qpc_path = model_w_sampler_w_guided_decoding.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) - model_wo_sampler_qpc_path: str = model_wo_sampler.compile( + model_wo_sampler_qpc_path = model_wo_sampler.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) + if is_vlm: + model_w_sampler_qpc_path = model_w_sampler_qpc_path[1] + model_w_sampler_w_guided_decoding_qpc_path = model_w_sampler_w_guided_decoding_qpc_path[1] + model_wo_sampler_qpc_path = 
model_wo_sampler_qpc_path[1] # Init qaic session model_w_sampler_session = QAICInferenceSession(model_w_sampler_qpc_path) + model_w_sampler_w_guided_decoding_session = QAICInferenceSession(model_w_sampler_w_guided_decoding_qpc_path) model_wo_sampler_session = QAICInferenceSession(model_wo_sampler_qpc_path) # Skip inputs/outputs buffers @@ -119,6 +222,12 @@ def test_sampler_transform( model_w_sampler_session.skip_buffers( set([x for x in model_w_sampler_session.output_names if x.endswith("_RetainedState")]) ) + model_w_sampler_w_guided_decoding_session.skip_buffers( + set([x for x in model_w_sampler_w_guided_decoding_session.input_names if x.startswith("past_")]) + ) + model_w_sampler_w_guided_decoding_session.skip_buffers( + set([x for x in model_w_sampler_w_guided_decoding_session.output_names if x.endswith("_RetainedState")]) + ) model_wo_sampler_session.skip_buffers( set([x for x in model_wo_sampler_session.input_names if x.startswith("past_")]) ) @@ -132,47 +241,67 @@ def test_sampler_transform( assert input_name in model_w_sampler_session.input_names, ( f"Sampler input {input_name} not found in QPC compiled with On Device Sampler" ) + assert input_name in model_w_sampler_w_guided_decoding_session.input_names, ( + f"Sampler input {input_name} not found in QPC compiled with On Device Sampler and Guided Decoding" + ) assert input_name not in model_wo_sampler_session.input_names, ( f"Sampler input {input_name} found in QPC compiled without On Device Sampler" ) + assert "token_bitmasks" in model_w_sampler_w_guided_decoding_session.input_names, ( + "Sampler input token_bitmasks not found in QPC compiled with On Device Sampler and Guided Decoding" + ) @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", greedy_sampling_configs, ) def test_greedy_sampling( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, full_batch_size: int, spec_length: int, + is_vlm: bool, ): """ - Test greedy sampling with QPC compiled with and without On Device Sampling. + Test greedy sampling with QPCs compiled with and without On Device Sampling. 
""" # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + additional_configs = {} + additional_params = {} + if is_vlm: + additional_configs["kv_offload"] = True + qeff_class = QEFFAutoModelForImageTextToText + assert isinstance(prompts, tuple) + additional_params["images"] = prompts[0] + additional_params["processor"] = AutoProcessor.from_pretrained(model) + prompts = prompts[1] + else: + additional_configs["num_hidden_layers"] = 4 + qeff_class = QEFFAutoModelForCausalLM + spec_length -= 1 + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=4, qaic_config={ "include_sampler": True, "return_pdfs": False, "max_top_k_ids": 512, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=4, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, ) model_w_sampler.compile( prefill_seq_len=prefill_seq_len, @@ -180,7 +309,7 @@ def test_greedy_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -190,7 +319,7 @@ def test_greedy_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -211,8 +340,9 @@ def test_greedy_sampling( "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "random_numbers": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.zeros((full_batch_size, 512), dtype=np.float32), }, + **additional_params, ) model_wo_sampler_exec_info = model_wo_sampler.generate( tokenizer=tokenizer, @@ -221,6 +351,7 @@ def test_greedy_sampling( include_sampler=False, return_pdfs=False, sampling_params=None, + **additional_params, ) # Compare generated texts and ids @@ -233,25 +364,37 @@ def test_greedy_sampling( @pytest.mark.on_qaic -@pytest.mark.skip @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", random_sampling_configs, ) def test_random_sampling( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, full_batch_size: int, spec_length: int, + is_vlm: bool, ): """ - Test random sampling with QPC compiled with and without On Device Sampling. + Test random sampling with QPCs compiled with and without On Device Sampling. 
""" # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + additional_configs = {} + additional_params = {} + if is_vlm: + additional_configs["kv_offload"] = True + qeff_class = QEFFAutoModelForImageTextToText + assert isinstance(prompts, tuple) + additional_params["images"] = prompts[0] + additional_params["processor"] = AutoProcessor.from_pretrained(model) + prompts = prompts[1] + else: + qeff_class = QEFFAutoModelForCausalLM + spec_length -= 1 + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, qaic_config={ @@ -259,14 +402,16 @@ def test_random_sampling( "return_pdfs": False, "max_top_k_ids": 512, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, ) model_w_sampler.compile( prefill_seq_len=prefill_seq_len, @@ -274,7 +419,7 @@ def test_random_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -284,13 +429,14 @@ def test_random_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) # Generate texts from prompts tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model) + np.random.seed(0) model_w_sampler_exec_info = model_w_sampler.generate( tokenizer=tokenizer, prompts=prompts, @@ -301,12 +447,15 @@ def test_random_sampling( "repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), # "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "temperatures": np.array(100.1, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(4.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), "top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "random_numbers": np.array(0.26, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=512), (full_batch_size, 1)).astype( + np.float32 + ), }, + **additional_params, ) model_wo_sampler_exec_info = model_wo_sampler.generate( tokenizer=tokenizer, @@ -315,36 +464,37 @@ def test_random_sampling( include_sampler=False, return_pdfs=False, sampling_params=None, + **additional_params, ) # Compare generated texts golden_texts = { - "w_sampler": "Raymond and my favorite color, alongside reds or purples (I can’t have them both", + "w_sampler": "Aiden and I am a freelance writer who loves to explore the world. With over", "wo_sampler": "John Smith and I am a software engineer. 
I have been working in the industry for the past ", } golden_ids = { "w_sampler": [ [ - 21380, + 319, + 3615, 322, - 590, - 25448, - 2927, - 29892, - 19963, - 2654, - 29879, - 470, - 3708, - 2701, - 313, - 29902, - 508, - 30010, - 29873, - 505, - 963, - 1716, + 306, + 626, + 263, + 3005, + 295, + 749, + 9227, + 1058, + 12355, + 267, + 304, + 26987, + 278, + 3186, + 29889, + 2973, + 975, ] ], "wo_sampler": [ @@ -385,3 +535,121 @@ def test_random_sampling( assert (model_wo_sampler_exec_info.generated_ids[i][:generation_len] == golden_ids["wo_sampler"]).all(), ( "Without sampler generated ids do not match" ) + + +@pytest.mark.on_qaic +@pytest.mark.parametrize( + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", + guided_decoding_configs, +) +def test_guided_decoding( + model: str, + prompts: List[str], + prefill_seq_len: int, + ctx_len: int, + generation_len: int, + full_batch_size: int, + spec_length: int, + is_vlm: bool, +): + """ + Test with QPCs compiled with and without guided decoding. + """ + # Export and compile QEfficient models + additional_configs = {} + additional_params = {} + if is_vlm: + additional_configs["kv_offload"] = True + qeff_class = QEFFAutoModelForImageTextToText + assert isinstance(prompts, tuple) + additional_params["images"] = prompts[0] + additional_params["processor"] = AutoProcessor.from_pretrained(model) + prompts = prompts[1] + else: + additional_configs["num_hidden_layers"] = 2 + qeff_class = QEFFAutoModelForCausalLM + spec_length -= 1 + model_w_sampler_w_guided_decoding = qeff_class.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 1024, + "include_guided_decoding": True, + }, + **additional_configs, + ) + model_w_sampler_wo_guided_decoding = qeff_class.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 1024, + }, + **additional_configs, + ) + model_w_sampler_w_guided_decoding.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + model_w_sampler_wo_guided_decoding.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + + # Generate texts from prompts + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model) + np.random.seed(0) + sampling_params = { + "repetition_penalties": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "presence_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + # "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "top_ks": np.array(1024, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), + "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.zeros((full_batch_size, 1024), dtype=np.float32), + } + model_w_sampler_w_guided_decoding_exec_info = model_w_sampler_w_guided_decoding.generate( + tokenizer=tokenizer, + prompts=prompts, + 
generation_len=generation_len, + include_sampler=True, + return_pdfs=False, + include_guided_decoding=True, + sampling_params={ + **sampling_params, + **{ + "token_bitmasks": np.tile( + np.random.choice([True, False], size=(model_w_sampler_w_guided_decoding.model.config.vocab_size,)), + (full_batch_size, 1), + ) + }, + }, + ) + model_w_sampler_wo_guided_decoding_exec_info = model_w_sampler_wo_guided_decoding.generate( + tokenizer=tokenizer, + prompts=prompts, + generation_len=generation_len, + include_sampler=True, + return_pdfs=False, + sampling_params=sampling_params, + ) + assert ( + model_w_sampler_w_guided_decoding_exec_info.generated_ids + != model_w_sampler_wo_guided_decoding_exec_info.generated_ids + ).any(), "Sampler outputs with and without guided decoding should not match"
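
For reference, a minimal end-to-end sketch of the guided-decoding path this patch enables, distilled from examples/performance/on_device_sampling.py and tests/transformers/sampler/test_sampler.py above. The model name, prompts, and compile settings are illustrative only, and the random token_bitmasks merely stands in for a mask produced by a structured-generation backend such as xgrammar.

import numpy as np

from QEfficient import QEFFAutoModelForCausalLM
from QEfficient.utils import load_hf_tokenizer

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # same model as the tests above
full_batch_size = 2
max_top_k_ids = 512

# Export/compile with the on-device sampler and guided decoding enabled.
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
    model_name,
    continuous_batching=True,
    qaic_config={
        "include_sampler": True,
        "return_pdfs": False,
        "max_top_k_ids": max_top_k_ids,
        "include_guided_decoding": True,
    },
)
qeff_model.compile(
    prefill_seq_len=32,
    ctx_len=256,
    full_batch_size=full_batch_size,
    num_devices=1,
    num_cores=16,
)

# Per-request sampling parameters; token_bitmasks marks which vocab ids may be sampled.
vocab_size = qeff_model.model.config.vocab_size
sampling_params = {
    "repetition_penalties": np.full((full_batch_size, 1), 1.0, dtype=np.float32),
    "presence_penalties": np.full((full_batch_size, 1), 0.0, dtype=np.float32),
    "temperatures": np.full((full_batch_size, 1), 0.0, dtype=np.float32),
    "top_ks": np.full((full_batch_size, 1), max_top_k_ids, dtype=np.int32),
    "top_ps": np.full((full_batch_size, 1), 1.0, dtype=np.float32),
    "min_ps": np.full((full_batch_size, 1), 0.0, dtype=np.float32),
    "random_numbers": np.zeros((full_batch_size, max_top_k_ids), dtype=np.float32),
    # Placeholder mask; a real logits processor (e.g. xgrammar) would provide this.
    "token_bitmasks": np.tile(np.random.choice([True, False], size=(vocab_size,)), (full_batch_size, 1)),
}

tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name)
exec_info = qeff_model.generate(
    tokenizer=tokenizer,
    prompts=["My name is"] * full_batch_size,
    generation_len=20,
    include_sampler=True,
    return_pdfs=False,
    include_guided_decoding=True,
    sampling_params=sampling_params,
)
print(exec_info.generated_ids)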