From aafe4005876eb52d04285d9f013b7e755c9ad695 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Mon, 17 Nov 2025 14:16:16 +0000 Subject: [PATCH 1/9] pushed all changes for incoperating subfunction in CausalLM Signed-off-by: abhishek-singh591 --- QEfficient/base/modeling_qeff.py | 38 +++- QEfficient/base/onnx_transforms.py | 166 +++++++++++++++++- QEfficient/transformers/cache_utils.py | 42 +++-- .../transformers/models/modeling_auto.py | 88 ++++++++-- .../transformers/models/pytorch_transforms.py | 26 +++ QEfficient/utils/_utils.py | 1 + QEfficient/utils/hash_utils.py | 2 +- QEfficient/utils/patches.py | 115 ++++++++++++ tests/transformers/test_subfunction.py | 69 ++++++++ 9 files changed, 515 insertions(+), 32 deletions(-) create mode 100644 QEfficient/utils/patches.py create mode 100644 tests/transformers/test_subfunction.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 6ecbf0fc0..827249d43 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -8,6 +8,7 @@ import gc import inspect import logging +import re import shutil import subprocess import warnings @@ -18,10 +19,14 @@ import onnx import torch -from QEfficient.base.onnx_transforms import OnnxTransform +from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile +from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc +from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.transformers.cache_utils import InvalidIndexProvider +from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export from QEfficient.utils import ( constants, create_json, @@ -32,6 +37,7 @@ hash_dict_params, load_json, ) +from QEfficient.utils.patches import apply_torch_patches, undo_torch_patches logger = logging.getLogger(__name__) @@ -53,7 +59,7 @@ class QEFFBaseModel(ABC): def _transform_names(cls) -> List[str]: return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms] - def __init__(self, model: torch.nn.Module, **kwargs) -> None: + def __init__(self, model: torch.nn.Module, use_subfunctions: bool = False, **kwargs) -> None: super().__init__() self.model = model self.hash_params = create_model_params(self, **kwargs) @@ -64,6 +70,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None: (arch := getattr(self.model.config, "architectures", None)) and len(arch) > 0 and arch[0] ) or None + self.use_subfunctions = use_subfunctions # Flag for checking if weights are offloaded self._is_weights_offloaded: bool = False @@ -179,6 +186,7 @@ def _export( onnx_transform_kwargs: Optional[Dict[str, any]] = None, export_dir: Optional[str] = None, offload_pt_weights: bool = True, + use_subfunctions: bool = False, ) -> str: """ Export the PyTorch model to ONNX and apply ONNX transforms @@ -243,7 +251,21 @@ def _export( input_names.append(param) try: + # Initialize the registry with your custom ops export_kwargs = {} if export_kwargs is None else export_kwargs + CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm) + CustomOpTransform.register_custom_op("CtxScatterFunc", CtxScatterFunc, CtxScatter) + CustomOpTransform.register_custom_op("CtxGatherFunc", CtxGatherFunc, CtxGather) + if use_subfunctions: + 
warnings.warn( + "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results." + ) + apply_torch_patches() + InvalidIndexProvider.SUBFUNC_ENABLED = True + output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names] + export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model) + self._onnx_transforms.append(CustomOpTransform) + torch.onnx.export( self.model, (example_inputs,), @@ -252,15 +274,16 @@ def _export( output_names=output_names, dynamic_axes=dynamic_axes, opset_version=constants.ONNX_EXPORT_OPSET, + do_constant_folding=True, **export_kwargs, ) logger.info("PyTorch export successful") - _ = self._offload_model_weights(offload_pt_weights) model = onnx.load(tmp_onnx_path, load_external_data=False) transform_kwargs = { "onnx_base_dir": str(tmp_onnx_dir), + "temp_onnx_path": tmp_onnx_path, "model_name": self.model_name, } if onnx_transform_kwargs is not None: @@ -284,6 +307,10 @@ def _export( finally: shutil.rmtree(tmp_onnx_dir, ignore_errors=True) + if use_subfunctions: + undo_torch_patches() + InvalidIndexProvider.SUBFUNC_ENABLED = False + self.onnx_path = onnx_path return onnx_path @@ -300,6 +327,7 @@ def _compile( num_speculative_tokens: Optional[int] = None, enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, + use_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -325,9 +353,9 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored. """ - if onnx_path is None and self.onnx_path is None: - self.export() + if onnx_path is None and self.onnx_path is None: + self.export(use_subfunctions=use_subfunctions) onnx_path = Path(onnx_path or self.onnx_path) compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 61b5c00f6..65287426a 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -5,9 +5,12 @@ # # ---------------------------------------------------------------------------- -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np +import onnx +import onnxslim +import torch from onnx import ModelProto, external_data_helper, numpy_helper @@ -99,3 +102,164 @@ def apply( current_file_size = tsize external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data") return model, transformed + + +class OnnxSlimTransform(OnnxTransform): + """ + Applies onnx-slim transformations on the given ONNX graph. + """ + + @classmethod + def apply( + cls, + model: ModelProto, + *, + onnx_base_dir: Optional[str] = None, + **kwargs, + ) -> Tuple[ModelProto, bool]: + """ + :param enable_onnx_slim_transform: If True, applies onnx-slim transformations. + :param temp_onnx_path: Path to save the slimmed ONNX model. + """ + transformed = False + onnx_slim_transform = True # kwargs.get("enable_onnx_slim_transform", False) + temp_onnx_path = kwargs.get("temp_onnx_path", None) + if not temp_onnx_path: + err_str = "temp_onnx_path is required for onnx-slim transform." 
+ raise RuntimeError(err_str) + if onnx_slim_transform: + transformed = True + slimmed_model = onnxslim.slim(model) + onnx.save(slimmed_model, temp_onnx_path) + return slimmed_model, transformed + return model, transformed + + +class CustomOpTransform(OnnxTransform): + """ + Transform to register custom operations and add their function protos to the ONNX model. + """ + + # Registry of custom operations + _custom_ops: Dict[str, Tuple[Any, Any]] = {} # op_name -> (func_class, onnxscript_func) + + @classmethod + def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any): + """Register a custom operation.""" + cls._custom_ops[op_name] = (func_class, onnxscript_func) + + @classmethod + def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple[ModelProto, bool]: + """ + Apply custom op registration and add function protos to the model. + + :param model: The ONNX model to transform + :param opset_version: ONNX opset version for symbolic registration + :returns: Transformed model and success flag + """ + transformed = False + + # Register all custom op symbolic functions with torch.onnx + for op_name, (func_class, _) in cls._custom_ops.items(): + if hasattr(func_class, "symbolic"): + torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version) + + # Add function protos for custom ops that are used in the model + used_protos = cls._get_function_protos_for_model(model) + + for proto in used_protos: + # Check if proto already exists to avoid duplicates + proto_name = proto.name + if not any(func.name == proto_name for func in model.functions): + model.functions.append(proto) + transformed = True + + return model, transformed + + @classmethod + def _get_function_protos_for_model(cls, model: ModelProto) -> List[Any]: + """Get function protos for custom ops that are actually used in the model.""" + used_protos = [] + + # Get all node op_types in the model + used_op_types = set() + for node in model.graph.node: + used_op_types.add(node.op_type) + + # Also check function calls + for func in model.functions: + for node in func.node: + used_op_types.add(node.op_type) + + # Check which custom ops are actually used + for op_name, (func_class, onnxscript_func) in cls._custom_ops.items(): + # Check if the custom op is referenced in the model + if cls._is_custom_op_used(model, op_name, used_op_types): + proto = onnxscript_func.to_function_proto() + used_protos.append(proto) + + return used_protos + + @classmethod + def _is_custom_op_used(cls, model: ModelProto, op_name: str, used_op_types: set) -> bool: + """Check if a custom op is used in the model.""" + # Check if the op_name appears in node op_types + if op_name in used_op_types: + return True + + # Check for domain-specific ops (e.g., "com.qti.aisw.onnx::CustomRMSNorm") + custom_op_pattern = f"com.qti.aisw.onnx::{op_name.replace('Func', '')}" + if custom_op_pattern in used_op_types: + return True + + # Heuristic checks based on op type + if "RMSNorm" in op_name: + # Check if any RMSNorm-related ops are present + return any("RMSNorm" in op_type for op_type in used_op_types) + + if "Ctx" in op_name: + # Check if Gather/Scatter operations are present (indicating KV cache usage) + return any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types) + + return False + + +class RenameFunctionOutputsTransform(OnnxTransform): + """ + Renames function outputs in decoder layers by removing 'Internal' from '_InternalRetainedState' patterns. 
+ """ + + @classmethod + def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]: + """ + Rename function outputs in decoder layer nodes. + + :param model: The ONNX model to transform + :returns: Transformed model and boolean indicating whether transform was applied + """ + graph = model.graph + op_type_to_func_map = {func.name: func for func in model.functions} + decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"] + transformed = False + model_graph_outputs = [val.name for val in model.graph.output] + layer_index = 0 + for node in graph.node: + if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns): + func = op_type_to_func_map.get(node.op_type) + if func is None: + continue + + for i, out_name in enumerate(func.output): + if "_InternalRetainedState" in out_name: + transformed = True + tmp = node.output[i] + if "key" in out_name: + new_name = f"past_key.{layer_index}_RetainedState" + elif "value" in out_name: + new_name = f"past_value.{layer_index}_RetainedState" + node.output[i] = new_name + # Update graph output name if it exists + if tmp in model_graph_outputs: + model.graph.output[model_graph_outputs.index(tmp)].name = new_name + layer_index = layer_index + 1 + return model, transformed diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 5452589f6..292fe0487 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -24,6 +24,33 @@ ) +class InvalidIndexProvider: + SUBFUNC_ENABLED = False + + @classmethod + def enable_subfunc(cls): + cls.SUBFUNC_ENABLED = True + + @classmethod + def _get_invalid_idx_value(cls): + """ + Get the appropriate invalid index value for CtxGather operations. + + For ONNX export with functions, we use 0 to avoid INT32_MAX constants + that cause issues when functions are inlined at runtime. + + Returns: + int: Invalid index value (0 for ONNX functions, INT32_MAX otherwise) + """ + if torch.onnx.is_in_onnx_export(): + if cls.SUBFUNC_ENABLED: + return 0 + else: + return torch.iinfo(torch.int32).max + else: + return 0 + + class QEffDynamicLayer(DynamicLayer): def read_only(self, cache_kwargs): """ @@ -46,10 +73,7 @@ def read_only(self, cache_kwargs): gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) @@ -143,10 +167,7 @@ def update( gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: @@ -419,10 +440,7 @@ def update( ctx_indices = torch.arange(ctx_len)[None, None, ...] 
gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 5f1ec51e6..cab72243d 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -27,7 +27,11 @@ import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel -from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform +from QEfficient.base.onnx_transforms import ( + FP16ClipTransform, + RenameFunctionOutputsTransform, + SplitTensorsTransform, +) from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.generation.text_generation_inference import ( @@ -315,7 +319,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, use_subfunctions: bool = False) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -350,6 +354,7 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names, dynamic_axes, export_dir=export_dir, + use_subfunctions=use_subfunctions, ) def compile( @@ -362,6 +367,7 @@ def compile( num_devices: int = 1, num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, + use_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -595,7 +601,15 @@ def __init__(self, model: nn.modules, **kwargs): self.model = model.get_qeff_vision_encoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ - def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): + def export( + self, + inputs, + output_names, + dynamic_axes, + export_dir=None, + offload_pt_weights=True, + use_subfunctions: bool = False, + ): """ Exports the vision encoder component to ONNX format. @@ -618,7 +632,12 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt Path to the generated ONNX graph file for the vision encoder. """ return self._export( - inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights + inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + offload_pt_weights=offload_pt_weights, + use_subfunctions=use_subfunctions, ) def compile( @@ -631,6 +650,7 @@ def compile( mdp_ts_num_devices, aic_num_cores, custom_io, + use_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -737,7 +757,15 @@ def __init__(self, model, **kwargs): self.model = model.get_qeff_language_decoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ - def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): + def export( + self, + inputs, + output_names, + dynamic_axes, + export_dir=None, + offload_pt_weights=True, + use_subfunctions: bool = False, + ): """ Exports the language decoder component to ONNX format. 
@@ -760,7 +788,12 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt Path to the generated ONNX graph file for the language decoder. """ return self._export( - inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights + inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + offload_pt_weights=offload_pt_weights, + use_subfunctions=use_subfunctions, ) def compile( @@ -773,6 +806,7 @@ def compile( mdp_ts_num_devices, aic_num_cores, custom_io, + use_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -973,6 +1007,7 @@ def qpc_path(self): def export( self, export_dir: Optional[str] = None, + use_subfunctions: bool = False, **kwargs, ) -> str: """ @@ -1043,6 +1078,7 @@ def compile( mxint8_kv_cache: bool = False, skip_vision: Optional[bool] = False, skip_lang: Optional[bool] = False, + use_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -1154,7 +1190,9 @@ def compile( if (self.vision_model.onnx_path is None and vision_onnx_path is None) or ( self.lang_model.onnx_path is None and lang_onnx_path is None ): - self.export() + self.export( + use_subfunctions=use_subfunctions, + ) # TODO this hould be removed once the continous batching is supported for all the models. compiler_options.pop("continuous_batching", None) @@ -1624,6 +1662,7 @@ def from_pretrained( def export( self, export_dir: Optional[str] = None, + use_subfunctions: bool = False, **kwargs, ) -> str: """ @@ -1644,7 +1683,13 @@ def export( inputs = self.model.get_dummy_inputs(comp_ctx_lengths=self.comp_ctx_lengths_decode) dynamic_axes = self.model.get_onnx_dynamic_axes(comp_ctx_lengths=self.comp_ctx_lengths_decode) output_names = self.model.get_output_names() - return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) + return self._export( + inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + use_subfunctions=use_subfunctions, + ) def compile( self, @@ -1662,6 +1707,7 @@ def compile( mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, + use_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -2232,7 +2278,11 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): SplitGateUpWeightsTransform, KVCacheExternalModuleMapperTransform, ] - _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + _onnx_transforms = [ + FP16ClipTransform, + RenameFunctionOutputsTransform, + SplitTensorsTransform, + ] def __init__( self, @@ -2423,7 +2473,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, use_subfunctions: bool = False, **kwargs) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. 
@@ -2532,6 +2582,8 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names, dynamic_axes, export_dir=export_dir, + use_subfunctions=use_subfunctions, + offload_pt_weights=kwargs.get("offload_pt_weights", True), ) def get_sampling_inputs_and_outputs( @@ -2742,6 +2794,7 @@ def compile( mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, prefill_only: Optional[bool] = None, + use_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -3135,7 +3188,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, use_subfunctions: bool = False) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -3156,7 +3209,13 @@ def export(self, export_dir: Optional[str] = None) -> str: inputs = self.model.get_dummy_inputs() dynamic_axes = self.model.get_onnx_dynamic_axes() output_names = self.model.get_output_names() - return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) + return self._export( + inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + use_subfunctions=use_subfunctions, + ) def compile( self, @@ -3174,6 +3233,7 @@ def compile( mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, + use_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -3499,7 +3559,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k def get_model_config(self) -> dict: return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, use_subfunctions: bool = False) -> str: """ Exports the model to ``ONNX`` format using ``torch.onnx.export``. @@ -3525,6 +3585,7 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names, dynamic_axes, export_dir=export_dir, + use_subfunctions=use_subfunctions, ) def compile( @@ -3537,6 +3598,7 @@ def compile( num_devices: int = 1, num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, + use_subfunctions: bool = False, **compiler_options, ) -> str: """ diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 773ce178c..62a873b9e 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -821,3 +821,29 @@ def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Modu model = PooledModel(model, pooling_method) warnings.warn("Pooling is applied to the model.") return model, transformed + + +def get_decoder_layer_classes_for_export(model: nn.Module) -> set: + """ + Dynamically determine which DecoderLayer classes should be exported as functions + based on the model's architecture using the existing KVCacheTransform mapping. 
+ """ + # Define patterns that identify decoder layer classes + DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"] + + # Get all QEff classes that are decoder layers from the existing mapping + decoder_layer_classes = set() + + for original_class, qeff_class in KVCacheTransform._module_mapping.items(): + # Check if the QEff class name contains decoder layer patterns + qeff_class_name = qeff_class.__name__ + if any(pattern in qeff_class_name for pattern in DECODER_LAYER_PATTERNS): + decoder_layer_classes.add(qeff_class) + + # Filter to only include classes that are actually used in the current model + model_decoder_classes = set() + for module in model.modules(): + if module.__class__ in decoder_layer_classes: + model_decoder_classes.add(module.__class__) + + return model_decoder_classes diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index d58f54952..a507a6a80 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -566,6 +566,7 @@ def wrapper(self, *args, **kwargs): dynamic_axes=all_args.get("dynamic_axes"), export_kwargs=all_args.get("export_kwargs", None), onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None), + use_subfunctions=all_args.get("use_subfunctions", False), ) export_dir = export_dir.with_name(export_dir.name + "-" + export_hash) kwargs["export_dir"] = export_dir diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py index b6b38b8b4..c99ae1bac 100644 --- a/QEfficient/utils/hash_utils.py +++ b/QEfficient/utils/hash_utils.py @@ -55,7 +55,7 @@ def create_export_hash(**kwargs): export_params = {} export_params["output_names"] = kwargs.get("output_names") export_params["dynamic_axes"] = kwargs.get("dynamic_axes") - + export_params["use_subfunctions"] = kwargs.get("use_subfunctions", False) export_hash_params["export_params"] = export_params export_kwargs = kwargs.get("export_kwargs") diff --git a/QEfficient/utils/patches.py b/QEfficient/utils/patches.py new file mode 100644 index 000000000..0b9b37afa --- /dev/null +++ b/QEfficient/utils/patches.py @@ -0,0 +1,115 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +"""Monkey patches for torch.onnx.utils to fix ONNX export issues.""" + +import torch +import torch.onnx.utils as onnx_utils +from torch import _C + +# Store original references before patching +_original_setup_trace_module_map = onnx_utils._setup_trace_module_map +_original_get_module_attributes = getattr(onnx_utils, "_get_module_attributes", None) + + +def _setup_trace_module_map_patched( + model, + export_modules_as_functions, +): + """Patched version of _setup_trace_module_map that fixes onnx_attrs type mismatch.""" + + def __register_attribute_hook(): + attr_name = "_onnx_attrs" + + def _track_module_attributes_forward_pre_hook(module, input): + setattr(module, attr_name, _get_module_attributes(module)) + + def _track_module_attributes_forward_hook(module, input, output): + tracing_state = _C._get_tracing_state() + if not tracing_state: + return + graph = tracing_state.graph() + onnx_attrs = {} + if hasattr(module, attr_name): + onnx_attrs = getattr(module, attr_name) + delattr(module, attr_name) + # FIX: use empty dict to avoid type mismatch + onnx_attrs = {} + _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs) + + for m in model.modules(): + m.register_forward_hook(_track_module_attributes_forward_hook) + m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook) + + def _unqualified_variable_name(qualified_name: str) -> str: + name_atoms = qualified_name.split(".") + for i, atom in reversed(list(enumerate(name_atoms))): + if not atom.isnumeric(): + return ".".join(name_atoms[i:]) + return qualified_name + + trace_module_map = { + _m: torch._C._jit_onnx_create_full_scope_name(torch.typename(type(_m)), _unqualified_variable_name(_n)) + for _n, _m in model.named_modules() + } + torch.jit._trace._trace_module_map = trace_module_map + + if isinstance(export_modules_as_functions, bool) and export_modules_as_functions: + module_typenames = {torch.typename(type(module)) for module in trace_module_map} + elif isinstance(export_modules_as_functions, set) and export_modules_as_functions: + + def _find_typename(v): + if isinstance(v, type): + return torch.typename(v) + else: + raise RuntimeError( + "Only type of the `nn.Module` should be passed in the set for argument `export_modules_as_functions`. " + f"Got `{type(v).__name__}`." 
+ ) + + module_typenames = {_find_typename(v) for v in export_modules_as_functions} + else: + module_typenames = set() + + if module_typenames: + __register_attribute_hook() + + return module_typenames + + +def _get_module_attributes(module): + """Helper function to get module attributes safely.""" + import typing + + import torch.nn + + annotations = typing.get_type_hints(type(module)) + base_m_annotations = typing.get_type_hints(torch.nn.Module) + [annotations.pop(k, None) for k in base_m_annotations] + + attrs = {} + for k in annotations: + try: + attrs[k] = getattr(module, k) + except AttributeError: + _C._jit_onnx_log(f"Skipping module attribute '{k}'") + continue + return attrs + + +def apply_torch_patches(): + """Apply monkey patches for ONNX export.""" + onnx_utils._setup_trace_module_map = _setup_trace_module_map_patched + if hasattr(onnx_utils, "_get_module_attributes"): + onnx_utils._get_module_attributes = _get_module_attributes + + +def undo_torch_patches(): + """Undo monkey patches and restore original functions.""" + onnx_utils._setup_trace_module_map = _original_setup_trace_module_map + if _original_get_module_attributes: + onnx_utils._get_module_attributes = _original_get_module_attributes diff --git a/tests/transformers/test_subfunction.py b/tests/transformers/test_subfunction.py new file mode 100644 index 000000000..9baf3cf52 --- /dev/null +++ b/tests/transformers/test_subfunction.py @@ -0,0 +1,69 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import pytest +import torch +from transformers import AutoConfig, AutoModelForCausalLM + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + +torch.manual_seed(42) + +configs = [ + ("gpt2", 256, 2, 4, 128, 512, 127, {}), +] + +configs = [ + AutoConfig.for_model( + model_name, + max_position_embeddings=max_position_embeddings, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + vocab_size=vocab_size, + **additional_params, + ) + for ( + model_name, + max_position_embeddings, + num_hidden_layers, + num_attention_heads, + hidden_size, + intermediate_size, + vocab_size, + additional_params, + ) in configs +] + +model_kwargs = {"attn_implementation": "eager"} +config_ids = [x.model_type for x in configs] + + +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_subfunction_vs_nonsubfunction(config, tmp_path): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(config.model_type) + model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False) + + with_sub_func_onnx = model_0_0.export(tmp_path, use_subfunctions=True, offload_pt_weights=False) + hash_0_0 = model_0_0.export_hash + + without_sub_func_onnx = model_0_0.export(tmp_path, use_subfunctions=False) + hash_0_1 = model_0_0.export_hash + + assert hash_0_0 != hash_0_1 + + compile_params = {"prefill_seq_len": 8, "ctx_len": 16} + model_0_0.compile(onnx_path=with_sub_func_onnx, **compile_params) + generation_00 = model_0_0.generate(prompts=["Help me with this"], tokenizer=tokenizer) + + model_0_0.compile(onnx_path=without_sub_func_onnx, **compile_params) + generation_01 = model_0_0.generate(prompts=["Help me with this"], 
tokenizer=tokenizer) + + assert generation_00.generated_texts == generation_01.generated_texts From f0413d6455f003f7aa1b20976d4da6910dd19e18 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Tue, 18 Nov 2025 05:49:23 +0000 Subject: [PATCH 2/9] Changed flag name from use_subfunctions to use_onnx_subfunctions Signed-off-by: abhishek-singh591 --- QEfficient/base/modeling_qeff.py | 14 +++--- .../transformers/models/modeling_auto.py | 48 +++++++++---------- QEfficient/utils/_utils.py | 2 +- QEfficient/utils/hash_utils.py | 2 +- tests/transformers/test_subfunction.py | 5 +- 5 files changed, 35 insertions(+), 36 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 827249d43..bb9ac57dc 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -59,7 +59,7 @@ class QEFFBaseModel(ABC): def _transform_names(cls) -> List[str]: return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms] - def __init__(self, model: torch.nn.Module, use_subfunctions: bool = False, **kwargs) -> None: + def __init__(self, model: torch.nn.Module, use_onnx_subfunctions: bool = False, **kwargs) -> None: super().__init__() self.model = model self.hash_params = create_model_params(self, **kwargs) @@ -70,7 +70,7 @@ def __init__(self, model: torch.nn.Module, use_subfunctions: bool = False, **kwa (arch := getattr(self.model.config, "architectures", None)) and len(arch) > 0 and arch[0] ) or None - self.use_subfunctions = use_subfunctions + self.use_onnx_subfunctions = use_onnx_subfunctions # Flag for checking if weights are offloaded self._is_weights_offloaded: bool = False @@ -186,7 +186,7 @@ def _export( onnx_transform_kwargs: Optional[Dict[str, any]] = None, export_dir: Optional[str] = None, offload_pt_weights: bool = True, - use_subfunctions: bool = False, + use_onnx_subfunctions: bool = False, ) -> str: """ Export the PyTorch model to ONNX and apply ONNX transforms @@ -256,7 +256,7 @@ def _export( CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm) CustomOpTransform.register_custom_op("CtxScatterFunc", CtxScatterFunc, CtxScatter) CustomOpTransform.register_custom_op("CtxGatherFunc", CtxGatherFunc, CtxGather) - if use_subfunctions: + if use_onnx_subfunctions: warnings.warn( "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results." 
) @@ -307,7 +307,7 @@ def _export( finally: shutil.rmtree(tmp_onnx_dir, ignore_errors=True) - if use_subfunctions: + if use_onnx_subfunctions: undo_torch_patches() InvalidIndexProvider.SUBFUNC_ENABLED = False @@ -327,7 +327,7 @@ def _compile( num_speculative_tokens: Optional[int] = None, enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, - use_subfunctions: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -355,7 +355,7 @@ def _compile( """ if onnx_path is None and self.onnx_path is None: - self.export(use_subfunctions=use_subfunctions) + self.export(use_onnx_subfunctions=use_onnx_subfunctions) onnx_path = Path(onnx_path or self.onnx_path) compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index cab72243d..26c80f9e4 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -319,7 +319,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -354,7 +354,7 @@ def export(self, export_dir: Optional[str] = None, use_subfunctions: bool = Fals output_names, dynamic_axes, export_dir=export_dir, - use_subfunctions=use_subfunctions, + use_onnx_subfunctions=use_onnx_subfunctions, ) def compile( @@ -367,7 +367,7 @@ def compile( num_devices: int = 1, num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, - use_subfunctions: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -608,7 +608,7 @@ def export( dynamic_axes, export_dir=None, offload_pt_weights=True, - use_subfunctions: bool = False, + use_onnx_subfunctions: bool = False, ): """ Exports the vision encoder component to ONNX format. @@ -637,7 +637,7 @@ def export( dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights, - use_subfunctions=use_subfunctions, + use_onnx_subfunctions=use_onnx_subfunctions, ) def compile( @@ -650,7 +650,7 @@ def compile( mdp_ts_num_devices, aic_num_cores, custom_io, - use_subfunctions: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -764,7 +764,7 @@ def export( dynamic_axes, export_dir=None, offload_pt_weights=True, - use_subfunctions: bool = False, + use_onnx_subfunctions: bool = False, ): """ Exports the language decoder component to ONNX format. 
@@ -793,7 +793,7 @@ def export( dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights, - use_subfunctions=use_subfunctions, + use_onnx_subfunctions=use_onnx_subfunctions, ) def compile( @@ -806,7 +806,7 @@ def compile( mdp_ts_num_devices, aic_num_cores, custom_io, - use_subfunctions: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -1007,7 +1007,7 @@ def qpc_path(self): def export( self, export_dir: Optional[str] = None, - use_subfunctions: bool = False, + use_onnx_subfunctions: bool = False, **kwargs, ) -> str: """ @@ -1078,7 +1078,7 @@ def compile( mxint8_kv_cache: bool = False, skip_vision: Optional[bool] = False, skip_lang: Optional[bool] = False, - use_subfunctions: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -1191,7 +1191,7 @@ def compile( self.lang_model.onnx_path is None and lang_onnx_path is None ): self.export( - use_subfunctions=use_subfunctions, + use_onnx_subfunctions=use_onnx_subfunctions, ) # TODO this hould be removed once the continous batching is supported for all the models. @@ -1662,7 +1662,7 @@ def from_pretrained( def export( self, export_dir: Optional[str] = None, - use_subfunctions: bool = False, + use_onnx_subfunctions: bool = False, **kwargs, ) -> str: """ @@ -1688,7 +1688,7 @@ def export( output_names, dynamic_axes, export_dir=export_dir, - use_subfunctions=use_subfunctions, + use_onnx_subfunctions=use_onnx_subfunctions, ) def compile( @@ -1707,7 +1707,7 @@ def compile( mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, - use_subfunctions: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -2473,7 +2473,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_subfunctions: bool = False, **kwargs) -> str: + def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False, **kwargs) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -2582,7 +2582,7 @@ def export(self, export_dir: Optional[str] = None, use_subfunctions: bool = Fals output_names, dynamic_axes, export_dir=export_dir, - use_subfunctions=use_subfunctions, + use_onnx_subfunctions=use_onnx_subfunctions, offload_pt_weights=kwargs.get("offload_pt_weights", True), ) @@ -2794,7 +2794,7 @@ def compile( mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, prefill_only: Optional[bool] = None, - use_subfunctions: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -3188,7 +3188,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. 
@@ -3214,7 +3214,7 @@ def export(self, export_dir: Optional[str] = None, use_subfunctions: bool = Fals output_names, dynamic_axes, export_dir=export_dir, - use_subfunctions=use_subfunctions, + use_onnx_subfunctions=use_onnx_subfunctions, ) def compile( @@ -3233,7 +3233,7 @@ def compile( mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, - use_subfunctions: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -3559,7 +3559,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k def get_model_config(self) -> dict: return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: """ Exports the model to ``ONNX`` format using ``torch.onnx.export``. @@ -3585,7 +3585,7 @@ def export(self, export_dir: Optional[str] = None, use_subfunctions: bool = Fals output_names, dynamic_axes, export_dir=export_dir, - use_subfunctions=use_subfunctions, + use_onnx_subfunctions=use_onnx_subfunctions, ) def compile( @@ -3598,7 +3598,7 @@ def compile( num_devices: int = 1, num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, - use_subfunctions: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index a507a6a80..1fb0311eb 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -566,7 +566,7 @@ def wrapper(self, *args, **kwargs): dynamic_axes=all_args.get("dynamic_axes"), export_kwargs=all_args.get("export_kwargs", None), onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None), - use_subfunctions=all_args.get("use_subfunctions", False), + use_onnx_subfunctions=all_args.get("use_onnx_subfunctions", False), ) export_dir = export_dir.with_name(export_dir.name + "-" + export_hash) kwargs["export_dir"] = export_dir diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py index c99ae1bac..b940dbe50 100644 --- a/QEfficient/utils/hash_utils.py +++ b/QEfficient/utils/hash_utils.py @@ -55,7 +55,7 @@ def create_export_hash(**kwargs): export_params = {} export_params["output_names"] = kwargs.get("output_names") export_params["dynamic_axes"] = kwargs.get("dynamic_axes") - export_params["use_subfunctions"] = kwargs.get("use_subfunctions", False) + export_params["use_onnx_subfunctions"] = kwargs.get("use_onnx_subfunctions", False) export_hash_params["export_params"] = export_params export_kwargs = kwargs.get("export_kwargs") diff --git a/tests/transformers/test_subfunction.py b/tests/transformers/test_subfunction.py index 9baf3cf52..34e149b4e 100644 --- a/tests/transformers/test_subfunction.py +++ b/tests/transformers/test_subfunction.py @@ -51,10 +51,10 @@ def test_subfunction_vs_nonsubfunction(config, tmp_path): tokenizer = AutoTokenizer.from_pretrained(config.model_type) model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False) - with_sub_func_onnx = model_0_0.export(tmp_path, use_subfunctions=True, offload_pt_weights=False) + with_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False) hash_0_0 = model_0_0.export_hash - without_sub_func_onnx = model_0_0.export(tmp_path, use_subfunctions=False) + without_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=False) hash_0_1 = 
model_0_0.export_hash assert hash_0_0 != hash_0_1 @@ -65,5 +65,4 @@ def test_subfunction_vs_nonsubfunction(config, tmp_path): model_0_0.compile(onnx_path=without_sub_func_onnx, **compile_params) generation_01 = model_0_0.generate(prompts=["Help me with this"], tokenizer=tokenizer) - assert generation_00.generated_texts == generation_01.generated_texts From 01a969656150059b4a1f2052f7c28e59a9cc5ea9 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Tue, 18 Nov 2025 06:16:24 +0000 Subject: [PATCH 3/9] Minor fixes Signed-off-by: abhishek-singh591 --- QEfficient/base/onnx_transforms.py | 33 ------------------------------ 1 file changed, 33 deletions(-) diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 65287426a..9a60eaa91 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -8,8 +8,6 @@ from typing import Any, Dict, List, Optional, Tuple import numpy as np -import onnx -import onnxslim import torch from onnx import ModelProto, external_data_helper, numpy_helper @@ -104,37 +102,6 @@ def apply( return model, transformed -class OnnxSlimTransform(OnnxTransform): - """ - Applies onnx-slim transformations on the given ONNX graph. - """ - - @classmethod - def apply( - cls, - model: ModelProto, - *, - onnx_base_dir: Optional[str] = None, - **kwargs, - ) -> Tuple[ModelProto, bool]: - """ - :param enable_onnx_slim_transform: If True, applies onnx-slim transformations. - :param temp_onnx_path: Path to save the slimmed ONNX model. - """ - transformed = False - onnx_slim_transform = True # kwargs.get("enable_onnx_slim_transform", False) - temp_onnx_path = kwargs.get("temp_onnx_path", None) - if not temp_onnx_path: - err_str = "temp_onnx_path is required for onnx-slim transform." - raise RuntimeError(err_str) - if onnx_slim_transform: - transformed = True - slimmed_model = onnxslim.slim(model) - onnx.save(slimmed_model, temp_onnx_path) - return slimmed_model, transformed - return model, transformed - - class CustomOpTransform(OnnxTransform): """ Transform to register custom operations and add their function protos to the ONNX model. 
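
For reference, a minimal usage sketch of the `use_onnx_subfunctions` flag introduced by this series, mirroring the flow exercised in tests/transformers/test_subfunction.py. The model config, prompt, and compile parameters below are illustrative placeholders, not values taken from the patches:

# Illustrative only: exercises the experimental use_onnx_subfunctions export path.
# Config values, prompt, and compile parameters are assumptions for the sketch.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM

config = AutoConfig.for_model("gpt2", num_hidden_layers=2)  # small toy config
tokenizer = AutoTokenizer.from_pretrained("gpt2")
qeff_model = QEFFAutoModelForCausalLM(
    AutoModelForCausalLM.from_config(config, attn_implementation="eager"), cb=False
)

# Export decoder layers as ONNX functions (experimental), keeping PyTorch weights
# resident so a second, plain export can reuse the same model instance.
onnx_with_funcs = qeff_model.export(use_onnx_subfunctions=True, offload_pt_weights=False)

# Plain export for comparison; the two exports hash to different export directories.
onnx_without_funcs = qeff_model.export(use_onnx_subfunctions=False)

qeff_model.compile(onnx_path=onnx_with_funcs, prefill_seq_len=8, ctx_len=16, num_cores=16)
out = qeff_model.generate(prompts=["Help me with this"], tokenizer=tokenizer)
print(out.generated_texts)

The test added in this series follows the same pattern and additionally asserts that the two exports produce different hashes while the compiled QPCs generate identical text.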
From 219230abb74d66d022231a141b782a5dd5602a02 Mon Sep 17 00:00:00 2001 From: quic-akuruvil Date: Wed, 19 Nov 2025 09:07:08 +0530 Subject: [PATCH 4/9] Fix for token during inference (#622) Fix for this JIRA from Imagine team Signed-off-by: Ann Kuruvilla Signed-off-by: abhishek-singh591 --- examples/gemma3_example/gemma3_mm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/gemma3_example/gemma3_mm.py b/examples/gemma3_example/gemma3_mm.py index e090148f7..ca82b2120 100644 --- a/examples/gemma3_example/gemma3_mm.py +++ b/examples/gemma3_example/gemma3_mm.py @@ -105,5 +105,5 @@ ) inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) output = qeff_model.generate(inputs=inputs, generation_len=100) - print(tokenizer.batch_decode(output.generated_ids)) + print(tokenizer.batch_decode(output.generated_ids, skip_special_tokens=True)) print(output) From 6daa209a4d1e996da4fe063ccddbb2fa16f77378 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Wed, 19 Nov 2025 06:42:30 +0000 Subject: [PATCH 5/9] Addressed all the comments Signed-off-by: abhishek-singh591 --- QEfficient/base/modeling_qeff.py | 14 ++- QEfficient/base/onnx_transforms.py | 115 ++++++++++-------- .../transformers/models/modeling_auto.py | 41 ++++++- QEfficient/utils/hash_utils.py | 3 +- .../utils/{patches.py => torch_patches.py} | 0 tests/transformers/test_subfunction.py | 5 +- 6 files changed, 115 insertions(+), 63 deletions(-) rename QEfficient/utils/{patches.py => torch_patches.py} (100%) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index bb9ac57dc..277c6baf3 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -19,7 +19,7 @@ import onnx import torch -from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform +from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform, RenameFunctionOutputsTransform from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc @@ -37,7 +37,7 @@ hash_dict_params, load_json, ) -from QEfficient.utils.patches import apply_torch_patches, undo_torch_patches +from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches logger = logging.getLogger(__name__) @@ -59,7 +59,7 @@ class QEFFBaseModel(ABC): def _transform_names(cls) -> List[str]: return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms] - def __init__(self, model: torch.nn.Module, use_onnx_subfunctions: bool = False, **kwargs) -> None: + def __init__(self, model: torch.nn.Module, **kwargs) -> None: super().__init__() self.model = model self.hash_params = create_model_params(self, **kwargs) @@ -70,7 +70,6 @@ def __init__(self, model: torch.nn.Module, use_onnx_subfunctions: bool = False, (arch := getattr(self.model.config, "architectures", None)) and len(arch) > 0 and arch[0] ) or None - self.use_onnx_subfunctions = use_onnx_subfunctions # Flag for checking if weights are offloaded self._is_weights_offloaded: bool = False @@ -264,8 +263,10 @@ def _export( InvalidIndexProvider.SUBFUNC_ENABLED = True output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names] export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model) + self._onnx_transforms.append(RenameFunctionOutputsTransform) self._onnx_transforms.append(CustomOpTransform) + # 
import pdb; pdb.set_trace() torch.onnx.export( self.model, (example_inputs,), @@ -274,7 +275,6 @@ def _export( output_names=output_names, dynamic_axes=dynamic_axes, opset_version=constants.ONNX_EXPORT_OPSET, - do_constant_folding=True, **export_kwargs, ) logger.info("PyTorch export successful") @@ -283,7 +283,6 @@ def _export( model = onnx.load(tmp_onnx_path, load_external_data=False) transform_kwargs = { "onnx_base_dir": str(tmp_onnx_dir), - "temp_onnx_path": tmp_onnx_path, "model_name": self.model_name, } if onnx_transform_kwargs is not None: @@ -310,6 +309,8 @@ def _export( if use_onnx_subfunctions: undo_torch_patches() InvalidIndexProvider.SUBFUNC_ENABLED = False + self._onnx_transforms.pop() + self._onnx_transforms.pop() self.onnx_path = onnx_path return onnx_path @@ -356,6 +357,7 @@ def _compile( if onnx_path is None and self.onnx_path is None: self.export(use_onnx_subfunctions=use_onnx_subfunctions) + onnx_path = Path(onnx_path or self.onnx_path) compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 9a60eaa91..a30f378cb 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -5,7 +5,7 @@ # # ---------------------------------------------------------------------------- -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple import numpy as np import torch @@ -107,11 +107,11 @@ class CustomOpTransform(OnnxTransform): Transform to register custom operations and add their function protos to the ONNX model. """ - # Registry of custom operations - _custom_ops: Dict[str, Tuple[Any, Any]] = {} # op_name -> (func_class, onnxscript_func) + # Registry of custom operations: op_name -> (func_class, onnxscript_func) + _custom_ops: Dict[str, Tuple[Any, Any]] = {} @classmethod - def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any): + def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any) -> None: """Register a custom operation.""" cls._custom_ops[op_name] = (func_class, onnxscript_func) @@ -120,9 +120,9 @@ def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple """ Apply custom op registration and add function protos to the model. - :param model: The ONNX model to transform - :param opset_version: ONNX opset version for symbolic registration - :returns: Transformed model and success flag + :param model: The ONNX model to transform. + :param opset_version: ONNX opset version for symbolic registration. + :returns: (Transformed model, success flag). 
""" transformed = False @@ -131,62 +131,70 @@ def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple if hasattr(func_class, "symbolic"): torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version) - # Add function protos for custom ops that are used in the model - used_protos = cls._get_function_protos_for_model(model) + # Gather function names and all nodes (graph + function nodes) + func_names: Set[str] = {func.name for func in model.functions} + all_nodes = list(model.graph.node) + for func in model.functions: + all_nodes.extend(func.node) + + # Collect used op types + used_op_types: Set[str] = {node.op_type for node in all_nodes} + + # Precompute heuristic flags + has_rmsnorm = any("RMSNorm" in op_type for op_type in used_op_types) + has_ctx_ops = any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types) + + # Get function protos for custom ops used in the model + used_protos = cls._get_function_protos_for_model(used_op_types, has_rmsnorm, has_ctx_ops) + # Append new function protos if not already present for proto in used_protos: - # Check if proto already exists to avoid duplicates - proto_name = proto.name - if not any(func.name == proto_name for func in model.functions): + if proto.name not in func_names: model.functions.append(proto) transformed = True return model, transformed @classmethod - def _get_function_protos_for_model(cls, model: ModelProto) -> List[Any]: - """Get function protos for custom ops that are actually used in the model.""" - used_protos = [] - - # Get all node op_types in the model - used_op_types = set() - for node in model.graph.node: - used_op_types.add(node.op_type) - - # Also check function calls - for func in model.functions: - for node in func.node: - used_op_types.add(node.op_type) - - # Check which custom ops are actually used - for op_name, (func_class, onnxscript_func) in cls._custom_ops.items(): - # Check if the custom op is referenced in the model - if cls._is_custom_op_used(model, op_name, used_op_types): - proto = onnxscript_func.to_function_proto() - used_protos.append(proto) + def _get_function_protos_for_model(cls, used_op_types: Set[str], has_rmsnorm: bool, has_ctx_ops: bool) -> List[Any]: + """ + Get function protos for custom ops that are actually used in the model. + :param used_op_types: Set of op types used in the model. + :param has_rmsnorm: Flag indicating if RMSNorm-related ops are present. + :param has_ctx_ops: Flag indicating if context-related ops are present. + :returns: List of ONNX function protos. + """ + used_protos: List[Any] = [] + for op_name, (_, onnxscript_func) in cls._custom_ops.items(): + if cls._is_custom_op_used(op_name, used_op_types, has_rmsnorm, has_ctx_ops): + used_protos.append(onnxscript_func.to_function_proto()) return used_protos @classmethod - def _is_custom_op_used(cls, model: ModelProto, op_name: str, used_op_types: set) -> bool: - """Check if a custom op is used in the model.""" - # Check if the op_name appears in node op_types + def _is_custom_op_used(cls, op_name: str, used_op_types: Set[str], has_rmsnorm: bool, has_ctx_ops: bool) -> bool: + """ + Check if a custom op is used in the model. + + :param op_name: Name of the custom op. + :param used_op_types: Set of op types used in the model. + :param has_rmsnorm: Precomputed RMSNorm presence flag. + :param has_ctx_ops: Precomputed context ops presence flag. + :returns: True if the custom op is used, False otherwise. 
+ """ if op_name in used_op_types: return True - # Check for domain-specific ops (e.g., "com.qti.aisw.onnx::CustomRMSNorm") - custom_op_pattern = f"com.qti.aisw.onnx::{op_name.replace('Func', '')}" - if custom_op_pattern in used_op_types: + # Check for domain-specific ops + if f"com.qti.aisw.onnx::{op_name.replace('Func', '')}" in used_op_types: return True - # Heuristic checks based on op type - if "RMSNorm" in op_name: - # Check if any RMSNorm-related ops are present - return any("RMSNorm" in op_type for op_type in used_op_types) + # Heuristic checks + if "RMSNorm" in op_name and has_rmsnorm: + return True - if "Ctx" in op_name: - # Check if Gather/Scatter operations are present (indicating KV cache usage) - return any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types) + if "Ctx" in op_name and has_ctx_ops: + return True return False @@ -208,7 +216,10 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]: op_type_to_func_map = {func.name: func for func in model.functions} decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"] transformed = False - model_graph_outputs = [val.name for val in model.graph.output] + + # Create a dict mapping output name to its index for quick lookup + model_graph_outputs_map = {val.name: idx for idx, val in enumerate(model.graph.output)} + layer_index = 0 for node in graph.node: if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns): @@ -219,14 +230,18 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]: for i, out_name in enumerate(func.output): if "_InternalRetainedState" in out_name: transformed = True - tmp = node.output[i] + original_output_name = node.output[i] + + # Generate new name based on key/value if "key" in out_name: new_name = f"past_key.{layer_index}_RetainedState" elif "value" in out_name: new_name = f"past_value.{layer_index}_RetainedState" node.output[i] = new_name + # Update graph output name if it exists - if tmp in model_graph_outputs: - model.graph.output[model_graph_outputs.index(tmp)].name = new_name - layer_index = layer_index + 1 + if original_output_name in model_graph_outputs_map: + idx = model_graph_outputs_map[original_output_name] + model.graph.output[idx].name = new_name + layer_index += 1 return model, transformed diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 26c80f9e4..fc9d77757 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -29,7 +29,6 @@ from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.base.onnx_transforms import ( FP16ClipTransform, - RenameFunctionOutputsTransform, SplitTensorsTransform, ) from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform @@ -331,6 +330,8 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = export_dir : str, optional Directory path where the exported ONNX graph will be saved. If not provided, the default export directory is used. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. Returns ------- @@ -394,6 +395,8 @@ def compile( Number of cores to use for compilation. mxfp6_matmul : bool, optional Use MXFP6 compression for weights. Default is False. 
+ use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. **compiler_options : dict Additional compiler options for QAIC or QNN compilers. These are passed directly to the underlying compilation command. @@ -437,6 +440,7 @@ def compile( mxfp6_matmul=mxfp6_matmul, mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -625,6 +629,8 @@ def export( Directory path where the exported ONNX graph will be saved. Default is None. offload_pt_weights : bool, optional If True, PyTorch weights will be offloaded after export. Default is True. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. Returns ------- @@ -674,6 +680,8 @@ def compile( Number of cores to use for compilation. custom_io : Dict[str, str] Custom I/O configurations for the compiler. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. **compiler_options : Additional compiler options passed to the underlying compilation command. @@ -691,6 +699,7 @@ def compile( mdp_ts_num_devices=mdp_ts_num_devices, aic_num_cores=aic_num_cores, custom_io=custom_io, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -781,6 +790,8 @@ def export( Directory path where the exported ONNX graph will be saved. Default is None. offload_pt_weights : bool, optional If True, PyTorch weights will be offloaded after export. Default is True. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. Returns ------- @@ -830,6 +841,8 @@ def compile( Number of cores to use for compilation. custom_io : Dict[str, str] Custom I/O configurations for the compiler. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. **compiler_options : Additional compiler options passed to the underlying compilation command. @@ -847,6 +860,7 @@ def compile( mdp_ts_num_devices=mdp_ts_num_devices, aic_num_cores=aic_num_cores, custom_io=custom_io, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -1020,6 +1034,8 @@ def export( ---------- export_dir : str, optional Directory path where the exported ONNX graphs will be saved. Default is None. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. **kwargs : Additional keyword arguments. @@ -1118,6 +1134,8 @@ def compile( If True, skips compilation of the vision encoder. Default is False. skip_lang : bool, optional If True, skips compilation of the language decoder. Default is False. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. 
**compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -1210,6 +1228,7 @@ def compile( aic_num_cores=num_cores, custom_io=custom_io_vision, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -1238,6 +1257,7 @@ def compile( aic_num_cores=num_cores, custom_io=custom_io_lang, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) return self.qpc_path @@ -1743,6 +1763,8 @@ def compile( Use MXINT8 compression for KV cache. Default is False. num_speculative_tokens : int, optional Not supported for this model; must be None. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. **compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -1815,6 +1837,7 @@ def compile( mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) return self.qpc_path @@ -2280,7 +2303,6 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): ] _onnx_transforms = [ FP16ClipTransform, - RenameFunctionOutputsTransform, SplitTensorsTransform, ] @@ -2486,7 +2508,8 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = export_dir : str, optional Directory path where the exported ONNX graph will be saved. If not provided, the default export directory is used. - + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. Returns ------- str @@ -2836,6 +2859,8 @@ def compile( prefill_only : bool, optional If True, compiles only for the prefill stage. If False, compiles only for the decode stage. If None, compiles for both stages. Default is None. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. **compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -2997,6 +3022,7 @@ def compile( num_speculative_tokens=num_speculative_tokens, aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -3200,6 +3226,8 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = export_dir : str, optional Directory path where the exported ONNX graph will be saved. If not provided, the default export directory is used. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. Returns ------- @@ -3275,6 +3303,8 @@ def compile( Not yet supported for this model. num_speculative_tokens : int, optional Not yet supported for this model. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. **compiler_options : dict Additional compiler options for QAIC. 
@@ -3342,6 +3372,7 @@ def compile( mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, custom_io=custom_io, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -3565,6 +3596,8 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = ``Optional`` Args: :export_dir (str, optional): The directory path to store ONNX-graph. + :use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. Returns: :str: Path of the generated ``ONNX`` graph. @@ -3614,6 +3647,7 @@ def compile( :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1. :num_cores (int): Number of cores used to compile the model. :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``. + :use_onnx_subfunctions: bool, optional: whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. :compiler_options (dict, optional): Additional compiler options. For QAIC Compiler: Extra arguments for qaic-exec can be passed. @@ -3646,6 +3680,7 @@ def compile( mxfp6_matmul=mxfp6_matmul, mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py index b940dbe50..948b72e6a 100644 --- a/QEfficient/utils/hash_utils.py +++ b/QEfficient/utils/hash_utils.py @@ -55,7 +55,8 @@ def create_export_hash(**kwargs): export_params = {} export_params["output_names"] = kwargs.get("output_names") export_params["dynamic_axes"] = kwargs.get("dynamic_axes") - export_params["use_onnx_subfunctions"] = kwargs.get("use_onnx_subfunctions", False) + if kwargs.get("use_onnx_subfunctions"): + export_params["use_onnx_subfunctions"] = True export_hash_params["export_params"] = export_params export_kwargs = kwargs.get("export_kwargs") diff --git a/QEfficient/utils/patches.py b/QEfficient/utils/torch_patches.py similarity index 100% rename from QEfficient/utils/patches.py rename to QEfficient/utils/torch_patches.py diff --git a/tests/transformers/test_subfunction.py b/tests/transformers/test_subfunction.py index 34e149b4e..36cfc0ce5 100644 --- a/tests/transformers/test_subfunction.py +++ b/tests/transformers/test_subfunction.py @@ -7,7 +7,7 @@ import pytest import torch -from transformers import AutoConfig, AutoModelForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM @@ -46,10 +46,9 @@ @pytest.mark.parametrize("config", configs, ids=config_ids) def test_subfunction_vs_nonsubfunction(config, tmp_path): - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(config.model_type) model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False) + # model_0_0 = QEFFAutoModelForCausalLM.from_pretrained(config.model_type) with_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False) hash_0_0 = model_0_0.export_hash From 50fda72465612ac92700aa20fd90d992f2a00a55 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Wed, 19 Nov 2025 10:02:12 +0000 Subject: [PATCH 6/9] Rebased and some other fixes Signed-off-by: abhishek-singh591 --- 
QEfficient/base/modeling_qeff.py | 10 ++-------- QEfficient/base/onnx_transforms.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 277c6baf3..72f5c050e 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -22,8 +22,6 @@ from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform, RenameFunctionOutputsTransform from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile -from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc -from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.transformers.cache_utils import InvalidIndexProvider from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export @@ -252,9 +250,6 @@ def _export( try: # Initialize the registry with your custom ops export_kwargs = {} if export_kwargs is None else export_kwargs - CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm) - CustomOpTransform.register_custom_op("CtxScatterFunc", CtxScatterFunc, CtxScatter) - CustomOpTransform.register_custom_op("CtxGatherFunc", CtxGatherFunc, CtxGather) if use_onnx_subfunctions: warnings.warn( "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results." @@ -266,7 +261,6 @@ def _export( self._onnx_transforms.append(RenameFunctionOutputsTransform) self._onnx_transforms.append(CustomOpTransform) - # import pdb; pdb.set_trace() torch.onnx.export( self.model, (example_inputs,), @@ -309,8 +303,8 @@ def _export( if use_onnx_subfunctions: undo_torch_patches() InvalidIndexProvider.SUBFUNC_ENABLED = False - self._onnx_transforms.pop() - self._onnx_transforms.pop() + self._onnx_transforms.remove(CustomOpTransform) + self._onnx_transforms.remove(RenameFunctionOutputsTransform) self.onnx_path = onnx_path return onnx_path diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index a30f378cb..3b6717c0c 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -11,6 +11,9 @@ import torch from onnx import ModelProto, external_data_helper, numpy_helper +from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc +from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc + class OnnxTransform: """ @@ -107,8 +110,11 @@ class CustomOpTransform(OnnxTransform): Transform to register custom operations and add their function protos to the ONNX model. 
""" - # Registry of custom operations: op_name -> (func_class, onnxscript_func) - _custom_ops: Dict[str, Tuple[Any, Any]] = {} + _custom_ops: Dict[str, Tuple[Any, Any]] = { + "CustomRMSNormFunc": (CustomRMSNormFunc, CustomRMSNorm), + "CtxScatterFunc": (CtxScatterFunc, CtxScatter), + "CtxGatherFunc": (CtxGatherFunc, CtxGather), + } @classmethod def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any) -> None: From f75a764a234d778950f58f2d2369d9db6006e1e4 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Wed, 19 Nov 2025 10:07:36 +0000 Subject: [PATCH 7/9] Rebased and some other fixes Signed-off-by: abhishek-singh591 --- .../transformers/models/modeling_auto.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index fc9d77757..a1eece0f7 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -331,7 +331,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = Directory path where the exported ONNX graph will be saved. If not provided, the default export directory is used. use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False Returns ------- @@ -396,7 +396,7 @@ def compile( mxfp6_matmul : bool, optional Use MXFP6 compression for weights. Default is False. use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : dict Additional compiler options for QAIC or QNN compilers. These are passed directly to the underlying compilation command. @@ -630,7 +630,7 @@ def export( offload_pt_weights : bool, optional If True, PyTorch weights will be offloaded after export. Default is True. use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False Returns ------- @@ -681,7 +681,7 @@ def compile( custom_io : Dict[str, str] Custom I/O configurations for the compiler. use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : Additional compiler options passed to the underlying compilation command. @@ -791,7 +791,7 @@ def export( offload_pt_weights : bool, optional If True, PyTorch weights will be offloaded after export. Default is True. 
use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + Whether to enable ONNX subfunctions during export. Exporting the PyTorch model to ONNX with modules as subfunctions helps reduce export/compile time. Defaults to False. Returns ------- @@ -842,7 +842,7 @@ def compile( custom_io : Dict[str, str] Custom I/O configurations for the compiler. use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + Whether to enable ONNX subfunctions during export. Exporting the PyTorch model to ONNX with modules as subfunctions helps reduce export/compile time. Defaults to False. **compiler_options : Additional compiler options passed to the underlying compilation command. @@ -1035,7 +1035,7 @@ def export( export_dir : str, optional Directory path where the exported ONNX graphs will be saved. Default is None. use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + Whether to enable ONNX subfunctions during export. Exporting the PyTorch model to ONNX with modules as subfunctions helps reduce export/compile time. Defaults to False. **kwargs : Additional keyword arguments. @@ -1135,7 +1135,7 @@ def compile( skip_lang : bool, optional If True, skips compilation of the language decoder. Default is False. use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + Whether to enable ONNX subfunctions during export. Exporting the PyTorch model to ONNX with modules as subfunctions helps reduce export/compile time. Defaults to False. **compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -1764,7 +1764,7 @@ def compile( num_speculative_tokens : int, optional Not supported for this model; must be None. use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + Whether to enable ONNX subfunctions during export. Exporting the PyTorch model to ONNX with modules as subfunctions helps reduce export/compile time. Defaults to False. **compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -2509,7 +2509,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = Directory path where the exported ONNX graph will be saved. If not provided, the default export directory is used. use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + Whether to enable ONNX subfunctions during export. Exporting the PyTorch model to ONNX with modules as subfunctions helps reduce export/compile time. Defaults to False. Returns ------- str @@ -2860,7 +2860,7 @@ def compile( If True, compiles only for the prefill stage. If False, compiles only for the decode stage. If None, compiles for both stages. Default is None.
use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + Whether to enable ONNX subfunctions during export. Exporting the PyTorch model to ONNX with modules as subfunctions helps reduce export/compile time. Defaults to False. **compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -3227,7 +3227,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = Directory path where the exported ONNX graph will be saved. If not provided, the default export directory is used. use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + Whether to enable ONNX subfunctions during export. Exporting the PyTorch model to ONNX with modules as subfunctions helps reduce export/compile time. Defaults to False. Returns ------- @@ -3304,7 +3304,7 @@ def compile( num_speculative_tokens : int, optional Not yet supported for this model. use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + Whether to enable ONNX subfunctions during export. Exporting the PyTorch model to ONNX with modules as subfunctions helps reduce export/compile time. Defaults to False. **compiler_options : dict Additional compiler options for QAIC. @@ -3597,7 +3597,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = ``Optional`` Args: :export_dir (str, optional): The directory path to store ONNX-graph. :use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + Whether to enable ONNX subfunctions during export. Exporting the PyTorch model to ONNX with modules as subfunctions helps reduce export/compile time. Defaults to False. Returns: :str: Path of the generated ``ONNX`` graph. @@ -3647,7 +3647,7 @@ def compile( :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1. :num_cores (int): Number of cores used to compile the model. :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``. - :use_onnx_subfunctions: bool, optional: whether to enable ONNX subfunctions during export. Using subfunctions can improve model compilation efficiency and execution performance on hardware. Defaults to False. + :use_onnx_subfunctions: bool, optional: Whether to enable ONNX subfunctions during export. Exporting the PyTorch model to ONNX with modules as subfunctions helps reduce export/compile time. Defaults to False. :compiler_options (dict, optional): Additional compiler options. For QAIC Compiler: Extra arguments for qaic-exec can be passed. From 13fe095f21d78cfe6dc3828c81dbefa5d156259e Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Wed, 19 Nov 2025 10:50:28 +0000 Subject: [PATCH 8/9] Changed CustomOpTransform to add all registered custom op protos
Signed-off-by: abhishek-singh591 --- QEfficient/base/onnx_transforms.py | 67 +++--------------------------- 1 file changed, 5 insertions(+), 62 deletions(-) diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 3b6717c0c..7ebe6bce5 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -5,7 +5,7 @@ # # ---------------------------------------------------------------------------- -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, Optional, Tuple import numpy as np import torch @@ -124,7 +124,7 @@ def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any) @classmethod def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple[ModelProto, bool]: """ - Apply custom op registration and add function protos to the model. + Apply custom op registration and add all function protos to the model. :param model: The ONNX model to transform. :param opset_version: ONNX opset version for symbolic registration. @@ -137,73 +137,16 @@ def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple if hasattr(func_class, "symbolic"): torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version) - # Gather function names and all nodes (graph + function nodes) - func_names: Set[str] = {func.name for func in model.functions} - all_nodes = list(model.graph.node) - for func in model.functions: - all_nodes.extend(func.node) + func_names = {func.name for func in model.functions} - # Collect used op types - used_op_types: Set[str] = {node.op_type for node in all_nodes} - - # Precompute heuristic flags - has_rmsnorm = any("RMSNorm" in op_type for op_type in used_op_types) - has_ctx_ops = any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types) - - # Get function protos for custom ops used in the model - used_protos = cls._get_function_protos_for_model(used_op_types, has_rmsnorm, has_ctx_ops) - - # Append new function protos if not already present - for proto in used_protos: + for _, onnxscript_func in cls._custom_ops.values(): + proto = onnxscript_func.to_function_proto() if proto.name not in func_names: model.functions.append(proto) transformed = True return model, transformed - @classmethod - def _get_function_protos_for_model(cls, used_op_types: Set[str], has_rmsnorm: bool, has_ctx_ops: bool) -> List[Any]: - """ - Get function protos for custom ops that are actually used in the model. - - :param used_op_types: Set of op types used in the model. - :param has_rmsnorm: Flag indicating if RMSNorm-related ops are present. - :param has_ctx_ops: Flag indicating if context-related ops are present. - :returns: List of ONNX function protos. - """ - used_protos: List[Any] = [] - for op_name, (_, onnxscript_func) in cls._custom_ops.items(): - if cls._is_custom_op_used(op_name, used_op_types, has_rmsnorm, has_ctx_ops): - used_protos.append(onnxscript_func.to_function_proto()) - return used_protos - - @classmethod - def _is_custom_op_used(cls, op_name: str, used_op_types: Set[str], has_rmsnorm: bool, has_ctx_ops: bool) -> bool: - """ - Check if a custom op is used in the model. - - :param op_name: Name of the custom op. - :param used_op_types: Set of op types used in the model. - :param has_rmsnorm: Precomputed RMSNorm presence flag. - :param has_ctx_ops: Precomputed context ops presence flag. - :returns: True if the custom op is used, False otherwise. 
- """ - if op_name in used_op_types: - return True - - # Check for domain-specific ops - if f"com.qti.aisw.onnx::{op_name.replace('Func', '')}" in used_op_types: - return True - - # Heuristic checks - if "RMSNorm" in op_name and has_rmsnorm: - return True - - if "Ctx" in op_name and has_ctx_ops: - return True - - return False - class RenameFunctionOutputsTransform(OnnxTransform): """ From 50a29174d6014609ac98727384efa77f74301083 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Wed, 19 Nov 2025 14:37:08 +0000 Subject: [PATCH 9/9] Made Minor fixes Signed-off-by: abhishek-singh591 --- QEfficient/peft/auto.py | 5 ++++- QEfficient/peft/lora/auto.py | 3 ++- QEfficient/transformers/models/modeling_auto.py | 8 +++++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py index 592c0c1d3..99d64cc2f 100644 --- a/QEfficient/peft/auto.py +++ b/QEfficient/peft/auto.py @@ -245,7 +245,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs): obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs) return obj - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: """ Export the model with the active adapter to ONNX format. @@ -286,6 +286,7 @@ def export(self, export_dir: Optional[str] = None) -> str: export_kwargs={"do_constant_folding": False}, # To avoid merging adapter weights with base weights onnx_transform_kwargs={"adapter_name": self.model.active_adapter}, export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, ) def compile( @@ -300,6 +301,7 @@ def compile( num_cores: int = 16, mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -367,6 +369,7 @@ def compile( mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) diff --git a/QEfficient/peft/lora/auto.py b/QEfficient/peft/lora/auto.py index 8196cd769..64fa3f61c 100644 --- a/QEfficient/peft/lora/auto.py +++ b/QEfficient/peft/lora/auto.py @@ -327,7 +327,7 @@ def _init_adapter_model(self): # load_weight to model self._load_adapter_weights_to_model() - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: """ Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``. @@ -387,6 +387,7 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names, dynamic_axes, export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, ) def generate( diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index a1eece0f7..cbff5be91 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1069,9 +1069,15 @@ def export( dynamic_axes["vision"], export_dir=export_dir, offload_pt_weights=False, + use_onnx_subfunctions=use_onnx_subfunctions, ) self.lang_model.export( - inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir=export_dir, offload_pt_weights=True + inputs["lang"], + output_names["lang"], + dynamic_axes["lang"], + export_dir=export_dir, + offload_pt_weights=True, + use_onnx_subfunctions=use_onnx_subfunctions, ) return self.onnx_path