From f3419ec29eaa4a282e44f3da9df1ac5a85c23d4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 18 Nov 2025 12:11:30 +0100 Subject: [PATCH 1/6] refactor qwen attention --- .../patches/_patch_transformers_qwen2_5.py | 140 +++++++++++------- 1 file changed, 87 insertions(+), 53 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py index c6905d23..c9758225 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py @@ -1,5 +1,7 @@ +import os from typing import Callable, Optional import torch +import torch.nn.functional as F from .patch_helper import _is_torchdynamo_exporting from ._patch_transformers_attention import patched_sdpa_attention_forward @@ -11,10 +13,9 @@ except ImportError: patch_qwen2_5 = False -if patch_qwen2_5: - import torch.nn.functional as F +strategy_for_attention_in_qwen_2_5 = os.environ.get("QWEN25ATTENTION", "BIGMASK") - use_loop_for_attention_in_qwen_2_5 = False +if patch_qwen2_5: class patched_Qwen2_5_VLForConditionalGeneration: _PATCHES_ = ["prepare_inputs_for_generation"] @@ -345,58 +346,35 @@ def forward( self.config._attn_implementation ] - if ( - self.config._attn_implementation == "flash_attention_2" - and _is_torchdynamo_exporting() - ): - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - attn_output = torch.onnx.ops.symbolic( - "custom::qwen25_attention", - ( - query_states, - key_states, - value_states, - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - torch.tensor(self.scaling, dtype=torch.float32), - ), - dtype=query_states.dtype, - shape=( - key_states.shape[0], - value_states.shape[1], - max_seqlen, - value_states.shape[-1], - ), - version=1, - ) - elif self.config._attn_implementation == "flash_attention_2": - # Flash Attention 2: Use cu_seqlens for variable length attention - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask=None, - scaling=self.scaling, - dropout=0.0 if not self.training else self.attention_dropout, - cu_seq_lens_q=cu_seqlens, - cu_seq_lens_k=cu_seqlens, - max_length_q=max_seqlen, - max_length_k=max_seqlen, - is_causal=False, - **kwargs, - ) - elif _is_torchdynamo_exporting(): - if ( + if _is_torchdynamo_exporting(): + if self.config._attn_implementation == "flash_attention_2": + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output = torch.onnx.ops.symbolic( + "custom::qwen25_flash_attention", + ( + query_states, + key_states, + value_states, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + torch.tensor(self.scaling, dtype=torch.float32), + ), + dtype=query_states.dtype, + shape=( + query_states.shape[0], # batch_size + query_states.shape[2], # sequence_length (total patches) + query_states.shape[1], # num_heads + query_states.shape[3], # head_size + ), + version=1, + ) + elif ( attention_interface is transformers.integrations.sdpa_attention.sdpa_attention_forward + and strategy_for_attention_in_qwen_2_5 == "LOOPMHA" ): - attention_interface = patched_sdpa_attention_forward - - if use_loop_for_attention_in_qwen_2_5: def _iteration(start_end, query_states, key_states, value_states): return patched_Qwen2_5_VLVisionAttentionOneIteration.forward( @@ -428,7 +406,11 @@ def _iteration(start_end, query_states, key_states, value_states): # 
starts_ends, query_states, key_states, value_states), tuple(),
                # )
                attn_output = torch.cat(attn_outputs, dim=1)
-            else:
+            elif (
+                attention_interface
+                is transformers.integrations.sdpa_attention.sdpa_attention_forward
+                and strategy_for_attention_in_qwen_2_5 == "BIGMASK"
+            ):
                 # make square mask
                 indices = torch.arange(
                     cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
@@ -455,6 +437,58 @@ def _iteration(start_end, query_states, key_states, value_states):
                     is_causal=False,
                     **kwargs,
                 )
+            elif (
+                attention_interface
+                is transformers.integrations.sdpa_attention.sdpa_attention_forward
+                and strategy_for_attention_in_qwen_2_5 == "PACKED"
+            ):
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+                attn_output = torch.onnx.ops.symbolic(
+                    "custom::qwen25_packed_attention",
+                    (
+                        query_states,
+                        key_states,
+                        value_states,
+                        cu_seqlens,
+                        cu_seqlens,
+                        max_seqlen,
+                        max_seqlen,
+                        torch.tensor(self.scaling, dtype=torch.float32),
+                    ),
+                    dtype=query_states.dtype,
+                    shape=(
+                        query_states.shape[0],  # batch_size
+                        query_states.shape[2],  # sequence_length (total patches)
+                        query_states.shape[1],  # num_heads
+                        query_states.shape[3],  # head_size
+                    ),
+                    version=1,
+                )
+            else:
+                raise NotImplementedError(
+                    f"No export strategy for strategy_for_attention_in_qwen_2_5="
+                    f"{strategy_for_attention_in_qwen_2_5!r} "
+                    f"(use QWEN25ATTENTION to change it) and attention_interface="
+                    f"{attention_interface!r} (use sdpa)"
+                )
+        elif self.config._attn_implementation == "flash_attention_2":
+            # Flash Attention 2: Use cu_seqlens for variable length attention
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+            attn_output, _ = attention_interface(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask=None,
+                scaling=self.scaling,
+                dropout=0.0 if not self.training else self.attention_dropout,
+                cu_seq_lens_q=cu_seqlens,
+                cu_seq_lens_k=cu_seqlens,
+                max_length_q=max_seqlen,
+                max_length_k=max_seqlen,
+                is_causal=False,
+                **kwargs,
+            )
         else:
             # Other implementations: Process each chunk separately
             lengths = cu_seqlens[1:] - cu_seqlens[:-1]

From 4579cf9bf764814b95075962f63a181575465a13 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Tue, 18 Nov 2025 12:13:11 +0100
Subject: [PATCH 2/6] doc

---
 CHANGELOGS.rst | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index a07f0688..7811c8b7 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,7 +4,8 @@ Change Logs
 0.8.3
 +++++
 
-* :pr:`310`: split patches into multiple files
+* :pr:`311`: uses a custom op and a local function to run PackedMultiHeadAttention from onnxruntime
+* :pr:`310`: splits patches into multiple files
 * :pr:`308`: add option --save_ep to dump the exported program as well as torch input
 * :pr:`304`, :pr:`306`: improves side-by-side comparison, creates command line sbs

From e786ca01ad45149d3f204399e1c7c17872fc8c14 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Tue, 18 Nov 2025 15:19:59 +0100
Subject: [PATCH 3/6] onnx_plug

---
 _doc/api/export/index.rst              |   1 +
 _doc/api/export/onnx_plug.rst          |   7 +
 _unittests/ut_export/test_onnx_plug.py | 103 +++++++++++
 onnx_diagnostic/export/api.py          |  46 ++++-
 onnx_diagnostic/export/control_flow.py |   2 +-
 onnx_diagnostic/export/onnx_plug.py    | 247 +++++++++++++++++++++++++
 6 files changed, 399 insertions(+), 7 deletions(-)
 create mode 100644 _doc/api/export/onnx_plug.rst
 create mode 100644 _unittests/ut_export/test_onnx_plug.py
 create mode 100644 onnx_diagnostic/export/onnx_plug.py
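Before the individual diffs, a sketch of the intended use of the new class; ``fn``, ``fn_shape``, ``proto``, ``x`` and ``y`` are hypothetical stand-ins for the objects built in test_onnx_plug.py below:

    # Hypothetical sketch: compare the eager function against its ONNX body.
    rep = EagerDirectReplacementWithOnnx(fn, fn_shape, proto, 2, 1)
    res = rep.verify(x, y)       # runs fn, checks fn_shape, then runs proto
    print(res.diffs[0]["abs"])   # maximum absolute difference, 0 when both match

diff --git 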
a/_doc/api/export/index.rst b/_doc/api/export/index.rst index f9858c23..ce546228 100644 --- a/_doc/api/export/index.rst +++ b/_doc/api/export/index.rst @@ -8,6 +8,7 @@ onnx_diagnostic.export api control_flow dynamic_shapes + onnx_plug shape_helper validate diff --git a/_doc/api/export/onnx_plug.rst b/_doc/api/export/onnx_plug.rst new file mode 100644 index 00000000..23d2eb98 --- /dev/null +++ b/_doc/api/export/onnx_plug.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.export.onnx_plug +================================ + +.. automodule:: onnx_diagnostic.export.onnx_plug + :members: + :no-undoc-members: diff --git a/_unittests/ut_export/test_onnx_plug.py b/_unittests/ut_export/test_onnx_plug.py new file mode 100644 index 00000000..9d990e4c --- /dev/null +++ b/_unittests/ut_export/test_onnx_plug.py @@ -0,0 +1,103 @@ +import unittest +import onnx.helper as oh +import torch +from onnx_diagnostic.ext_test_case import ExtTestCase +from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx +from onnx_diagnostic.export.api import to_onnx + + +class TestOnnxPlus(ExtTestCase): + def test_onnx_plug_verify(self): + def _test_customadd(x, y): + return x + y + + def _test_customadd_shape(x, y): + return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype) + + def make_function_proto(): + return oh.make_function( + "onnx_plug", + "_test_customadd", + ["x", "y"], + ["z"], + [oh.make_node("Add", ["x", "y"], ["z"])], + opset_imports=[oh.make_opsetid("", 22)], + ) + + rep = EagerDirectReplacementWithOnnx( + _test_customadd, _test_customadd_shape, make_function_proto(), 2, 1 + ) + + x = torch.randn((3, 4), dtype=torch.float32) + y = torch.randn((3, 1), dtype=torch.float32) + self.assertEqualArray(_test_customadd(x, y), x + y) + res = rep.verify(x, y) + self.assertEqualAny(res.eager_outputs, (x + y,)) + self.assertEqual(len(res.diffs), 1) + self.assertEqual(res.diffs[0]["abs"], 0) + + def test_onnx_plug_export(self): + def _test_customsub(x, y): + return x - y + + def _test_customsub_shape(x, y): + return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype) + + def make_function_proto(): + return oh.make_function( + "onnx_plug", + "_test_customsub", + ["x", "y"], + ["z"], + [oh.make_node("Sub", ["x", "y"], ["z"])], + opset_imports=[oh.make_opsetid("", 22)], + ) + + class Model(torch.nn.Module): + def forward(self, x): + y = x.sum(axis=1, keepdim=True) + d = torch.ops.onnx_plug._test_customsub(x, y) + return torch.abs(d) + + replacements = [ + EagerDirectReplacementWithOnnx( + _test_customsub, _test_customsub_shape, make_function_proto(), 2, 1 + ) + ] + + x = torch.randn((3, 4), dtype=torch.float32) + model = Model() + expected = model(x) + ds = ({0: "d1", 1: "d2"},) + ep = torch.export.export(model, (x,), dynamic_shapes=self.use_dyn_not_str(ds)) + self.assertIn("torch.ops.onnx_plug._test_customsub.default", str(ep)) + got = ep.module()(x) + self.assertEqualArray(expected, got) + + with self.subTest(exporter="custom"): + onx = to_onnx( + model, + (x,), + dynamic_shapes=ds, + exporter="custom", + onnx_plugs=replacements, + target_opset=22, + ) + self.assert_onnx_disc("test_onnx_plug_export_custom", onx.model_proto, model, (x,)) + + with self.subTest(exporter="onnx-dynamo"): + onx = to_onnx( + model, + (x,), + dynamic_shapes=ds, + exporter="onnx-dynamo", + onnx_plugs=replacements, + target_opset=22, + ) + self.assert_onnx_disc( + "test_onnx_plug_export_onnx_dynamo", onx.model_proto, model, (x,) + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git 
a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py
index ebfbd7e9..92a409bf 100644
--- a/onnx_diagnostic/export/api.py
+++ b/onnx_diagnostic/export/api.py
@@ -1,5 +1,6 @@
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 import torch
+from .onnx_plug import EagerDirectReplacementWithOnnx
 
 
 def to_onnx(
@@ -18,6 +19,7 @@ def to_onnx(
     save_ep: Optional[str] = None,
     optimize: bool = True,
     use_control_flow_dispatcher: bool = False,
+    onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
 ) -> Any:
     """
     Common API for exporters. By default, the models are optimized to use the
@@ -41,6 +43,7 @@ def to_onnx(
     :param optimize: optimizes the model
     :param use_control_flow_dispatcher: use the dispatcher created to support
         custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
+    :param onnx_plugs: list of replacements mapping a piece of code to its ONNX translation
     :return: the output of the selected exporter,
         usually a structure including an onnx model
@@ -55,6 +58,10 @@ def to_onnx(
             exporter=exporter,
             filename=filename,
         )
+
+    Some examples using control flow are available in
+    :func:`onnx_diagnostic.export.control_flow.loop_for` or
+    :class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
     """
     if exporter == "custom":
         from experimental_experiment.torch_interpreter import (
             to_onnx as _to_onnx,
         )
         from experimental_experiment.xbuilder import OptimizationOptions
 
-        if use_control_flow_dispatcher:
-            from .control_flow import create_global_dispatcher
-
-            dispatcher = create_global_dispatcher()
-
         options = None
         if exporter_kwargs is not None:
             options = exporter_kwargs.pop("options", None)
         if options is None:
             options = OptimizationOptions(patterns="default+onnxruntime")
+        if onnx_plugs or use_control_flow_dispatcher:
+            from experimental_experiment.torch_interpreter import Dispatcher
+
+            if use_control_flow_dispatcher:
+                from .control_flow import create_global_dispatcher
+
+                control_flow_dispatcher = create_global_dispatcher()
+            else:
+                control_flow_dispatcher = None
+
+            class MainDispatcher(Dispatcher):
+                def __init__(self):
+                    super().__init__({})
+
+            main_dispatcher = MainDispatcher()
+            if control_flow_dispatcher:
+                main_dispatcher.registered_functions.update(
+                    control_flow_dispatcher.registered_functions
+                )
+            if onnx_plugs:
+                for plug in onnx_plugs:
+                    main_dispatcher.registered_functions[plug.target_name] = (
+                        plug.custom_converter()
+                    )
+
+        else:
+            main_dispatcher = None
 
         return _to_onnx(
             mod,
@@ -89,7 +118,7 @@ def to_onnx(
             export_options=ExportOptions(save_ep=save_ep),
             options=options,
             **(exporter_kwargs or {}),
-            dispatcher=dispatcher if use_control_flow_dispatcher else None,
+            dispatcher=main_dispatcher,
         )
 
     if exporter in ("dynamo", "onnx-dynamo"):
@@ -99,6 +128,10 @@ def to_onnx(
         assert (
             not output_dynamic_shapes
         ), f"output_dynamic_shapes not supported for exporter={exporter!r}"
+        custom_translation_table = {}
+        if onnx_plugs:
+            for plug in onnx_plugs:
+                custom_translation_table[plug.torch_op] = plug.onnx_dynamo_converter()
         epo = torch.onnx.export(
             mod,
             args=args or tuple(),
@@ -111,6 +144,7 @@ def to_onnx(
             verbose=verbose,
             dump_exported_program=bool(save_ep),
             artifacts_dir=os.path.dirname(filename) if filename else ".",
+            custom_translation_table=custom_translation_table,
             **(exporter_kwargs or {}),
         )
         if optimize:
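The "onnx-dynamo" branch above boils down to one translation-table entry per plug. A condensed sketch of that wiring, with most arguments omitted; ``dynamo=True`` is an assumption about how the full call is configured, it does not appear in the context lines above:

    # One entry per plug: the torch custom op maps to its onnxscript translation.
    custom_translation_table = {
        plug.torch_op: plug.onnx_dynamo_converter() for plug in onnx_plugs
    }
    epo = torch.onnx.export(
        mod,
        args=args,
        dynamic_shapes=dynamic_shapes,
        dynamo=True,  # assumption: the dynamo-based exporter is used here
        custom_translation_table=custom_translation_table,
    )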
diff --git a/onnx_diagnostic/export/control_flow.py b/onnx_diagnostic/export/control_flow.py
index 9b8b927d..f20ff66e 100644
--- a/onnx_diagnostic/export/control_flow.py
+++ b/onnx_diagnostic/export/control_flow.py
@@ -36,7 +36,7 @@ def register(self, aten_name: str, converter: Callable):
 
 @contextlib.contextmanager
 def enable_code_export_control_flow():
-    """Enables the code means to be exported."""
+    """Enables the code meant to be exported."""
     global _TEST_EXPORT
     old = _TEST_EXPORT
     _TEST_EXPORT = True
diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py
new file mode 100644
index 00000000..7ed39fa0
--- /dev/null
+++ b/onnx_diagnostic/export/onnx_plug.py
@@ -0,0 +1,247 @@
+import inspect
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple
+import onnx
+import torch
+from ..helpers import max_diff
+from ..helpers.torch_helper import torch_dtype_to_onnx_dtype
+from ..reference import OnnxruntimeEvaluator
+
+TUPLE_TENSORS = Tuple[torch.Tensor, ...]
+
+
+def is_exporting() -> bool:
+    """
+    Returns True if :func:`torch.compiler.is_exporting` or
+    :func:`torch.compiler.is_compiling` returns True.
+    """
+    return torch.compiler.is_exporting() or torch.compiler.is_compiling()
+
+
+@dataclass
+class VerifyResult:
+    """
+    Outputs of method :meth:`verify
+    <onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx.verify>`.
+    """
+
+    eager_outputs: TUPLE_TENSORS
+    onnx_output: TUPLE_TENSORS
+    diffs: Tuple[Dict[str, float]]
+
+
+class EagerDirectReplacementWithOnnx:
+    """
+    Replaces a piece of code by another one written in ONNX
+    at export time. The class inserts a custom operator
+    and links it to ``eager_fn``.
+
+    :param eager_fn: the code being replaced; it must be given in order to be able
+        to execute the torch.fx.Graph the exporter produces
+    :param shape_fn: a function producing dummy outputs with the shapes
+        the exporter can use for the next operators in the graph
+    :param function_proto: an instance of ``onnx.FunctionProto``,
+        its domain must be ``onnx_plug``
+    :param n_inputs: number of inputs of the function, if not given,
+        the class will infer it from eager_fn signature,
+        only tensors must be counted
+    :param n_outputs: same for the number of outputs,
+        only tensors must be counted
+    :param name: the name of the custom op, the function name if not specified
+    """
+
+    def __init__(
+        self,
+        eager_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
+        shape_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
+        function_proto: onnx.FunctionProto,
+        n_inputs: Optional[int] = None,
+        n_outputs: Optional[int] = None,
+        name: Optional[str] = None,
+    ):
+        assert isinstance(
+            function_proto, onnx.FunctionProto
+        ), f"Unexpected type {type(function_proto)} for function_proto"
+        assert isinstance(n_inputs, int), f"not implemented yet when n_inputs={n_inputs}"
+        assert isinstance(n_outputs, int), f"not implemented yet when n_outputs={n_outputs}"
+        self.eager_fn = eager_fn
+        self.shape_fn = shape_fn
+        self.function_proto = function_proto
+        self.n_inputs = n_inputs
+        self.n_outputs = n_outputs
+        self.name = name or eager_fn.__name__
+        sig = inspect.signature(self.eager_fn)
+        params = list(sig.parameters)
+        assert (
+            len(params) >= n_inputs
+        ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={n_inputs}"
+        assert n_inputs == len(function_proto.input), (
+            f"Input mismatch n_inputs={n_inputs} but "
+            f"function_proto.input={function_proto.input}"
+        )
+        assert n_outputs == len(function_proto.output), (
+            f"Output mismatch n_outputs={n_outputs} but "
+            f"function_proto.output={function_proto.output}"
+        )
+        assert (
+            function_proto.domain == self.domain
+        ), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}"
+        self.arg_names = params
+        self.custom_op = self._registers()
+
+    @property
+    def domain(self) -> str:
+        "Returns the onnx domain."
+        return "onnx_plug"
+
+    @property
+    def target_name(self) -> str:
+        "Returns the target name (as seen in the exported program)."
+        return f"{self.domain}::{self.name}"
+
+    @property
+    def torch_op(self) -> Callable:
+        "Returns ``torch.ops.onnx_plug.<name>``."
+        return getattr(getattr(torch.ops, self.domain), self.name).default
+
+    def __call__(self, *args):
+        """Calls eager_fn or shape_fn if the model is being exported."""
+        if is_exporting():
+            return self.shape_fn(*args)
+        return self.eager_fn(*args)
+
+    def _registers(self):
+        """Registers the custom op."""
+        inputs = ", ".join([f"Tensor {p}" for p in self.arg_names])
+        schema = f"({inputs}) -> Tensor"
+        if self.n_outputs > 1:
+            schema += "[]"
+        custom_def = torch.library.CustomOpDef(self.domain, self.name, schema, self.eager_fn)
+        custom_def.register_kernel(None)(self.eager_fn)
+        custom_def._abstract_fn = self.shape_fn
+        return custom_def
+
+    def verify(self, *args, engine: Optional[Callable] = None) -> VerifyResult:
+        """
+        Verifies that the eager mode is equivalent to the onnx function given
+        as a replacement. This function evaluates `eager_fn`, checks that the shapes
+        are equivalent to the ones given by `shape_fn`, and finally evaluates the
+        onnx translation if the previous steps did not fail.
+
+        :param args: function inputs
+        :param engine: by default an instance of
+            :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`.
+        :return: a :class:`VerifyResult` holding the outputs of
+            :func:`onnx_diagnostic.helpers.max_diff`
+        """
+        expected = self.eager_fn(*args)
+        shapes = self.shape_fn(*args)
+        if isinstance(expected, torch.Tensor):
+            expected = (expected,)
+            assert isinstance(shapes, torch.Tensor), (
+                f"eager_fn={self.eager_fn} returns a Tensor but shape_fn={self.shape_fn} "
+                f"returns a {type(shapes)}"
+            )
+            shapes = (shapes,)
+        assert isinstance(expected, tuple) and isinstance(shapes, tuple), (
+            f"eager_fn={self.eager_fn} returns a {type(expected)} "
+            f"and shape_fn={self.shape_fn} returns a {type(shapes)}"
+        )
+        assert len(expected) == len(shapes), (
+            f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn} "
+            f"do not return the same number of tensors."
+        )
+        for i, (e, s) in enumerate(zip(expected, shapes)):
+            assert e.dtype == s.dtype, (
+                f"Type mismatch {e.dtype} != {s.dtype} for output {i}, "
+                f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn}"
+            )
+            assert e.shape == s.shape, (
+                f"Shape mismatch {e.shape} != {s.shape} for output {i}, "
+                f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn}"
+            )
+
+        # Now the ONNX execution.
+        assert engine is None, f"Not implemented yet with engine={engine!r}"
+        sess = OnnxruntimeEvaluator(self.function_proto)
+        feeds = dict(zip(sess.input_names, args))
+        got = sess.run(None, feeds)
+        diffs = [max_diff(e, g) for e, g in zip(expected, got)]
+        return VerifyResult(eager_outputs=expected, onnx_output=tuple(got), diffs=diffs)
+
+    def custom_converter(
+        self,
+    ) -> Callable:
+        """
+        Returns a function which
+        converts a custom op found in the fx graph into ONNX
+        following the API of the custom exporter.
+        The converter adds a custom op and registers the local function.
+        """
+
+        def converter(
+            g: "GraphBuilder",  # noqa: F821
+            sts: Optional[Dict[str, Any]],
+            outputs: List[str],
+            *args,
+        ) -> Any:
+            if not g.has_local_function(self.name, self.domain):
+                g.add_function(self.function_proto)
+            res = g.make_node(
+                self.name, args, outputs, domain=self.domain, name=self.target_name
+            )
+            if not sts:
+                new_shapes = self.shape_fn(*args)
+                if not isinstance(new_shapes, tuple):
+                    new_shapes = (new_shapes,)
+                for sh, o in zip(new_shapes, outputs):
+                    g.set_type(o, torch_dtype_to_onnx_dtype(sh.dtype))
+                    g.set_shape(o, sh.shape)
+            return res
+
+        return converter
+
+    def onnx_dynamo_converter(self) -> Callable:
+        """
+        Returns a function which converts a custom op found
+        in the fx graph into ONNX
+        following the API of :func:`torch.onnx.export`.
+        """
+        import onnxscript
+
+        onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1)
+        schema = onnx_plug_op[self.name]
+        if schema is None:
+            all_types = [
+                "tensor(float)",
+                "tensor(float16)",
+                "tensor(bfloat16)",
+                "tensor(double)",
+                "tensor(int64)",
+                "tensor(int32)",
+            ]
+            type_constraints = []
+            for i in range(self.n_inputs):
+                type_constraints.append((f"T{i}", all_types, ""))
+            for i in range(self.n_outputs):
+                type_constraints.append((f"U{i}", all_types, ""))
+            schema = onnx.defs.OpSchema(
+                self.name,
+                self.domain,
+                1,
+                inputs=[
+                    onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
+                    for i in range(self.n_inputs)
+                ],
+                outputs=[
+                    onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
+                    for i in range(self.n_outputs)
+                ],
+                type_constraints=type_constraints,
+            )
+            onnx.defs.register_schema(schema)
+        op = onnxscript.values.Op(onnx_plug_op, self.name, schema)
+
+        def converter(*cargs):
+            return op(*cargs, n_outputs=self.n_outputs)
+
+        return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)
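A note on what the registration above buys: once EagerDirectReplacementWithOnnx has registered the op, it is reachable as ``torch.ops.onnx_plug.<name>``. Eager calls run ``eager_fn`` through the registered kernel, while the fake tensors seen during ``torch.export`` go through ``shape_fn``, the abstract function, which is what lets tracing proceed without running the real computation. A hypothetical sketch (``my_fn``, ``my_fn_shape``, ``proto``, ``model``, ``x`` and ``y`` are stand-ins, not names from the diffs):

    rep = EagerDirectReplacementWithOnnx(my_fn, my_fn_shape, proto, 2, 1)

    z = torch.ops.onnx_plug.my_fn(x, y)    # eager mode: runs my_fn
    ep = torch.export.export(model, (x,))  # tracing: shapes come from my_fn_shape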
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_L_Model_forward_L_body_", + str(ep), + ) onx = to_onnx( model, @@ -164,7 +173,10 @@ def body(i, x): ep = torch.export.export( model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})) ) - self.assertIn("torch.ops.onnx_higher_ops.loop_for_body_", str(ep)) + self.assertIn( + "torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_two_custom_reduction_dim_L_Model_forward_L_body_", + str(ep), + ) onx = to_onnx( model, diff --git a/_unittests/ut_export/test_onnx_plug.py b/_unittests/ut_export/test_onnx_plug.py index 9d990e4c..97f101f7 100644 --- a/_unittests/ut_export/test_onnx_plug.py +++ b/_unittests/ut_export/test_onnx_plug.py @@ -1,7 +1,7 @@ import unittest import onnx.helper as oh import torch -from onnx_diagnostic.ext_test_case import ExtTestCase +from onnx_diagnostic.ext_test_case import ExtTestCase, has_torch from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx from onnx_diagnostic.export.api import to_onnx @@ -85,6 +85,8 @@ def forward(self, x): ) self.assert_onnx_disc("test_onnx_plug_export_custom", onx.model_proto, model, (x,)) + if not has_torch("2.9"): + raise unittest.SkipTest("onnx-dynamo + custom op not fully working on 2.8") with self.subTest(exporter="onnx-dynamo"): onx = to_onnx( model, diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py index 92a409bf..c6b615eb 100644 --- a/onnx_diagnostic/export/api.py +++ b/onnx_diagnostic/export/api.py @@ -20,6 +20,7 @@ def to_onnx( optimize: bool = True, use_control_flow_dispatcher: bool = False, onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None, + inline: bool = True, ) -> Any: """ Common API for exporters. By default, the models are optimized to use the @@ -44,6 +45,7 @@ def to_onnx( :param use_control_flow_dispatcher: use the dispatcher created to supported custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`) :param onnx_plugs: the code was modified to replace some parts with onnx translation + :param inline: inline local functions :return: the output of the selected exporter, usually a structure including an onnx model @@ -63,6 +65,11 @@ def to_onnx( :func:`onnx_diagnostic.export.control_flow.loop_for` or :class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`. 
""" + if exporter_kwargs and "inline" in exporter_kwargs: + assert ( + inline == exporter_kwargs["inline"] + ), f"Mismatch between inline={inline} and exporter_kwargs={exporter_kwargs}" + exporter_kwargs.pop("inline") if exporter == "custom": from experimental_experiment.torch_interpreter import ( to_onnx as _to_onnx, @@ -90,7 +97,7 @@ def __init__(self): super().__init__({}) main_dispatcher = MainDispatcher() - if control_flow_dispatcher: + if use_control_flow_dispatcher: main_dispatcher.registered_functions.update( control_flow_dispatcher.registered_functions ) @@ -99,7 +106,6 @@ def __init__(self): main_dispatcher.registered_functions[plug.target_name] = ( plug.custom_converter() ) - else: main_dispatcher = None @@ -117,8 +123,9 @@ def __init__(self): output_dynamic_shapes=output_dynamic_shapes, export_options=ExportOptions(save_ep=save_ep), options=options, - **(exporter_kwargs or {}), + inline=inline, dispatcher=main_dispatcher, + **(exporter_kwargs or {}), ) if exporter in ("dynamo", "onnx-dynamo"): @@ -147,7 +154,21 @@ def __init__(self): custom_translation_table=custom_translation_table, **(exporter_kwargs or {}), ) - if optimize: + if not inline and optimize: + ort_fusions.optimize_for_ort(epo.model) + + if onnx_plugs: + import onnx_ir as ir + import onnx_ir.passes.common as common_passes + + irfunctions = [ir.from_proto(plug.function_proto) for plug in onnx_plugs] + for func in irfunctions: + epo.model.functions[func.identifier()] = func + if inline: + common_passes.InlinePass()(epo.model) + common_passes.RemoveUnusedOpsetsPass()(epo.model) + + if inline and optimize: ort_fusions.optimize_for_ort(epo.model) if filename: epo.save(filename, external_data=True) diff --git a/onnx_diagnostic/export/control_flow.py b/onnx_diagnostic/export/control_flow.py index f20ff66e..7e2b7298 100644 --- a/onnx_diagnostic/export/control_flow.py +++ b/onnx_diagnostic/export/control_flow.py @@ -134,7 +134,8 @@ def make_custom_loop_for( assert body_outputs is not None, "body_outputs cannot be None" srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs) sred = "x".join(map(str, reduction_dim)) if reduction_dim else "" - name = f"loop_for_{body_fn.__name__}_{id(body_fn)}_{srank}_{sred}" + full_name = body_fn.__qualname__.replace("", "L").replace(".", "_") + name = f"loop_for_{full_name}_{srank}_{sred}" if name in _REGISTERED_SCHEMA: return name, _REGISTERED_SCHEMA[name][0] sig = inspect.signature(body_fn) diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py index 7ed39fa0..01fa45e1 100644 --- a/onnx_diagnostic/export/onnx_plug.py +++ b/onnx_diagnostic/export/onnx_plug.py @@ -49,6 +49,88 @@ class EagerDirectReplacementWithOnnx: :param n_outputs: same for the number of outputs, only tensors must be counted :param name: the name of the custom op, the function name if not specified + + Here is an example: + + .. 
diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py
index 7ed39fa0..01fa45e1 100644
--- a/onnx_diagnostic/export/onnx_plug.py
+++ b/onnx_diagnostic/export/onnx_plug.py
@@ -49,6 +49,88 @@ class EagerDirectReplacementWithOnnx:
     :param n_outputs: same for the number of outputs,
         only tensors must be counted
     :param name: the name of the custom op, the function name if not specified
+
+    Here is an example:
+
+    .. runpython::
+        :showcode:
+
+        import onnx.helper as oh
+        import torch
+        from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
+        from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
+        from onnx_diagnostic.export.api import to_onnx
+        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+
+        def demo_customsub(x, y):
+            return x - y
+
+
+        def demo_customsub_shape(x, y):
+            return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
+
+
+        def make_function_proto():
+            return oh.make_function(
+                "onnx_plug",
+                "demo_customsub",
+                ["x", "y"],
+                ["z"],
+                [oh.make_node("Sub", ["x", "y"], ["z"])],
+                opset_imports=[oh.make_opsetid("", 22)],
+            )
+
+
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                y = x.sum(axis=1, keepdim=True)
+                d = torch.ops.onnx_plug.demo_customsub(x, y)
+                return torch.abs(d)
+
+
+        replacements = [
+            EagerDirectReplacementWithOnnx(
+                demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
+            )
+        ]
+
+        x = torch.randn((3, 4), dtype=torch.float32)
+        model = Model()
+        ds = ({0: "d1", 1: "d2"},)
+
+        # The exported program shows a custom op.
+        ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
+        print(ep)
+
+        # As the exporter knows how to replace this custom op,
+        # let's export.
+
+        onx = to_onnx(
+            model,
+            (x,),
+            dynamic_shapes=ds,
+            exporter="custom",
+            onnx_plugs=replacements,
+            target_opset=22,
+            inline=False,
+        ).model_proto
+
+        print(pretty_onnx(onx))
+
+        # And with :func:`torch.onnx.export`:
+
+        onx = to_onnx(
+            model,
+            (x,),
+            dynamic_shapes=ds,
+            exporter="onnx-dynamo",
+            onnx_plugs=replacements,
+            target_opset=22,
+            inline=False,
+        ).model_proto
+
+        print(pretty_onnx(onx))

From 6d2385343193437a914ed561461efd1eeeee56bb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Tue, 18 Nov 2025 16:08:07 +0100
Subject: [PATCH 5/6] mypy

---
 onnx_diagnostic/export/onnx_plug.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py
index 01fa45e1..b2d72252 100644
--- a/onnx_diagnostic/export/onnx_plug.py
+++ b/onnx_diagnostic/export/onnx_plug.py
@@ -28,7 +28,7 @@ class VerifyResult:
 
     eager_outputs: TUPLE_TENSORS
     onnx_output: TUPLE_TENSORS
-    diffs: Tuple[Dict[str, float]]
+    diffs: Tuple[Dict[str, float], ...]
class EagerDirectReplacementWithOnnx: @@ -247,8 +247,8 @@ def verify(self, *args, engine: Optional[Callable] = None) -> VerifyResult: sess = OnnxruntimeEvaluator(self.function_proto) feeds = dict(zip(sess.input_names, args)) got = sess.run(None, feeds) - diffs = [max_diff(e, g) for e, g in zip(expected, got)] - return VerifyResult(eager_outputs=expected, onnx_output=tuple(got), diffs=diffs) + diffs = tuple(max_diff(e, g) for e, g in zip(expected, got)) + return VerifyResult(eager_outputs=expected, onnx_output=tuple(got), diffs=diffs) # type: ignore[arg-type] def custom_converter( self, @@ -261,7 +261,7 @@ def custom_converter( """ def converter( - g: "GraphBuilder", # noqa: F821 + g: Any, # GraphBuilder sts: Optional[Dict[str, Any]], outputs: List[str], *args, From 69e8436b1bd9f9815dbed8d7ab7e8cddbdc3e19f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 18 Nov 2025 16:30:41 +0100 Subject: [PATCH 6/6] fix inh --- onnx_diagnostic/export/api.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py index c6b615eb..8ee7a84f 100644 --- a/onnx_diagnostic/export/api.py +++ b/onnx_diagnostic/export/api.py @@ -93,14 +93,33 @@ def to_onnx( control_flow_dispatcher = None class MainDispatcher(Dispatcher): - def __init__(self): + def __init__(self, previous_dispatcher=None): super().__init__({}) - - main_dispatcher = MainDispatcher() - if use_control_flow_dispatcher: - main_dispatcher.registered_functions.update( - control_flow_dispatcher.registered_functions - ) + self.previous_dispatcher = previous_dispatcher + + @property + def supported(self): + if self.previous_dispatcher: + return ( + set(self.registered_functions) | self.previous_dispatcher.supported + ) + return set(self.registered_functions) + + def find_function(self, name: Any): + if self.previous_dispatcher: + find = self.previous_dispatcher.find_function(name) + if find: + return find + return Dispatcher.find_function(self, name) + + def find_method(self, name: Any): + if self.previous_dispatcher: + find = self.previous_dispatcher.find_method(name) + if find: + return find + return Dispatcher.find_method(self, name) + + main_dispatcher = MainDispatcher(control_flow_dispatcher) if onnx_plugs: for plug in onnx_plugs: main_dispatcher.registered_functions[plug.target_name] = (