From e31b475714990f9b1f2a9f38155e6f782c651cf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 13:55:48 +0100 Subject: [PATCH 01/13] unittest --- .../ut_torch_onnx/test_discrepancies.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 _unittests/ut_torch_onnx/test_discrepancies.py diff --git a/_unittests/ut_torch_onnx/test_discrepancies.py b/_unittests/ut_torch_onnx/test_discrepancies.py new file mode 100644 index 00000000..01c755d1 --- /dev/null +++ b/_unittests/ut_torch_onnx/test_discrepancies.py @@ -0,0 +1,54 @@ +import os +import unittest +import numpy as np +import onnx +from onnx_diagnostic.ext_test_case import ExtTestCase + + +class TestDiscrepancies(ExtTestCase): + def test_attention_opset15_in_a_loop(self): + model = onnx.load( + os.path.join(os.path.dirname(__file__), "data", "attention_loopa24.onnx") + ) + sess = self.check_ort(model) + feeds = dict( + c_lift_tensor_0=np.array([0], dtype=np.int64), + cat_2=np.array( + [ + 0, + 64, + 128, + 192, + 256, + 304, + 368, + 432, + 496, + 560, + 608, + 672, + 736, + 800, + 864, + 912, + 976, + 1040, + 1104, + 1168, + 1216, + 1232, + 1248, + 1264, + 1280, + 1292, + ], + dtype=np.int64, + ), + unsqueeze_4=np.random.randn(1, 16, 1292, 80).astype(np.float32), + ) + got = sess.run(None, feeds) + self.assertEqual(len(got), 1) + + +if __name__ == "__main__": + unittest.main(verbosity=2) From 384166b8ed9fc55b9993dda8432532c0cd3398f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 13:56:05 +0100 Subject: [PATCH 02/13] add model --- .../ut_torch_onnx/data/attention_loopa24.onnx | Bin 0 -> 6040 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 _unittests/ut_torch_onnx/data/attention_loopa24.onnx diff --git a/_unittests/ut_torch_onnx/data/attention_loopa24.onnx b/_unittests/ut_torch_onnx/data/attention_loopa24.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9694ffe0cbaddfc1bb56fad4f19a60b68dae0e67 GIT binary patch literal 6040 zcmdT|&u`mQ9FNnaNxwc5?jU&4vgv5nEfD_`vi38$1;J^uq3n%!#XFGoW(ROG8jYyH>m+!Mb-`}6_ z%jPzITY=A$kZvosEJ3j?cl38{O_goUu*|+(ft+sWc2l;h@^Urie=;eX#(`{g^tLA3 zrYy~pfuZjYG`VdWmewBF`aMld@6i9H+4-Hfm#=m@mn!pHW_1OYZ#OsItOC8+NVH0j z(GB~sr&;8}*3H|E`rM_t*lVi@S#w}JqG=}5T7j%?SbGZHEtVDr)kb-+{z`RWX%U}N zs)jGE2^hmixEqAat(PIM8~2pXKx1p^)+MP{Yf6<`V{NtG zY_6_Vnh>@JDKP_~HK>V?+d4ZHwVWi5USm7UIY=o7x;2w{EF@1bP7~A@iIq#puqg1>Xx^G(ENEZV+kv1Eu2xQH&{SV){B2kcwXLtaHJ0w9W-C zSW=s}Zf&=|E$yccQA#4#Cqg4JjF zgd7ie1wKkbtS%(q9>|uq-@#%fS7(#PnwYw#Sax{~lD0WBN>imrqvgN+qFD;QS={mz zwnt5Bo#lKR)VFnL1E>UootnX=y4dTouP|d-vrcqs}7DeCr+ zJ8o>>x_UbZ&RJUE+;I-;;9Rw_G`Fy)tEy(m_jT3YTUeS~t1M!h;yMSD3`^0SC<)&Q z5P?nd*mimYTV`zrm5{JG(zYe#Sde0a+}V2y&`#HjfIb2WD$IdSmFi22ttA*oC|t`s zR+oR4cS%7?!OVcKJm}v!|3#6*8NiP%A4WoX=w+B^5R(8=%t5wGG!R zIX_7fL5DcHRmfNSmaOfaHupSMDixdrz|{97!D68SF~qo}(EP!ub3bAoWQW>!PgK zwn5*#Dvs}9qtcAFcD>uP56gK-tCT;@2((6K86qcH0&cLnAjRo`#s(HqBo77XqnWQ$3V^BJh3^9qO1Xj%TKHxvI{-#Bya*!2B{ z9j8dVfy!!qV|kg@g}1!YAzk!Cuer-&&oe)jssB^+pW-57%!vqOhvEF9OJ_oqDdRFq zq#vQg%a;~h#gVSULo((?QJvKXKhx&DOgn>_=1T=~4HXKwFea*lZj5Bsmjd|D24;6M zF@)qKSJJ3GYi=E3$uOMO3n4pBQakbn`&w{7LI@Hk_K|ItE!6W#}6UT z=6`oEi1mL2HxC+hfNw*c+%tx?XysP5qYl^+6^@E%&Pw~H|Txz;?G*mrm5N$@Lb_vgMH2+9vhrXRw(g^^l%BugzKK* zB{C5FAi(Hd?NB~Ju{4#TYOVc0)S@0^6yqetGL95gZT8vm%aN#TIASx@nYWplF%z+Fj*9#<{JnTe!%Yd3g%*ckjuRdhBV6YRwd|!3blZ>g6g&qmldo_T;LpdH}QqIdBncZx&$Lq;K$&5#6L@_^B$F;&sZ-|TwM z9Cow_{mlC@hlDfaOS~QWGUhnsSDYESw?kmmBIow+<0NOPJ%!R@CynF|vQR~;Q*W<{ z8JAKHXN74W%oq)wn6~QF@W3G&m1GmvQ%{O`dKo-@yhtAr*9x4;63P8ww`;0er+h9e zAQ$8hX--k(AM literal 0 HcmV?d00001 From 45156515093600fb5f4cbb231f46341ef1f0c353 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 13:44:38 +0000 Subject: [PATCH 03/13] fix other series of bugs --- .../ut_torch_onnx/data/attention_loopa24.onnx | Bin 6040 -> 7463 bytes .../ut_torch_onnx/test_discrepancies.py | 5 +- onnx_diagnostic/helpers/dot_helper.py | 26 +------- onnx_diagnostic/helpers/onnx_helper.py | 61 +++++++++++++++--- onnx_diagnostic/reference/ort_evaluator.py | 35 ++-------- onnx_diagnostic/torch_onnx/runtime_info.py | 25 +------ 6 files changed, 67 insertions(+), 85 deletions(-) diff --git a/_unittests/ut_torch_onnx/data/attention_loopa24.onnx b/_unittests/ut_torch_onnx/data/attention_loopa24.onnx index 9694ffe0cbaddfc1bb56fad4f19a60b68dae0e67..a904be2a4c92d8fa614f92046c8183df20ebfc64 100644 GIT binary patch delta 918 zcmbQCzubzIgWKwtr7vvYGrf?ygT{x<+a8M1D~jGaUsB z15E{e1+d%Hkew?DR3K^LCCkN=nTO Set[str]: - hidden = set() - memo = ( - {i.name for i in graph.initializer} - | {i.values.name for i in graph.sparse_initializer} - | {i.name for i in graph.input} - ) - for node in graph.node: - for i in node.input: - if i not in memo: - hidden.add(i) - for att in node.attribute: - if att.type == onnx.AttributeProto.GRAPH and att.g: - hid = _get_hidden_inputs(att.g) - less = set(h for h in hid if h not in memo) - hidden |= less - memo |= set(node.output) - return hidden +from .onnx_helper import onnx_dtype_name, pretty_onnx, get_hidden_inputs def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str: @@ -221,7 +201,7 @@ def _mkn(obj: object) -> int: unique = set() for att in node.attribute: if att.type == onnx.AttributeProto.GRAPH: - unique |= _get_hidden_inputs(att.g) + unique |= get_hidden_inputs(att.g) for i in unique: edge = name_to_ids[i], _mkn(node) # type: ignore[assignment] if edge in done: diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 85713459..bf4fe528 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1198,6 +1198,30 @@ def shadowing_names( return shadow, post_shadow, created +def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]: + """ + Returns the hidden inputs (inputs coming from an upper context) + used by a subgraph. + """ + hidden = set() + memo = ( + set(i.name for i in graph.initializer) + | set(i.name for i in graph.sparse_initializer) + | set(i.name for i in graph.input) + ) + for node in graph.node: + for i in node.input: + if i not in memo: + hidden.add(i) + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH and att.g: + hid = get_hidden_inputs(att.g) + less = set(h for h in hid if h not in memo) + hidden |= less + memo |= set(node.output) + return hidden + + def extract_subset_of_nodes( model: ModelProto, name: str, @@ -1240,14 +1264,28 @@ def extract_subset_of_nodes( current_input_index = 0 intermediate = {name} cut_points -= {name} + cached = {} inputs = set(k for k in node.input if k) while not (inputs <= cut_points) and current_node_index >= 0: node = model.graph.node[current_node_index] - if current_input_index == 0 or not node.input: + # node inputs including hidden ones + if current_node_index in cached: + node_inputs = cached[current_node_index] + else: + node_inputs = set(i for i in node.input if i) + if node.op_type in {"Scan", "If", "Loop"}: + # there are hidden inputs + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH: + node_inputs |= get_hidden_inputs(att.g) + node_inputs = list(node_inputs) + cached[current_node_index] = node_inputs + # processing + if current_input_index == 0 or not node_inputs: needs = [o for o in node.output if o in intermediate and o not in cut_points] if needs: selected.add(current_node_index) - if not node.input: + if not node_inputs: current_node_index -= 1 current_input_index = 0 continue @@ -1255,15 +1293,16 @@ def extract_subset_of_nodes( current_node_index -= 1 current_input_index = 0 continue - assert current_input_index < len(node.input), ( - f"current_input_index={current_input_index} but node.input={node.input}, " + # more intermediate results + assert current_input_index < len(node_inputs), ( + f"current_input_index={current_input_index} but node_inputs={node_inputs}, " f"node={pretty_onnx(node)}" ) - res = node.input[current_input_index] + res = node_inputs[current_input_index] if res not in cut_points: intermediate.add(res) current_input_index += 1 - if current_input_index >= len(node.input): + if current_input_index >= len(node_inputs): current_node_index -= 1 current_input_index = 0 @@ -1296,8 +1335,14 @@ def _mkv_(name, itype, irank): not_known: Set[str] = set() for node in nodes[::-1]: - not_known -= set(node.output) - not_known |= set(node.input) + not_known -= {o for o in node.output if o} + not_known |= {i for i in node.input if i} + if node.op_type in {"Scan", "If", "Loop"}: + # there are hidden inputs + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH: + print("++++", get_hidden_inputs(att.g)) + not_known |= get_hidden_inputs(att.g) model = oh.make_model( oh.make_graph( diff --git a/onnx_diagnostic/reference/ort_evaluator.py b/onnx_diagnostic/reference/ort_evaluator.py index 6d7f20d6..ffa511d6 100644 --- a/onnx_diagnostic/reference/ort_evaluator.py +++ b/onnx_diagnostic/reference/ort_evaluator.py @@ -18,10 +18,11 @@ import onnxruntime from ..helpers import string_type from ..helpers.onnx_helper import ( - pretty_onnx, + get_hidden_inputs, dtype_to_tensor_dtype, - to_array_extended, np_dtype_to_tensor_dtype, + to_array_extended, + pretty_onnx, ) from ..helpers.torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype from ..helpers.ort_session import ( @@ -472,39 +473,15 @@ def enumerate_nodes(self, nodes: List[NodeProto]) -> Iterator[NodeProto]: yield from self.enumerate_nodes(att.g.node) yield node - @classmethod - def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]: - """ - Returns the hidden inputs (inputs coming from an upper context) - used by a subgraph. - """ - hidden = set() - memo = ( - {i.name for i in graph.initializer} - | {i.name for i in graph.sparse_initializer} - | {i.name for i in graph.input} - ) - for node in graph.node: - for i in node.input: - if i not in memo: - hidden.add(i) - for att in node.attribute: - if att.type == AttributeProto.GRAPH and att.g: - hid = cls._get_hidden_inputs(att.g) - less = set(h for h in hid if h not in memo) - hidden |= less - memo |= set(node.output) - return hidden - @classmethod def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]: - """Calls multiple _get_hidden_inputs on every attribute.""" + """Calls multiple get_hidden_inputs on every attribute.""" if node.op_type not in {"Loop", "Scan", "If"}: return set() hidden = set() for att in node.attribute: if att.type == AttributeProto.GRAPH: - hidden |= cls._get_hidden_inputs(att.g) + hidden |= get_hidden_inputs(att.g) return hidden - (hidden & set(node.input)) def _get_sess( @@ -624,7 +601,7 @@ def _get_sess_init_subgraph( value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape) vinputs.append(value) - reduced_set = self._get_hidden_inputs(g) + reduced_set = get_hidden_inputs(g) for i, v in context.items(): if i in reduced_set and i not in unique_names: unique_names.add(i) diff --git a/onnx_diagnostic/torch_onnx/runtime_info.py b/onnx_diagnostic/torch_onnx/runtime_info.py index 86664dd9..5d3ed0a4 100644 --- a/onnx_diagnostic/torch_onnx/runtime_info.py +++ b/onnx_diagnostic/torch_onnx/runtime_info.py @@ -4,6 +4,7 @@ import torch from ..api import TensorLike from ..helpers import string_type +from ..helpers.onnx_helper import get_hidden_inputs class RuntimeValueKind(enum.IntEnum): @@ -151,30 +152,6 @@ def is_initializer(self) -> bool: return self.kind == RuntimeValueKind.INITIALIZER -def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]: - """ - Returns the hidden inputs (inputs coming from an upper context) - used by a subgraph. - """ - hidden = set() - memo = ( - set(i.name for i in graph.initializer) - | set(i.name for i in graph.sparse_initializer) - | set(i.name for i in graph.input) - ) - for node in graph.node: - for i in node.input: - if i not in memo: - hidden.add(i) - for att in node.attribute: - if att.type == onnx.AttributeProto.GRAPH and att.g: - hid = get_hidden_inputs(att.g) - less = set(h for h in hid if h not in memo) - hidden |= less - memo |= set(node.output) - return hidden - - def set_is_shape( node: onnx.NodeProto, values: Dict[str, RuntimeValue], drop: Optional[Set[str]] = None ) -> List[str]: From 3e9216681d1c8d946702ffcc2a9d5656b395e123 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 14:29:54 +0000 Subject: [PATCH 04/13] mystery --- .../ut_torch_onnx/test_discrepancies.py | 72 ++++++++++++++++++- .../patches/_patch_transformers_attention.py | 1 + 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_torch_onnx/test_discrepancies.py b/_unittests/ut_torch_onnx/test_discrepancies.py index 165ae4eb..b08a0453 100644 --- a/_unittests/ut_torch_onnx/test_discrepancies.py +++ b/_unittests/ut_torch_onnx/test_discrepancies.py @@ -2,15 +2,53 @@ import unittest import numpy as np import onnx -from onnx_diagnostic.ext_test_case import ExtTestCase +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings +from onnx_diagnostic.reference import OnnxruntimeEvaluator class TestDiscrepancies(ExtTestCase): + @ignore_warnings(DeprecationWarning) def test_attention_opset15_in_a_loop(self): + import torch + from onnx_diagnostic.torch_export_patches.patches._patch_transformers_attention import ( # noqa: E501 + patched_sdpa_attention_forward, + ) + + def qwen_sdpa_attention( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + cu_seqlens: torch.Tensor, + scaling: float = 0, + num_heads: int = 16, + ) -> torch.Tensor: + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) + for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + patched_sdpa_attention_forward( + None, + q, + k, + v, + attention_mask=None, + scaling=scaling, + dropout=0.0, + is_causal=False, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + return attn_output + model = onnx.load( os.path.join(os.path.dirname(__file__), "data", "attention_loopa24.onnx") ) sess = self.check_ort(model) + feeds = dict( c_lifted_tensor_0=np.array([0], dtype=np.int64), cat_2=np.array( @@ -48,9 +86,41 @@ def test_attention_opset15_in_a_loop(self): unsqueeze_5=np.random.randn(1, 16, 1292, 80).astype(np.float32), unsqueeze_6=np.random.randn(1, 16, 1292, 80).astype(np.float32), ) + + dummy_inputs = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "dump_test", + "replay", + "qwen_sdpa_attention_loopa24", + "onnx_inputs.pt", + ) + if os.path.exists(dummy_inputs): + print("-- use dummy inputs") + feeds = {k: v.detach().cpu().numpy() for k, v in torch.load(dummy_inputs).items()} + for k, v in feeds.items(): + print(f"-- {k}: {self.string_type(v, with_shape=True, with_min_max=True)}") + + # feeds["cat_2"] = np.array([0, 1292], dtype=np.int64) got = sess.run(None, feeds) self.assertEqual(len(got), 1) self.assertEqual((1, 1292, 16, 80), got[0].shape) + expected = qwen_sdpa_attention( + torch.from_numpy(feeds["unsqueeze_4"]), + torch.from_numpy(feeds["unsqueeze_5"]), + torch.from_numpy(feeds["unsqueeze_6"]), + torch.from_numpy(feeds["cat_2"]), + scaling=0.11180339753627777, + num_heads=16, + ) + self.assertEqualArray(expected, got[0], atol=1e-5) + + tfeeds = {k: torch.from_numpy(v) for k, v in feeds.items()} + ev = OnnxruntimeEvaluator(model) + got2 = ev.run(None, tfeeds) + self.assertEqual(len(got2), 1) + self.assertEqualArray(got[0], got2[0], atol=1e-5) if __name__ == "__main__": diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py index 8be65493..98cdabde 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py @@ -118,6 +118,7 @@ def patched_sdpa_attention_forward( torch._check(value.shape[1] > 0) torch._check(value.shape[2] > 0) torch._check(value.shape[3] > 0) + return ( torch.nn.functional.scaled_dot_product_attention( query, From a9d482c3ab6dc8abbe3143e9c151dd1078306483 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 14:31:44 +0000 Subject: [PATCH 05/13] mypy --- onnx_diagnostic/helpers/onnx_helper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index bf4fe528..7bce38ac 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1264,7 +1264,7 @@ def extract_subset_of_nodes( current_input_index = 0 intermediate = {name} cut_points -= {name} - cached = {} + cached: Dict[int, List[str]] = {} inputs = set(k for k in node.input if k) while not (inputs <= cut_points) and current_node_index >= 0: node = model.graph.node[current_node_index] @@ -1272,13 +1272,13 @@ def extract_subset_of_nodes( if current_node_index in cached: node_inputs = cached[current_node_index] else: - node_inputs = set(i for i in node.input if i) + set_inputs = set(i for i in node.input if i) if node.op_type in {"Scan", "If", "Loop"}: # there are hidden inputs for att in node.attribute: if att.type == onnx.AttributeProto.GRAPH: - node_inputs |= get_hidden_inputs(att.g) - node_inputs = list(node_inputs) + set_inputs |= get_hidden_inputs(att.g) + node_inputs = list(set_inputs) cached[current_node_index] = node_inputs # processing if current_input_index == 0 or not node_inputs: From 320aae86431a2093b99083d901d22472efe68343 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 15:07:12 +0000 Subject: [PATCH 06/13] improve unittest --- .../ut_torch_onnx/test_discrepancies.py | 197 +++++++++++------- onnx_diagnostic/helpers/onnx_helper.py | 1 - 2 files changed, 123 insertions(+), 75 deletions(-) diff --git a/_unittests/ut_torch_onnx/test_discrepancies.py b/_unittests/ut_torch_onnx/test_discrepancies.py index b08a0453..9d344ff4 100644 --- a/_unittests/ut_torch_onnx/test_discrepancies.py +++ b/_unittests/ut_torch_onnx/test_discrepancies.py @@ -2,8 +2,9 @@ import unittest import numpy as np import onnx -from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, has_onnxruntime from onnx_diagnostic.reference import OnnxruntimeEvaluator +from onnx_diagnostic.helpers import max_diff, string_diff class TestDiscrepancies(ExtTestCase): @@ -44,83 +45,131 @@ def qwen_sdpa_attention( attn_output = torch.cat(attn_outputs, dim=1) return attn_output - model = onnx.load( - os.path.join(os.path.dirname(__file__), "data", "attention_loopa24.onnx") - ) - sess = self.check_ort(model) + for model_name in ["attention_loopa24.onnx", "attention_loopmha.onnx"]: + if model_name == "attention_loopa24.onnx" and not has_onnxruntime("1.24"): + # not available + continue + with self.subTest(model=model_name): + model = onnx.load(os.path.join(os.path.dirname(__file__), "data", model_name)) + sess = self.check_ort(model) - feeds = dict( - c_lifted_tensor_0=np.array([0], dtype=np.int64), - cat_2=np.array( - [ - 0, - 64, - 128, - 192, - 256, - 304, - 368, - 432, - 496, - 560, - 608, - 672, - 736, - 800, - 864, - 912, - 976, - 1040, - 1104, - 1168, - 1216, - 1232, - 1248, - 1264, - 1280, - 1292, - ], - dtype=np.int64, - ), - unsqueeze_4=np.random.randn(1, 16, 1292, 80).astype(np.float32), - unsqueeze_5=np.random.randn(1, 16, 1292, 80).astype(np.float32), - unsqueeze_6=np.random.randn(1, 16, 1292, 80).astype(np.float32), - ) + feeds = dict( + c_lifted_tensor_0=np.array([0], dtype=np.int64), + cat_2=np.array( + [ + 0, + 64, + 128, + 192, + 256, + 304, + 368, + 432, + 496, + 560, + 608, + 672, + 736, + 800, + 864, + 912, + 976, + 1040, + 1104, + 1168, + 1216, + 1232, + 1248, + 1264, + 1280, + 1292, + ], + dtype=np.int64, + ), + unsqueeze_4=np.random.randn(1, 16, 1292, 80).astype(np.float32), + unsqueeze_5=np.random.randn(1, 16, 1292, 80).astype(np.float32), + unsqueeze_6=np.random.randn(1, 16, 1292, 80).astype(np.float32), + ) - dummy_inputs = os.path.join( - os.path.dirname(__file__), - "..", - "..", - "dump_test", - "replay", - "qwen_sdpa_attention_loopa24", - "onnx_inputs.pt", - ) - if os.path.exists(dummy_inputs): - print("-- use dummy inputs") - feeds = {k: v.detach().cpu().numpy() for k, v in torch.load(dummy_inputs).items()} - for k, v in feeds.items(): - print(f"-- {k}: {self.string_type(v, with_shape=True, with_min_max=True)}") + dummy_inputs = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "dump_test", + "replay", + "qwen_sdpa_attention_loopmha", + "onnx_inputs.pt", + ) + if os.path.exists(dummy_inputs): + print("-- use dummy inputs") - # feeds["cat_2"] = np.array([0, 1292], dtype=np.int64) - got = sess.run(None, feeds) - self.assertEqual(len(got), 1) - self.assertEqual((1, 1292, 16, 80), got[0].shape) - expected = qwen_sdpa_attention( - torch.from_numpy(feeds["unsqueeze_4"]), - torch.from_numpy(feeds["unsqueeze_5"]), - torch.from_numpy(feeds["unsqueeze_6"]), - torch.from_numpy(feeds["cat_2"]), - scaling=0.11180339753627777, - num_heads=16, - ) - self.assertEqualArray(expected, got[0], atol=1e-5) + feeds1 = torch.load(dummy_inputs) + res1 = qwen_sdpa_attention( + feeds1["unsqueeze_4"], + feeds1["unsqueeze_5"], + feeds1["unsqueeze_6"], + feeds1["cat_2"], + scaling=0.11180339753627777, + num_heads=16, + ) + feeds1o = {k: v.detach().cpu().numpy() for k, v in feeds1.items()} + reso1 = sess.run(None, feeds1o)[0] + dummy_inputs2 = dummy_inputs.replace("onnx_inputs", "torch_inputs") + assert dummy_inputs != dummy_inputs2 + feeds2 = torch.load(dummy_inputs2) + res2 = qwen_sdpa_attention( + feeds2["unsqueeze_4"], + feeds2["unsqueeze_5"], + feeds2["unsqueeze_6"], + feeds2["cat_2"], + scaling=0.11180339753627777, + num_heads=16, + ) + feeds2o = {k: v.detach().cpu().numpy() for k, v in feeds2.items()} + reso2 = sess.run(None, feeds2o)[0] + diff = max_diff(res1, res2, hist=[0.1]) + print(f"-- diff torch-onnx: {string_diff(diff)}") + diff = max_diff(res2, reso2, hist=[0.1]) + print(f"-- diff torch-onnxo1: {string_diff(diff)}") + diff = max_diff(res1, reso1, hist=[0.1]) + print(f"-- diff torch-onnxo2: {string_diff(diff)}") + if diff["abs"] > 0.1: + for k in feeds1: + print( + f"-- {k}: " + f"{string_diff(max_diff(feeds1[k], feeds2[k], hist=[0.1]))}" + ) + + feeds = { + k: v.detach().cpu().numpy() + for k, v in torch.load(dummy_inputs).items() + } + + for k, v in feeds.items(): + print( + f"-- {k}: " + f"{self.string_type(v, with_shape=True, with_min_max=True)}" + ) + + # feeds["cat_2"] = np.array([0, 1292], dtype=np.int64) + got = sess.run(None, feeds) + self.assertEqual(len(got), 1) + self.assertEqual((1, 1292, 16, 80), got[0].shape) + expected = qwen_sdpa_attention( + torch.from_numpy(feeds["unsqueeze_4"]), + torch.from_numpy(feeds["unsqueeze_5"]), + torch.from_numpy(feeds["unsqueeze_6"]), + torch.from_numpy(feeds["cat_2"]), + scaling=0.11180339753627777, + num_heads=16, + ) + self.assertEqualArray(expected, got[0], atol=1e-5) - tfeeds = {k: torch.from_numpy(v) for k, v in feeds.items()} - ev = OnnxruntimeEvaluator(model) - got2 = ev.run(None, tfeeds) - self.assertEqual(len(got2), 1) - self.assertEqualArray(got[0], got2[0], atol=1e-5) + tfeeds = {k: torch.from_numpy(v) for k, v in feeds.items()} + ev = OnnxruntimeEvaluator(model) + got2 = ev.run(None, tfeeds) + self.assertEqual(len(got2), 1) + self.assertEqualArray(got[0], got2[0], atol=1e-5) if __name__ == "__main__": diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 7bce38ac..360e0be2 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1341,7 +1341,6 @@ def _mkv_(name, itype, irank): # there are hidden inputs for att in node.attribute: if att.type == onnx.AttributeProto.GRAPH: - print("++++", get_hidden_inputs(att.g)) not_known |= get_hidden_inputs(att.g) model = oh.make_model( From 467b2e45bf393f5332096f519c1ad7c35a033756 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 15:55:34 +0000 Subject: [PATCH 07/13] fix --- onnx_diagnostic/_command_lines_parser.py | 15 + onnx_diagnostic/helpers/onnx_helper.py | 359 +++++++++++++++++- onnx_diagnostic/torch_onnx/sbs_dataclasses.py | 27 ++ 3 files changed, 400 insertions(+), 1 deletion(-) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 519e5dc2..ae885966 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1342,6 +1342,20 @@ def get_parser_sbs() -> ArgumentParser: default="replay", help="If the replay is triggered, this defines the folder where everything is dumped.", ) + parser.add_argument( + "-p", + "--replay-prefix-model", + action=BooleanOptionalAction, + default=False, + help=textwrap.dedent( + """ + There are two ways to recompute an intermediate output, the first one is to " + produce the minimal model between torch and onnx. + The second one is to dump onnx models from the inputs + to the considered intermediate results. This enables the second one. + """ + ), + ) return parser @@ -1431,6 +1445,7 @@ def _size(name): set(args.replay_op_types.split(",")) if args.replay_op_types else None ), dump_folder=args.replay_folder, + dump_prefix_model=args.replay_prefix_model, ) print("-- starts side-by-side") diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 360e0be2..fdb817f5 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -3,7 +3,19 @@ import os import sys import warnings -from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) import numpy as np import numpy.typing as npt import onnx @@ -1222,6 +1234,16 @@ def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]: return hidden +def get_all_node_inputs(node: onnx.NodeProto) -> Set[str]: + """ + Returns input and hidden inputs of a node. + See :func:`get_hidden_inputs`. + """ + if node.op_type in {"Scan", "Loop", "If"}: + return set(node.input) | get_hidden_inputs(node) + return set(node.input) + + def extract_subset_of_nodes( model: ModelProto, name: str, @@ -1354,3 +1376,338 @@ def _mkv_(name, itype, irank): opset_imports=opset_imports, ) return model + + +def get_tensor_shape( + obj: Union[onnx.ValueInfoProto, onnx.TypeProto, onnx.TensorProto], +) -> Optional[List[Optional[Union[int, str]]]]: + """ + Returns the shape if that makes sense for this object. + """ + if isinstance(obj, ValueInfoProto): + return get_tensor_shape(obj.type) + elif not isinstance(obj, onnx.TypeProto): + raise TypeError(f"Unexpected type {type(obj)!r}.") + if not obj.tensor_type.HasField("shape"): + return None + shape = [] + for d in obj.tensor_type.shape.dim: + v = d.dim_value if d.dim_value > 0 else d.dim_param + shape.append(v) + if not shape: + return shape + return [None if s in (0, "") else s for s in shape] + + +def _enumerate_model_node_outputs( + model: ModelProto, add_node: bool = False, order: bool = False +) -> Iterable[Union[str, Tuple[str, NodeProto]]]: + """ + Enumerates all the nodes of a model. + + :param model: :epkg:`ONNX` graph + :param add_node: if False, the function enumerates + all output names from every node, otherwise, it + enumerates tuple (output name, node) + :param order: goes through outputs following the graph order + :return: enumerator + """ + assert hasattr(model, "graph"), "Parameter model is not an ONNX model but {type(model)}" + if order: + edges = [] + dorder = {} + node_names = {} + for inp in model.graph.input: + dorder[0, inp.name] = 0 + for node in model.graph.node: + dorder[1, node.name] = 0 + for i in node.input: + edges.append(("in", i, node.name)) + for o in node.output: + edges.append(("out", o, node.name)) + node_names[o] = node + dorder[0, o] = 0 + + modif = 1 + n_iter = 0 + while modif > 0 and n_iter <= len(model.graph.node): + modif = 0 + n_iter += 1 + for kind, data_name, node_name in edges: + if kind == "in": + if (0, data_name) not in dorder: + continue + if dorder[0, data_name] + 1 > dorder[1, node_name]: + modif += 1 + dorder[1, node_name] = dorder[0, data_name] + 1 + else: + if dorder[1, node_name] + 1 > dorder[0, data_name]: + modif += 1 + dorder[0, data_name] = dorder[1, node_name] + 1 + + orders = [(v, k) for k, v in dorder.items()] + orders.sort() + + for _, k in orders: + if k[0] == 1: + continue + out = k[1] + if out not in node_names: + continue + yield (out, node_names[out]) if add_node else out + else: + for node in model.graph.node: + for out in node.output: + yield (out, node) if add_node else out + + +def onnx_remove_node_unused( + graph: Union[onnx.GraphProto, onnx.FunctionProto], recursive=True +) -> Union[onnx.GraphProto, onnx.FunctionProto]: + """ + Removes unused nodes of the graph. An unused node + is not involved in the output computation. + + :param onnx_model: onnx model + :param recursive: looks into subgraphs + :return: new Graph + """ + is_function = isinstance(graph, FunctionProto) + + # mark outputs + marked = ( + {o: set() for o in graph.output} + if is_function + else {o.name: set() for o in graph.output} + ) + nodes = list(graph.node) + + # mark node output + for node in reversed(nodes): + used = False + for o in node.output: + if o in marked: + for i in get_all_node_inputs(node): + marked[o].add(i) + used = True + if used: + for i in get_all_node_inputs(node): + marked[i] = set() + + # removed nodes + removed = set() + marked_set = set(marked) + for ind, node in enumerate(nodes): + if not (set(node.output) & marked_set): + removed.add(ind) + + if not is_function: + initializers = [i for i in graph.initializer if i.name in marked] + sparse_initializers = [i for i in graph.sparse_initializer if i.name in marked] + new_nodes = [node for i, node in enumerate(nodes) if i not in removed] + + # Finally create the new graph. + if is_function: + return oh.make_function( + graph.domain, + graph.name, + graph.input, + graph.output, + new_nodes, + opset_imports=graph.opset_import, + attributes=graph.attribute, + doc_string=graph.doc_string, + ) + new_graph = oh.make_graph( + new_nodes, + graph.name, + graph.input, + graph.output, + initializers, + sparse_initializers, + ) + new_graph.value_info.extend(graph.value_info) + return new_graph + + +def select_model_inputs_outputs( + model: ModelProto, + outputs: Optional[List[str]] = None, + inputs: Optional[List[str]] = None, + infer_shapes: bool = True, + overwrite: Optional[Dict[str, Any]] = None, + remove_unused: bool = True, + verbose: int = 0, +): + """ + Takes a model and changes its outputs. + + :param model: :epkg:`ONNX` model + :param inputs: new inputs, same ones if None + :param outputs: new outputs, same ones if None + :param infer_shapes: infer inputs and outputs shapes + :param overwrite: overwrite type and shapes for + inputs or outputs, *overwrite* is a + dictionary `{'name': (numpy dtype, shape)}` + :param remove_unused: remove unused nodes from the graph + :param verbose: display information while converting + :return: modified model + + The function removes unneeded nodes. + + The following example shows how to change the inputs of model + to bypass the first nodes. Shape inferences fails to determine + the new inputs type. They need to be overwritten. + `verbose=1` shows the number of deleted nodes. + + :: + + import onnx + from onnx_extended.tools.onnx_nodes import select_model_inputs_outputs + + onx = onnx.load(path) + onx2 = select_model_inputs_outputs( + onx, inputs=["a", "b"], + infer_shapes=True, verbose=1, + overwrite={'a': (numpy.int32, None), 'b': (numpy.int64, None)}) + onnx.save(onx2, path2) + """ + if not isinstance(model, ModelProto): + raise TypeError(f"Unexpected type {type(model)} for model.") + if inputs is not None and not isinstance(inputs, list): + inputs = [inputs] + if outputs is not None and not isinstance(outputs, list): + outputs = [outputs] + if inputs is None: + inputs = [i.name for i in model.graph.input] + if outputs is None: + outputs = [o.name for o in model.graph.output] + + mark_var = {} + for out in _enumerate_model_node_outputs(model): + mark_var[out] = 0 + for inp in inputs: + mark_var[inp] = 0 + for out in outputs: + assert out in mark_var, "Output '{out}' not found in model." + mark_var[out] = 1 + + nodes = list(model.graph.node[::-1]) + mark_op = {} + for node in list(nodes): + mark_op[id(node)] = 0 + + # We mark all the nodes we need to keep. + nb = 1 + while nb > 0: + nb = 0 + for node in nodes: + if mark_op[id(node)] == 1: + continue + mod = False + for out in node.output: + if mark_var[out] == 1: + mark_op[id(node)] = 1 + mod = True + break + if not mod: + continue + + hidden = get_hidden_inputs([node]) + node_inputs = list(node.input) + list(hidden) + + nb += 1 + for inp in node_inputs: + if inp in inputs: + continue + if mark_var.get(inp, 0) == 1: + continue + mark_var[inp] = 1 + nb += 1 + + # All nodes verifies mark_op[node.name] == 1 + keep_nodes = [node for node in nodes[::-1] if mark_op[id(node)] == 1] + + known_shapes = {} + if infer_shapes: + shapes = onnx.shape_inference.infer_shapes(model) + for shape in shapes.graph.value_info: + known_shapes[shape.name] = shape.type + for shape in shapes.graph.input: + known_shapes[shape.name] = shape.type + for shape in shapes.graph.output: + known_shapes[shape.name] = shape.type + else: + for shape in model.graph.input: + known_shapes[shape.name] = shape.type + for shape in model.graph.output: + known_shapes[shape.name] = shape.type + + var_in = [] + for name in inputs: + if overwrite is not None and name in overwrite: + dtype, shape = overwrite[name] + proto_dtype = np_dtype_to_tensor_dtype(dtype) + value_info = oh.make_tensor_value_info(name, proto_dtype, shape) + elif name in known_shapes: + info = known_shapes[name].tensor_type + proto_dtype = info.elem_type + if proto_dtype == 0: + value_info = ValueInfoProto() + value_info.name = name + else: + shape = get_tensor_shape(known_shapes[name]) + value_info = oh.make_tensor_value_info(name, proto_dtype, shape) + else: + value_info = ValueInfoProto() + value_info.name = name + var_in.append(value_info) + + var_out = [] + for name in outputs: + if overwrite is not None and name in overwrite: + dtype, shape = overwrite[name] + proto_dtype = np_dtype_to_tensor_dtype(dtype) + value_info = oh.make_tensor_value_info(name, proto_dtype, shape) + elif name in known_shapes: + info = known_shapes[name].tensor_type + proto_dtype = info.elem_type + if proto_dtype == 0: + value_info = ValueInfoProto() + value_info.name = name + else: + shape = get_tensor_shape(known_shapes[name]) + value_info = oh.make_tensor_value_info(name, proto_dtype, shape) + else: + value_info = ValueInfoProto() + value_info.name = name + var_out.append(value_info) + + graph = oh.make_graph( + keep_nodes, + model.graph.name, + var_in, + var_out, + model.graph.initializer, + sparse_initializer=model.graph.sparse_initializer, + ) + if remove_unused: + graph = onnx_remove_node_unused(graph, recursive=False) + onnx_model = oh.make_model(graph, functions=model.functions) + onnx_model.ir_version = model.ir_version + onnx_model.producer_name = model.producer_name + onnx_model.producer_version = model.producer_version + onnx_model.domain = model.domain + onnx_model.model_version = model.model_version + onnx_model.doc_string = model.doc_string + if model.metadata_props: + values = {p.key: p.value for p in model.metadata_props} + oh.set_model_props(onnx_model, values) + + del onnx_model.opset_import[:] + for oimp in model.opset_import: + op_set = onnx_model.opset_import.add() + op_set.domain = oimp.domain + op_set.version = oimp.version + + return onnx_model diff --git a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py index 2b7255d2..20db8f30 100644 --- a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py +++ b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py @@ -61,12 +61,16 @@ class ReplayConfiguration: :param selected_names: list of results names to dump :param selected_op_types: list of onnx operators to dump :param threshold: only keep those whose discrepancies is greater than that threshold + :param dump_prefix_model: aftrer dumping the smallest model able to replicate + one given output, if also dumps the models producing the inputs + and the outputs truncated from the big one """ dump_folder: str selected_names: Optional[Set[str]] = None selected_op_types: Optional[Set[str]] = None threshold: float = 0.1 + dump_prefix_model: bool = False def __post_init__(self): assert self.dump_folder, "dump_folder is empty and this is not allowed for the replay" @@ -318,6 +322,29 @@ def dump( ) with open(os.path.join(folder, "replay.py"), "w") as f: f.write(self.get_replay_code()) + + if self.dump_prefix_model: + from onnx_extended.tools.onnx_nodes import select_model_inputs_outputs + + main_inputs = { + i.name: onnx_inputs.get(i.name, torch_inputs.get(i.name, None)) + for i in model.graph.input + } + # only saving onnx inputs, torch should be the same + torch.save(main_inputs, os.path.join(folder, "onnx_main_inputs.pt")) + + model_inputs_file = os.path.join(folder, "model.inputs.onnx") + model_inputs = select_model_inputs_outputs( + model, inputs=[i.name for i in submodel.graph.input] + ) + onnx.save(model_inputs, model_inputs_file) + + model_outputs_file = os.path.join(folder, "model.inputs.onnx") + model_outputs = select_model_inputs_outputs( + model, outputs=[i.name for i in submodel.graph.output] + ) + onnx.save(model_outputs, model_outputs_file) + if verbose: print(f"[ReplayConfiguration.dump] done {folder!r}") return folder From d306f2e73048c5ddb60a627efed5d62a7ee3efd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 15:55:59 +0000 Subject: [PATCH 08/13] bug --- onnx_diagnostic/torch_onnx/sbs_dataclasses.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py index 20db8f30..c229c294 100644 --- a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py +++ b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py @@ -11,7 +11,12 @@ import onnx import numpy as np import torch -from ..helpers.onnx_helper import extract_subset_of_nodes, make_submodel, from_array_extended +from ..helpers.onnx_helper import ( + extract_subset_of_nodes, + make_submodel, + from_array_extended, + select_model_inputs_outputs, +) from ..helpers.torch_helper import torch_dtype_to_onnx_dtype @@ -324,8 +329,6 @@ def dump( f.write(self.get_replay_code()) if self.dump_prefix_model: - from onnx_extended.tools.onnx_nodes import select_model_inputs_outputs - main_inputs = { i.name: onnx_inputs.get(i.name, torch_inputs.get(i.name, None)) for i in model.graph.input From ce5c1bf8f65ea81a6dfd94b1327d3f5fd341055f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 16:01:09 +0000 Subject: [PATCH 09/13] fin --- onnx_diagnostic/helpers/onnx_helper.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index fdb817f5..fe89d31d 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1239,9 +1239,12 @@ def get_all_node_inputs(node: onnx.NodeProto) -> Set[str]: Returns input and hidden inputs of a node. See :func:`get_hidden_inputs`. """ + start = set(node.input) if node.op_type in {"Scan", "Loop", "If"}: - return set(node.input) | get_hidden_inputs(node) - return set(node.input) + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH: + start |= get_hidden_inputs(att.g) + return start def extract_subset_of_nodes( @@ -1613,8 +1616,7 @@ def select_model_inputs_outputs( if not mod: continue - hidden = get_hidden_inputs([node]) - node_inputs = list(node.input) + list(hidden) + node_inputs = get_all_node_inputs(node) nb += 1 for inp in node_inputs: From 859f4a9d7b456e12ed2f525d735ce229ab06dba7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 16:01:59 +0000 Subject: [PATCH 10/13] h --- onnx_diagnostic/torch_onnx/sbs_dataclasses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py index c229c294..d202be09 100644 --- a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py +++ b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py @@ -342,7 +342,7 @@ def dump( ) onnx.save(model_inputs, model_inputs_file) - model_outputs_file = os.path.join(folder, "model.inputs.onnx") + model_outputs_file = os.path.join(folder, "model.outputs.onnx") model_outputs = select_model_inputs_outputs( model, outputs=[i.name for i in submodel.graph.output] ) From 207a26bc9d917ec7c5e41157a39107baf7841168 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 16:09:11 +0000 Subject: [PATCH 11/13] spell gi --- onnx_diagnostic/helpers/onnx_helper.py | 22 +++++++++---------- onnx_diagnostic/torch_onnx/sbs_dataclasses.py | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index fe89d31d..3cdeea47 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1418,18 +1418,18 @@ def _enumerate_model_node_outputs( assert hasattr(model, "graph"), "Parameter model is not an ONNX model but {type(model)}" if order: edges = [] - dorder = {} + d_order = {} node_names = {} for inp in model.graph.input: - dorder[0, inp.name] = 0 + d_order[0, inp.name] = 0 for node in model.graph.node: - dorder[1, node.name] = 0 + d_order[1, node.name] = 0 for i in node.input: edges.append(("in", i, node.name)) for o in node.output: edges.append(("out", o, node.name)) node_names[o] = node - dorder[0, o] = 0 + d_order[0, o] = 0 modif = 1 n_iter = 0 @@ -1438,17 +1438,17 @@ def _enumerate_model_node_outputs( n_iter += 1 for kind, data_name, node_name in edges: if kind == "in": - if (0, data_name) not in dorder: + if (0, data_name) not in d_order: continue - if dorder[0, data_name] + 1 > dorder[1, node_name]: + if d_order[0, data_name] + 1 > d_order[1, node_name]: modif += 1 - dorder[1, node_name] = dorder[0, data_name] + 1 + d_order[1, node_name] = d_order[0, data_name] + 1 else: - if dorder[1, node_name] + 1 > dorder[0, data_name]: + if d_order[1, node_name] + 1 > d_order[0, data_name]: modif += 1 - dorder[0, data_name] = dorder[1, node_name] + 1 + d_order[0, data_name] = d_order[1, node_name] + 1 - orders = [(v, k) for k, v in dorder.items()] + orders = [(v, k) for k, v in d_order.items()] orders.sort() for _, k in orders: @@ -1478,7 +1478,7 @@ def onnx_remove_node_unused( is_function = isinstance(graph, FunctionProto) # mark outputs - marked = ( + marked: Dict[str, Set[str]] = ( {o: set() for o in graph.output} if is_function else {o.name: set() for o in graph.output} diff --git a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py index d202be09..16312123 100644 --- a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py +++ b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py @@ -66,7 +66,7 @@ class ReplayConfiguration: :param selected_names: list of results names to dump :param selected_op_types: list of onnx operators to dump :param threshold: only keep those whose discrepancies is greater than that threshold - :param dump_prefix_model: aftrer dumping the smallest model able to replicate + :param dump_prefix_model: after dumping the smallest model able to replicate one given output, if also dumps the models producing the inputs and the outputs truncated from the big one """ From fe4ef58ee37f9897c3012d41055f5b426bca5fdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 17:31:22 +0000 Subject: [PATCH 12/13] improves a few things --- .../ut_torch_onnx/data/attention_loopmha.onnx | Bin 0 -> 7440 bytes onnx_diagnostic/helpers/onnx_helper.py | 14 +++++++------- onnx_diagnostic/helpers/torch_helper.py | 3 ++- onnx_diagnostic/torch_onnx/sbs.py | 12 +++++++++--- onnx_diagnostic/torch_onnx/sbs_dataclasses.py | 7 +++++-- 5 files changed, 23 insertions(+), 13 deletions(-) create mode 100644 _unittests/ut_torch_onnx/data/attention_loopmha.onnx diff --git a/_unittests/ut_torch_onnx/data/attention_loopmha.onnx b/_unittests/ut_torch_onnx/data/attention_loopmha.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cba6a71e5bdf3f5d5fa266161856b8246c59882a GIT binary patch literal 7440 zcmdT}OK%(373NT+NF90|m_Ub%Vv-mU7*L#ws}+Yj6O zZ`FX_>?b;9$QY(`HqdQ$CD3%F5%rE-R`^sdBc2H?TvIk#;(w9Rrj`S?{<6(zfHpRF(Z4j zT(7s~YQ44FY_!|W-D(@ce6TVn2-zTtr`@`8DaxE&n!U&GtmGi2J~8aWAG0g=lS zA`1Zm~#gESPh@YR^=I6Q1A8c#Q?aVVRm>^-YL80xkijI--88`j@ifLr$J#Du&b zX`JGi9eM%lE8>PM1iS(tCLz`qmhX-fTYuccWTn(rlICtHbx*aO${ZveYk8Ko%7CKf z|M`!!6eqNJ=_$-dUG7{Hd>gd4GmVQcB?wr}(oTPXX*X%^l~TvX`}*@>F`5^)xYd|u z@q+06WO)-lmRJtUNnKNTK1TDZ6r`r@sHRg{1K_zD^QmLsN#O6-{*Uxj%E-lisy;Mz zyS`7(N2lZ_PD#X}%M1RX2||5HmFruZomD7c$q;G;V1Zc2zwO#4d?G<^WZGEx^!Ih8 z!HW25c&0c*)wBnet!rpOtu*)0e_vUI)IcBh5qk2DbZ3T8r#z%$_`G)X;O@g9Tz6w2h8o~f5IZ@Gnk z>;$o(aK?Q#66cJd`aox#V;6}E3Z+*>pIjsFMQ)JEcQ-<1x8SN0Kj{7#2)Y)V!Bl|O zG@kP~`-cRCG~^C5MD)4?F%`V&Ay|bVFpG6+ApR< zekn0JQnVT<4ra`nSV3NfG!MASeNTjO$LgGb;0D)V#86G=0 zK&6Yqg;3V8&f&`=;~hOV8TmTP&8RVD_L5K;GG%`L+3ZB1{0g-rR}Pe7CJ0#09QBof z>U2+ao8ZW$nd8xkT=26!KfYnU7xUwH1Zl{!DCQK^aZLK=W~p$D>RLCu`m6ncb5_Yi zTBAeXGS>K{oaKmOYCDSN@#zpnbFi+5CG@{9foL%#V2pu*k49+`o>nP!WLbmC-x#D$ zEbUDGl)=xE=unrgxWjSEP-@!B_jtiPP*DYVRp!t>c9p^p(O7+GZf}z~_>r$4SlktwhScyX7Ctu$1497Dze8A@XC~9E7 zkOi%F)8A=Vai>jY1@_vvmS>T0{#b&^+Q4~d^J7Tv=At_PjC<=GCc|(#FNJtqARa}+ z{v^2|Aq0};kk#G@k24fGRLf`XY8q}ov2*fl{`a~WvHp+X7D1y8@R!i>!=pNlVteQy z-ezg`b=5i5hrxqnp76h}-fB}EN8#6HTsuCY;Zb5;2;Zf^$&gq*Cz2zUggx;4uq7h% zY(_VU{?>D14*W}kv0w$4Z6J&=&yy5(hUhF*ms!S@wFZK3`+c{?K=gPn5Zi5n!1Y}* zxyA|JeZeI==jvC_J;6()i}7~}Jf$(!ur#R=GHdt6FpE6qYuc%XW!zb0+TM#{+CNdo zcwq*vTfjGhLH|@$+%qE;&1#9gIqEsa8@j69#ad@NhGkZYFh?@u6bJ3)B3ARt0#ejf zd-~;_|9<&{r;^0WyNsbK?-#mOf2VJBhn8)u3Bl{7bl0<=N>w8o|kB^*@P)>X@YZ(zKvt-gv1HTw{hD#9gpb= zWU!f`;r3=({sPG*1XyDa@vQF{&Y4`K4SuNIBEopH3%~!F+lvhJC)lDe`uHV}7RNe= z%|a@hVF?=Yfzi+$4$X`7vBU|3F=T!l&N&h1G#R$D(*Y?L*HR>Zg_#R@lxrG3E3EH+ zE_$@ZawIN=@NAdZJ>Fk`ZD{p3lpe}W+)m`}-RKD{9AGi^(~Y@6i=Tq5O#PH^+!w-UE5TMC3O!>LbG#OH1l42n204}klNo`TVbKM8j8?!!@5VsAi^%?nLFoERCYMMmVv zNIkfGlf^FIj4Y2Mq2+NTv^IW%E6t5f Set[str]: """ Returns the hidden inputs (inputs coming from an upper context) - used by a subgraph. + used by a subgraph. It excludes empty names. """ hidden = set() memo = ( @@ -1223,7 +1223,7 @@ def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]: ) for node in graph.node: for i in node.input: - if i not in memo: + if i and i not in memo: hidden.add(i) for att in node.attribute: if att.type == onnx.AttributeProto.GRAPH and att.g: @@ -1237,9 +1237,9 @@ def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]: def get_all_node_inputs(node: onnx.NodeProto) -> Set[str]: """ Returns input and hidden inputs of a node. - See :func:`get_hidden_inputs`. + See :func:`get_hidden_inputs`. It excludes empty names. """ - start = set(node.input) + start = {i for i in node.input if i} if node.op_type in {"Scan", "Loop", "If"}: for att in node.attribute: if att.type == onnx.AttributeProto.GRAPH: @@ -1489,7 +1489,7 @@ def onnx_remove_node_unused( for node in reversed(nodes): used = False for o in node.output: - if o in marked: + if o and o in marked: for i in get_all_node_inputs(node): marked[o].add(i) used = True @@ -1501,7 +1501,7 @@ def onnx_remove_node_unused( removed = set() marked_set = set(marked) for ind, node in enumerate(nodes): - if not (set(node.output) & marked_set): + if not ({o for o in node.output if o} & marked_set): removed.add(ind) if not is_function: @@ -1592,7 +1592,7 @@ def select_model_inputs_outputs( for inp in inputs: mark_var[inp] = 0 for out in outputs: - assert out in mark_var, "Output '{out}' not found in model." + assert out in mark_var, f"Output {out!r} not found in model." mark_var[out] = 1 nodes = list(model.graph.node[::-1]) diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index 11fc05a1..f7c0bbfc 100644 --- a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -811,7 +811,8 @@ def torch_deepcopy(value: Any) -> Any: if isinstance(value, tuple): return tuple(torch_deepcopy(v) for v in value) if isinstance(value, list): - return [torch_deepcopy(v) for v in value] + if type(value) is list: + return [torch_deepcopy(v) for v in value] if isinstance(value, set): return {torch_deepcopy(v) for v in value} if isinstance(value, dict): diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 4bdfed0a..c6018e90 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -7,7 +7,13 @@ import torch from ..helpers import string_type, string_diff, max_diff, flatten_object from ..helpers.onnx_helper import pretty_onnx -from ..helpers.torch_helper import to_numpy, from_numpy, to_tensor, torch_dtype_to_onnx_dtype +from ..helpers.torch_helper import ( + to_numpy, + from_numpy, + to_tensor, + torch_dtype_to_onnx_dtype, + torch_deepcopy, +) from ..helpers.torch_fx_graph_helper import prepare_args_kwargs, run_fx_node from ..reference.ort_evaluator import OnnxList, OnnxruntimeEvaluator from .sbs_dataclasses import ( @@ -188,7 +194,7 @@ def _loop_onnx_node( print(f"[run_aligned] feeds={string_type(feeds, **str_kws)}") begin = time.perf_counter() try: - res = ref.run(None, feeds) # type: ignore[attr-defined] + res = ref.run(None, torch_deepcopy(feeds)) # type: ignore[attr-defined] except Exception as e: raise RuntimeError( f"Unable to run node {node.op_type}, domain={node.domain} " @@ -241,7 +247,7 @@ def _loop_onnx_node( f"[run_aligned] feeds for second run=" f"{string_type(new_feeds, **str_kws)}" ) - cross = ref.run(None, new_feeds) + cross = ref.run(None, torch_deepcopy(new_feeds)) if verbose > 1: print(f"[run_aligned] got for second run={string_type(cross, **str_kws)}") # Gemm = torch.nn.function.linear, in that case, we just run it as well diff --git a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py index 16312123..b0ef5a5e 100644 --- a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py +++ b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py @@ -306,7 +306,7 @@ def dump( del submodel.graph.input[:] submodel.graph.input.extend(new_inputs) if verbose: - print(f"[ReplayConfiguration.dump] removed input {removed_inputs}") + print(f"[ReplayConfiguration.dump] removed inputs {removed_inputs}") print(f"[ReplayConfiguration.dump] final model inputs {input_names}") onnx.save(submodel, os.path.join(folder, "model.onnx")) @@ -337,8 +337,11 @@ def dump( torch.save(main_inputs, os.path.join(folder, "onnx_main_inputs.pt")) model_inputs_file = os.path.join(folder, "model.inputs.onnx") + exclude = {i.name for i in model.graph.input} | { + i.name for i in model.graph.initializer + } model_inputs = select_model_inputs_outputs( - model, inputs=[i.name for i in submodel.graph.input] + model, outputs=[i.name for i in submodel.graph.input if i.name not in exclude] ) onnx.save(model_inputs, model_inputs_file) From b15806b7a22e455654d4ab62d2c08c1898c0f473 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 23:05:04 +0100 Subject: [PATCH 13/13] mon --- CHANGELOGS.rst | 1 + _unittests/ut_helpers/test_onnx_helper.py | 98 +++++++++++++++++++ .../export/control_flow_research.py | 15 ++- 3 files changed, 109 insertions(+), 5 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 30fd0e48..dba062c7 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.8.4 +++++ +* :pr:`338`: fixes ReplayConfiguration.dump, add function to select of part of a model * :pr:`337`: fixes extract_subset_of_nodes * :pr:`336`: implements versioned onnx plugs diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py index 9a9242e8..9b2cebea 100644 --- a/_unittests/ut_helpers/test_onnx_helper.py +++ b/_unittests/ut_helpers/test_onnx_helper.py @@ -7,8 +7,14 @@ import onnx.numpy_helper as onh from onnx import TensorProto, FunctionProto, ValueInfoProto from onnx.checker import check_model +from onnx.external_data_helper import ( + load_external_data_for_model, + _get_all_tensors, + uses_external_data, +) import torch from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout +from onnx_diagnostic.reference import ExtendedReferenceEvaluator from onnx_diagnostic.helpers.onnx_helper import ( onnx_lighten, onnx_unlighten, @@ -23,6 +29,7 @@ onnx_dtype_name, extract_subset_of_nodes, make_submodel, + select_model_inputs_outputs, ) @@ -570,6 +577,97 @@ def test_extract_subset_of_nodes_bigger(self): [n.op_type for n in nodes], ) + def _get_model_select(self): + X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) + Z = oh.make_tensor_value_info("Z", TensorProto.INT64, [None, None]) + graph = oh.make_graph( + [ + oh.make_node("Mul", ["X", "X"], ["X2"]), + oh.make_node("Add", ["X2", "Y"], ["z1"]), + oh.make_node("Mul", ["z1", "W"], ["z2"]), + oh.make_node("Cast", ["z2"], ["Z"], to=TensorProto.INT64), + ], + "add", + [X], + [Z], + [ + onh.from_array(np.arange(16).reshape((-1, 4)).astype(np.float32), name="Y"), + onh.from_array( + (np.arange(16).reshape((-1, 4)) + 100).astype(np.float32), name="W" + ), + ], + ) + onnx_model = oh.make_model( + graph, opset_imports=[oh.make_opsetid("", 18)], ir_version=8 + ) + return onnx_model + + def test_select_model_inputs_outputs(self): + def enumerate_model_tensors(model): + for tensor in _get_all_tensors(model): + yield tensor, uses_external_data(tensor) + + model = self._get_model_select() + root = self.get_dump_folder("test_select_model_inputs_outputs") + name = os.path.join(root, "model_ext.onnx") + location = os.path.basename(name) + ".data" + onnx.save( + model, name, save_as_external_data=True, size_threshold=15, location=location + ) + self.assertEqual( + list(sorted(os.listdir(root))), + ["model_ext.onnx", "model_ext.onnx.data"], + ) + + # X + name2 = os.path.join(root, "sub_model_ext.onnx") + model2 = onnx.load(name, load_external_data=False) + new_model = select_model_inputs_outputs(model2, outputs=["X2"]) + onnx.save(new_model, name2) + + x = np.arange(16).reshape((-1, 4)).astype(np.float32) + y = np.arange(16).reshape((-1, 4)).astype(np.float32) + + sess = ExtendedReferenceEvaluator(new_model) + got = sess.run(None, {"X": x})[0] + self.assertEqual((x**2).tolist(), got.tolist()) + + sess = ExtendedReferenceEvaluator(name2) + got = sess.run(None, {"X": x})[0] + self.assertEqual((x**2).tolist(), got.tolist()) + + # z1 + name3 = os.path.join(root, "sub_model_ext_z1.onnx") + model2 = onnx.load(name, load_external_data=False) + new_model = select_model_inputs_outputs(model2, outputs=["z1"]) + onnx.save(new_model, name3) + self.assertEqual( + [ + "model_ext.onnx", + "model_ext.onnx.data", + "sub_model_ext.onnx", + "sub_model_ext_z1.onnx", + ], + list(sorted(os.listdir(root))), + ) + + x = np.arange(16).reshape((-1, 4)).astype(np.float32) + + sess = ExtendedReferenceEvaluator(name3) + got = sess.run(None, {"X": x})[0] + self.assertEqual((x**2 + y).tolist(), got.tolist()) + + tensors = list(enumerate_model_tensors(new_model)) + self.assertEqual(len(tensors), 1) + self.assertIsInstance(tensors[0], tuple) + self.assertEqual(len(tensors[0]), 2) + self.assertTrue(tensors[0][-1]) + self.assertIsInstance(tensors[0][0], TensorProto) + load_external_data_for_model(new_model, root) + sess = ExtendedReferenceEvaluator(new_model) + got = sess.run(None, {"X": x})[0] + self.assertEqual((x**2 + y).tolist(), got.tolist()) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/export/control_flow_research.py b/onnx_diagnostic/export/control_flow_research.py index 261d0a5a..c10d9b07 100644 --- a/onnx_diagnostic/export/control_flow_research.py +++ b/onnx_diagnostic/export/control_flow_research.py @@ -92,10 +92,11 @@ def _loop_for_op_wrapper(*args, **kwargs): from torch._higher_order_ops.utils import setup_compilation_env - with setup_compilation_env() as backend: - return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)( - n_iter, body_fn, operands - ) + with setup_compilation_env() as _backend: + return _loop_for_op_wrapper(n_iter, body_fn, *operands) + # return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)( + # n_iter, body_fn, operands + # ) def trace_loop_for(proxy_mode, func_overload, n_iter, body_fn, operands): @@ -127,9 +128,13 @@ def loop_for_op_dense(n_iter, body_fn, operands): ), f"Dense implementation operands must be a list of tensors and ints {operands}" mode = _get_current_dispatch_mode() assert mode is None, "Mode should never be enabled for CPU/CUDA key" - return _loop_for_onnx_fn(body_fn, n_iter, None, *operands) + return _loop_for_onnx_fn(body_fn, n_iter, None, operands) @simple_loop_for_op.py_impl(ProxyTorchDispatchMode) def inner(mode, n_iter, body_fn, operands): return trace_loop_for(mode, simple_loop_for_op, n_iter, body_fn, operands) + + +simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU) +simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)