From 6af4693139cb47689ed855901217b369c66e4f8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 24 Nov 2025 09:58:09 +0000 Subject: [PATCH 01/11] fix a few inefficiencies in sbs --- onnx_diagnostic/torch_onnx/sbs.py | 48 +++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index a2f814bf..959e8aef 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -2,7 +2,7 @@ import os import time from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Self, Set, Tuple, Union import onnx import onnx.helper as oh import numpy as np @@ -371,6 +371,29 @@ def set_diff(self, diff: Dict[str, Any]): self.err_nan = diff["nan"] if "rep" in diff: self.err_h01 = diff["rep"][">0.1"] + return self + + @property + def key(self) -> Tuple[int, int, int, str, str]: + "Creates a unique identifier." + return ( + self.ep_id_node, + self.onnx_id_node, + self.onnx_id_output, + self.ep_name, + self.onnx_name, + ) + + def check(self, already_yielded: Dict[Tuple[int, int, int, str, str], int]) -> Self: + "Checks a record was not already yielded." + key = self.key + assert key not in already_yielded, ( + f"Record with key={key} was already yielded, " + f"number of records={len(already_yielded)} and previous " + f"record at position {already_yielded[key]} (self={self})" + ) + already_yielded[key] = len(already_yielded) + return self @dataclass @@ -451,8 +474,6 @@ def run_aligned( for the onnx runtime :param atol: absolute tolerance :param rtol: relative tolerance - :param gemmlinear: if True, replaces ``Gemm(A,X.T,B)`` by - ``torch.nn.functional.linear(A,X,B)`` on onnx side :param verbose: verbosity level :param exc: stops if an exception :param reset_names: list of names, the onnx execution takes the torch outputs instead @@ -595,6 +616,7 @@ def forward(self, x): -v 1 --atol=0.1 --rtol=1 """ assert callable(run_cls), f"run_cls={run_cls} not a callable" + already_yielded = {} reset_names = set(reset_names) if reset_names else set() # type: ignore[assignment] str_kws = dict(with_shape=True, with_device=True) has_cuda = any( @@ -774,7 +796,7 @@ def _loop_onnx_node( list_node_output = list(node.output) node_output = [o for o in list_node_output if o] for o, r in zip(node_output, res): - if r is None or o is None: + if r is None or not o: continue tmp = _loop_cmp( mapping_onnx_to_torch, @@ -1033,7 +1055,7 @@ def _duplicated_values(d): onnx_name=init.name, onnx_op_type="initializer", onnx_shape_type=string_type(t, **str_kws), - ) + ).check(already_yielded) size = t.element_size() * t.numel() if t.is_cuda: @@ -1115,7 +1137,7 @@ def _duplicated_values(d): onnx_results[torch_names_to_onnx_names[node.name]], **str_kws ), ) - yield record + yield record.check(already_yielded) else: assert node.name in placeholders_to_state_dict, ( f"Unable to find placeholder {node.name!r} (node.op={node.op!r}), " @@ -1155,7 +1177,7 @@ def _duplicated_values(d): hist=[0.1], ) ) - yield record + yield record.check(already_yielded) else: if verbose > 1: print( @@ -1166,7 +1188,7 @@ def _duplicated_values(d): ep_name=node.name, ep_target="placeholder", ep_shape_type=string_type(t, **str_kws), - ) + ).check(already_yielded) continue outputs = [node.name] if isinstance(node.name, str) else list(node.name) @@ -1197,6 +1219,8 @@ def _duplicated_values(d): continue 
for i_onnx in range(last_position, max_pos + 1): + if i_onnx in already_run: + continue for r in _loop_onnx_node( onx, ep_graph_nodes, @@ -1216,7 +1240,7 @@ def _duplicated_values(d): verbose, ): if r: - yield r + yield r.check(already_yielded) last_position = max_pos + 1 @@ -1227,6 +1251,8 @@ def _duplicated_values(d): f"to {len(onx.graph.node)}" ) for i_onnx in range(last_position, len(onx.graph.node)): + if i_onnx in already_run: + continue for r in _loop_onnx_node( onx, ep_graph_nodes, @@ -1246,9 +1272,7 @@ def _duplicated_values(d): verbose, ): if r: - yield r - - already_run.add(i_onnx) + yield r.check(already_yielded) if verbose: print(f"[run_aligned] done with status={status.to_str()}") From 389578a8f7b881fd14250d7b5f387cd2669c452c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 24 Nov 2025 10:36:45 +0000 Subject: [PATCH 02/11] fix --- CHANGELOGS.rst | 2 +- onnx_diagnostic/_command_lines_parser.py | 2 +- onnx_diagnostic/torch_onnx/sbs.py | 45 +++++++++++++++++------- 3 files changed, 34 insertions(+), 15 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index f8c85340..bce920d9 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -8,7 +8,7 @@ Change Logs * :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime * :pr:`310`: splits patches into multiple files * :pr:`308`: add option --save_ep to dump the exported program as well as torch input -* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`, :pr:`318`: improves side-by-side comparison, creates command line sbs +* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`, :pr:`318`, :pr:`319`: improves side-by-side comparison, creates command line sbs 0.8.2 +++++ diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index c67fc770..182dcdd5 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1234,7 +1234,7 @@ def get_parser_sbs() -> ArgumentParser: "--replay-threshold", type=float, required=False, - default=1e6, + default=1e9, help="Triggers the replay if the discrepancies are higher than this value.", ) parser.add_argument( diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 959e8aef..b4b8034e 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -214,13 +214,11 @@ def select( :param err_abs: measured discrepancy :return: True if this should be dumped """ - if name and self.selected_names: - if name in self.selected_names: - return True - if op_type and self.selected_op_types: - if op_type in self.selected_op_types: - return True - if err_abs is not None and err_abs >= self.threshold: + if name and self.selected_names and name in self.selected_names: + return True + if op_type and self.selected_op_types and op_type in self.selected_op_types: + return True + if err_abs is not None and self.threshold is not None and err_abs >= self.threshold: return True return False @@ -286,15 +284,23 @@ def dump( if n in onnx_name_to_ep_name: torch_inputs[n] = torch_results[onnx_name_to_ep_name[n]] else: - raise AssertionError(f"n={n!r}, onnx_name_to_ep_name={onnx_name_to_ep_name}") + # It is possible that this result only exists in the onnx worlds. 
+ pass onnx_inputs = {n: onnx_results[n] for n in input_names} assert ( name in onnx_name_to_ep_name ), f"Unable to find {name!r} in {onnx_name_to_ep_name}" - expected_outputs = (torch_results[onnx_name_to_ep_name[name]],) + expected_outputs_and_mapping = dict( + expected=(torch_results[onnx_name_to_ep_name[name]],), + mapping={ + k: onnx_name_to_ep_name[k] for k in input_names if k in onnx_name_to_ep_name + }, + ) torch.save(torch_inputs, os.path.join(folder, "torch_inputs.pt")) torch.save(onnx_inputs, os.path.join(folder, "onnx_inputs.pt")) - torch.save(expected_outputs, os.path.join(folder, "torch_outputs.pt")) + torch.save( + expected_outputs_and_mapping, os.path.join(folder, "torch_outputs_and_mapping.pt") + ) if verbose: print(f"[ReplayConfiguration.dump] done {folder!r}") return folder @@ -374,7 +380,9 @@ def set_diff(self, diff: Dict[str, Any]): return self @property - def key(self) -> Tuple[int, int, int, str, str]: + def key( + self, + ) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[str], Optional[str]]: "Creates a unique identifier." return ( self.ep_id_node, @@ -384,8 +392,17 @@ def key(self) -> Tuple[int, int, int, str, str]: self.onnx_name, ) - def check(self, already_yielded: Dict[Tuple[int, int, int, str, str], int]) -> Self: + def check( + self, + already_yielded: Dict[ + Tuple[Optional[int], Optional[int], Optional[int], Optional[str], Optional[str]], + int, + ], + ) -> Self: "Checks a record was not already yielded." + if self.onnx_op_type == "reset": + # no record for this one + return self key = self.key assert key not in already_yielded, ( f"Record with key={key} was already yielded, " @@ -616,7 +633,7 @@ def forward(self, x): -v 1 --atol=0.1 --rtol=1 """ assert callable(run_cls), f"run_cls={run_cls} not a callable" - already_yielded = {} + already_yielded = {} # type: ignore[var-annotated] reset_names = set(reset_names) if reset_names else set() # type: ignore[assignment] str_kws = dict(with_shape=True, with_device=True) has_cuda = any( @@ -647,6 +664,8 @@ def forward(self, x): if verbose: print(f"[run_aligned] run_cls={run_cls}") print(f"[run_aligned] run_cls_kwargs={run_cls_kwargs}") + if replay_configuration: + print(f"[run_aligned] replay={replay_configuration}") def _check_tensor_(name, obj, flip_type=False): if flip_type: From 85894da713dd16ceb03e47fca2bbd1be98d16c5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 24 Nov 2025 11:08:34 +0000 Subject: [PATCH 03/11] better doc --- onnx_diagnostic/torch_onnx/sbs.py | 91 ++++++++++++++++++++++++++++--- 1 file changed, 84 insertions(+), 7 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index b4b8034e..ed4162f6 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -1,14 +1,20 @@ import inspect import os +import textwrap import time from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterator, List, Optional, Self, Set, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union import onnx import onnx.helper as oh import numpy as np import torch from ..helpers import string_type, string_diff, max_diff, flatten_object -from ..helpers.onnx_helper import pretty_onnx, extract_subset_of_nodes, make_submodel +from ..helpers.onnx_helper import ( + pretty_onnx, + extract_subset_of_nodes, + make_submodel, + from_array_extended, +) from ..helpers.torch_helper import to_numpy, from_numpy, to_tensor, torch_dtype_to_onnx_dtype @@ -222,6 +228,62 
@@ def select( return True return False + def get_replay_code(self) -> str: + """ + Returns a code letting the user replay the onnx model. + It looks like the following. + + .. runpython:: + :showcode: + + from onnx_diagnostic.torch_onnx.sbs import ReplayConfiguration + + rc = ReplayConfiguration(dump_folder="unsued") + print(rc.get_replay_code()) + """ + return textwrap.dedent( + """ + import onnx + import torch + from onnx_diagnostic.helpers import max_diff, string_diff, string_type + from onnx_diagnostic.helpers.onnx_helper import pretty_onnx + from onnx_diagnostic.reference import OnnxruntimeEvaluator + + skws = dict(with_shape=True, with_device=True) + + torch_inputs = torch.load("torch_inputs.pt") + onnx_inputs = torch.load("onnx_inputs.pt") + expected_outputs_and_mapping = torch.load("torch_outputs_and_mapping.pt") + expected = expected_outputs_and_mapping["expected"] + mapping = expected_outputs_and_mapping["mapping"] + + print(f"-- torch_inputs={string_type(torch_inputs, **skws)}") + print(f"-- onnx_inputs={string_type(onnx_inputs, **skws)}") + print(f"-- expected={string_type(expected, **skws)}") + print(f"-- mapping={mapping}") + + model = onnx.load("model.onnx") + print("-- model.onnx") + print(pretty_onnx(model)) + print("--") + + print("-- run with onnx_inputs") + sess = OnnxruntimeEvaluator(model, whole=True) + feeds = onnx_inputs + obtained = sess.run(None, feeds) + print(f"-- obtained={string_type(obtained, **skws)}") + diff = max_diff(expected, tuple(obtained)) + print(f"-- diff: {string_diff(diff)}") + + print("-- run with torch_inputs") + feeds = {k: torch_inputs[mapping[k]] for k in feeds} + obtained = sess.run(None, feeds) + print(f"-- obtained={string_type(obtained, **skws)}") + diff = max_diff(expected, tuple(obtained)) + print(f"-- diff: {string_diff(diff)}") + """ + ) + def dump( self, name: str, @@ -278,14 +340,27 @@ def dump( os.makedirs(folder, exist_ok=True) if verbose: print(f"[ReplayConfiguration.dump] dumps into folder {folder!r}") - onnx.save(submodel, os.path.join(folder, "model.onnx")) + torch_inputs = {} + removed_inputs = set() for n in input_names: if n in onnx_name_to_ep_name: torch_inputs[n] = torch_results[onnx_name_to_ep_name[n]] else: - # It is possible that this result only exists in the onnx worlds. - pass + # We add that input as an initializer because it is probably a constant. 
+ submodel.graph.initializer.append(from_array_extended(onnx_results[n], name=n)) + removed_inputs.add(n) + + if removed_inputs: + input_names = [i for i in input_names if i not in removed_inputs] + new_inputs = [i for i in submodel.graph.input if i.name not in removed_inputs] + del submodel.graph.input[:] + submodel.graph.input.extend(new_inputs) + if verbose: + print(f"[ReplayConfiguration.dump] removed input {removed_inputs}") + print(f"[ReplayConfiguration.dump] final model inputs {input_names}") + + onnx.save(submodel, os.path.join(folder, "model.onnx")) onnx_inputs = {n: onnx_results[n] for n in input_names} assert ( name in onnx_name_to_ep_name @@ -301,6 +376,8 @@ def dump( torch.save( expected_outputs_and_mapping, os.path.join(folder, "torch_outputs_and_mapping.pt") ) + with open(os.path.join(folder, "replay.py"), "w") as f: + f.write(self.get_replay_code()) if verbose: print(f"[ReplayConfiguration.dump] done {folder!r}") return folder @@ -363,7 +440,7 @@ def __post_init__(self): f"ep_id_node={self.ep_id_node}" ) - def set_diff(self, diff: Dict[str, Any]): + def set_diff(self, diff: Dict[str, Any]) -> "Self": # noqa: F821 """Sets error.""" if diff is None: return @@ -398,7 +475,7 @@ def check( Tuple[Optional[int], Optional[int], Optional[int], Optional[str], Optional[str]], int, ], - ) -> Self: + ) -> "Self": # noqa: F821 "Checks a record was not already yielded." if self.onnx_op_type == "reset": # no record for this one From 76a42ac25834bdd95b17bf1f6d34d0ce4407fdc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 24 Nov 2025 11:09:16 +0000 Subject: [PATCH 04/11] to avoid any wrong commit --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 900e2ae8..0cc02aeb 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,7 @@ _cache/* .coverage dist/* build/* +_sbs_* .eggs/* .olive-cache/* .hypothesis/* From b11c81b761bc206b53daa89d1f2aba44dfeb589d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 24 Nov 2025 11:36:20 +0000 Subject: [PATCH 05/11] sbs --- onnx_diagnostic/torch_onnx/sbs.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index ed4162f6..e55493ab 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -231,7 +231,7 @@ def select( def get_replay_code(self) -> str: """ Returns a code letting the user replay the onnx model. - It looks like the following. + It looks like the following. It may have to be adapted. .. 
runpython:: :showcode: @@ -267,6 +267,20 @@ def get_replay_code(self) -> str: print(pretty_onnx(model)) print("--") + print("-- discrepancies") + ep_feeds = {} + for k, v in onnx_inputs.items(): + tk = mapping.get(k, k) + tkv = torch_inputs[k] if k in torch_inputs else torch_inputs[tk] + ep_feeds[k] = tkv + diff = max_diff(v, tkv) + print( + f"-- {k} -> {tk} ep:{string_type(tkv, **skws)} " + f"nx:{string_type(v, **skws)} / diff {string_diff(diff)}" + ) + + print("-- done.") + print("--") print("-- run with onnx_inputs") sess = OnnxruntimeEvaluator(model, whole=True) feeds = onnx_inputs @@ -274,10 +288,9 @@ def get_replay_code(self) -> str: print(f"-- obtained={string_type(obtained, **skws)}") diff = max_diff(expected, tuple(obtained)) print(f"-- diff: {string_diff(diff)}") - + print("--") print("-- run with torch_inputs") - feeds = {k: torch_inputs[mapping[k]] for k in feeds} - obtained = sess.run(None, feeds) + obtained = sess.run(None, ep_feeds) print(f"-- obtained={string_type(obtained, **skws)}") diff = max_diff(expected, tuple(obtained)) print(f"-- diff: {string_diff(diff)}") From 67af11566c77ba25c2e89ba7d6da8e916c5d3ad2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 24 Nov 2025 14:10:17 +0000 Subject: [PATCH 06/11] fix disc --- onnx_diagnostic/torch_onnx/sbs.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index e55493ab..a7507b61 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -3,7 +3,7 @@ import textwrap import time from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Self, Set, Tuple, Union import onnx import onnx.helper as oh import numpy as np @@ -238,7 +238,7 @@ def get_replay_code(self) -> str: from onnx_diagnostic.torch_onnx.sbs import ReplayConfiguration - rc = ReplayConfiguration(dump_folder="unsued") + rc = ReplayConfiguration(dump_folder="unused") print(rc.get_replay_code()) """ return textwrap.dedent( @@ -453,7 +453,7 @@ def __post_init__(self): f"ep_id_node={self.ep_id_node}" ) - def set_diff(self, diff: Dict[str, Any]) -> "Self": # noqa: F821 + def set_diff(self, diff: Dict[str, Any]) -> Self: """Sets error.""" if diff is None: return @@ -488,7 +488,7 @@ def check( Tuple[Optional[int], Optional[int], Optional[int], Optional[str], Optional[str]], int, ], - ) -> "Self": # noqa: F821 + ) -> Self: "Checks a record was not already yielded." 
if self.onnx_op_type == "reset": # no record for this one @@ -970,13 +970,14 @@ def _loop_onnx_node( torch_results, onnx_results, o, - r, + torch_results[tmp.ep_name], verbose, atol, rtol, i, i_onnx, ) + assert tmp.err_abs == 0, f"Reset did not happen, tmp={tmp}" if tmp is not None: tmp.onnx_op_type = "reset" tmp.onnx_id_output = list_node_output.index(o) From 8715162a60970212d9369a6b893aa85d98d9750e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 24 Nov 2025 15:03:05 +0000 Subject: [PATCH 07/11] style --- onnx_diagnostic/helpers/torch_helper.py | 72 +++++++++++++++++++++++++ onnx_diagnostic/torch_onnx/sbs.py | 19 ++++++- 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index ec220b29..699414a0 100644 --- a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -1,6 +1,7 @@ import contextlib import ctypes import inspect +import math import os import sys import warnings @@ -1003,3 +1004,74 @@ def get_weight_type(model: torch.nn.Module) -> torch.dtype: counts[dt] += 1 final = max(list(counts.items())) return final[0] + + +def closest_factor_pair(n: int): + """Tries to find ``a, b`` such as ``n == a * b``.""" + assert n > 0, f"n={n} must be a positive integer" + start = math.isqrt(n) + for a in range(start, 0, -1): + if n % a == 0: + b = n // a + return a, b + return 1, n + + +def study_discrepancies( + t1: torch.Tensor, + t2: torch.Tensor, + bins: int = 50, + figsize: Optional[Tuple[int, int]] = (15, 15), + title: Optional[str] = None, + name: Optional[str] = None, +) -> "Axes": # noqa: F821 + """ + Computes different metrics for the discrepancies. + Returns graphs. + """ + assert t1.dtype == t2.dtype, f"Type mismatch {t1.dtype} != {t2.dtype}" + assert t1.shape == t2.shape, f"Shape mismatch {t1.shape} != {t2.shape}" + d1, d2 = ( + (t1, t2) if t1.dtype == torch.float64 else (t1.to(torch.float32), t2.to(torch.float32)) + ) + + d1 = d1.squeeze() + d2 = d2.squeeze() + if len(d1.shape) == 1: + new_shape = closest_factor_pair(d1.shape[0]) + d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape) + elif len(d1.shape) > 2: + new_shape = (-1, max(d1.shape)) + d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape) + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(3, 2, figsize=figsize) + vmin, vmax = d1.min().item(), d1.max().item() + ax[0, 0].imshow(d1.detach().cpu().numpy(), cmap="Greys", vmin=vmin, vmax=vmax) + ax[0, 0].set_title(f"Color plot of the first tensor in\n[{vmin}, {vmax}]") + + diff = d2 - d1 + vmin, vmax = diff.min().item(), diff.max().item() + ax[0, 1].imshow(diff.detach().cpu().numpy(), cmap="seismic", vmin=vmin, vmax=vmax) + ax[0, 1].set_title(f"Color plot of the differences in \n[{vmin}, {vmax}]") + + ax[1, 0].hist(d1.detach().cpu().numpy().ravel(), bins=bins) + ax[1, 0].set_title("Distribution of the first tensor") + + ax[1, 1].hist(diff.detach().cpu().numpy().ravel(), bins=bins) + ax[1, 1].set_title("Distribution of the differences") + + tf1 = d1.ravel() + td1 = diff.ravel() + ax[2, 1].plot(tf1.detach().cpu().numpy(), td1.detach().cpu().numpy(), ".") + ax[2, 1].set_title("Graph XY") + ax[2, 1].set_xlabel("First tensor values") + ax[2, 1].set_ylabel("Difference values") + + if title: + fig.suptitle(title) + fig.tight_layout() + if name: + fig.savefig(name) + return ax diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index a7507b61..5ec510e3 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py 
+++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -3,7 +3,13 @@ import textwrap import time from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterator, List, Optional, Self, Set, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union + +try: + from typing import Self +except ImportError: + # python <= 3.10 + Self = "Self" import onnx import onnx.helper as oh import numpy as np @@ -246,6 +252,7 @@ def get_replay_code(self) -> str: import onnx import torch from onnx_diagnostic.helpers import max_diff, string_diff, string_type + from onnx_diagnostic.helpers.torch_helper import study_discrepancies from onnx_diagnostic.helpers.onnx_helper import pretty_onnx from onnx_diagnostic.reference import OnnxruntimeEvaluator @@ -294,6 +301,16 @@ def get_replay_code(self) -> str: print(f"-- obtained={string_type(obtained, **skws)}") diff = max_diff(expected, tuple(obtained)) print(f"-- diff: {string_diff(diff)}") + + print("-- plots") + for i in range(len(expected)): + study_discrepancies( + expected[i], + obtained[i], + title=f"study output {i}", + name=f"disc{i}.png", + bins=50, + ) """ ) From 81bd6647cabdab2d315346298dba8c70f60e1189 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 24 Nov 2025 15:17:06 +0000 Subject: [PATCH 08/11] fxi --- onnx_diagnostic/torch_onnx/sbs.py | 14 ++++++++++---- pyproject.toml | 4 ++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 5ec510e3..40c2121d 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -274,7 +274,13 @@ def get_replay_code(self) -> str: print(pretty_onnx(model)) print("--") - print("-- discrepancies") + print("-- range of inputs") + for k, v in onnx_inputs.items(): + print(f"-- {k}: {string_type(v, **skws, with_min_max=True)}") + print("-- done.") + print("--") + + print("-- discrepancies of inputs") ep_feeds = {} for k, v in onnx_inputs.items(): tk = mapping.get(k, k) @@ -285,21 +291,21 @@ def get_replay_code(self) -> str: f"-- {k} -> {tk} ep:{string_type(tkv, **skws)} " f"nx:{string_type(v, **skws)} / diff {string_diff(diff)}" ) - print("-- done.") print("--") + print("-- run with onnx_inputs") sess = OnnxruntimeEvaluator(model, whole=True) feeds = onnx_inputs obtained = sess.run(None, feeds) print(f"-- obtained={string_type(obtained, **skws)}") - diff = max_diff(expected, tuple(obtained)) + diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01]) print(f"-- diff: {string_diff(diff)}") print("--") print("-- run with torch_inputs") obtained = sess.run(None, ep_feeds) print(f"-- obtained={string_type(obtained, **skws)}") - diff = max_diff(expected, tuple(obtained)) + diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01]) print(f"-- diff: {string_diff(diff)}") print("-- plots") diff --git a/pyproject.toml b/pyproject.toml index 5efe5e6d..f7ce6af7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,10 @@ disable_error_code = ["union-attr"] module = ["onnx_diagnostic.helpers.rt_helper"] disable_error_code = ["arg-type", "assignment", "attr-defined", "call-overload", "misc", "name-defined", "union-attr", "name-defined"] +[[tool.mypy.overrides]] +module = ["onnx_diagnostic.helpers.torch_helper"] +disable_error_code = ["arg-type", "assignment", "attr-defined", "call-overload", "misc", "name-defined", "union-attr", "name-defined"] + [[tool.mypy.overrides]] module = ["onnx_diagnostic.reference.report_results_comparison"] 
disable_error_code = ["name-defined"] From 1ab2efb3a74d21fc7601f1ee296669083b4fcc70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 24 Nov 2025 17:24:11 +0000 Subject: [PATCH 09/11] changes --- onnx_diagnostic/torch_onnx/sbs.py | 49 +++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 40c2121d..ddb614dd 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -9,7 +9,7 @@ from typing import Self except ImportError: # python <= 3.10 - Self = "Self" + Self = "Self" # type: ignore[assignment] import onnx import onnx.helper as oh import numpy as np @@ -848,7 +848,7 @@ def _loop_cmp( onnx_shape_type=string_type(r, **str_kws), ) r.set_diff(d) - mapping_onnx_to_torch[to] = o + mapping_onnx_to_torch[o] = to return r return None @@ -1035,7 +1035,8 @@ def _duplicated_values(d): if verbose: print(f"[run_aligned] ep: walks through {len(ep.graph.nodes)} nodes from torch") - positions: Dict[str, Any] = {} + # dictionary mapping result names and their position in both graphs. + positions: Dict[str, Dict[str, int]] = {} ep_graph_nodes = list(ep.graph.nodes) torch_results: Dict[str, Any] = {} last_position = 0 @@ -1087,6 +1088,18 @@ def _duplicated_values(d): print(f"[run_aligned] ep: found inputs {torch_input_names}") print(f"[run_aligned] ep: found outputs {torch_output_names}") print(f"[run_aligned] nx: walks through {len(onx.graph.node)} nodes from onnx") + for inp in onx.graph.input: + n = inp.name + if n in positions: + positions[n]["onnx"] = -1 + else: + positions[n] = dict(onnx=-1) + for inp in onx.graph.initializer: + n = inp.name + if n in positions: + positions[n]["onnx"] = -1 + else: + positions[n] = dict(onnx=-1) for i, node in enumerate(onx.graph.node): for n in node.output: if n in positions: @@ -1156,7 +1169,6 @@ def _duplicated_values(d): memory_cpu = 0 memory_cuda = 0 for init in onx.graph.initializer: # type: ignore - positions[init.name] = -1 t = None if init.name in torch_results: if init.name not in skip_mapping_torch_onnx: @@ -1215,6 +1227,8 @@ def _duplicated_values(d): print(f"[run_aligned-nx] +ini: {k}: {string_type(v, **str_kws)}") # starts the side-by-side + if verbose: + print(f"[run_aligned] ep: starts side-by-side with {len(ep_graph_nodes)} nodes") if verbose == 1: import tqdm @@ -1222,8 +1236,6 @@ def _duplicated_values(d): else: loop = list(enumerate(ep_graph_nodes)) - if verbose: - print(f"[run_aligned] ep: starts side-by-side with {len(ep_graph_nodes)} nodes") already_run: Set[int] = set() ep_durations = {} status = StatusRunAligned() @@ -1345,15 +1357,33 @@ def _duplicated_values(d): max_pos = -2 for n in outputs: - if n in positions and "onnx" in positions[n]: - max_pos = max(max_pos, positions[n]["onnx"]) + if n in positions: + if "onnx" in positions[n]: + max_pos = max(max_pos, positions[n]["onnx"]) + if "fx" in positions[n]: + if positions[n]["fx"] > i: + max_pos = -2 + break if max_pos == -2: # we skip. continue + next_to_visit = last_position for i_onnx in range(last_position, max_pos + 1): if i_onnx in already_run: continue + # The onnx node may produce more than one output, in that + # case, we need to check the exported program is not behind. 
+ node = onx.graph.node[i_onnx] + ep_behind = False + for iname in node.output: + if iname in positions and "fx" in positions[iname]: + if positions[iname]["fx"] > i: + ep_behind = True + break + if ep_behind: + break + for r in _loop_onnx_node( onx, ep_graph_nodes, @@ -1374,8 +1404,9 @@ def _duplicated_values(d): ): if r: yield r.check(already_yielded) + next_to_visit = i_onnx + 1 - last_position = max_pos + 1 + last_position = next_to_visit # complete the execution of the onnx graph if verbose: From 82ed143e55ff06b0b21a93ebc48bd0cafb1063ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 24 Nov 2025 18:36:05 +0000 Subject: [PATCH 10/11] fix --- onnx_diagnostic/helpers/torch_helper.py | 4 +- onnx_diagnostic/torch_onnx/sbs.py | 56 ++++++++++++++++++------- 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index 699414a0..b79288ba 100644 --- a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -1049,7 +1049,9 @@ def study_discrepancies( fig, ax = plt.subplots(3, 2, figsize=figsize) vmin, vmax = d1.min().item(), d1.max().item() ax[0, 0].imshow(d1.detach().cpu().numpy(), cmap="Greys", vmin=vmin, vmax=vmax) - ax[0, 0].set_title(f"Color plot of the first tensor in\n[{vmin}, {vmax}]") + ax[0, 0].set_title( + f"Color plot of the first tensor in\n[{vmin}, {vmax}]\n{t1.shape} -> {d1.shape}" + ) diff = d2 - d1 vmin, vmax = diff.min().item(), diff.max().item() diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index ddb614dd..23586421 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -269,18 +269,24 @@ def get_replay_code(self) -> str: print(f"-- expected={string_type(expected, **skws)}") print(f"-- mapping={mapping}") - model = onnx.load("model.onnx") + print() print("-- model.onnx") + print() + + model = onnx.load("model.onnx") print(pretty_onnx(model)) - print("--") - print("-- range of inputs") + print() + print("-- range of inputs --") + print() + for k, v in onnx_inputs.items(): print(f"-- {k}: {string_type(v, **skws, with_min_max=True)}") - print("-- done.") - print("--") - print("-- discrepancies of inputs") + print() + print("-- discrepancies of inputs --") + print() + ep_feeds = {} for k, v in onnx_inputs.items(): tk = mapping.get(k, k) @@ -291,24 +297,29 @@ def get_replay_code(self) -> str: f"-- {k} -> {tk} ep:{string_type(tkv, **skws)} " f"nx:{string_type(v, **skws)} / diff {string_diff(diff)}" ) - print("-- done.") - print("--") - print("-- run with onnx_inputs") + print() + print("-- SVD --") + print() + + for k, v in onnx_inputs.items(): + if len(v.shape) == 2: + U, S, Vt = torch.linalg.svd(v.to(torch.float32)) + print(f" -- {k}: {S[:5]}") + + print() + print("-- run with onnx_inputs --") + print() + sess = OnnxruntimeEvaluator(model, whole=True) feeds = onnx_inputs obtained = sess.run(None, feeds) print(f"-- obtained={string_type(obtained, **skws)}") diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01]) print(f"-- diff: {string_diff(diff)}") - print("--") - print("-- run with torch_inputs") - obtained = sess.run(None, ep_feeds) - print(f"-- obtained={string_type(obtained, **skws)}") - diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01]) - print(f"-- diff: {string_diff(diff)}") + print() + print("-- plots --") - print("-- plots") for i in range(len(expected)): study_discrepancies( expected[i], @@ -317,6 +328,19 @@ def get_replay_code(self) -> 
str: name=f"disc{i}.png", bins=50, ) + + print() + print("-- run with torch_inputs --") + print() + + obtained = sess.run(None, ep_feeds) + print(f"-- obtained={string_type(obtained, **skws)}") + diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01]) + print(f"-- diff: {string_diff(diff)}") + + print() + print("-- end --") + print() """ ) From aceb30605a79895e9e24b3128d12a3f6f3d1b462 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 25 Nov 2025 09:23:48 +0100 Subject: [PATCH 11/11] study --- _unittests/ut_helpers/test_torch_helper.py | 8 +++++++- onnx_diagnostic/helpers/torch_helper.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_helpers/test_torch_helper.py b/_unittests/ut_helpers/test_torch_helper.py index 9441e425..2c3c2990 100644 --- a/_unittests/ut_helpers/test_torch_helper.py +++ b/_unittests/ut_helpers/test_torch_helper.py @@ -28,7 +28,7 @@ ) from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model from onnx_diagnostic.helpers.onnx_helper import from_array_extended, to_array_extended -from onnx_diagnostic.helpers.torch_helper import to_tensor +from onnx_diagnostic.helpers.torch_helper import to_tensor, study_discrepancies TFLOAT = onnx.TensorProto.FLOAT @@ -425,6 +425,12 @@ def test_get_weight_type(self): dt = get_weight_type(model) self.assertEqual(torch.float32, dt) + def test_study_discrepancies(self): + t1 = torch.rand((3, 4)) + t2 = torch.rand((3, 4)) + ax = study_discrepancies(t1, t2) + self.assertEqual(ax.shape, ((3, 2))) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index b79288ba..b0d5cae6 100644 --- a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -1024,7 +1024,7 @@ def study_discrepancies( figsize: Optional[Tuple[int, int]] = (15, 15), title: Optional[str] = None, name: Optional[str] = None, -) -> "Axes": # noqa: F821 +) -> "matplotlib.axes.Axes": # noqa: F821 """ Computes different metrics for the discrepancies. Returns graphs.
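
A minimal usage sketch for the two helpers introduced above (``closest_factor_pair`` and ``study_discrepancies`` from ``onnx_diagnostic.helpers.torch_helper``), assuming the signatures added in patches 07 and 11 and a matplotlib installation; the tensor values and the output file name below are illustrative only::

    import torch
    from onnx_diagnostic.helpers.torch_helper import closest_factor_pair, study_discrepancies

    # closest_factor_pair(n) returns (a, b) with a * b == n and a close to sqrt(n);
    # study_discrepancies uses it to fold a 1D tensor into a 2D grid for the color plots.
    print(closest_factor_pair(12))  # (3, 4)

    # Two tensors of identical shape and dtype, e.g. a torch result and its onnx counterpart
    # (illustrative values, not taken from a real model).
    expected = torch.rand((128, 64))
    obtained = expected + 1e-3 * torch.randn((128, 64))

    # Returns the 3x2 grid of matplotlib axes (color plots, histograms, XY graph);
    # passing name=... also writes the figure to disk, as replay.py does with disc{i}.png.
    ax = study_discrepancies(expected, obtained, bins=50, title="study output 0", name="disc0.png")
    print(ax.shape)  # (3, 2), matching the new unit test in patch 11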