From c6ddf5d71213c18dc9df561385b593062fc16a1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 2 Dec 2025 10:34:01 +0100 Subject: [PATCH 1/3] improves dot --- onnx_diagnostic/helpers/dot_helper.py | 28 +++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/onnx_diagnostic/helpers/dot_helper.py b/onnx_diagnostic/helpers/dot_helper.py index a1d6d1a9..3db05769 100644 --- a/onnx_diagnostic/helpers/dot_helper.py +++ b/onnx_diagnostic/helpers/dot_helper.py @@ -1,6 +1,8 @@ from typing import Dict, Set +import numpy as np import onnx import onnx.numpy_helper as onh +from ..reference import ExtendedReferenceEvaluator as Inference from .onnx_helper import onnx_dtype_name, pretty_onnx @@ -142,6 +144,7 @@ def _mkn(obj: object) -> int: inits = list(model.graph.initializer) tiny_inits = {} name_to_ids = {} + for inp in inputs: if not inp.name: continue @@ -149,7 +152,29 @@ def _mkn(obj: object) -> int: rows.append(f' I_{_mkn(inp)} [label="{inp.name}\\n{lab}", fillcolor="#aaeeaa"];') name_to_ids[inp.name] = f"I_{_mkn(inp)}" edge_label[inp.name] = _make_edge_label(inp, multi_line=True) + + # Small constant --> initializer + for node in nodes: + if node.op_type != "Constant": + continue + skip = False + for att in node.attribute: + if att.name == "value" and ( + len(att.t.dims) > 1 or np.prod(tuple(att.t.dims)) > 10 + ): + skip = True + break + if skip: + continue + + sess = Inference(node) + value = sess.run(None, {})[0] + inits.append(onh.from_array(value, name=node.output[0])) + for init in inits: + if init.name in inputs: + # hide optional inputs + continue shape = tuple(init.dims) if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10): a = onh.to_array(init) @@ -161,7 +186,10 @@ def _mkn(obj: object) -> int: rows.append(f' i_{_mkn(init)} [label="{init.name}\\n{ls}", fillcolor="#cccc00"];') name_to_ids[init.name] = f"i_{_mkn(init)}" edge_label[init.name] = ls + for node in nodes: + if node.op_type == "Constant" and node.output[0] in name_to_ids: + continue color = op_type_colors.get(node.op_type, "#cccccc") label = _make_node_label(node, tiny_inits) rows.append(f' {node.op_type}_{_mkn(node)} [label="{label}", fillcolor="{color}"];') From c0e2e46776bc77bcc0079156727fccfa8e0bfff5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 2 Dec 2025 11:09:04 +0100 Subject: [PATCH 2/3] improves dot rendering --- _unittests/ut_helpers/test_dot_helper.py | 57 ++++++++++++++++++++++++ onnx_diagnostic/helpers/dot_helper.py | 11 +++-- 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/_unittests/ut_helpers/test_dot_helper.py b/_unittests/ut_helpers/test_dot_helper.py index 5b250bc5..3c885149 100644 --- a/_unittests/ut_helpers/test_dot_helper.py +++ b/_unittests/ut_helpers/test_dot_helper.py @@ -62,6 +62,63 @@ def test_custom_doc_kernels_layer_normalization(self): self.maxDiff = None self.assertEqual(expected.strip("\n "), dot.strip("\n ")) + def test_custom_doc_kernels_layer_normalization_constant(self): + TFLOAT16 = onnx.TensorProto.FLOAT16 + model = oh.make_model( + oh.make_graph( + [ + oh.make_node( + "LayerNormalization", + ["X", "W", "B"], + ["ln"], + axis=-1, + epsilon=9.999999974752427e-7, + ), + oh.make_node("Constant", [], ["cst"], value_float=[1]), + oh.make_node("Cast", ["cst"], ["cst16"], to=onnx.TensorProto.FLOAT16), + oh.make_node("Add", ["ln", "cst16"], ["Z"], axis=-1), + ], + "dummy", + [ + oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]), + oh.make_tensor_value_info("W", TFLOAT16, ["d"]), + oh.make_tensor_value_info("B", TFLOAT16, ["d"]), + ], + [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])], + ), + ir_version=9, + opset_imports=[oh.make_opsetid("", 18)], + ) + dot = to_dot(model) + expected = ( + textwrap.dedent( + """ + digraph { + graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8]; + node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box]; + edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0]; + I_0 [label="X\\nFLOAT16(b,c,d)", fillcolor="#aaeeaa"]; + I_1 [label="W\\nFLOAT16(d)", fillcolor="#aaeeaa"]; + I_2 [label="B\\nFLOAT16(d)", fillcolor="#aaeeaa"]; + LayerNormalization_3 [label="LayerNormalization(., ., ., axis=-1)", fillcolor="#cccccc"]; + Cast_4 [label="Cast([1.0], to=FLOAT16)", fillcolor="#cccccc"]; + Add_5[label="Add(.,.,axis=-1)",fillcolor="#cccccc"]; + I_0 -> LayerNormalization_3 [label="FLOAT16(b,c,d)"]; + I_1 -> LayerNormalization_3 [label="FLOAT16(d)"]; + I_2 -> LayerNormalization_3 [label="FLOAT16(d)"]; + LayerNormalization_3 -> Add_5 [label="FLOAT16(b,c,d)"]; + Cast_4->Add_5[label="FLOAT16()"]; + O_6 [label="Z\\nFLOAT16(b,c,d)", fillcolor="#aaaaee"]; + Add_5 -> O_6; + } + """ + ) + .strip("\n") + .replace(" ", "") + ) + self.maxDiff = None + self.assertEqual(expected, dot.strip("\n").replace(" ", "")) + @requires_transformers("4.57") def test_dot_plot_tiny(self): data = get_untrained_model_with_inputs("arnir0/Tiny-LLM") diff --git a/onnx_diagnostic/helpers/dot_helper.py b/onnx_diagnostic/helpers/dot_helper.py index 3db05769..f2340210 100644 --- a/onnx_diagnostic/helpers/dot_helper.py +++ b/onnx_diagnostic/helpers/dot_helper.py @@ -27,7 +27,7 @@ def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]: def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str: - els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "("] + els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "\\n("] ee = [tiny_inits.get(i, ".") if i else "" for i in node.input] for att in node.attribute: if att.name == "to": @@ -44,7 +44,10 @@ def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str: els.append(")") if node.op_type == "Constant": els.extend([" -> ", node.output[0]]) - return "".join(els) + res = "".join(els) + if len(res) < 40: + return res.replace("\\n(", "(") + return res def _make_edge_label(value_info: onnx.ValueInfoProto, multi_line: bool = False) -> str: @@ -172,7 +175,7 @@ def _mkn(obj: object) -> int: inits.append(onh.from_array(value, name=node.output[0])) for init in inits: - if init.name in inputs: + if init.name in name_to_ids: # hide optional inputs continue shape = tuple(init.dims) @@ -188,7 +191,7 @@ def _mkn(obj: object) -> int: edge_label[init.name] = ls for node in nodes: - if node.op_type == "Constant" and node.output[0] in name_to_ids: + if node.op_type == "Constant" and node.output[0] in tiny_inits: continue color = op_type_colors.get(node.op_type, "#cccccc") label = _make_node_label(node, tiny_inits) From 1b9718b170fc301938185a661992580b473d2793 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 2 Dec 2025 13:43:06 +0100 Subject: [PATCH 3/3] fix last bug --- onnx_diagnostic/helpers/dot_helper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnx_diagnostic/helpers/dot_helper.py b/onnx_diagnostic/helpers/dot_helper.py index f2340210..360e3910 100644 --- a/onnx_diagnostic/helpers/dot_helper.py +++ b/onnx_diagnostic/helpers/dot_helper.py @@ -157,8 +157,9 @@ def _mkn(obj: object) -> int: edge_label[inp.name] = _make_edge_label(inp, multi_line=True) # Small constant --> initializer + output_names = {n.name for n in outputs} for node in nodes: - if node.op_type != "Constant": + if node.op_type != "Constant" or node.output[0] in output_names: continue skip = False for att in node.attribute: