From dc061aebb9fd2ac8a0d2451362fbba5e05e16585 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 14 Nov 2025 12:40:41 +0100 Subject: [PATCH 01/22] improves side by side --- .gitignore | 2 + _unittests/ut_torch_onnx/test_sbs.py | 78 ++++-- onnx_diagnostic/helpers/torch_helper.py | 25 +- onnx_diagnostic/torch_onnx/sbs.py | 311 +++++++++++++++++------- 4 files changed, 295 insertions(+), 121 deletions(-) diff --git a/.gitignore b/.gitignore index 2ba9c1e6..900e2ae8 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,8 @@ *.weight *.nsys-rep *.pkl +*.pt +*.pt2 *.xlsx *.sarif *.sqlitest diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index 63732455..7b9020fa 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ b/_unittests/ut_torch_onnx/test_sbs.py @@ -6,6 +6,7 @@ ignore_errors, ) from onnx_diagnostic.reference import ExtendedReferenceEvaluator +from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str from onnx_diagnostic.torch_onnx.sbs import run_aligned try: @@ -15,15 +16,18 @@ class TestSideBySide(ExtTestCase): + @classmethod + def setUpClass(cls): + import torch + + cls.torch = torch @hide_stdout() @unittest.skipIf(to_onnx is None, "to_onnx not installed") @ignore_errors(OSError) # connectivity issues @ignore_warnings((UserWarning,)) def test_ep_onnx_sync_exp(self): - import torch - - class Model(torch.nn.Module): + class Model(self.torch.nn.Module): def forward(self, x): ry = x.abs() rz = ry.exp() @@ -31,20 +35,20 @@ def forward(self, x): ru = rw.log() + rw return ru - x = torch.randn((5, 4)) + x = self.torch.randn((5, 4)) Model()(x) - ep = torch.export.export( - Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},) + ep = self.torch.export.export( + Model(), (x,), dynamic_shapes=({0: self.torch.export.Dim("batch")},) ) onx = to_onnx(ep) results = list( run_aligned( ep, onx, - (x,), - check_conversion_cls=dict( - cls=ExtendedReferenceEvaluator, atol=1e-5, rtol=1e-5 - ), + args=(x,), + run_cls=ExtendedReferenceEvaluator, + atol=1e-5, + rtol=1e-5, verbose=1, ), ) @@ -53,9 +57,7 @@ def forward(self, x): @hide_stdout() @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) def test_ep_onnx_sync_a(self): - import torch - - class Model(torch.nn.Module): + class Model(self.torch.nn.Module): def forward(self, x): ry = x.abs() rz = ry.exp() @@ -63,21 +65,55 @@ def forward(self, x): ru = rw.log() + rw return ru - x = torch.randn((5, 4)) + x = self.torch.randn((5, 4)) Model()(x) - ep = torch.export.export( - Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},) + ep = self.torch.export.export( + Model(), (x,), dynamic_shapes=({0: self.torch.export.Dim("batch")},) + ) + epo = self.torch.onnx.export( + ep, (x,), dynamic_shapes=({0: self.torch.export.Dim("batch")},) + ) + onx = epo.model_proto + results = list( + run_aligned( + ep, + onx, + args=(x,), + run_cls=ExtendedReferenceEvaluator, + atol=1e-5, + rtol=1e-5, + verbose=1, + ), + ) + self.assertEqual(len(results), 4) + + @hide_stdout() + @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) + def test_sbs_dict(self): + class Model(self.torch.nn.Module): + def forward(self, x): + ry = x.abs() + rz = ry.exp() + rw = rz + 1 + ru = rw.log() + rw + return ru + + inputs = dict(x=self.torch.randn((5, 4))) + ds = dict(x={0: "batch"}) + Model()(**inputs) + ep = self.torch.export.export( + Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) ) - epo = torch.onnx.export(ep, (x,), dynamic_shapes=({0: 
torch.export.Dim("batch")},))
+        epo = self.torch.onnx.export(Model(), (), kwargs=inputs, dynamic_shapes=ds)
         onx = epo.model_proto
         results = list(
             run_aligned(
                 ep,
                 onx,
-                (x,),
-                check_conversion_cls=dict(
-                    cls=ExtendedReferenceEvaluator, atol=1e-4, rtol=1e-4
-                ),
+                kwargs=inputs,
+                run_cls=ExtendedReferenceEvaluator,
+                atol=1e-5,
+                rtol=1e-5,
                 verbose=1,
             ),
         )
diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py
index e86dbef6..ec220b29 100644
--- a/onnx_diagnostic/helpers/torch_helper.py
+++ b/onnx_diagnostic/helpers/torch_helper.py
@@ -30,9 +30,7 @@
 
 
 def proto_from_tensor(
-    arr: "torch.Tensor",  # noqa: F821
-    name: Optional[str] = None,
-    verbose: int = 0,
+    arr: torch.Tensor, name: Optional[str] = None, verbose: int = 0
 ) -> onnx.TensorProto:
     """
     Converts a torch Tensor into a TensorProto.
@@ -98,7 +96,7 @@ def proto_from_tensor(
     return tensor
 
 
-def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
+def onnx_dtype_to_torch_dtype(itype: int) -> torch.dtype:
     """
     Converts an onnx type into a torch dtype.
 
@@ -140,7 +138,7 @@ def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
     )
 
 
-def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int:  # noqa: F821
+def torch_dtype_to_onnx_dtype(to: torch.dtype) -> int:
     """
     Converts a torch dtype into an onnx element type.
 
@@ -483,7 +481,7 @@ def is_torchdynamo_exporting() -> bool:
         return False
 
 
-def to_numpy(tensor: "torch.Tensor") -> np.ndarray:  # noqa: F821
+def to_numpy(tensor: torch.Tensor) -> np.ndarray:
     """Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
     try:
         return tensor.detach().cpu().numpy()
@@ -498,6 +496,21 @@ def to_numpy(tensor: "torch.Tensor") -> np.ndarray:  # noqa: F821
     return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])
 
 
+def from_numpy(tensor: np.ndarray) -> torch.Tensor:
+    """Converts a :class:`numpy.ndarray` to :class:`torch.Tensor`."""
+    try:
+        return torch.from_numpy(tensor)
+    except TypeError:
+        # We try with ml_dtypes
+        pass
+
+    import ml_dtypes
+
+    conv = {ml_dtypes.bfloat16: torch.bfloat16}
+    assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
+    # upcast to float32 in numpy first, then cast back to the torch dtype
+    return torch.from_numpy(tensor.astype(np.float32)).to(conv[tensor.dtype])
+
+
 def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
     """Replaces strings by ``torch.export.Dim.DYNAMIC``."""
     import torch
diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py
index 01e36080..ea6ee874 100644
--- a/onnx_diagnostic/torch_onnx/sbs.py
+++ b/onnx_diagnostic/torch_onnx/sbs.py
@@ -1,9 +1,10 @@
-from typing import Any, Dict, Iterator, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
 import onnx
+import onnx.helper as oh
+import numpy as np
 import torch
-from ..helpers import string_type, string_diff, max_diff
-from ..helpers.onnx_helper import to_array_extended
-from ..helpers.torch_helper import to_numpy
+from ..helpers import string_type, string_diff, max_diff, flatten_object
+from ..helpers.torch_helper import to_numpy, from_numpy
 
 
 def validate_fx_tensor(
@@ -148,9 +149,22 @@ def prepare_args_kwargs(
 def run_aligned(
     ep: torch.export.ExportedProgram,
     onx: Union[onnx.ModelProto, onnx.FunctionProto],
-    args: Tuple[torch.Tensor, ...],
-    check_conversion_cls: Union[Dict[str, Any], type],
+    run_cls: Callable[
+        [
+            Union[
+                onnx.ModelProto,
+                onnx.FunctionProto,
+                onnx.GraphProto,
+                onnx.NodeProto,
+            ]
+        ],
+        List[Union[np.ndarray, torch.Tensor]],
+    ],
+    args: 
Optional[Tuple[torch.Tensor, ...]] = None,
     kwargs: Optional[Dict[str, Any]] = None,
+    use_tensor: bool = False,
+    atol: Optional[float] = None,
+    rtol: Optional[float] = None,
     verbose: int = 0,
 ) -> Iterator[Tuple[Any, ...]]:
     """
@@ -162,11 +176,14 @@
 
     :param ep: exported program
     :param onx: model or function proto
+    :param run_cls: defines the runtime to use for this task
     :param args: input args
-    :param check_conversion_cls: defines the runtime to use for this task
     :param kwargs: input kwargs
+    :param use_tensor: use torch tensors instead of numpy arrays
+    :param atol: absolute tolerance
+    :param rtol: relative tolerance
     :param verbose: verbosity level
-    :return: a list of tuples containing the results, they come in tuple,
+    :return: an iterator over tuples
+        ``(ep_node_index, onnx_node_index, onnx_result_name, ep_result_name, diff)``
 
     Example:
 
@@ -174,7 +191,6 @@
         :showcode:
         :warningout: UserWarning
 
-        import pprint
         import pandas
         import torch
         from onnx_diagnostic.reference import (
@@ -212,11 +228,89 @@ def post_process(obs):
             map(
                 post_process,
                 run_aligned(
-                    ep,
-                    onx,
-                    (x,),
-                    check_conversion_cls=dict(cls=ReferenceEvaluator, atol=1e-5, rtol=1e-5),
-                    verbose=1,
+                    ep, onx, ReferenceEvaluator, (x,), atol=1e-5, rtol=1e-5, verbose=1
                 ),
             ),
         )
+        print("------------")
+        print("final results")
+        df = pandas.DataFrame(results)
+        print(df)
+
+
+    This example uses :class:`onnx.reference.ReferenceEvaluator` to run the onnx model,
+    but onnxruntime can also be used through
+    :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`.
+    It relies on :epkg:`onnxruntime` and selects CPU or CUDA depending
+    on the device where the inputs are located.
+
+    The :class:`torch.export.ExportedProgram` can be saved on disk
+    with ``torch.export.save(ep, "<file>.pt2")`` and restored with
+    ``torch.export.load("<file>.pt2")``. That leaves only the inputs to save.
+    This makes it possible to decouple the export from the alignment.
+
+    .. runpython::
+        :showcode:
+        :warningout: UserWarning
+
+        import onnx
+        import torch
+        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                ry = x.abs()
+                rz = ry.exp()
+                rw = rz + 1
+                ru = rw.log() + rw
+                return ru
+
+
+        x = torch.randn((5, 4))
+        dynamic_shapes = ({0: "batch"},)
+        Model()(x)  # to make sure the model is running
+        ep = torch.export.export(
+            Model(), (x,), dynamic_shapes=use_dyn_not_str(dynamic_shapes)
+        )
+        onx = torch.onnx.export(
+            Model(), (x,), dynamic_shapes=dynamic_shapes
+        ).model_proto
+
+        torch.export.save(ep, "test_doc_sbs_example.pt2")
+        onnx.save(onx, "test_doc_sbs_example.onnx")
+        torch.save((x,), "test_doc_sbs_example.pt")
+
+    Then we can restore all of them and run it.
+
+    .. 
runpython::
+        :showcode:
+        :warningout: UserWarning
+
+        import pandas
+        import onnx
+        import torch
+        from onnx_diagnostic.torch_onnx.sbs import run_aligned
+        from onnx_diagnostic.helpers.ort_session import InferenceSessionForTorch
+
+
+        ep = torch.export.load("test_doc_sbs_example.pt2")
+        onx = onnx.load("test_doc_sbs_example.onnx")
+        inputs = torch.load("test_doc_sbs_example.pt")
+
+
+        def post_process(obs):
+            dobs = dict(zip(["ep_id_node", "onnx_id_node", "ep_name", "onnx_name"], obs))
+            dobs["err_abs"] = obs[-1]["abs"]
+            dobs["err_rel"] = obs[-1]["rel"]
+            return dobs
+
+
+        results = list(
+            map(
+                post_process,
+                run_aligned(
+                    ep, onx, InferenceSessionForTorch, inputs, atol=1e-5, rtol=1e-5, verbose=1
                 ),
             ),
         )
@@ -225,18 +319,56 @@ def post_process(obs):
         df = pandas.DataFrame(results)
         print(df)
     """
-    assert not kwargs, f"Not implemented when kwargs={string_type(kwargs,with_shape=True)}"
-    cls, atol, rtol = (
-        (
-            check_conversion_cls["cls"],
-            check_conversion_cls["atol"],
-            check_conversion_cls["rtol"],
+    assert callable(run_cls), f"run_cls={run_cls} not a callable"
+
+    def _check_tensor_(name, obj):
+        assert not use_tensor or isinstance(obj, torch.Tensor), (
+            f"Unexpected type {type(obj)} for {name!r}. "
+            f"use_tensor is True so torch.Tensor is expected."
         )
-        if isinstance(check_conversion_cls, dict)
-        else (check_conversion_cls, None, None)
-    )
+        assert use_tensor or isinstance(obj, np.ndarray), (
+            f"Unexpected type {type(obj)} for {name!r}. "
+            f"use_tensor is False so np.array is expected."
+        )
+        return obj
+
+    def _make_node_from_initializer(proto: onnx.TensorProto) -> onnx.NodeProto:
+        return oh.make_node("Constant", [], [proto.name], value=proto)
 
-    # retrieve the positions
+    def _loop_cmp(
+        mapping_onnx_to_torch, torch_results, onnx_results, o, r, verbose, atol, rtol
+    ):
+        onnx_results[o] = _check_tensor_(o, r)
+        if verbose:
+            print(
+                f"[run_aligned-nx] +res: {o}="
+                f"{string_type(r, with_shape=True, with_min_max=True)}"
+            )
+
+        to = mapping_onnx_to_torch.get(o, o)
+        if to in torch_results:
+            d = max_diff(torch_results[to], r)
+            if verbose:
+                if o == to:
+                    print(f"[run_aligned-==] cmp {to}: {string_diff(d)}")
+                else:
+                    print(f"[run_aligned-~~] cmp {to}/{o}: {string_diff(d)}")
+            if not (
+                atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol)
+            ):
+                skw = dict(with_shape=True, with_min_max=True)
+                raise ValueError(
+                    f"discrepancies detected for results [{to}/{o}]: "
+                    f"{string_diff(d)}"
+                    f"\n-- torch_results: {string_type(torch_results[to], **skw)}"
+                    f"\n-- onnx_results: {string_type(r, **skw)}"
+                    f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
+                )
+            return (i, i_onnx, o, to, d)
+        return None
+
+    if verbose:
+        print(f"[run_aligned] walks through {len(ep.graph.nodes)} nodes from torch")
     positions: Dict[str, Any] = {}
     for i, node in enumerate(ep.graph.nodes):
         if isinstance(node.name, str):
@@ -245,6 +377,8 @@
             for n in node.name:
                 positions[n] = dict(fx=i)
 
+    if verbose:
+        print(f"[run_aligned] walks through {len(onx.graph.node)} nodes from onnx")
     for i, node in enumerate(onx.graph.node):
         for n in node.output:
             if n in positions:
@@ -252,10 +386,14 @@
             else:
                 positions[n] = dict(onnx=i)
 
+    if verbose:
+        print(f"[run_aligned] handles {len(onx.graph.initializer)} initializers from onnx")
     onnx_results: Dict[str, Any] = {}
     for init in onx.graph.initializer:  # type: ignore
         positions[init.name] = -1
-        onnx_results[init.name] = to_array_extended(init)
+        onnx_results[init.name] = _check_tensor_(
+            init.name, 
run_cls(_make_node_from_initializer(init)).run(None, {})[0]
+        )
         param_name = f"p_{init.name.replace('.', '_')}"
         if param_name == init.name:
             continue
@@ -265,11 +403,17 @@
         )
         onnx_results[param_name] = onnx_results[init.name]
 
+    if verbose:
+        print(f"[run_aligned] handles {len(onnx_results)} common initializers from torch")
+    # we should be careful, torch may modified inplace the weights,
+    # it may be difficult to share weights
     torch_results: Dict[str, Any] = {
-        k: torch.from_numpy(v.copy())
-        for k, v in onnx_results.items()
-        if not k.startswith("init")
+        k: from_numpy(v) for k, v in onnx_results.items() if not k.startswith("init")
     }
+    if verbose:
+        print(
+            f"[run_aligned] handles other constants from the {len(ep.graph.nodes)} torch nodes"
+        )
     last_position = 0
     torch_output_names = None
     for node in ep.graph.nodes:
@@ -285,22 +429,33 @@
     mapping_onnx_to_torch = dict(zip(onnx_outputs_names, torch_output_names))
 
     if verbose:
+        print(f"[run_aligned] torch {len(torch_results)} constants")
+        print(f"[run_aligned] onnx {len(onnx_results)} constants")
+        print(f"[run_aligned] common {len(mapping_onnx_to_torch)} constants")
         for k, v in torch_results.items():
             print(
-                f"[run_aligned] +torch-cst: {k}: "
+                f"[run_aligned-ep] +cst: {k}: "
                 f"{string_type(v, with_shape=True, with_min_max=True)}"
             )
         for k, v in onnx_results.items():
             print(
-                f"[run_aligned] +onnx-init: {k}: "
+                f"[run_aligned-nx] +ini: {k}: "
                 f"{string_type(v, with_shape=True, with_min_max=True)}"
            )
 
-    for inp, v in zip(onx.graph.input, args):
-        onnx_results[inp.name] = to_numpy(v)
+    onnx_args = list(args) if args else []
+    if kwargs:
+        onnx_args.extend(flatten_object(kwargs, drop_keys=True))
+    if verbose:
+        print(f"[run_aligned] args: {string_type(args, with_shape=True)}")
+        print(f"[run_aligned] kwargs: {string_type(kwargs, with_shape=True)}")
+        print(f"[run_aligned] onnx: {string_type(onnx_args, with_shape=True)}")
+        print(f"[run_aligned] walks through {len(onx.graph.input)} onnx inputs")
+    for inp, v in zip(onx.graph.input, onnx_args):
+        onnx_results[inp.name] = _check_tensor_(inp.name, v if use_tensor else to_numpy(v))
         if verbose:
             print(
-                f"[run_aligned] +onnx-input: {inp.name}: "
+                f"[run_aligned-nx] +inp: {inp.name}: "
                 f"{string_type(v, with_shape=True, with_min_max=True)}"
             )
 
@@ -316,17 +471,23 @@
 
         if node.op == "placeholder":
             if node.name in onnx_results:
-                torch_results[node.name] = torch.from_numpy(onnx_results[node.name].copy())
+                torch_results[node.name] = (
+                    onnx_results[node.name]
+                    if use_tensor
+                    else torch.from_numpy(onnx_results[node.name])
+                )
                 if verbose:
                     t = torch_results[node.name]
                     print(
-                        f"[run_aligned] +torch {node.name}="
+                        f"[run_aligned-ep] +plh: {node.name}="
                         f"{string_type(t, with_shape=True, with_min_max=True)}"
                     )
                 continue
             raise AssertionError(
                 f"unable to process node {node.op} -> {node.name!r} "
-                f"not in {sorted(onnx_results)}, len(args)={len(args)}, "
+                f"not in {sorted(onnx_results)}, "
+                f"args={string_type(args, with_shape=True)}, "
+                f"kwargs={string_type(kwargs, with_shape=True)}, "
                 f"onx.graph.input={[i.name for i in onx.graph.input]}"
             )
 
@@ -345,7 +506,7 @@
         if verbose:
             for k, v in zip(outputs, new_outputs):
                 print(
-                    f"[run_aligned] +torch {k}="
+                    f"[run_aligned-ep] +res: {k}="
                     f"{string_type(v, with_shape=True, with_min_max=True)}"
                 )
 
@@ -364,39 +525,22 @@
             f"[run_aligned] run onx.graph.node[{i_onnx}]: "
             f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
         )
-        ref = 
cls(node) + ref = run_cls(node) feeds = {k: onnx_results[k] for k in node.input} res = ref.run(None, feeds) for o, r in zip(node.output, res): - onnx_results[o] = r - if verbose: - print( - f"[run_aligned] +onnx {o}=" - f"{string_type(r, with_shape=True, with_min_max=True)}" - ) - - to = mapping_onnx_to_torch.get(o, o) - if to in torch_results: - d = max_diff(torch_results[to], r) - if verbose: - if o == to: - print(f"[run_aligned] =common results {to}: {string_diff(d)}") - else: - print(f"[run_aligned] =common results {to}/{o}: {string_diff(d)}") - if not ( - atol is None - or rtol is None - or (d["abs"] <= atol and d["rel"] <= rtol) - ): - skw = dict(with_shape=True, with_min_max=True) - raise ValueError( - f"discrepancies detected for results [{to}/{o}]: " - f"{string_diff(d)}" - f"\n-- torch_results: {string_type(torch_results[to], **skw)}" - f"\n-- onnx_results: {string_type(r, **skw)}" - f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}" - ) - yield (i, i_onnx, o, to, d) + tmp = _loop_cmp( + mapping_onnx_to_torch, + torch_results, + onnx_results, + o, + r, + verbose, + atol, + rtol, + ) + if tmp is not None: + yield tmp last_position = max_pos + 1 @@ -408,33 +552,12 @@ def post_process(obs): f"[run_aligned] run onx.graph.node[{i_onnx}]: " f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}" ) - ref = cls(node) + ref = run_cls(node) feeds = {k: onnx_results[k] for k in node.input} res = ref.run(None, feeds) for o, r in zip(node.output, res): - onnx_results[o] = r - if verbose: - print( - f"[run_aligned] +onnx {o}=" - f"{string_type(r, with_shape=True, with_min_max=True)}" - ) - - to = mapping_onnx_to_torch.get(o, o) - if to in torch_results: - d = max_diff(torch_results[to], r) - if verbose: - if o == to: - print(f"[run_aligned] =common results* {to}: {string_diff(d)}") - else: - print(f"[run_aligned] =common results* {to}/{o}: {string_diff(d)}") - if not ( - atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol) - ): - skw = dict(with_shape=True, with_min_max=True) - raise ValueError( - f"discrepancies detected for results* [{to}/{o}]: {string_diff(d)}" - f"\n-- torch_results: {string_type(torch_results[to], **skw)}" - f"\n-- onnx_results: {string_type(r, **skw)}" - f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}" - ) - yield (i, i_onnx, o, to, d) + tmp = _loop_cmp( + mapping_onnx_to_torch, torch_results, onnx_results, o, r, verbose, atol, rtol + ) + if tmp is not None: + yield tmp From 1ccdb79e56aead8881ac911767bc15d5275c2d71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 14 Nov 2025 12:42:36 +0100 Subject: [PATCH 02/22] mypy --- onnx_diagnostic/torch_onnx/sbs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index ea6ee874..765dd215 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -392,7 +392,7 @@ def _loop_cmp( for init in onx.graph.initializer: # type: ignore positions[init.name] = -1 onnx_results[init.name] = _check_tensor_( - init.name, run_cls(_make_node_from_initializer(init)).run(None, {})[0] + init.name, run_cls(_make_node_from_initializer(init)).run(None, {})[0] # type: ignore[attr-defined] ) param_name = f"p_{init.name.replace('.', '_')}" if param_name == init.name: @@ -527,7 +527,7 @@ def _loop_cmp( ) ref = run_cls(node) feeds = {k: onnx_results[k] for k in node.input} - res = ref.run(None, feeds) + res = ref.run(None, feeds) # type: ignore[attr-defined] for o, r in 
zip(node.output, res): tmp = _loop_cmp( mapping_onnx_to_torch, @@ -554,7 +554,7 @@ def _loop_cmp( ) ref = run_cls(node) feeds = {k: onnx_results[k] for k in node.input} - res = ref.run(None, feeds) + res = ref.run(None, feeds) # type: ignore[attr-defined] for o, r in zip(node.output, res): tmp = _loop_cmp( mapping_onnx_to_torch, torch_results, onnx_results, o, r, verbose, atol, rtol From 294bc7876aa6623532e1b205f4add510dac32d64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 14 Nov 2025 13:48:24 +0100 Subject: [PATCH 03/22] changes --- CHANGELOGS.rst | 1 + _unittests/ut_torch_onnx/test_sbs.py | 71 ++++++++++++++++++++++++++-- onnx_diagnostic/ext_test_case.py | 2 +- onnx_diagnostic/torch_onnx/sbs.py | 52 +++++++++++++++++--- 4 files changed, 115 insertions(+), 11 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 73d274b6..70d18e8a 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.8.2 +++++ +* :pr:`304`: improves side-by-side comparison * :pr:`303`: fix inputs for summarization, feature extraction tasks * :pr:`302`: adds helpers to analyse onnxruntime profiling * :pr:`297`: experiment around a higher ops ``loop_for`` diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index 7b9020fa..22841865 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ b/_unittests/ut_torch_onnx/test_sbs.py @@ -5,7 +5,7 @@ ignore_warnings, ignore_errors, ) -from onnx_diagnostic.reference import ExtendedReferenceEvaluator +from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str from onnx_diagnostic.torch_onnx.sbs import run_aligned @@ -71,7 +71,7 @@ def forward(self, x): Model(), (x,), dynamic_shapes=({0: self.torch.export.Dim("batch")},) ) epo = self.torch.onnx.export( - ep, (x,), dynamic_shapes=({0: self.torch.export.Dim("batch")},) + ep, (x,), dynamic_shapes=({0: self.torch.export.Dim("batch")},), dynamo=True ) onx = epo.model_proto results = list( @@ -104,7 +104,9 @@ def forward(self, x): ep = self.torch.export.export( Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) ) - epo = self.torch.onnx.export(Model(), (), kwargs=inputs, dynamic_shapes=ds) + epo = self.torch.onnx.export( + Model(), (), kwargs=inputs, dynamic_shapes=ds, dynamo=True + ) onx = epo.model_proto results = list( run_aligned( @@ -119,6 +121,69 @@ def forward(self, x): ) self.assertEqual(len(results), 4) + @hide_stdout() + @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) + def test_sbs_dict_onnxruntime(self): + class Model(self.torch.nn.Module): + def forward(self, x): + ry = x.abs() + rz = ry.exp() + rw = rz + 1 + ru = rw.log() + rw + return ru + + inputs = dict(x=self.torch.randn((5, 4))) + ds = dict(x={0: "batch"}) + Model()(**inputs) + ep = self.torch.export.export( + Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) + ) + onx = to_onnx(ep) + results = list( + run_aligned( + ep, + onx, + kwargs=inputs, + run_cls=OnnxruntimeEvaluator, + atol=1e-5, + rtol=1e-5, + verbose=11, + ), + ) + self.assertEqual(len(results), 5) + + @hide_stdout() + @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) + def test_sbs_dict_tensor(self): + class Model(self.torch.nn.Module): + def forward(self, x): + ry = x.abs() + rz = ry.exp() + rw = rz + 1 + ru = rw.log() + rw + return ru + + inputs = dict(x=self.torch.randn((5, 4))) + ds = dict(x={0: "batch"}) + Model()(**inputs) + ep = 
self.torch.export.export( + Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) + ) + onx = to_onnx(ep) + results = list( + run_aligned( + ep, + onx, + kwargs=inputs, + run_cls=OnnxruntimeEvaluator, + atol=1e-5, + rtol=1e-5, + verbose=11, + use_tensor=True, + ), + ) + self.assertEqual(len(results), 5) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py index efb8e633..7afa3108 100644 --- a/onnx_diagnostic/ext_test_case.py +++ b/onnx_diagnostic/ext_test_case.py @@ -147,7 +147,7 @@ def hide_stdout(f: Optional[Callable] = None) -> Callable: def wrapper(fct): def call_f(self): - if os.environ.get("UNHIDE", ""): + if os.environ.get("UNHIDE", "") in (1, "1", "True", "true"): fct(self) return st = StringIO() diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 765dd215..362a5810 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -1,3 +1,4 @@ +import inspect from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import onnx import onnx.helper as oh @@ -291,7 +292,7 @@ def forward(self, x): import onnx import torch from onnx_diagnostic.torch_onnx.sbs import run_aligned - from onnx_diagnostic.helpers.ort_session import InferenceSessionForTorch + from onnx_diagnostic.reference import OnnxruntimeEvaluator ep = torch.export.load("test_doc_sbs_example.pt2") @@ -310,7 +311,14 @@ def post_process(obs): map( post_process, run_aligned( - ep, onx, InferenceSessionForTorch, inputs, atol=1e-5, rtol=1e-5, verbose=1 + ep, + onx, + OnnxruntimeEvaluator, + inputs, + atol=1e-5, + rtol=1e-5, + verbose=1, + use_tensor=True, ), ), ) @@ -320,8 +328,29 @@ def post_process(obs): print(df) """ assert callable(run_cls), f"run_cls={run_cls} not a callable" + run_cls_kwargs = { + "ir_version": onx.ir_version, + "opsets": {d.domain: d.version for d in onx.opset_import}, + "verbose": max(verbose - 1, 0), + } + run_cls_kwargs = { + k: v + for k, v in run_cls_kwargs.items() + if k in set(inspect.signature(run_cls).parameters) + } + if verbose: + print(f"[run_aligned] run_cls={run_cls}") + print(f"[run_aligned] run_cls_kwargs={run_cls_kwargs}") + + def _check_tensor_(name, obj, flip_type=False): + if flip_type: + if use_tensor: + if isinstance(obj, np.ndarray): + obj = from_numpy(obj) + else: + if isinstance(obj, torch.Tensor): + obj = to_numpy(obj) - def _check_tensor_(name, obj): assert not use_tensor or isinstance(obj, torch.Tensor), ( f"Unexpected type {type(obj)} for {name!r}. " f"use_tensor is True so torch.Tensor is expected." 
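
A note on the ``run_cls_kwargs`` filtering introduced above: only keyword arguments that ``run_cls`` actually declares are forwarded, so the same call site works for runtimes with different constructors. A minimal, hypothetical sketch of the pattern (``DummyEvaluator`` and ``filter_kwargs`` are made up for illustration, not part of the library):

    import inspect

    def filter_kwargs(factory, **candidates):
        # keep only the keyword arguments the factory actually declares
        accepted = set(inspect.signature(factory).parameters)
        return {k: v for k, v in candidates.items() if k in accepted}

    class DummyEvaluator:
        def __init__(self, proto, verbose=0):  # no 'providers' parameter
            self.proto, self.verbose = proto, verbose

    kept = filter_kwargs(DummyEvaluator, verbose=1, providers=["CPUExecutionProvider"])
    assert kept == {"verbose": 1}  # 'providers' was silently dropped
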
@@ -392,7 +421,14 @@ def _loop_cmp( for init in onx.graph.initializer: # type: ignore positions[init.name] = -1 onnx_results[init.name] = _check_tensor_( - init.name, run_cls(_make_node_from_initializer(init)).run(None, {})[0] # type: ignore[attr-defined] + init.name, + run_cls( + _make_node_from_initializer(init), + **run_cls_kwargs, + ).run(None, {})[ + 0 + ], # type: ignore[attr-defined] + flip_type=True, ) param_name = f"p_{init.name.replace('.', '_')}" if param_name == init.name: @@ -408,7 +444,9 @@ def _loop_cmp( # we should be careful, torch may modified inplace the weights, # it may be difficult to share weights torch_results: Dict[str, Any] = { - k: from_numpy(v) for k, v in onnx_results.items() if not k.startswith("init") + k: (v if use_tensor else from_numpy(v)) + for k, v in onnx_results.items() + if not k.startswith("init") } if verbose: print( @@ -525,7 +563,7 @@ def _loop_cmp( f"[run_aligned] run onx.graph.node[{i_onnx}]: " f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}" ) - ref = run_cls(node) + ref = run_cls(node, **run_cls_kwargs) feeds = {k: onnx_results[k] for k in node.input} res = ref.run(None, feeds) # type: ignore[attr-defined] for o, r in zip(node.output, res): @@ -552,7 +590,7 @@ def _loop_cmp( f"[run_aligned] run onx.graph.node[{i_onnx}]: " f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}" ) - ref = run_cls(node) + ref = run_cls(node, **run_cls_kwargs) feeds = {k: onnx_results[k] for k in node.input} res = ref.run(None, feeds) # type: ignore[attr-defined] for o, r in zip(node.output, res): From 26b47a8fe22db600dd4169c302f074cf040def90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 14 Nov 2025 13:51:39 +0100 Subject: [PATCH 04/22] mypy --- onnx_diagnostic/torch_onnx/sbs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 362a5810..c66c6ffa 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -425,9 +425,9 @@ def _loop_cmp( run_cls( _make_node_from_initializer(init), **run_cls_kwargs, - ).run(None, {})[ - 0 - ], # type: ignore[attr-defined] + ).run( # type: ignore[attr-defined] + None, {} + )[0], flip_type=True, ) param_name = f"p_{init.name.replace('.', '_')}" From 49722bc1d9026986ab5f4804b61585f5905344c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 14 Nov 2025 16:24:19 +0100 Subject: [PATCH 05/22] fixes --- _unittests/ut_export/test_api.py | 2 + _unittests/ut_helpers/test_helper.py | 26 ++++ _unittests/ut_helpers/test_log_helper.py | 5 + _unittests/ut_helpers/test_ort_session.py | 1 + _unittests/ut_helpers/test_torch_helper.py | 4 + .../test_patch_transformers.py | 2 + .../test_hghub_mode_rewrite.py | 4 +- .../ut_torch_models/test_validate_models.py | 3 + .../test_validate_whole_models1.py | 8 ++ .../test_validate_whole_models2.py | 1 + .../test_validate_whole_models3.py | 2 + _unittests/ut_torch_onnx/test_sbs.py | 111 +++++++++++++++++- .../ut_xrun_doc/test_check_ort_float16.py | 1 + onnx_diagnostic/ext_test_case.py | 10 ++ onnx_diagnostic/helpers/helper.py | 82 +++++++++++-- onnx_diagnostic/helpers/ort_session.py | 31 ++++- onnx_diagnostic/reference/ort_evaluator.py | 20 +++- onnx_diagnostic/torch_onnx/sbs.py | 102 +++++++++------- 18 files changed, 348 insertions(+), 67 deletions(-) diff --git a/_unittests/ut_export/test_api.py b/_unittests/ut_export/test_api.py index 76078acd..c1fa8b09 100644 --- 
a/_unittests/ut_export/test_api.py +++ b/_unittests/ut_export/test_api.py @@ -110,6 +110,8 @@ def test_tiny_llm_to_onnx(self): diff = max_diff(expected, got) assert diff["abs"] <= 1e-5, f"diff={diff}" + self.clean_dump() + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_helpers/test_helper.py b/_unittests/ut_helpers/test_helper.py index c9890e98..29a279d7 100644 --- a/_unittests/ut_helpers/test_helper.py +++ b/_unittests/ut_helpers/test_helper.py @@ -11,6 +11,7 @@ hide_stdout, requires_onnx, requires_transformers, + requires_cuda, ) from onnx_diagnostic.helpers.helper import ( string_type, @@ -199,6 +200,31 @@ def test_flatten(self): d = string_diff(diff) self.assertIsInstance(d, str) + @hide_stdout() + def test_maxdiff_device(self): + inputs = (torch.arange(2), torch.cos(torch.arange(3))) + diff = max_diff(inputs, inputs, flatten=True, verbose=10) + self.assertEqual(diff["abs"], 0) + self.assertEqual(diff["dev"], 0) + + @hide_stdout() + @requires_cuda() + def test_maxdiff_device_cuda(self): + diff = max_diff(torch.ones((2,)).cuda(), torch.ones((2,)), verbose=10) + self.assertEqual(diff["dev"], 1) + inputs = (torch.arange(2), torch.cos(torch.arange(3))) + inputs2 = (inputs[0].cuda(), inputs[1].cuda()) + diff = max_diff(inputs, inputs2, verbose=10) + self.assertEqual(diff["abs"], 0) + self.assertEqual(diff["dev"], 2) + inputs2 = (inputs[0], inputs[1].cuda()) + diff = max_diff(inputs, inputs2, verbose=10) + self.assertEqual(diff["abs"], 0) + self.assertEqual(diff["dev"], 1) + diff = max_diff(inputs2, inputs2, verbose=10) + self.assertEqual(diff["abs"], 0) + self.assertEqual(diff["dev"], 0) + def test_flatten_cache(self): cache = make_dynamic_cache([(torch.ones((5, 6, 5, 6)), torch.ones((5, 6, 5, 6)) + 2)]) flat = flatten_object(cache, drop_keys=True) diff --git a/_unittests/ut_helpers/test_log_helper.py b/_unittests/ut_helpers/test_log_helper.py index 96fc88d7..df8af469 100644 --- a/_unittests/ut_helpers/test_log_helper.py +++ b/_unittests/ut_helpers/test_log_helper.py @@ -189,6 +189,7 @@ def test_cube_logs_excel(self): verbose=1, ) self.assertExists(output) + self.clean_dump() @hide_stdout() def test_enumerate_csv_files(self): @@ -210,6 +211,7 @@ def test_enumerate_csv_files(self): cube.load(verbose=1) self.assertEqual((3, 11), cube.shape) self.assertIn("RAWFILENAME", cube.data.columns) + self.clean_dump() def test_cube_logs_performance1(self): output = self.get_dump_file("test_cube_logs_performance1.xlsx") @@ -235,6 +237,7 @@ def test_cube_logs_performance1(self): ], ) self.assertExists(output) + self.clean_dump() def test_cube_logs_performance2(self): output = self.get_dump_file("test_cube_logs_performance2.xlsx") @@ -470,6 +473,7 @@ def test_historical_cube_time_mask(self): ) cube = CubeLogs(df, keys=["^m_*", "exporter"], time="date").load() cube.to_excel(output, views=["time_p"], time_mask=True, verbose=1) + self.clean_dump() def test_cube_sbs_no_time(self): df = pandas.DataFrame( @@ -532,6 +536,7 @@ def test_cube_sbs_no_time(self): verbose=0, sbs=dict(CFA=dict(exporter="E1", opt="O"), CFB=dict(exporter="E2", opt="O")), ) + self.clean_dump() def test_cube_sbs_with_time(self): df = pandas.DataFrame( diff --git a/_unittests/ut_helpers/test_ort_session.py b/_unittests/ut_helpers/test_ort_session.py index 0ec5af44..87d5f3df 100644 --- a/_unittests/ut_helpers/test_ort_session.py +++ b/_unittests/ut_helpers/test_ort_session.py @@ -310,6 +310,7 @@ def test_profiling(self): got = wrap.run(None, feeds) self.assertIsInstance(got[0], torch.Tensor) 
self.assertEqualArray(expected[0], got[0]) + self.clean_dump() if __name__ == "__main__": diff --git a/_unittests/ut_helpers/test_torch_helper.py b/_unittests/ut_helpers/test_torch_helper.py index bc0a29b8..9441e425 100644 --- a/_unittests/ut_helpers/test_torch_helper.py +++ b/_unittests/ut_helpers/test_torch_helper.py @@ -151,6 +151,7 @@ def forward(self, x, y): self.assertEqualAny(restored["main", 1, "I"], (inputs, {})) self.assertEqualAny(restored["main", 0, "O"], res1) self.assertEqualAny(restored["main", 0, "O"], res2) + self.clean_dump() @hide_stdout() def test_steal_forward_dump_file_steal_append(self): @@ -181,6 +182,7 @@ def forward(self, x, y): {("", 1, "I"), ("", 1, "O"), "sx", ("", 0, "O"), "sx_1", ("", 0, "I")}, set(restored), ) + self.clean_dump() @hide_stdout() def test_steal_forward_dump_file_steal_append_drop(self): @@ -214,6 +216,7 @@ def forward(self, x, y): first = restored[("", 0, "I")] _a, kws = first self.assertNotIn("x", kws) + self.clean_dump() @hide_stdout() def test_steal_forward_submodules(self): @@ -257,6 +260,7 @@ def forward(self, x, y): ), len(sorted(restored)), ) + self.clean_dump() def test_replace_string_by_dynamic(self): example = { diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index 5bc3d0ba..bf96ffcd 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -464,6 +464,7 @@ def forward( atol=1e-3, rtol=1, ) + self.clean_dump() @requires_transformers("4.99") @requires_torch("2.9.99") @@ -508,6 +509,7 @@ def test_qwen2_5_vl_vision_attention_iteration(self): atol=1e-3, rtol=1, ) + self.clean_dump() if __name__ == "__main__": diff --git a/_unittests/ut_torch_models/test_hghub_mode_rewrite.py b/_unittests/ut_torch_models/test_hghub_mode_rewrite.py index 1dcdca82..24907bb0 100644 --- a/_unittests/ut_torch_models/test_hghub_mode_rewrite.py +++ b/_unittests/ut_torch_models/test_hghub_mode_rewrite.py @@ -25,8 +25,8 @@ def test_export_rewriting_bart(self): data = get_untrained_model_with_inputs(mid, verbose=1) model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] dump_folder = self.get_dump_file("test_export_rewritin_bart") - print(self.string_type(inputs)) - print(self.string_type(ds)) + print("--", self.string_type(inputs)) + print("--", self.string_type(ds)) with torch_export_patches( patch_transformers=True, rewrite=model, dump_rewriting=dump_folder ): diff --git a/_unittests/ut_torch_models/test_validate_models.py b/_unittests/ut_torch_models/test_validate_models.py index f2caa5e0..6ecc2653 100644 --- a/_unittests/ut_torch_models/test_validate_models.py +++ b/_unittests/ut_torch_models/test_validate_models.py @@ -43,6 +43,7 @@ def test_validate_tiny_llms_bfloat16(self): ) self.assertLess(summary["disc_onnx_ort_run_abs"], 2e-2) self.assertIn("onnx_filename", data) + self.clean_dump() @unittest.skipIf(torch29_and_tr_main, "combination not working") @requires_transformers("4.53") @@ -65,6 +66,7 @@ def test_validate_microsoft_phi4_reasoning(self): ) self.assertLess(summary["disc_onnx_ort_run_abs"], 2e-5) self.assertIn("onnx_filename", data) + self.clean_dump() @unittest.skipIf(torch29_and_tr_main, "combination not working") @requires_transformers("4.53") @@ -87,6 +89,7 @@ def test_validate_microsoft_phi3_mini_128k(self): ) self.assertLess(summary["disc_onnx_ort_run_abs"], 2e-5) self.assertIn("onnx_filename", data) + self.clean_dump() if __name__ == "__main__": 
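
The ``self.clean_dump()`` calls added throughout the tests above all follow the same idiom: a test writes its artifacts through ``get_dump_file`` into the shared ``dump_test`` folder and wipes that folder once its assertions have passed. A short sketch of the intended usage, relying on the ``clean_dump`` helper added to ``ExtTestCase`` later in this same patch (the artifact name is made up):

    from onnx_diagnostic.ext_test_case import ExtTestCase

    class TestWithArtifacts(ExtTestCase):
        def test_writes_and_cleans(self):
            # get_dump_file returns a path inside the shared dump folder
            path = self.get_dump_file("made_up_artifact.txt")
            with open(path, "w") as f:
                f.write("temporary content")
            self.assertExists(path)
            # removes every file and subfolder under dump_test
            self.clean_dump()
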
diff --git a/_unittests/ut_torch_models/test_validate_whole_models1.py b/_unittests/ut_torch_models/test_validate_whole_models1.py index dafa4297..d1924e08 100644 --- a/_unittests/ut_torch_models/test_validate_whole_models1.py +++ b/_unittests/ut_torch_models/test_validate_whole_models1.py @@ -50,6 +50,7 @@ def test_c_validate_model(self): self.assertIsInstance(summary, dict) self.assertIsInstance(data, dict) validate_model(mid, do_run=True, verbose=2, quiet=True) + self.clean_dump() @hide_stdout() def test_d_validate_model_dtype(self): @@ -60,6 +61,7 @@ def test_d_validate_model_dtype(self): self.assertIsInstance(summary, dict) self.assertIsInstance(data, dict) validate_model(mid, do_run=True, verbose=2, quiet=True) + self.clean_dump() @hide_stdout() def test_e_validate_model_export(self): @@ -74,6 +76,7 @@ def test_e_validate_model_export(self): ) self.assertIsInstance(summary, dict) self.assertIsInstance(data, dict) + self.clean_dump() @requires_torch("2.10.99") @requires_transformers("4.54") @@ -99,6 +102,7 @@ def test_f_validate_model_onnx_dynamo_ir(self): run_ort_fusion( onnx_filename, output_path, num_attention_heads=2, hidden_size=192, verbose=10 ) + self.clean_dump() @requires_torch("2.7") @requires_onnxscript("0.7") @@ -122,6 +126,7 @@ def test_g_validate_model_onnx_dynamo_os_ort(self): self.assertLess(summary["disc_onnx_ort_run2_batch1_abs"], 1e-4) onnx_filename = data["onnx_filename"] self.assertExists(onnx_filename) + self.clean_dump() @requires_torch("2.7") @hide_stdout() @@ -152,6 +157,7 @@ def test_i_validate_model_custom(self): run_ort_fusion( onnx_filename, output_path, num_attention_heads=2, hidden_size=192, verbose=10 ) + self.clean_dump() @requires_torch("2.7") @hide_stdout() @@ -176,6 +182,7 @@ def test_j_validate_model_custom_torch(self): self.assertIsInstance(data, dict) self.assertIn("disc_onnx_ort_run_abs", summary) self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4) + self.clean_dump() def test_k_filter_inputs(self): inputs, ds = {"a": 1, "b": 2}, {"a": 20, "b": 30} @@ -222,6 +229,7 @@ def test_n_validate_phi35_mini_instruct(self): onx = onnx.load(onnx_filename) op_types = set(n.op_type for n in onx.graph.node) self.assertIn("If", op_types) + self.clean_dump() if __name__ == "__main__": diff --git a/_unittests/ut_torch_models/test_validate_whole_models2.py b/_unittests/ut_torch_models/test_validate_whole_models2.py index 3f0ad51a..bbedacd7 100644 --- a/_unittests/ut_torch_models/test_validate_whole_models2.py +++ b/_unittests/ut_torch_models/test_validate_whole_models2.py @@ -41,6 +41,7 @@ def test_o_validate_phi35_4k_mini_instruct(self): onx = onnx.load(onnx_filename) op_types = set(n.op_type for n in onx.graph.node) self.assertIn("If", op_types) + self.clean_dump() if __name__ == "__main__": diff --git a/_unittests/ut_torch_models/test_validate_whole_models3.py b/_unittests/ut_torch_models/test_validate_whole_models3.py index 106cfadd..419dbe13 100644 --- a/_unittests/ut_torch_models/test_validate_whole_models3.py +++ b/_unittests/ut_torch_models/test_validate_whole_models3.py @@ -34,6 +34,7 @@ def test_l_validate_model_modelbuilder(self): self.assertLess(summary["disc_onnx_ort_run_abs"], 3e-2) onnx_filename = data["onnx_filename"] self.assertExists(onnx_filename) + self.clean_dump() @requires_torch("2.7") @hide_stdout() @@ -59,6 +60,7 @@ def test_m_validate_model_vit_model(self): self.assertEqual("#1[A1s3x2]", summary["run_output_inputs2"]) onnx_filename = data["onnx_filename"] self.assertExists(onnx_filename) + self.clean_dump() if __name__ == "__main__": 
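
The side-by-side tests below exercise the reworked ``run_aligned`` API: the runtime class is passed through ``run_cls`` and the tolerances are plain ``atol``/``rtol`` arguments instead of the former ``check_conversion_cls`` dictionary. A hedged, self-contained sketch of a call (it assumes ``torch.onnx.export(..., dynamo=True)`` is available in the installed torch; the model is made up):

    import torch
    from onnx.reference import ReferenceEvaluator
    from onnx_diagnostic.torch_onnx.sbs import run_aligned

    class Model(torch.nn.Module):
        def forward(self, x):
            return (x.abs() + 1).log()

    x = torch.randn((4, 3))
    ep = torch.export.export(Model(), (x,))
    onx = torch.onnx.export(Model(), (x,), dynamo=True).model_proto
    for obs in run_aligned(
        ep, onx, run_cls=ReferenceEvaluator, args=(x,), atol=1e-5, rtol=1e-5
    ):
        # obs = (ep node index, onnx node index, onnx name, ep name, diff)
        print(obs)
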
diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index 22841865..be82aaa6 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ b/_unittests/ut_torch_onnx/test_sbs.py @@ -4,6 +4,7 @@ hide_stdout, ignore_warnings, ignore_errors, + requires_cuda, ) from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str @@ -160,7 +161,7 @@ def forward(self, x): ry = x.abs() rz = ry.exp() rw = rz + 1 - ru = rw.log() + rw + ru = rw.log() + rw + ry return ru inputs = dict(x=self.torch.randn((5, 4))) @@ -182,7 +183,113 @@ def forward(self, x): use_tensor=True, ), ) - self.assertEqual(len(results), 5) + self.assertEqual(len(results), 6) + self.clean_dump() + + @hide_stdout() + @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) + @requires_cuda() + def test_sbs_dict_tensor_cuda(self): + class Model(self.torch.nn.Module): + def forward(self, x): + ry = x.abs() + rz = ry.exp() + rw = rz + 1 + ru = rw.log() + rw + ry + return ru + + inputs = dict(x=self.torch.randn((5, 4)).to("cuda")) + ds = dict(x={0: "batch"}) + Model()(**inputs) + ep = self.torch.export.export( + Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) + ) + onx = to_onnx(ep) + results = list( + run_aligned( + ep, + onx, + kwargs=inputs, + run_cls=OnnxruntimeEvaluator, + atol=1e-5, + rtol=1e-5, + verbose=11, + use_tensor=True, + ), + ) + self.assertEqual(len(results), 6) + self.assertEqual([r[-1]["dev"] for r in results], [0, 0, 0, 0, 0, 0]) + + @hide_stdout() + @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) + @requires_cuda() + def test_sbs_dict_tensor_cuda_reshape(self): + class Model(self.torch.nn.Module): + def forward(self, x): + ry = x.abs() + ry1 = ry.reshape((-1, 1)) + ry2 = ry.reshape((1, -1)) + prod = ry1 * ry2 + shape = prod.shape + resh = prod.reshape((-1, shape[0] // 2, shape[1] // 2)) + return resh.transpose(2, 1) + + inputs = dict(x=self.torch.randn((16, 16)).to("cuda")) + ds = dict(x={0: "batch"}) + Model()(**inputs) + ep = self.torch.export.export( + Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) + ) + onx = to_onnx(ep) + results = list( + run_aligned( + ep, + onx, + kwargs=inputs, + run_cls=OnnxruntimeEvaluator, + atol=1e-5, + rtol=1e-5, + verbose=11, + use_tensor=True, + ), + ) + self.assertEqual(len(results), 7) + self.assertEqual([r[-1].get("dev", 0) for r in results], [0, 0, 0, 0, 0, 0, 0]) + + @hide_stdout() + @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) + def test_sbs_dict_tensor_cpu_reshape(self): + class Model(self.torch.nn.Module): + def forward(self, x): + ry = x.abs() + ry1 = ry.reshape((-1, 1)) + ry2 = ry.reshape((1, -1)) + prod = ry1 * ry2 + shape = prod.shape + resh = prod.reshape((-1, shape[0] // 2, shape[1] // 2)) + return resh.transpose(2, 1) + + inputs = dict(x=self.torch.randn((16, 16))) + ds = dict(x={0: "batch"}) + Model()(**inputs) + ep = self.torch.export.export( + Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) + ) + onx = to_onnx(ep) + results = list( + run_aligned( + ep, + onx, + kwargs=inputs, + run_cls=OnnxruntimeEvaluator, + atol=1e-5, + rtol=1e-5, + verbose=11, + use_tensor=True, + ), + ) + self.assertEqual(len(results), 7) + self.assertEqual([r[-1].get("dev", 0) for r in results], [0, 0, 0, 0, 0, 0, 0]) if __name__ == "__main__": diff --git a/_unittests/ut_xrun_doc/test_check_ort_float16.py b/_unittests/ut_xrun_doc/test_check_ort_float16.py index 
6adacc71..0f7fb120 100644
--- a/_unittests/ut_xrun_doc/test_check_ort_float16.py
+++ b/_unittests/ut_xrun_doc/test_check_ort_float16.py
@@ -141,6 +141,7 @@ def common_scatter(self, opset, providers, dtype, reduction, expected_names):
             short_list,
             tuple([("CUDAExecutionProvider", o) for o in en] for en in expected_names),
         )
+        self.clean_dump()
 
     @unittest.skip("https://github.com/sdpython/onnx-diagnostic/issues/240")
     @requires_cuda()
diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py
index 7afa3108..f1cec362 100644
--- a/onnx_diagnostic/ext_test_case.py
+++ b/onnx_diagnostic/ext_test_case.py
@@ -9,6 +9,7 @@
 import logging
 import os
 import re
+import shutil
 import sys
 import unittest
 import warnings
@@ -806,6 +807,15 @@ def get_dump_folder(self, folder: str) -> str:
             os.makedirs(folder)
         return folder
 
+    def clean_dump(self, folder: str = "dump_test"):
+        """Removes all files and subfolders inside ``folder``."""
+        if not os.path.exists(folder):
+            return
+        for item in os.listdir(folder):
+            item_path = os.path.join(folder, item)
+            if os.path.isfile(item_path) or os.path.islink(item_path):
+                os.remove(item_path)
+            elif os.path.isdir(item_path):
+                shutil.rmtree(item_path)
+
     def dump_onnx(
         self,
         name: str,
diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py
index 092070b6..1c19840e 100644
--- a/onnx_diagnostic/helpers/helper.py
+++ b/onnx_diagnostic/helpers/helper.py
@@ -529,16 +529,20 @@ def string_type(
                 return "OV(NO-NUMPY:FIXIT)"
             if verbose:
                 print(f"[string_type] V4:{type(obj)}")
-            return f"OV({string_type(t, with_shape=with_shape, with_min_max=with_min_max)})"
+            dev = ("G" if obj.device_name() == "Cuda" else "C") if with_device else ""
+            return (
+                f"{dev}OV({string_type(t, with_shape=with_shape, with_min_max=with_min_max)})"
+            )
         dt = obj.element_type()
         shape = obj.shape()
+        dev = ("G" if obj.device_name() == "Cuda" else "C") if with_device else ""
         if with_shape:
             if verbose:
                 print(f"[string_type] V5:{type(obj)}")
-            return f"OV{dt}s{'x'.join(map(str, shape))}"
+            return f"{dev}OV{dt}s{'x'.join(map(str, shape))}"
         if verbose:
             print(f"[string_type] V6:{type(obj)}")
-        return f"OV{dt}r{len(shape)}"
+        return f"{dev}OV{dt}r{len(shape)}"
 
     # others classes
 
@@ -1015,6 +1019,7 @@ def max_diff(
       output, this number will be the number of elements of this output
     * dnan: difference in the number of nan
+    * dev: number of tensors not located on the same device as their counterpart,
+      0 when all compared devices match (only reported for torch tensors)
 
     You may use :func:`string_diff` to display the discrepancies in one string.
     """
@@ -1167,7 +1172,7 @@ def max_diff(
         if verbose >= 6:
             print(f"[max_diff] list,tuple,6: {string_type(expected)} ? {string_type(got)}")
 
-        am, rm, sm, n, dn, drep = 0, 0, 0.0, 0.0, 0, None
+        am, rm, sm, n, dn, drep, dd = 0, 0, 0.0, 0.0, 0, None, None
         for ip, (e, g) in enumerate(zip(expected, got)):
             d = max_diff(
                 e,
@@ -1199,7 +1204,15 @@ def max_diff(
                 else:
                     for k, v in d["rep"].items():
                         drep[k] += v
+            if "dev" in d:
+                if dd is None:
+                    dd = d["dev"]
+                else:
+                    dd += d["dev"]
+
         res = dict(abs=am, rel=rm, sum=sm, n=n, dnan=dn)
+        if dd is not None:
+            res["dev"] = dd
         if drep:
             res["rep"] = drep
         return res  # type: ignore
@@ -1233,33 +1246,42 @@ def max_diff(
     import torch
 
     if isinstance(expected, np.ndarray) or isinstance(got, np.ndarray):
+        dev = None
        if isinstance(expected, torch.Tensor):
             from .torch_helper import to_numpy
 
+            dev = 0 if expected.device.type == "cpu" else 1
             expected = to_numpy(expected)
         if isinstance(got, torch.Tensor):
             from .torch_helper import to_numpy
 
+            dev = 0 if got.device.type == "cpu" else 1
             got = to_numpy(got)
         if verbose >= 6:
             print(f"[max_diff] tensor: {string_type(expected)} ? 
{string_type(got)}") if _index < begin or (end != -1 and _index >= end): # out of boundary - return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0) + res = dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0) + if dev: + res[dev] = dev + return res if isinstance(expected, (int, float)): if isinstance(got, np.ndarray) and len(got.shape) == 0: got = float(got) if isinstance(got, (int, float)): if expected == got: return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0) - return dict( + res = dict( abs=abs(expected - got), rel=abs(expected - got) / (abs(expected) + 1e-5), sum=abs(expected - got), n=1, dnan=0, ) + if dev: + res[dev] = dev + return res return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) if expected.dtype in (np.complex64, np.complex128): if got.dtype == expected.dtype: @@ -1339,6 +1361,8 @@ def max_diff( res: Dict[str, float] = dict( # type: ignore abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm ) + if dev: + res[dev] = dev if hist: if isinstance(hist, bool): hist = np.array([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype) @@ -1352,9 +1376,14 @@ def max_diff( if isinstance(expected, torch.Tensor) and isinstance(got, torch.Tensor): if verbose >= 6: print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}") + dev = 0 if expected.device == got.device else 1 if _index < begin or (end != -1 and _index >= end): # out of boundary - return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0) + if verbose >= 10: + if debug_info: + print("\n".join(debug_info)) + print("[max_diff] out of boundary") + return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0, dev=dev) if expected.dtype in (torch.complex64, torch.complex128): if got.dtype == expected.dtype: got = torch.view_as_real(got) @@ -1448,7 +1477,13 @@ def max_diff( ) res: Dict[str, float] = dict( # type: ignore - abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm + abs=abs_diff, + rel=rel_diff, + sum=sum_diff, + n=n_diff, + dnan=nan_diff, + argm=argm, + dev=dev, ) if hist: if isinstance(hist, bool): @@ -1466,13 +1501,31 @@ def max_diff( ) return res # type: ignore + if isinstance(expected, int) and isinstance(got, torch.Tensor): + # a size + if verbose >= 6: + print(f"[max_diff] int: {string_type(expected)} ? {string_type(got)}") + if got.shape != tuple(): + return dict( # type: ignore + abs=np.inf, + rel=np.inf, + sum=np.inf, + n=np.inf, + dnan=np.inf, + argm=np.inf, + ) + return dict( # type: ignore + abs=abs(expected - got.item()), + rel=abs((expected - got.item()) / max(1, expected)), + sum=abs(expected - got.item()), + n=1, + dnan=0, + ) + if "SquashedNormal" in expected.__class__.__name__: if verbose >= 6: print(f"[max_diff] SquashedNormal: {string_type(expected)} ? 
{string_type(got)}") - values = ( - expected.mean.detach().to("cpu"), - expected.scale.detach().to("cpu"), - ) + values = (expected.mean, expected.scale) return max_diff(values, got, debug_info=_debug("SquashedNormal"), **_dkws) if expected.__class__ in torch.utils._pytree.SUPPORTED_NODES: @@ -1677,7 +1730,7 @@ def max_diff( raise AssertionError( f"Not implemented with implemented with expected=" - f"{string_type(expected)}, got={string_type(got)},\n" + f"{string_type(expected)} ({type(expected)}), got={string_type(got)},\n" f"level={level}" ) @@ -1685,6 +1738,9 @@ def max_diff( def string_diff(diff: Dict[str, Any]) -> str: """Renders discrepancies return by :func:`max_diff` into one string.""" # dict(abs=, rel=, sum=, n=n_diff, dnan=) + if "dev" in diff: + ddiff = {k: v for k, v in diff.items() if k != "dev"} + return f"{string_diff(ddiff)}, dev={diff['dev']}" suffix = "" if "rep" in diff: rows = [] diff --git a/onnx_diagnostic/helpers/ort_session.py b/onnx_diagnostic/helpers/ort_session.py index 7477c8cd..7f90c3cc 100644 --- a/onnx_diagnostic/helpers/ort_session.py +++ b/onnx_diagnostic/helpers/ort_session.py @@ -338,6 +338,7 @@ class InferenceSessionForTorch(_InferenceSession): :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions` :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions` :param use_training_api: use onnxruntime-traning API + :param cpu_output: if True, force the outputs to be on CPU """ def __init__( @@ -353,6 +354,7 @@ def __init__( optimized_model_filepath: Optional[str] = None, disable_aot_function_inlining: Optional[bool] = None, use_training_api: Optional[bool] = None, + cpu_outputs: bool = False, ): super().__init__( sess, @@ -367,6 +369,7 @@ def __init__( disable_aot_function_inlining=disable_aot_function_inlining, use_training_api=use_training_api, ) + self.cpu_outputs = cpu_outputs def _get_ortvalues_from_torch_tensors( self, tensors: Tuple[torch.Tensor, ...], n_outputs: int @@ -490,23 +493,39 @@ def run_dlpack( feeds is a dictionary of :class:`torch.Tensor`. The output device is CPU even if the outputs are on CUDA. """ - new_feeds = {} + input_names = [] + values = ORTC.OrtValueVector() + device = -1 for k, v in feeds.items(): + device = max(device, v.get_device()) assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized" if not v.is_contiguous(): v = v.contiguous() if v.dtype == torch.bool: # It does not work with dlpack # unless onnxruntime updates the version it is using. 
-                new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
+                v = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
                     v.detach().numpy(), onnx.TensorProto.BOOL
                 )
             else:
-                new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
+                v = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
+            input_names.append(k)
+            values.push_back(v)
         if self.nvtx:
-            self.torch.cuda.nvtx.range_push("run_with_ort_values")
-        ort_outputs = self.sess._sess.run_with_ort_values(
-            new_feeds, output_names or self.output_names, self.run_options
+            self.torch.cuda.nvtx.range_push("run_with_ortvaluevector")
+
+        # ort_outputs = self.sess._sess.run_with_ort_values(
+        #     new_feeds, output_names or self.output_names, self.run_options
+        # )
+        ort_outputs = ORTC.OrtValueVector()
+        out_names = output_names or self.output_names
+        self.sess._sess.run_with_ortvaluevector(
+            self.run_options,
+            input_names,
+            values,
+            out_names,
+            ort_outputs,
+            [DEVICES[-1 if self.cpu_outputs else device] for o in out_names],
         )
         if self.nvtx:
             self.torch.cuda.nvtx.range_pop()
diff --git a/onnx_diagnostic/reference/ort_evaluator.py b/onnx_diagnostic/reference/ort_evaluator.py
index cba391ef..7d16d8ed 100644
--- a/onnx_diagnostic/reference/ort_evaluator.py
+++ b/onnx_diagnostic/reference/ort_evaluator.py
@@ -413,6 +413,7 @@ def _get_hidden_node_inputs(self, node: NodeProto) -> Set[str]:
     def _get_sess(
         self, node: Union[ModelProto, NodeProto], inputs: List[Any]
     ) -> Tuple[ModelProto, _InferenceSession]:
+        on_cpu = None
         if isinstance(node, ModelProto):
             onx = node
         else:
@@ -443,6 +444,8 @@ def _get_sess(
             voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
 
             onx = self._make_model_proto([node], vinputs, voutputs)
+            if node.op_type in {"Shape", "Size"}:
+                on_cpu = True
 
         cls = (
             InferenceSessionForNumpy
@@ -450,8 +453,17 @@ def _get_sess(
             and (not isinstance(self.torch_or_numpy, bool) or not self.torch_or_numpy)
             else InferenceSessionForTorch
         )
+        if (
+            "providers" not in self.session_kwargs or not self.session_kwargs["providers"]
+        ) and any(hasattr(t, "device") and t.device.type.startswith("cuda") for t in inputs):
+            sess_kwargs = self.session_kwargs.copy()
+            sess_kwargs["providers"] = ["CUDAExecutionProvider"]
+        else:
+            sess_kwargs = self.session_kwargs
+        if on_cpu and "CUDAExecutionProvider" in sess_kwargs.get("providers", []):
+            # copy before mutating so self.session_kwargs stays untouched
+            sess_kwargs = sess_kwargs.copy()
+            sess_kwargs["cpu_outputs"] = True
         try:
-            sess = cls(onx, **self.session_kwargs)
+            sess = cls(onx, **sess_kwargs)
         except (
             onnxruntime.capi.onnxruntime_pybind11_state.Fail,
             onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
@@ -540,7 +552,11 @@ def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> L
         feeds = dict(zip(node.input, inputs))
         if "" in feeds:
-            feeds[""] = np.array([0], dtype=np.float32)
+            cls = None
+            for k, v in feeds.items():
+                if k != "":
+                    cls = v.__class__
+                    break
+            # keep the empty feed consistent with the other inputs' tensor class
+            feeds[""] = (
+                np.array([0], dtype=np.float32)
+                if cls is None or cls is np.ndarray
+                else cls([0])
+            )
         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
         outputs = list(sess.run(None, feeds))
 
diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py
index c66c6ffa..47683bac 100644
--- a/onnx_diagnostic/torch_onnx/sbs.py
+++ b/onnx_diagnostic/torch_onnx/sbs.py
@@ -5,6 +5,7 @@
 import numpy as np
 import torch
 from ..helpers import string_type, string_diff, max_diff, flatten_object
+from ..helpers.onnx_helper import pretty_onnx
 from ..helpers.torch_helper import to_numpy, from_numpy
 
 
@@ -328,10 +329,26 @@ def post_process(obs):
         print(df)
     """
     assert callable(run_cls), f"run_cls={run_cls} not a callable"
+    str_kws = dict(with_shape=True, with_device=True, 
with_min_max=True)
+    has_cuda = any(
+        (isinstance(t, torch.Tensor) and t.is_cuda)
+        for t in flatten_object([args, kwargs], drop_keys=True)
+    )
+    default_device = None
+    if has_cuda:
+        for t in flatten_object([args, kwargs], drop_keys=True):
+            if t is not None and t.is_cuda:
+                default_device = t.device
+                break
     run_cls_kwargs = {
         "ir_version": onx.ir_version,
         "opsets": {d.domain: d.version for d in onx.opset_import},
         "verbose": max(verbose - 1, 0),
+        "providers": (
+            ["CUDAExecutionProvider", "CPUExecutionProvider"]
+            if has_cuda
+            else ["CPUExecutionProvider"]
+        ),
     }
     run_cls_kwargs = {
         k: v
@@ -369,10 +386,7 @@ def _loop_cmp(
     ):
         onnx_results[o] = _check_tensor_(o, r)
         if verbose:
-            print(
-                f"[run_aligned-nx] +res: {o}="
-                f"{string_type(r, with_shape=True, with_min_max=True)}"
-            )
+            print(f"[run_aligned-nx] +res: {o}={string_type(r, **str_kws)}")
 
         to = mapping_onnx_to_torch.get(o, o)
         if to in torch_results:
@@ -385,12 +399,11 @@ def _loop_cmp(
             if not (
                 atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol)
             ):
-                skw = dict(with_shape=True, with_min_max=True)
                 raise ValueError(
                     f"discrepancies detected for results [{to}/{o}]: "
                     f"{string_diff(d)}"
-                    f"\n-- torch_results: {string_type(torch_results[to], **skw)}"
-                    f"\n-- onnx_results: {string_type(r, **skw)}"
+                    f"\n-- torch_results: {string_type(torch_results[to], **str_kws)}"
+                    f"\n-- onnx_results: {string_type(r, **str_kws)}"
                     f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
                 )
             return (i, i_onnx, o, to, d)
@@ -420,16 +433,18 @@ def _loop_cmp(
     for init in onx.graph.initializer:  # type: ignore
         positions[init.name] = -1
-        onnx_results[init.name] = _check_tensor_(
-            init.name,
-            run_cls(
-                _make_node_from_initializer(init),
-                **run_cls_kwargs,
-            ).run(  # type: ignore[attr-defined]
-                None, {}
-            )[0],
-            flip_type=True,
-        )
+        t = run_cls(
+            _make_node_from_initializer(init),
+            **run_cls_kwargs,
+        ).run(  # type: ignore[attr-defined]
+            None, {}
+        )[
+            0
+        ]
+        if default_device and t.numel() >= 1024:
+            # Let's force its way to cuda (we should check the device as well). 
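+            # (the 1024-element threshold is a heuristic: small, shape-like
+            # constants stay on CPU while large weights follow the inputs)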
+ t = t.to(default_device) + onnx_results[init.name] = _check_tensor_(init.name, t, flip_type=True) param_name = f"p_{init.name.replace('.', '_')}" if param_name == init.name: continue @@ -471,31 +486,22 @@ def _loop_cmp( print(f"[run_aligned] onnx {len(onnx_results)} constants") print(f"[run_aligned] common {len(mapping_onnx_to_torch)} constants") for k, v in torch_results.items(): - print( - f"[run_aligned-ep] +cst: {k}: " - f"{string_type(v, with_shape=True, with_min_max=True)}" - ) + print(f"[run_aligned-ep] +cst: {k}: {string_type(v, str_kws)}") for k, v in onnx_results.items(): - print( - f"[run_aligned-nx] +ini: {k}: " - f"{string_type(v, with_shape=True, with_min_max=True)}" - ) + print(f"[run_aligned-nx] +ini: {k}: {string_type(v, str_kws)}") onnx_args = list(args) if args else [] if kwargs: onnx_args.extend(flatten_object(kwargs, drop_keys=True)) if verbose: - print(f"[run_aligned] args: {string_type(args, with_shape=True)}") - print(f"[run_aligned] kwargs: {string_type(kwargs, with_shape=True)}") - print(f"[run_aligned] onnx: {string_type(onnx_args, with_shape=True)}") + print(f"[run_aligned] args: {string_type(args, **str_kws)}") + print(f"[run_aligned] kwargs: {string_type(kwargs, **str_kws)}") + print(f"[run_aligned] onnx: {string_type(onnx_args, **str_kws)}") print(f"[run_aligned] walks through {len(onx.graph.input)} onnx inputs") for inp, v in zip(onx.graph.input, onnx_args): onnx_results[inp.name] = _check_tensor_(inp.name, v if use_tensor else to_numpy(v)) if verbose: - print( - f"[run_aligned-nx] +inp: {inp.name}: " - f"{string_type(v, with_shape=True, with_min_max=True)}" - ) + print(f"[run_aligned-nx] +inp: {inp.name}: {string_type(v, **str_kws)}") for i, node in enumerate(ep.graph.nodes): if verbose: @@ -516,16 +522,13 @@ def _loop_cmp( ) if verbose: t = torch_results[node.name] - print( - f"[run_aligned-ep] +plh: {node.name}=" - f"{string_type(t, with_shape=True, with_min_max=True)}" - ) + print(f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}") continue raise AssertionError( f"unable to process node {node.op} -> {node.name!r} " f"not in {sorted(onnx_results)}, " - f"args={string_type(args, with_shape=True)}, " - f"kwargs={string_type(kwargs, with_shape=True)}, " + f"args={string_type(args, **str_kws)}, " + f"kwargs={string_type(kwargs, **str_kws)}, " f"onx.graph.input={[i.name for i in onx.graph.input]}" ) @@ -543,10 +546,7 @@ def _loop_cmp( torch_results[k] = v if verbose: for k, v in zip(outputs, new_outputs): - print( - f"[run_aligned-ep] +res: {k}=" - f"{string_type(v, with_shape=True, with_min_max=True)}" - ) + print(f"[run_aligned-ep] +res: {k}={string_type(v, **str_kws)}") max_pos = -2 for n in outputs: @@ -566,6 +566,22 @@ def _loop_cmp( ref = run_cls(node, **run_cls_kwargs) feeds = {k: onnx_results[k] for k in node.input} res = ref.run(None, feeds) # type: ignore[attr-defined] + assert ( + not has_cuda + or not any(t is not None and t.is_cuda for t in feeds.values()) + or any( + t is not None + and t.is_cuda + and t.dtype in {torch.float32, torch.float16, torch.bfloat16} + for t in res + ) + or node.op_type in {"Shape", "Size"} # on CPU no matter what + ), ( + f"One input is on cuda but there is no float output on cuda, " + f"feeds={string_type(feeds, with_device=True, with_shape=True)}, " + f"res={string_type(res, with_device=True, with_shape=True)}, " + f"node is {pretty_onnx(node)}" + ) for o, r in zip(node.output, res): tmp = _loop_cmp( mapping_onnx_to_torch, @@ -583,6 +599,8 @@ def _loop_cmp( last_position = max_pos + 1 # complete the 
execution of the onnx graph + if verbose: + print(f"[run_aligned] complete execution of onnx graph from pos={last_position}") for i_onnx in range(last_position, len(onx.graph.node)): node = onx.graph.node[i_onnx] if verbose: From 6a1f2f7b4c6f19c508caf1e839e3efc02dabbbbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 14 Nov 2025 16:55:37 +0100 Subject: [PATCH 06/22] fix ut --- _unittests/ut_helpers/test_bench_run.py | 31 +++++++++++++++---- .../test_onnxruntime_evaluator.py | 5 ++- onnx_diagnostic/reference/ort_evaluator.py | 6 ++-- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/_unittests/ut_helpers/test_bench_run.py b/_unittests/ut_helpers/test_bench_run.py index 12dbae6b..5e831826 100644 --- a/_unittests/ut_helpers/test_bench_run.py +++ b/_unittests/ut_helpers/test_bench_run.py @@ -109,35 +109,51 @@ def test_make_configs_replace(self): def test_max_diff(self): self.assertEqual( max_diff(torch.Tensor([1, 2]), torch.Tensor([1, 2])), - {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 2.0, "dnan": 0.0, "argm": (0,)}, + { + "abs": 0.0, + "rel": 0.0, + "sum": 0.0, + "n": 2.0, + "dnan": 0.0, + "argm": (0,), + "dev": 0, + }, ) self.assertEqual( max_diff( (torch.Tensor([1, 2]),), (torch.Tensor([1, 2])), ), - {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 2.0, "dnan": 0.0, "argm": (0,)}, + { + "abs": 0.0, + "rel": 0.0, + "sum": 0.0, + "n": 2.0, + "dnan": 0.0, + "argm": (0,), + "dev": 0, + }, ) self.assertEqual( max_diff( (torch.Tensor([1, 2]), (torch.Tensor([1, 2]),)), (torch.Tensor([1, 2]), (torch.Tensor([1, 2]),)), ), - {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 4.0, "dnan": 0.0}, + {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 4.0, "dnan": 0.0, "dev": 0}, ) self.assertEqual( max_diff( {"a": torch.Tensor([1, 2])}, {"a": torch.Tensor([1, 2])}, ), - {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 2.0, "dnan": 0.0}, + {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 2.0, "dnan": 0.0, "dev": 0}, ) self.assertEqual( max_diff( {"a": torch.Tensor([1, 2])}, [torch.Tensor([1, 2])], ), - {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 2.0, "dnan": 0.0}, + {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 2.0, "dnan": 0.0, "dev": 0}, ) self.assertEqual( max_diff( @@ -150,6 +166,7 @@ def test_max_diff(self): "n": 2.0, "rel": 0.9999999997999001, "sum": 9999999998.0, + "dev": 0, }, ) @@ -164,7 +181,9 @@ def test_max_diff_dynamic_cache(self): flatten=True, verbose=10, ) - self.assertEqual(md, {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 10.0, "dnan": 0}) + self.assertEqual( + md, {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 10.0, "dnan": 0, "dev": 0} + ) if __name__ == "__main__": diff --git a/_unittests/ut_reference/test_onnxruntime_evaluator.py b/_unittests/ut_reference/test_onnxruntime_evaluator.py index 738b0ea3..7906d586 100644 --- a/_unittests/ut_reference/test_onnxruntime_evaluator.py +++ b/_unittests/ut_reference/test_onnxruntime_evaluator.py @@ -4,7 +4,7 @@ import onnx.helper as oh import torch import onnxruntime -from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout +from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings from onnx_diagnostic.helpers.onnx_helper import from_array_extended from onnx_diagnostic.reference import ( OnnxruntimeEvaluator, @@ -22,6 +22,7 @@ class TestOnnxruntimeEvaluator(ExtTestCase): + @ignore_warnings(FutureWarning) def test_ort_eval_scan_cdist_add(self): def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor): @@ -69,6 +70,7 @@ def forward(self, x): got = orte.run(None, {name: x.numpy()})[0] self.assertEqualArray(expected, 
got) + @ignore_warnings((UserWarning, FutureWarning)) def test_ort_eval_cond(self): import torch @@ -180,6 +182,7 @@ def test_constant_bool_input(self): self.assertEqual(got.dtype, torch.bool) self.assertEqual(got[0], True) + @hide_stdout() def test_ort_eval_loop(self): model = torch.nn.EmbeddingBag(num_embeddings=49157, embedding_dim=32, mode="sum") a = torch.tensor([[39906, 39906]]).long() diff --git a/onnx_diagnostic/reference/ort_evaluator.py b/onnx_diagnostic/reference/ort_evaluator.py index 7d16d8ed..d168a7b1 100644 --- a/onnx_diagnostic/reference/ort_evaluator.py +++ b/onnx_diagnostic/reference/ort_evaluator.py @@ -455,12 +455,12 @@ def _get_sess( ) if ( "providers" not in self.session_kwargs or not self.session_kwargs["providers"] - ) and any(hasattr(t, "device") and t.device.type.startswith("cuda") for t in inputs): + ) and any(hasattr(t, "is_cuda") and t.is_cuda for t in inputs): sess_kwargs = self.session_kwargs.copy() sess_kwargs["providers"] = ["CUDAExecutionProvider"] else: - sess_kwargs = self.session_kwargs - if on_cpu and "CUDAExecutionProvider" in sess_kwargs.get("providers", []): + sess_kwargs = self.session_kwargs or {} + if on_cpu and "CUDAExecutionProvider" in (sess_kwargs.get("providers", []) or []): sess_kwargs["cpu_outputs"] = True try: sess = cls(onx, **sess_kwargs) From e85c69721e56a380ac2745117d956ca87cbf3530 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 14 Nov 2025 18:02:15 +0100 Subject: [PATCH 07/22] doc --- onnx_diagnostic/torch_onnx/sbs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 47683bac..e70e5aab 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -257,7 +257,7 @@ def post_process(obs): import onnx import torch - from onnx_diagnostic.torch_export_patches.patch_inputs.use_dyn_not_str + from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str class Model(torch.nn.Module): From 0c25bc95585d64bd6f678f6375a531c7575bb95c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 14 Nov 2025 17:23:26 +0000 Subject: [PATCH 08/22] sbs --- _unittests/ut_tasks/try_export.py | 4 +- .../ut_torch_models/test_validate_models.py | 4 +- _unittests/ut_xrun_doc/test_command_lines.py | 8 + onnx_diagnostic/_command_lines_parser.py | 169 ++++++++++++++++-- onnx_diagnostic/torch_onnx/sbs.py | 111 +++++++++--- 5 files changed, 258 insertions(+), 38 deletions(-) diff --git a/_unittests/ut_tasks/try_export.py b/_unittests/ut_tasks/try_export.py index 75f7e173..0397c310 100644 --- a/_unittests/ut_tasks/try_export.py +++ b/_unittests/ut_tasks/try_export.py @@ -86,6 +86,8 @@ def _config_reduction(config, task): hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device), grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device), ) + print("-- save inputs") + torch.save(inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.pt")) print(f"-- inputs: {self.string_type(inputs, with_shape=True)}") # this is too long @@ -120,7 +122,7 @@ def _config_reduction(config, task): filename=filename, exporter=exporter, verbose=1, - save_ep=fileep, + save_ep=(fileep, 2**35), target_opset=22, optimize=True, ) diff --git a/_unittests/ut_torch_models/test_validate_models.py b/_unittests/ut_torch_models/test_validate_models.py index 6ecc2653..7f0138ee 100644 --- a/_unittests/ut_torch_models/test_validate_models.py +++ b/_unittests/ut_torch_models/test_validate_models.py @@ -46,8 
+46,8 @@ def test_validate_tiny_llms_bfloat16(self):
         self.clean_dump()
 
     @unittest.skipIf(torch29_and_tr_main, "combination not working")
-    @requires_transformers("4.53")
-    @requires_torch("2.8.99")
+    @requires_transformers("4.57")  # 4.53 works but some jobs fail due to no space left
+    @requires_torch("2.9.99")  # 2.9 works but some jobs fail due to no space left
     @requires_experimental()
     @hide_stdout()
     def test_validate_microsoft_phi4_reasoning(self):
diff --git a/_unittests/ut_xrun_doc/test_command_lines.py b/_unittests/ut_xrun_doc/test_command_lines.py
index b478f538..dff47c1e 100644
--- a/_unittests/ut_xrun_doc/test_command_lines.py
+++ b/_unittests/ut_xrun_doc/test_command_lines.py
@@ -9,6 +9,7 @@
     get_parser_find,
     get_parser_lighten,
     get_parser_print,
+    get_parser_sbs,
     get_parser_stats,
     get_parser_unlighten,
     get_parser_validate,
@@ -79,6 +80,13 @@ def test_parser_agg(self):
         text = st.getvalue()
         self.assertIn("--recent", text)
 
+    def test_parser_sbs(self):
+        st = StringIO()
+        with redirect_stdout(st):
+            get_parser_sbs().print_help()
+        text = st.getvalue()
+        self.assertIn("--model", text)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py
index 854943bf..5b0c28b9 100644
--- a/onnx_diagnostic/_command_lines_parser.py
+++ b/onnx_diagnostic/_command_lines_parser.py
@@ -4,6 +4,7 @@
 import re
 import sys
 import textwrap
+import time
 import onnx
 from typing import Any, Dict, List, Optional, Union
 from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
@@ -1104,6 +1105,146 @@ def _cmd_agg(argv: List[Any]):
     print(f"Wrote {args.output!r}")
 
 
+def get_parser_sbs() -> ArgumentParser:
+    parser = ArgumentParser(
+        prog="side-by-side (sbs)",
+        description=textwrap.dedent(
+            """
+            Compares the intermediate outputs between the exported program and
+            the exported onnx model. It assumes some names are common.
+            The execution of the exported program and the onnx model
+            are done in parallel. The device is the one used to store the
+            model and the inputs.
+            """
+        ),
+        epilog="Where do discrepancies start? This function tries to answer that question.",
+    )
+    parser.add_argument(
+        "-i",
+        "--inputs",
+        type=str,
+        required=True,
+        help="model inputs saved with torch.save",
+    )
+    parser.add_argument(
+        "--ep",
+        type=str,
+        required=True,
+        help="exported program saved with torch.export.save",
+    )
+    parser.add_argument(
+        "-m",
+        "--onnx",
+        type=str,
+        required=True,
+        help="exported model in onnx format",
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        required=True,
+        help="output name to store what the command line produces, "
+        "it should be an excel file",
+    )
+    parser.add_argument(
+        "--atol",
+        default=1e-5,
+        required=False,
+        help="absolute tolerance",
+    )
+    parser.add_argument(
+        "--rtol",
+        default=1e-5,
+        required=False,
+        help="relative tolerance",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        default=0,
+        required=False,
+        help="verbosity",
+    )
+    return parser
+
+
+def _cmd_sbs(argv: List[Any]):
+    import pandas
+    import torch
+    from .helpers import string_type
+    from .torch_onnx.sbs import run_aligned
+    from .reference import OnnxruntimeEvaluator
+
+    parser = get_parser_sbs()
+    args = parser.parse_args(argv[1:])
+
+    def _size(name):
+        s = os.stat(name).st_size
+        return f"{s / 2**20:1.3f} Mb"
+
+    print("-- side by side")
+    print(f"-- ep: {_size(args.ep)}: {args.ep}")
+    print(f"-- inputs: {_size(args.inputs)}: {args.inputs}")
+    print(f"-- onnx: {_size(args.onnx)}: {args.onnx}")
+    print(f"-- output: {args.output}")
+
+    print(f"-- load inputs {args.inputs!r}")
+    begin = time.perf_counter()
+    inputs = torch.load(args.inputs)
+    s = string_type(inputs, with_shape=True, with_device=True)
+    print(f"-- done in {time.perf_counter() - begin:1.1f}s - {s}")
+
+    if isinstance(inputs, dict) and len(inputs) == 2 and set(inputs) == {"args", "kwargs"}:
+        margs = inputs["args"]
+        mkwargs = inputs["kwargs"]
+    elif isinstance(inputs, tuple):
+        margs = inputs
+        mkwargs = {}
+    elif isinstance(inputs, dict):
+        margs = tuple()
+        mkwargs = inputs
+    else:
+        raise ValueError(
+            f"Unable to infer args, kwargs from inputs {string_type(inputs, with_shape=True)}"
+        )
+
+    print(f"-- load ep {args.ep!r}")
+    begin = time.perf_counter()
+    ep = torch.export.load(args.ep)
+    print(f"-- done in {time.perf_counter() - begin:1.1f}s")
+
+    print(f"-- load onnx {args.onnx!r}")
+    begin = time.perf_counter()
+    onx = onnx.load(args.onnx)
+    print(f"-- done in {time.perf_counter() - begin:1.1f}s")
+
+    def post_process(obs):
+        dobs = dict(zip(["ep_id_node", "onnx_id_node", "ep_name", "onnx_name"], obs))
+        dobs["err_abs"] = obs[-1]["abs"]
+        dobs["err_rel"] = obs[-1]["rel"]
+        return dobs
+
+    print("-- starts side-by-side")
+    data = []
+    for obs in run_aligned(
+        ep,
+        onx,
+        run_cls=OnnxruntimeEvaluator,
+        atol=float(args.atol),
+        rtol=float(args.rtol),
+        verbose=int(args.verbose),
+        args=margs,
+        kwargs=mkwargs,
+        use_tensor=True,
+        exc=False,
+    ):
+        data.append(post_process(obs))
+    df = pandas.DataFrame(data)
+    df.to_excel(args.output)
+    print("-- done")
+
+
 def get_main_parser() -> 
ArgumentParser: "find", "lighten", "print", + "sbs", "stats", "unlighten", "validate", @@ -1146,15 +1289,16 @@ def get_main_parser() -> ArgumentParser: def main(argv: Optional[List[Any]] = None): fcts = dict( + agg=_cmd_agg, + config=_cmd_config, + exportsample=_cmd_export_sample, + find=_cmd_find, lighten=_cmd_lighten, - unlighten=_cmd_unlighten, print=_cmd_print, - find=_cmd_find, - config=_cmd_config, - validate=_cmd_validate, + sbs=_cmd_sbs, stats=_cmd_stats, - agg=_cmd_agg, - exportsample=_cmd_export_sample, + unlighten=_cmd_unlighten, + validate=_cmd_validate, ) if argv is None: @@ -1169,15 +1313,16 @@ def main(argv: Optional[List[Any]] = None): parser.parse_args(argv) else: parsers = dict( + agg=get_parser_agg, + config=get_parser_config, + exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator] + find=get_parser_find, lighten=get_parser_lighten, - unlighten=get_parser_unlighten, print=get_parser_print, - find=get_parser_find, - config=get_parser_config, - validate=get_parser_validate, + sbs=get_parser_sbs, stats=get_parser_stats, - agg=get_parser_agg, - exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator] + unlighten=get_parser_unlighten, + validate=get_parser_validate, ) cmd = argv[0] if cmd not in parsers: diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 47683bac..e7dda3f4 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -108,7 +108,14 @@ def run_fx_node( return args if node.op == "call_function": assert callable(node.target), f"{node.target!r} not callable in node {node!r}" - outputs = node.target(*args, **(kwargs or {})) + try: + outputs = node.target(*args, **(kwargs or {})) + except RuntimeError as e: + raise RuntimeError( + f"Unable to run node {node!r}, target={node.target!r}, " + f"args={string_type(args, with_shape=True, with_device=True)}, " + f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}" + ) from e validate_fx_outputs(node, outputs) return outputs raise NotImplementedError( @@ -130,6 +137,8 @@ def _pick_result(torch_results: Dict[str, Any], ref: Any) -> Any: return ref if ref is None: return None + if isinstance(ref, torch.layout): + return ref raise NotImplementedError(f"Unable to process args type {type(ref)}") @@ -168,6 +177,7 @@ def run_aligned( atol: Optional[float] = None, rtol: Optional[float] = None, verbose: int = 0, + exc: bool = True, ) -> Iterator[Tuple[Any, ...]]: """ Runs in parallel both the exported program @@ -185,6 +195,7 @@ def run_aligned( :param atol: absolute tolerance :param rtol: relative tolerance :param verbose: verbosity level + :param exc: stops if an exception :return: a list of tuples containing the results, they come in tuple Example: @@ -382,7 +393,16 @@ def _make_node_from_initializer(proto: onnx.TensorProto) -> onnx.NodeProto: return oh.make_node("Constant", [], [proto.name], value=proto) def _loop_cmp( - mapping_onnx_to_torch, torch_results, onnx_results, o, r, verbose, atol, rtol + mapping_onnx_to_torch, + torch_results, + onnx_results, + o, + r, + verbose, + atol, + rtol, + i, + i_onnx, ): onnx_results[o] = _check_tensor_(o, r) if verbose: @@ -399,13 +419,20 @@ def _loop_cmp( if not ( atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol) ): - raise ValueError( - f"discrepancies detected for results [{to}/{o}]: " - f"{string_diff(d)}" - f"\n-- torch_results: {string_type(torch_results[to], **str_kws)}" - f"\n-- onnx_results: {string_type(r, **str_kws)}" - 
f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}" - ) + if exc: + raise ValueError( + f"discrepancies detected for results [{to}/{o}]: " + f"{string_diff(d)}" + f"\n-- torch_results: {string_type(torch_results[to], **str_kws)}" + f"\n-- onnx_results: {string_type(r, **str_kws)}" + f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}" + ) + else: + print( + f"[run_align-dx] discrepancies " + f"{string_diff(d, with_shape=True, with_device=True)} - " + f"[{to}/{o}]" + ) return (i, i_onnx, o, to, d) return None @@ -445,6 +472,11 @@ def _loop_cmp( # Let's force its way to cuda (should check the device has well). t = t.to(default_device) onnx_results[init.name] = _check_tensor_(init.name, t, flip_type=True) + if init.name.startswith("init"): + # not a weight + continue + + # quick fixes param_name = f"p_{init.name.replace('.', '_')}" if param_name == init.name: continue @@ -454,6 +486,15 @@ def _loop_cmp( ) onnx_results[param_name] = onnx_results[init.name] + param_name = f"{init.name.replace('.', '_')}".split("::")[0] + if param_name == init.name: + continue + assert param_name not in onnx_results, ( + f"Some confusion may happen because {init.name!r} -> {param_name!r} " + f"and onnx_results has {sorted(onnx_results)}" + ) + onnx_results[param_name] = onnx_results[init.name] + if verbose: print(f"[run_aligned] handles common {len(onnx_results)} initializer from torch") # we should be careful, torch may modified inplace the weights, @@ -524,13 +565,19 @@ def _loop_cmp( t = torch_results[node.name] print(f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}") continue - raise AssertionError( - f"unable to process node {node.op} -> {node.name!r} " - f"not in {sorted(onnx_results)}, " - f"args={string_type(args, **str_kws)}, " - f"kwargs={string_type(kwargs, **str_kws)}, " - f"onx.graph.input={[i.name for i in onx.graph.input]}" - ) + if exc: + raise AssertionError( + f"unable to process node {node.op} -> {node.name!r}, " + f"possible candiate are " + f"{sorted(p for p in onnx_results if node.name in p)}, " + f"not in {sorted(onnx_results)}, " + f"args={string_type(args, **str_kws)}, " + f"kwargs={string_type(kwargs, **str_kws)}, " + f"onx.graph.input={[i.name for i in onx.graph.input]}" + ) + elif verbose: + print(f"[run_aligned] unable to {node.name!r} among onnx weights") + continue outputs = [node.name] if isinstance(node.name, str) else list(node.name) args, kwargs = prepare_args_kwargs(torch_results, node) @@ -569,13 +616,20 @@ def _loop_cmp( assert ( not has_cuda or not any(t is not None and t.is_cuda for t in feeds.values()) - or any( - t is not None - and t.is_cuda - and t.dtype in {torch.float32, torch.float16, torch.bfloat16} - for t in res - ) + or any(t is not None and t.is_cuda for t in res) or node.op_type in {"Shape", "Size"} # on CPU no matter what + or node.op_type + in { + "Add", + "Concat", + "Div", + "Gather", + "Mul", + "Range", + "Squeeze", + "Sub", + "Unsqueeze", + } # not sure, could be about shapes ), ( f"One input is on cuda but there is no float output on cuda, " f"feeds={string_type(feeds, with_device=True, with_shape=True)}, " @@ -592,6 +646,8 @@ def _loop_cmp( verbose, atol, rtol, + i, + i_onnx, ) if tmp is not None: yield tmp @@ -613,7 +669,16 @@ def _loop_cmp( res = ref.run(None, feeds) # type: ignore[attr-defined] for o, r in zip(node.output, res): tmp = _loop_cmp( - mapping_onnx_to_torch, torch_results, onnx_results, o, r, verbose, atol, rtol + mapping_onnx_to_torch, + torch_results, + onnx_results, + o, + r, + verbose, + atol, + rtol, + i, + i_onnx, ) if tmp 
is not None: yield tmp From 30459e542f291da6af05d5b544e92855cbe8da9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 14 Nov 2025 18:26:44 +0100 Subject: [PATCH 09/22] doc --- onnx_diagnostic/torch_onnx/sbs.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 4918888d..4f9a793c 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -338,6 +338,16 @@ def post_process(obs): print("final results") df = pandas.DataFrame(results) print(df) + + A command line can also be run: + + .. code-block:: bash + + python -m onnx_diagnostic sbs -i .input.pt \\ + --ep .pt2 \\ + -m .onnx \\ + -o results.xlsx \\ + -v 1 --atol=0.1 --rtol=1 """ assert callable(run_cls), f"run_cls={run_cls} not a callable" str_kws = dict(with_shape=True, with_device=True, with_min_max=True) From 8fb3ec0eb0d3c9d4c1916407fc0c47cb5e26c40a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 14 Nov 2025 18:36:17 +0100 Subject: [PATCH 10/22] last --- CHANGELOGS.rst | 6 +++++- _doc/index.rst | 2 +- onnx_diagnostic/__init__.py | 2 +- onnx_diagnostic/torch_onnx/sbs.py | 2 +- pyproject.toml | 2 +- 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 70d18e8a..96ba2c91 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -1,10 +1,14 @@ Change Logs =========== -0.8.2 +0.8.3 +++++ * :pr:`304`: improves side-by-side comparison + +0.8.2 ++++++ + * :pr:`303`: fix inputs for summarization, feature extraction tasks * :pr:`302`: adds helpers to analyse onnxruntime profiling * :pr:`297`: experiment around a higher ops ``loop_for`` diff --git a/_doc/index.rst b/_doc/index.rst index b1f7a98b..c41e655a 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -239,8 +239,8 @@ The function replaces dynamic dimensions defined as strings by Older versions ============== +* `0.8.3 <../v0.8.3/index.html>`_ * `0.8.2 <../v0.8.2/index.html>`_ -* `0.8.1 <../v0.8.1/index.html>`_ * `0.7.16 <../v0.7.16/index.html>`_ * `0.6.3 <../v0.6.3/index.html>`_ * `0.5.0 <../v0.5.0/index.html>`_ diff --git a/onnx_diagnostic/__init__.py b/onnx_diagnostic/__init__.py index 9f99501a..e8c842df 100644 --- a/onnx_diagnostic/__init__.py +++ b/onnx_diagnostic/__init__.py @@ -3,5 +3,5 @@ Functions, classes to dig into a model when this one is right, slow, wrong... """ -__version__ = "0.8.2" +__version__ = "0.8.3" __author__ = "Xavier Dupré" diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 4f9a793c..b4e0f89b 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -338,7 +338,7 @@ def post_process(obs): print("final results") df = pandas.DataFrame(results) print(df) - + A command line can also be run: .. code-block:: bash diff --git a/pyproject.toml b/pyproject.toml index deb8af51..2e5e4107 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "onnx-diagnostic" -version = "0.8.2" +version = "0.8.3" description = "Tools to help converting pytorch models into ONNX." 
readme = "README.rst" authors = [ From f8bca95d23e1adc2ce0806a3d9752b839cf34c54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 15 Nov 2025 00:40:06 +0100 Subject: [PATCH 11/22] fix --- _unittests/ut_xrun_doc/test_command_lines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_xrun_doc/test_command_lines.py b/_unittests/ut_xrun_doc/test_command_lines.py index dff47c1e..5317190f 100644 --- a/_unittests/ut_xrun_doc/test_command_lines.py +++ b/_unittests/ut_xrun_doc/test_command_lines.py @@ -85,7 +85,7 @@ def test_parser_sbs(self): with redirect_stdout(st): get_parser_sbs().print_help() text = st.getvalue() - self.assertIn("--model", text) + self.assertIn("--onnx", text) if __name__ == "__main__": From e7399a1dfefde582b7b89fd60ec7657f1f673a0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 15 Nov 2025 02:09:12 +0100 Subject: [PATCH 12/22] fix --- _unittests/ut_torch_onnx/test_sbs.py | 72 ++++++++++++++++------ onnx_diagnostic/_command_lines_parser.py | 2 +- onnx_diagnostic/helpers/helper.py | 8 +-- onnx_diagnostic/reference/ort_evaluator.py | 6 +- onnx_diagnostic/torch_onnx/sbs.py | 6 +- 5 files changed, 67 insertions(+), 27 deletions(-) diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index be82aaa6..5fe1a7a8 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ b/_unittests/ut_torch_onnx/test_sbs.py @@ -1,4 +1,5 @@ import unittest +import onnx from onnx_diagnostic.ext_test_case import ( ExtTestCase, hide_stdout, @@ -9,11 +10,7 @@ from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str from onnx_diagnostic.torch_onnx.sbs import run_aligned - -try: - from experimental_experiment.torch_interpreter import to_onnx -except ImportError: - to_onnx = None +from onnx_diagnostic.export.api import to_onnx class TestSideBySide(ExtTestCase): @@ -41,7 +38,7 @@ def forward(self, x): ep = self.torch.export.export( Model(), (x,), dynamic_shapes=({0: self.torch.export.Dim("batch")},) ) - onx = to_onnx(ep) + onx = to_onnx(ep, exporter="custom").model_proto results = list( run_aligned( ep, @@ -71,10 +68,12 @@ def forward(self, x): ep = self.torch.export.export( Model(), (x,), dynamic_shapes=({0: self.torch.export.Dim("batch")},) ) - epo = self.torch.onnx.export( - ep, (x,), dynamic_shapes=({0: self.torch.export.Dim("batch")},), dynamo=True - ) - onx = epo.model_proto + onx = to_onnx( + ep, + (x,), + dynamic_shapes=({0: self.torch.export.Dim("batch")},), + exporter="onnx-dynamo", + ).model_proto results = list( run_aligned( ep, @@ -105,9 +104,7 @@ def forward(self, x): ep = self.torch.export.export( Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) ) - epo = self.torch.onnx.export( - Model(), (), kwargs=inputs, dynamic_shapes=ds, dynamo=True - ) + epo = to_onnx(Model(), (), kwargs=inputs, dynamic_shapes=ds, exporter="onnx-dynamo") onx = epo.model_proto results = list( run_aligned( @@ -139,7 +136,7 @@ def forward(self, x): ep = self.torch.export.export( Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) ) - onx = to_onnx(ep) + onx = to_onnx(ep, exporter="custom").model_proto results = list( run_aligned( ep, @@ -170,7 +167,7 @@ def forward(self, x): ep = self.torch.export.export( Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) ) - onx = to_onnx(ep) + onx = to_onnx(ep, exporter="custom").model_proto results = list( run_aligned( ep, @@ -204,7 
+201,7 @@ def forward(self, x): ep = self.torch.export.export( Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) ) - onx = to_onnx(ep) + onx = to_onnx(ep, exporter="custom").model_proto results = list( run_aligned( ep, @@ -240,7 +237,7 @@ def forward(self, x): ep = self.torch.export.export( Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) ) - onx = to_onnx(ep) + onx = to_onnx(ep, exporter="custom").model_proto results = list( run_aligned( ep, @@ -275,7 +272,7 @@ def forward(self, x): ep = self.torch.export.export( Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) ) - onx = to_onnx(ep) + onx = to_onnx(ep, exporter="custom").model_proto results = list( run_aligned( ep, @@ -291,6 +288,45 @@ def forward(self, x): self.assertEqual(len(results), 7) self.assertEqual([r[-1].get("dev", 0) for r in results], [0, 0, 0, 0, 0, 0, 0]) + @hide_stdout() + @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) + def test_sbs_model_with_weights(self): + torch = self.torch + + class Model(self.torch.nn.Module): + def __init__(self): + super(Model, self).__init__() + self.fc1 = torch.nn.Linear(10, 32) # input size 10 → hidden size 32 + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(32, 1) # hidden → output + + def forward(self, x): + x = self.relu(self.fc1(x)) + x = self.fc2(x) + return x + + inputs = dict(x=self.torch.randn((5, 10))) + ds = dict(x={0: "batch"}) + Model()(**inputs) + ep = self.torch.export.export( + Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) + ) + filename = self.get_dump_file("test_sbs_model_with_weights.onnx") + to_onnx(ep, exporter="custom", filename=filename) + onx = onnx.load(filename) + results = list( + run_aligned( + ep, + onx, + kwargs=inputs, + run_cls=OnnxruntimeEvaluator, + verbose=11, + use_tensor=True, + ), + ) + self.assertEqual(len(results), 7) + self.assertEqual([r[-1].get("dev", 0) for r in results], [0, 0, 0, 0, 0, 0, 0]) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 5b0c28b9..4b9d4585 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1230,7 +1230,7 @@ def post_process(obs): for obs in run_aligned( ep, onx, - run_cls=OnnxruntimeEvaluator, + run_cls=OnnxruntimeEvaluator, # type: ignore[arg-type] atol=float(args.atol), rtol=float(args.rtol), verbose=int(args.verbose), diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index 1c19840e..5b57498a 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -1207,7 +1207,7 @@ def max_diff( if "dev" in d: if dd is None: dd = d["dev"] - else: + elif d["dev"] is not None: dd += d["dev"] res = dict(abs=am, rel=rm, sum=sm, n=n, dnan=dn) @@ -1264,7 +1264,7 @@ def max_diff( # out of boundary res = dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0) if dev: - res[dev] = dev + res["dev"] = dev return res if isinstance(expected, (int, float)): if isinstance(got, np.ndarray) and len(got.shape) == 0: @@ -1280,7 +1280,7 @@ def max_diff( dnan=0, ) if dev: - res[dev] = dev + res["dev"] = dev return res return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) if expected.dtype in (np.complex64, np.complex128): @@ -1362,7 +1362,7 @@ def max_diff( abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm ) if dev: - res[dev] = dev + res["dev"] = dev if hist: if isinstance(hist, bool): hist = np.array([0, 0.0001, 
0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype) diff --git a/onnx_diagnostic/reference/ort_evaluator.py b/onnx_diagnostic/reference/ort_evaluator.py index d168a7b1..ed8b5f34 100644 --- a/onnx_diagnostic/reference/ort_evaluator.py +++ b/onnx_diagnostic/reference/ort_evaluator.py @@ -553,9 +553,13 @@ def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> L feeds = dict(zip(node.input, inputs)) if "" in feeds: cls = None - for k, v in feeds: + for k, v in feeds.items(): if k != "": cls = v.__class__ + break + assert ( + cls is not None + ), f"Unable to get input class (array or tensor), feeds={string_type(feeds)}" feeds[""] = cls([0]) assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}" diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index b4e0f89b..7bf612f0 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -537,9 +537,9 @@ def _loop_cmp( print(f"[run_aligned] onnx {len(onnx_results)} constants") print(f"[run_aligned] common {len(mapping_onnx_to_torch)} constants") for k, v in torch_results.items(): - print(f"[run_aligned-ep] +cst: {k}: {string_type(v, str_kws)}") + print(f"[run_aligned-ep] +cst: {k}: {string_type(v, **str_kws)}") for k, v in onnx_results.items(): - print(f"[run_aligned-nx] +ini: {k}: {string_type(v, str_kws)}") + print(f"[run_aligned-nx] +ini: {k}: {string_type(v, **str_kws)}") onnx_args = list(args) if args else [] if kwargs: @@ -578,7 +578,7 @@ def _loop_cmp( if exc: raise AssertionError( f"unable to process node {node.op} -> {node.name!r}, " - f"possible candiate are " + f"possible candidate are " f"{sorted(p for p in onnx_results if node.name in p)}, " f"not in {sorted(onnx_results)}, " f"args={string_type(args, **str_kws)}, " From 82e58d9c2cadfd5f8fc81cde8825d3260cc7cf58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 15 Nov 2025 10:20:21 +0100 Subject: [PATCH 13/22] fix sbs --- _unittests/ut_torch_onnx/test_sbs.py | 7 +- onnx_diagnostic/_command_lines_parser.py | 10 +-- onnx_diagnostic/torch_onnx/sbs.py | 89 +++++++++--------------- 3 files changed, 42 insertions(+), 64 deletions(-) diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index 5fe1a7a8..86a3b3ba 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ b/_unittests/ut_torch_onnx/test_sbs.py @@ -1,4 +1,5 @@ import unittest +import pandas import onnx from onnx_diagnostic.ext_test_case import ( ExtTestCase, @@ -9,7 +10,7 @@ ) from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str -from onnx_diagnostic.torch_onnx.sbs import run_aligned +from onnx_diagnostic.torch_onnx.sbs import run_aligned, post_process_run_aligned_obs from onnx_diagnostic.export.api import to_onnx @@ -324,8 +325,12 @@ def forward(self, x): use_tensor=True, ), ) + pandas.DataFrame(list(map(post_process_run_aligned_obs, results))).to_excel( + self.get_dump_file("test_sbs_model_with_weights.xlsx") + ) self.assertEqual(len(results), 7) self.assertEqual([r[-1].get("dev", 0) for r in results], [0, 0, 0, 0, 0, 0, 0]) + self.clean_dump() if __name__ == "__main__": diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 4b9d4585..b5c614bb 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1173,7 +1173,7 @@ def _cmd_sbs(argv: List[Any]): import pandas import torch from 
.helpers import string_type - from .torch_onnx.sbs import run_aligned + from .torch_onnx.sbs import run_aligned, post_process_run_aligned_obs from .reference import OnnxruntimeEvaluator parser = get_parser_sbs() @@ -1219,12 +1219,6 @@ def _size(name): onx = onnx.load(args.onnx) print(f"-- done in {time.perf_counter() - begin:1.1f}s") - def post_process(obs): - dobs = dict(zip(["ep_id_node", "onnx_id_node", "ep_name", "onnx_name"], obs)) - dobs["err_abs"] = obs[-1]["abs"] - dobs["err_rel"] = obs[-1]["rel"] - return dobs - print("-- starts side-by-side") data = [] for obs in run_aligned( @@ -1239,7 +1233,7 @@ def post_process(obs): use_tensor=True, exc=False, ): - data.append(post_process(obs)) + data.append(post_process_run_aligned_obs(obs)) df = pandas.DataFrame(data) df.to_excel(args.output) print("-- done") diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 7bf612f0..97dd9b57 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -157,6 +157,17 @@ def prepare_args_kwargs( return new_args, new_kwargs +def post_process_run_aligned_obs(obs: Dict[str, Any]) -> Dict[str, Union[str, float, int]]: + """ + Flattens an observations produced by function + :func:`onnx_diagnostic.torch_onnx.sbs.run_aligned`. + """ + dobs = dict(zip(["ep_id_node", "onnx_id_node", "ep_name", "onnx_name"], obs)) + dobs["err_abs"] = obs[-1]["abs"] + dobs["err_rel"] = obs[-1]["rel"] + return dobs + + def run_aligned( ep: torch.export.ExportedProgram, onx: Union[onnx.ModelProto, onnx.FunctionProto], @@ -210,7 +221,10 @@ def run_aligned( # This can be replace by any runtime taking NodeProto as an input. ExtendedReferenceEvaluator as ReferenceEvaluator, ) - from onnx_diagnostic.torch_onnx.sbs import run_aligned + from onnx_diagnostic.torch_onnx.sbs import ( + run_aligned, + post_process_run_aligned_obs, + ) class Model(torch.nn.Module): @@ -222,13 +236,6 @@ def forward(self, x): return ru - def post_process(obs): - dobs = dict(zip(["ep_id_node", "onnx_id_node", "ep_name", "onnx_name"], obs)) - dobs["err_abs"] = obs[-1]["abs"] - dobs["err_rel"] = obs[-1]["rel"] - return dobs - - x = torch.randn((5, 4)) Model()(x) # to make sure the model is running ep = torch.export.export( @@ -239,7 +246,7 @@ def post_process(obs): ).model_proto results = list( map( - post_process, + post_process_run_aligned_obs, run_aligned( ep, onx, ReferenceEvaluator, (x,), atol=1e-5, rtol=1e-5, verbose=1 ), @@ -303,7 +310,10 @@ def forward(self, x): import pandas import onnx import torch - from onnx_diagnostic.torch_onnx.sbs import run_aligned + from onnx_diagnostic.torch_onnx.sbs import ( + run_aligned, + post_process_run_aligned_obs, + ) from onnx_diagnostic.reference import OnnxruntimeEvaluator @@ -312,16 +322,9 @@ def forward(self, x): inputs = torch.load("test_doc_sbs_example.pt") - def post_process(obs): - dobs = dict(zip(["ep_id_node", "onnx_id_node", "ep_name", "onnx_name"], obs)) - dobs["err_abs"] = obs[-1]["abs"] - dobs["err_rel"] = obs[-1]["rel"] - return dobs - - results = list( map( - post_process, + post_process_run_aligned_obs, run_aligned( ep, onx, @@ -486,34 +489,11 @@ def _loop_cmp( # not a weight continue - # quick fixes - param_name = f"p_{init.name.replace('.', '_')}" - if param_name == init.name: - continue - assert param_name not in onnx_results, ( - f"Some confusion may happen because {init.name!r} -> {param_name!r} " - f"and onnx_results has {sorted(onnx_results)}" - ) - onnx_results[param_name] = onnx_results[init.name] - - param_name = 
f"{init.name.replace('.', '_')}".split("::")[0] - if param_name == init.name: - continue - assert param_name not in onnx_results, ( - f"Some confusion may happen because {init.name!r} -> {param_name!r} " - f"and onnx_results has {sorted(onnx_results)}" - ) - onnx_results[param_name] = onnx_results[init.name] - if verbose: print(f"[run_aligned] handles common {len(onnx_results)} initializer from torch") # we should be careful, torch may modified inplace the weights, # it may be difficult to share weights - torch_results: Dict[str, Any] = { - k: (v if use_tensor else from_numpy(v)) - for k, v in onnx_results.items() - if not k.startswith("init") - } + torch_results: Dict[str, Any] = {} if verbose: print( f"[run_aligned] handles other constant from {len(ep.graph.nodes)} nodes from torch" @@ -554,6 +534,10 @@ def _loop_cmp( if verbose: print(f"[run_aligned-nx] +inp: {inp.name}: {string_type(v, **str_kws)}") + alias_placeholder = { + **ep.state_dict, + **{f"p_{name.replace('.', '_')}": v for name, v in ep.state_dict.items()}, + } for i, node in enumerate(ep.graph.nodes): if verbose: if node.op == "call_function": @@ -573,20 +557,15 @@ def _loop_cmp( ) if verbose: t = torch_results[node.name] - print(f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}") + print(f"[run_aligned-ep] =plh: {node.name}={string_type(t, **str_kws)}") continue - if exc: - raise AssertionError( - f"unable to process node {node.op} -> {node.name!r}, " - f"possible candidate are " - f"{sorted(p for p in onnx_results if node.name in p)}, " - f"not in {sorted(onnx_results)}, " - f"args={string_type(args, **str_kws)}, " - f"kwargs={string_type(kwargs, **str_kws)}, " - f"onx.graph.input={[i.name for i in onx.graph.input]}" - ) - elif verbose: - print(f"[run_aligned] unable to {node.name!r} among onnx weights") + else: + assert ( + node.name in alias_placeholder + ), f"Unable to find placeholder {node.name!r} in {sorted(alias_placeholder)}" + torch_results[node.name] = alias_placeholder[node.name] + if verbose: + print(f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}") continue outputs = [node.name] if isinstance(node.name, str) else list(node.name) From 51d5d377816c092f827a060ea4cb42b1c729e4a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 15 Nov 2025 11:01:57 +0100 Subject: [PATCH 14/22] sbs --- _unittests/ut_torch_onnx/test_sbs.py | 46 ++++++++++----- onnx_diagnostic/torch_onnx/sbs.py | 87 +++++++++++++++++++++++++--- 2 files changed, 111 insertions(+), 22 deletions(-) diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index 86a3b3ba..1ea9ff27 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ b/_unittests/ut_torch_onnx/test_sbs.py @@ -51,7 +51,7 @@ def forward(self, x): verbose=1, ), ) - self.assertEqual(len(results), 5) + self.assertEqual(len(results), 7) @hide_stdout() @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) @@ -86,7 +86,7 @@ def forward(self, x): verbose=1, ), ) - self.assertEqual(len(results), 4) + self.assertEqual(len(results), 6) @hide_stdout() @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) @@ -118,7 +118,7 @@ def forward(self, x): verbose=1, ), ) - self.assertEqual(len(results), 4) + self.assertEqual(len(results), 6) @hide_stdout() @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) @@ -149,7 +149,7 @@ def forward(self, x): verbose=11, ), ) - self.assertEqual(len(results), 5) + self.assertEqual(len(results), 7) @hide_stdout() 
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) @@ -181,7 +181,7 @@ def forward(self, x): use_tensor=True, ), ) - self.assertEqual(len(results), 6) + self.assertEqual(len(results), 8) self.clean_dump() @hide_stdout() @@ -215,8 +215,7 @@ def forward(self, x): use_tensor=True, ), ) - self.assertEqual(len(results), 6) - self.assertEqual([r[-1]["dev"] for r in results], [0, 0, 0, 0, 0, 0]) + self.assertEqual(len(results), 8) @hide_stdout() @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) @@ -251,8 +250,7 @@ def forward(self, x): use_tensor=True, ), ) - self.assertEqual(len(results), 7) - self.assertEqual([r[-1].get("dev", 0) for r in results], [0, 0, 0, 0, 0, 0, 0]) + self.assertEqual(len(results), 14) @hide_stdout() @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) @@ -286,8 +284,8 @@ def forward(self, x): use_tensor=True, ), ) - self.assertEqual(len(results), 7) - self.assertEqual([r[-1].get("dev", 0) for r in results], [0, 0, 0, 0, 0, 0, 0]) + self.assertEqual(len(results), 14) + self.assertEqual([r[-1].get("dev", 0) for r in results], [0] * 14) @hide_stdout() @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) @@ -325,11 +323,29 @@ def forward(self, x): use_tensor=True, ), ) - pandas.DataFrame(list(map(post_process_run_aligned_obs, results))).to_excel( - self.get_dump_file("test_sbs_model_with_weights.xlsx") + df = pandas.DataFrame(list(map(post_process_run_aligned_obs, results))) + df.to_excel(self.get_dump_file("test_sbs_model_with_weights.xlsx")) + self.assertEqual( + [ + "ep_id_node", + "ep_name", + "ep_target", + "err_abs", + "err_dev", + "err_rel", + "onnx_id_node", + "onnx_name", + "onnx_op_type", + "shape_type", + ], + sorted(df.columns), + ) + self.assertEqual(len(results), 10) + self.assertEqual([r[-1].get("dev", 0) for r in results], [0] * 10) + self.assertEqual( + [-1.0, -1.0, -1.0, -10, -10, -10, -10, 0.0, 1.0, 2.0], + df["onnx_id_node"].fillna(-10).tolist(), ) - self.assertEqual(len(results), 7) - self.assertEqual([r[-1].get("dev", 0) for r in results], [0, 0, 0, 0, 0, 0, 0]) self.clean_dump() diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 97dd9b57..6bef47cf 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -162,9 +162,26 @@ def post_process_run_aligned_obs(obs: Dict[str, Any]) -> Dict[str, Union[str, fl Flattens an observations produced by function :func:`onnx_diagnostic.torch_onnx.sbs.run_aligned`. """ - dobs = dict(zip(["ep_id_node", "onnx_id_node", "ep_name", "onnx_name"], obs)) - dobs["err_abs"] = obs[-1]["abs"] - dobs["err_rel"] = obs[-1]["rel"] + dobs = dict( + zip( + [ + "ep_id_node", + "onnx_id_node", + "ep_name", + "onnx_name", + "ep_target", + "onnx_op_type", + "shape_type", + ], + obs, + ) + ) + if "abs" in obs[-1]: + dobs["err_abs"] = obs[-1]["abs"] + if "rel" in obs[-1]: + dobs["err_rel"] = obs[-1]["rel"] + if "dev" in obs[-1]: + dobs["err_dev"] = obs[-1]["dev"] return dobs @@ -209,6 +226,17 @@ def run_aligned( :param exc: stops if an exception :return: a list of tuples containing the results, they come in tuple + Each tuple is: + + - ep_id_node + - onnx_id_node + - ep_name + - onnx_name + - ep target name + - onnx op _type + - ep or onnx shape and type + - difference + Example: .. 
runpython:: @@ -446,7 +474,7 @@ def _loop_cmp( f"{string_diff(d, with_shape=True, with_device=True)} - " f"[{to}/{o}]" ) - return (i, i_onnx, o, to, d) + return (i, i_onnx, o, to, string_type(torch_results[to], **str_kws), d) return None if verbose: @@ -538,7 +566,20 @@ def _loop_cmp( **ep.state_dict, **{f"p_{name.replace('.', '_')}": v for name, v in ep.state_dict.items()}, } - for i, node in enumerate(ep.graph.nodes): + for n in onnx_results: + if n not in alias_placeholder: + yield ( + None, + -1, + None, + n, + None, + "initializer", + string_type(onnx_results[n], **str_kws), + {}, + ) + ep_graph_nodes = list(ep.graph.nodes) + for i, node in enumerate(ep_graph_nodes): if verbose: if node.op == "call_function": print( @@ -558,6 +599,18 @@ def _loop_cmp( if verbose: t = torch_results[node.name] print(f"[run_aligned-ep] =plh: {node.name}={string_type(t, **str_kws)}") + if node.name in alias_placeholder: + # Otherwise, it is an input. + yield ( + -1, + -1, + node.name, + node.name, + "placeholder", + "initializer", + string_type(t, **str_kws), + max_diff(alias_placeholder[node.name], onnx_results[node.name]), + ) continue else: assert ( @@ -566,6 +619,16 @@ def _loop_cmp( torch_results[node.name] = alias_placeholder[node.name] if verbose: print(f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}") + yield ( + -1, + None, + node.name, + None, + "placeholder", + None, + string_type(torch_results[node.name], **str_kws), + {}, + ) continue outputs = [node.name] if isinstance(node.name, str) else list(node.name) @@ -639,7 +702,12 @@ def _loop_cmp( i_onnx, ) if tmp is not None: - yield tmp + yield ( + *tmp[:4], + str(ep_graph_nodes[tmp[0]].target), + onx.graph.node[tmp[1]].op_type, + *tmp[-2:], + ) last_position = max_pos + 1 @@ -670,4 +738,9 @@ def _loop_cmp( i_onnx, ) if tmp is not None: - yield tmp + yield ( + *tmp[:4], + str(ep_graph_nodes[tmp[0]].target), + onx.graph.node[tmp[1]].op_type, + *tmp[-2:], + ) From 9b63a268aba09e91fcdb83956aabe0cfc9715e9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 15 Nov 2025 11:55:11 +0100 Subject: [PATCH 15/22] fix sbs --- _unittests/ut_torch_onnx/test_sbs.py | 73 +++++++++++++++++++++++++--- onnx_diagnostic/torch_onnx/sbs.py | 48 ++++++++++-------- 2 files changed, 94 insertions(+), 27 deletions(-) diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index 1ea9ff27..46a568fc 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ b/_unittests/ut_torch_onnx/test_sbs.py @@ -289,7 +289,7 @@ def forward(self, x): @hide_stdout() @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) - def test_sbs_model_with_weights(self): + def test_sbs_model_with_weights_custom(self): torch = self.torch class Model(self.torch.nn.Module): @@ -310,7 +310,7 @@ def forward(self, x): ep = self.torch.export.export( Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) ) - filename = self.get_dump_file("test_sbs_model_with_weights.onnx") + filename = self.get_dump_file("test_sbs_model_with_weights_custom.onnx") to_onnx(ep, exporter="custom", filename=filename) onx = onnx.load(filename) results = list( @@ -324,7 +324,7 @@ def forward(self, x): ), ) df = pandas.DataFrame(list(map(post_process_run_aligned_obs, results))) - df.to_excel(self.get_dump_file("test_sbs_model_with_weights.xlsx")) + df.to_excel(self.get_dump_file("test_sbs_model_with_weights_custom.xlsx")) self.assertEqual( [ "ep_id_node", @@ -340,10 +340,71 @@ def forward(self, x): ], sorted(df.columns), ) - 
self.assertEqual(len(results), 10) - self.assertEqual([r[-1].get("dev", 0) for r in results], [0] * 10) + self.assertEqual(len(results), 12) + self.assertEqual([r[-1].get("dev", 0) for r in results], [0] * 12) self.assertEqual( - [-1.0, -1.0, -1.0, -10, -10, -10, -10, 0.0, 1.0, 2.0], + [-1.0, -1.0, -1.0, -1.0, -10.0, -10.0, -10.0, -10.0, -1.0, 0.0, 1.0, 2.0], + df["onnx_id_node"].fillna(-10).tolist(), + ) + self.clean_dump() + + @hide_stdout() + @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) + def test_sbs_model_with_weights_dynamo(self): + torch = self.torch + + class Model(self.torch.nn.Module): + def __init__(self): + super(Model, self).__init__() + self.fc1 = torch.nn.Linear(10, 32) # input size 10 → hidden size 32 + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(32, 1) # hidden → output + + def forward(self, x): + x = self.relu(self.fc1(x)) + x = self.fc2(x) + return x + + inputs = dict(x=self.torch.randn((5, 10))) + ds = dict(x={0: "batch"}) + Model()(**inputs) + ep = self.torch.export.export( + Model(), (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) + ) + filename = self.get_dump_file("test_sbs_model_with_weights_dynamo.onnx") + to_onnx(ep, exporter="onnx-dynamo", filename=filename) + onx = onnx.load(filename) + results = list( + run_aligned( + ep, + onx, + kwargs=inputs, + run_cls=OnnxruntimeEvaluator, + verbose=11, + use_tensor=True, + ), + ) + df = pandas.DataFrame(list(map(post_process_run_aligned_obs, results))) + df.to_excel(self.get_dump_file("test_sbs_model_with_weights_dynamo.xlsx")) + self.assertEqual( + [ + "ep_id_node", + "ep_name", + "ep_target", + "err_abs", + "err_dev", + "err_rel", + "onnx_id_node", + "onnx_name", + "onnx_op_type", + "shape_type", + ], + sorted(df.columns), + ) + self.assertEqual(len(results), 12) + self.assertEqual([r[-1].get("dev", 0) for r in results], [0] * 12) + self.assertEqual( + [-1.0, -1.0, -1.0, -1.0, -10.0, -10.0, -10.0, -10.0, -1.0, 0.0, 1.0, 2.0], df["onnx_id_node"].fillna(-10).tolist(), ) self.clean_dump() diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 6bef47cf..4cdbce0a 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -562,12 +562,12 @@ def _loop_cmp( if verbose: print(f"[run_aligned-nx] +inp: {inp.name}: {string_type(v, **str_kws)}") - alias_placeholder = { - **ep.state_dict, - **{f"p_{name.replace('.', '_')}": v for name, v in ep.state_dict.items()}, + placeholders = {node.name for node in ep.graph.nodes if node.op == "placeholder"} + placeholders_to_state_dict = { + f"p_{name.replace('.', '_')}": name for name in ep.state_dict } for n in onnx_results: - if n not in alias_placeholder: + if n not in placeholders: yield ( None, -1, @@ -599,24 +599,30 @@ def _loop_cmp( if verbose: t = torch_results[node.name] print(f"[run_aligned-ep] =plh: {node.name}={string_type(t, **str_kws)}") - if node.name in alias_placeholder: - # Otherwise, it is an input. - yield ( - -1, - -1, - node.name, - node.name, - "placeholder", - "initializer", - string_type(t, **str_kws), - max_diff(alias_placeholder[node.name], onnx_results[node.name]), - ) - continue + # Otherwise, it is an input. 
+ is_input = node.name in placeholders + yield ( + -1, + -1, + node.name, + node.name, + "input" if is_input else "placeholder", + "input" if is_input else "initializer", + string_type(t, **str_kws), + ( + {} + if is_input + else max_diff( + placeholders_to_state_dict[node.name], onnx_results[node.name] + ) + ), + ) else: - assert ( - node.name in alias_placeholder - ), f"Unable to find placeholder {node.name!r} in {sorted(alias_placeholder)}" - torch_results[node.name] = alias_placeholder[node.name] + assert node.name in placeholders_to_state_dict, ( + f"Unable to find placeholder {node.name!r} in " + f"{sorted(placeholders_to_state_dict)}" + ) + torch_results[node.name] = ep.state_dict[placeholders_to_state_dict[node.name]] if verbose: print(f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}") yield ( From ed4eaa68f844db1ff71be640b4d9f97c0e7c3eef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 15 Nov 2025 11:24:34 +0000 Subject: [PATCH 16/22] fix --- onnx_diagnostic/_command_lines_parser.py | 8 +++++--- onnx_diagnostic/helpers/ort_session.py | 7 ++----- onnx_diagnostic/torch_onnx/sbs.py | 6 ++++-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index b5c614bb..a40ecc55 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1233,9 +1233,11 @@ def _size(name): use_tensor=True, exc=False, ): - data.append(post_process_run_aligned_obs(obs)) - df = pandas.DataFrame(data) - df.to_excel(args.output) + pobs = post_process_run_aligned_obs(obs) + data.append(pobs) + if "initializer" not in pobs and "placeholder" not in pobs: + df = pandas.DataFrame(data) + df.to_excel(args.output) print("-- done") diff --git a/onnx_diagnostic/helpers/ort_session.py b/onnx_diagnostic/helpers/ort_session.py index 7f90c3cc..69f41826 100644 --- a/onnx_diagnostic/helpers/ort_session.py +++ b/onnx_diagnostic/helpers/ort_session.py @@ -502,11 +502,8 @@ def run_dlpack( if not v.is_contiguous(): v = v.contiguous() if v.dtype == torch.bool: - # It does not work with dlpack - # unless onnxruntime updates the version it is using. 
- v = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type( - v.detach().numpy(), onnx.TensorProto.BOOL - ) + v = v.to(torch.uint8) + v = ORTC.OrtValue.from_dlpack(v.__dlpack__(), True) else: v = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False) input_names.append(k) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 4cdbce0a..30fd1bf6 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -563,8 +563,10 @@ def _loop_cmp( print(f"[run_aligned-nx] +inp: {inp.name}: {string_type(v, **str_kws)}") placeholders = {node.name for node in ep.graph.nodes if node.op == "placeholder"} + ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers())} placeholders_to_state_dict = { - f"p_{name.replace('.', '_')}": name for name in ep.state_dict + **{f"p_{name.replace('.', '_')}": name for name in ep.state_dict}, + **{f"b_{name.replace('.', '_')}": name for name, _ in ep.named_buffers()}, } for n in onnx_results: if n not in placeholders: @@ -622,7 +624,7 @@ def _loop_cmp( f"Unable to find placeholder {node.name!r} in " f"{sorted(placeholders_to_state_dict)}" ) - torch_results[node.name] = ep.state_dict[placeholders_to_state_dict[node.name]] + torch_results[node.name] = ep_state_dict[placeholders_to_state_dict[node.name]] if verbose: print(f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}") yield ( From 6a0f91f58eaae290fdd41ee5c0874a8fc65baa05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 15 Nov 2025 11:48:33 +0000 Subject: [PATCH 17/22] fix --- onnx_diagnostic/torch_onnx/sbs.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 30fd1bf6..b90e4060 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -108,6 +108,21 @@ def run_fx_node( return args if node.op == "call_function": assert callable(node.target), f"{node.target!r} not callable in node {node!r}" + for a, ea in zip(args, node.args): + if isinstance(a, torch.Tensor) and hasattr(ea, "meta") and "val" in ea.meta: + ta = ea.meta["val"] + # if not isinstance(ta, torch.Tensor): + # print("******", args) + # print("******", node.args) + # print("******", node.kwargs) + # print("******", node.meta) + # print(ta) + assert len(a.shape) == len(ta.shape) and a.dtype == ta.dtype, ( + f"Unable to run node {node!r}, target={node.target!r}, " + f"node.args={node.args!r}, node.kwargs={node.kwargs!r}, " + f"args={string_type(args, with_shape=True, with_device=True)}, " + f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}" + ) try: outputs = node.target(*args, **(kwargs or {})) except RuntimeError as e: @@ -381,7 +396,7 @@ def forward(self, x): -v 1 --atol=0.1 --rtol=1 """ assert callable(run_cls), f"run_cls={run_cls} not a callable" - str_kws = dict(with_shape=True, with_device=True, with_min_max=True) + str_kws = dict(with_shape=True, with_device=True) has_cuda = any( (isinstance(t, torch.Tensor) and t.is_cuda) for t in flatten_object([args, kwargs], drop_keys=True) @@ -592,7 +607,12 @@ def _loop_cmp( print(f"[run_aligned] run ep.graph.nodes[{i}]: {node.op} -> {node.name!r}") if node.op == "placeholder": - if node.name in onnx_results: + is_input = node.name in placeholders + if node.name in onnx_results and ( + is_input + or ep_state_dict[placeholders_to_state_dict[node.name]].shape + == onnx_results[node.name] + ): torch_results[node.name] = ( onnx_results[node.name] if use_tensor @@ 
-602,7 +622,6 @@ def _loop_cmp(
                 t = torch_results[node.name]
                 print(f"[run_aligned-ep] =plh: {node.name}={string_type(t, **str_kws)}")
             # Otherwise, it is an input.
-            is_input = node.name in placeholders
             yield (
                 -1,
                 -1,
@@ -615,7 +634,8 @@ def _loop_cmp(
                     {}
                     if is_input
                     else max_diff(
-                        placeholders_to_state_dict[node.name], onnx_results[node.name]
+                        ep_state_dict[placeholders_to_state_dict[node.name]],
+                        onnx_results[node.name],
                     )
                 ),
             )

From f0c543c4e20377d158cf95ea8f4017fa67761ee8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Sat, 15 Nov 2025 17:08:52 +0100
Subject: [PATCH 18/22] mypy

---
 onnx_diagnostic/helpers/helper.py | 10 +++++-----
 onnx_diagnostic/torch_onnx/sbs.py | 12 +++++++++++-
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py
index 5b57498a..ed1bd4e3 100644
--- a/onnx_diagnostic/helpers/helper.py
+++ b/onnx_diagnostic/helpers/helper.py
@@ -1204,10 +1204,10 @@ def max_diff(
         else:
             for k, v in d["rep"].items():
                 drep[k] += v
-        if "dev" in d:
+        if "dev" in d and d["dev"] is not None:
             if dd is None:
                 dd = d["dev"]
-            elif d["dev"] is not None:
+            else:
                 dd += d["dev"]

     res = dict(abs=am, rel=rm, sum=sm, n=n, dnan=dn)
@@ -1263,7 +1263,7 @@ def max_diff(
         if _index < begin or (end != -1 and _index >= end):
             # out of boundary
             res = dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
-            if dev:
+            if dev is not None:
                 res["dev"] = dev
             return res
         if isinstance(expected, (int, float)):
@@ -1279,7 +1279,7 @@ def max_diff(
                 n=1,
                 dnan=0,
             )
-            if dev:
+            if dev is not None:
                 res["dev"] = dev
             return res
         return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
@@ -1361,7 +1361,7 @@ def max_diff(
     res: Dict[str, float] = dict(  # type: ignore
         abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
     )
-    if dev:
+    if dev is not None:
         res["dev"] = dev
     if hist:
         if isinstance(hist, bool):
diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py
index b90e4060..5a66a069 100644
--- a/onnx_diagnostic/torch_onnx/sbs.py
+++ b/onnx_diagnostic/torch_onnx/sbs.py
@@ -172,7 +172,17 @@ def prepare_args_kwargs(
     return new_args, new_kwargs


-def post_process_run_aligned_obs(obs: Dict[str, Any]) -> Dict[str, Union[str, float, int]]:
+def post_process_run_aligned_obs(
+    obs: Tuple[
+        Optional[int],
+        Optional[int],
+        Optional[str],
+        Optional[str],
+        Optional[str],
+        Optional[str],
+        Dict[str, Union[int, float]],
+    ],
+) -> Dict[str, Union[str, float, int]]:
     """
     Flattens an observation produced by function
     :func:`onnx_diagnostic.torch_onnx.sbs.run_aligned`.
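A note on the `dev` changes in this patch: `dev` carries a numeric device-discrepancy measure where `0.0` is a legitimate measurement and `None` means "not computed", so the earlier truthiness guards (`if dev:`) silently dropped zeros. Below is a minimal standalone sketch of the corrected pattern; the `diffs` list is a hypothetical stand-in for the per-element dictionaries that `max_diff` aggregates:

# Sketch only: `dev` semantics inferred from the diff above.
# 0.0 is a real measurement, None means the value was never computed.
def aggregate_dev(diffs):
    dd = None
    for d in diffs:
        if "dev" in d and d["dev"] is not None:  # `if d["dev"]:` would skip 0.0
            dd = d["dev"] if dd is None else dd + d["dev"]
    res = dict(abs=0.0, rel=0.0)
    if dd is not None:  # keep the key even when the accumulated value is 0.0
        res["dev"] = dd
    return res

print(aggregate_dev([{"dev": 0.0}, {"dev": 0.0}]))  # {'abs': 0.0, 'rel': 0.0, 'dev': 0.0}
print(aggregate_dev([{}]))                          # {'abs': 0.0, 'rel': 0.0}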
From 940bae51c20ef95d06ed222d2620af1d418a4d01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 15 Nov 2025 17:03:46 +0000 Subject: [PATCH 19/22] onx fix --- onnx_diagnostic/torch_onnx/sbs.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 5a66a069..8af1b140 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -111,13 +111,11 @@ def run_fx_node( for a, ea in zip(args, node.args): if isinstance(a, torch.Tensor) and hasattr(ea, "meta") and "val" in ea.meta: ta = ea.meta["val"] - # if not isinstance(ta, torch.Tensor): - # print("******", args) - # print("******", node.args) - # print("******", node.kwargs) - # print("******", node.meta) - # print(ta) - assert len(a.shape) == len(ta.shape) and a.dtype == ta.dtype, ( + assert ( + isinstance(ta, torch.Tensor) + and len(a.shape) == len(ta.shape) + and a.dtype == ta.dtype + ), ( f"Unable to run node {node!r}, target={node.target!r}, " f"node.args={node.args!r}, node.kwargs={node.kwargs!r}, " f"args={string_type(args, with_shape=True, with_device=True)}, " @@ -672,7 +670,7 @@ def _loop_cmp( outputs = [node.name] if isinstance(node.name, str) else list(node.name) args, kwargs = prepare_args_kwargs(torch_results, node) new_outputs = run_fx_node(node, args, kwargs) - if isinstance(new_outputs, (torch.Tensor, int, float, list)): + if isinstance(new_outputs, (torch.Tensor, int, float, list, tuple)): new_outputs = (new_outputs,) if new_outputs is None: From ca35bc8d6fab9821a68ebadad30e37a235c3966b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 15 Nov 2025 18:33:02 +0100 Subject: [PATCH 20/22] fix mypy --- .../test_patch_torch.py | 2 +- _unittests/ut_torch_onnx/test_sbs.py | 25 +++++++++++++++++++ onnx_diagnostic/_command_lines_parser.py | 13 +++++++++- onnx_diagnostic/helpers/helper.py | 10 ++++---- onnx_diagnostic/torch_onnx/sbs.py | 6 ++--- 5 files changed, 46 insertions(+), 10 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py index 74afca7d..1bcdd337 100644 --- a/_unittests/ut_torch_export_patches/test_patch_torch.py +++ b/_unittests/ut_torch_export_patches/test_patch_torch.py @@ -510,7 +510,7 @@ def _batch1(t): got = ep.module()(**torch_deepcopy(inputs)) self.assertEqualArrayAny(expected, got) - @requires_torch("2.9", "Eq(s3, Max(s10, s3)) is inconsistent!") + @requires_torch("2.11", "Eq(s3, Max(s10, s3)) is inconsistent!, until we know more") def test_patch_tiny_llm_dim_meta_level_1(self): class Model(torch.nn.Module): def forward(self, x, ind1, ind2): diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index 46a568fc..6ae72e94 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ b/_unittests/ut_torch_onnx/test_sbs.py @@ -409,6 +409,31 @@ def forward(self, x): ) self.clean_dump() + @hide_stdout() + def test_sbs_unique_consecutive(self): + torch = self.torch + + class Model(torch.nn.Module): + def forward(self, x): + return torch.unique_consecutive(x) + + model = Model() + inputs = (torch.tensor([0, 1, 2, 2, 3, 3, 0, 0], dtype=torch.int64),) + ds = ({0: "length"},) + ep = torch.export.export(model, inputs, dynamic_shapes=use_dyn_not_str(ds)) + onx = to_onnx(model, inputs, dynamic_shapes=ds, exporter="custom").model_proto + results = list( + run_aligned( + ep, + onx, + kwargs=inputs, + run_cls=OnnxruntimeEvaluator, + verbose=11, + 
use_tensor=True,
+            ),
+        )
+        self.assertEqual(len(results), 5)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py
index a40ecc55..3f61915e 100644
--- a/onnx_diagnostic/_command_lines_parser.py
+++ b/onnx_diagnostic/_command_lines_parser.py
@@ -1166,6 +1166,13 @@ def get_parser_sbs() -> ArgumentParser:
         required=False,
         help="verbosity",
     )
+    parser.add_argument(
+        "-r",
+        "--ratio",
+        default=5,
+        required=False,
+        help="Saves the result into the excel file every n-th node.",
+    )
     return parser


@@ -1220,6 +1227,7 @@ def _size(name):
     print(f"-- done in {time.perf_counter() - begin:1.1f}s")

     print("-- starts side-by-side")
+    ratio = int(args.ratio)
     data = []
     for obs in run_aligned(
         ep,
@@ -1235,9 +1243,12 @@
     ):
         pobs = post_process_run_aligned_obs(obs)
         data.append(pobs)
-        if "initializer" not in pobs and "placeholder" not in pobs:
+        if "initializer" not in pobs and "placeholder" not in pobs and len(data) % ratio == 0:
             df = pandas.DataFrame(data)
             df.to_excel(args.output)
+    print(f"-- final save into {args.output!r}")
+    df = pandas.DataFrame(data)
+    df.to_excel(args.output)
     print("-- done")
diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py
index ed1bd4e3..665954f5 100644
--- a/onnx_diagnostic/helpers/helper.py
+++ b/onnx_diagnostic/helpers/helper.py
@@ -994,7 +994,7 @@ def max_diff(
     _index: int = 0,
     allow_unique_tensor_with_list_of_one_element: bool = True,
     hist: Optional[Union[bool, List[float]]] = None,
-) -> Dict[str, Union[float, int, Tuple[int, ...]]]:
+) -> Dict[str, Union[float, int, Tuple[Any, ...]]]:
     """
     Returns the maximum discrepancy.

@@ -1208,7 +1208,7 @@ def max_diff(
             if dd is None:
                 dd = d["dev"]
             else:
-                dd += d["dev"]
+                dd += d["dev"]  # type: ignore[operator]

     res = dict(abs=am, rel=rm, sum=sm, n=n, dnan=dn)
@@ -1264,8 +1264,8 @@ def max_diff(
             # out of boundary
             res = dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
             if dev is not None:
-                res["dev"] = dev
-            return res
+                res["dev"] = dev  # type: ignore[operator]
+            return res  # type: ignore[return-value]
         if isinstance(expected, (int, float)):
             if isinstance(got, np.ndarray) and len(got.shape) == 0:
                 got = float(got)
@@ -1281,7 +1281,7 @@ def max_diff(
             )
             if dev is not None:
                 res["dev"] = dev
-            return res
+            return res  # type: ignore[return-value]
         return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
     if expected.dtype in (np.complex64, np.complex128):
         if got.dtype == expected.dtype:
diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py
index 8af1b140..0dfc3c1c 100644
--- a/onnx_diagnostic/torch_onnx/sbs.py
+++ b/onnx_diagnostic/torch_onnx/sbs.py
@@ -199,11 +199,11 @@ def post_process_run_aligned_obs(
             obs,
         )
     )
-    if "abs" in obs[-1]:
+    if "abs" in obs[-1] and obs[-1]["abs"] is not None:
         dobs["err_abs"] = obs[-1]["abs"]
-    if "rel" in obs[-1]:
+    if "rel" in obs[-1] and obs[-1]["rel"] is not None:
         dobs["err_rel"] = obs[-1]["rel"]
-    if "dev" in obs[-1]:
+    if "dev" in obs[-1] and obs[-1]["dev"] is not None:
         dobs["err_dev"] = obs[-1]["dev"]
     return dobs

From fea9fa625c7227051cba5cfbc313ee7c629eeab9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Sat, 15 Nov 2025 18:58:52 +0100
Subject: [PATCH 21/22] fix missing domain

---
 .../test_onnxruntime_evaluator.py          | 30 +++++++++++++++++++
 onnx_diagnostic/reference/ort_evaluator.py |  6 ++++
 onnx_diagnostic/torch_onnx/sbs.py          |  4 +--
 3 files changed, 38 insertions(+), 2 deletions(-)

diff --git
a/_unittests/ut_reference/test_onnxruntime_evaluator.py b/_unittests/ut_reference/test_onnxruntime_evaluator.py
index 7906d586..8454a9c4 100644
--- a/_unittests/ut_reference/test_onnxruntime_evaluator.py
+++ b/_unittests/ut_reference/test_onnxruntime_evaluator.py
@@ -1,4 +1,5 @@
 import unittest
+from typing import Optional
 import numpy as np
 import onnx
 import onnx.helper as oh
@@ -22,6 +23,13 @@


 class TestOnnxruntimeEvaluator(ExtTestCase):
+    def _range(self, *shape, bias: Optional[float] = None):
+        n = np.prod(shape)
+        x = np.arange(n).astype(np.float32) / n
+        if bias:
+            x = x + bias
+        return x.reshape(tuple(shape)).astype(np.float32)
+
     @ignore_warnings(FutureWarning)
     def test_ort_eval_scan_cdist_add(self):

@@ -229,6 +237,28 @@ def test_report_results_comparison_ort(self):
         self.assertLess(d[(0, "nx"), "r_cos"], 1e-6)
         self.assertLess(d[(2, "u"), "r_exp"], 1e-6)

+    @hide_stdout()
+    def test_skip_layer_normalization(self):
+        node = oh.make_node(
+            "SkipLayerNormalization",
+            ["x", "skip", "beta", "gamma", "bias"],
+            ["Z"],
+            epsilon=1.0e-5,
+            domain="com.microsoft",
+        )
+        feeds = dict(
+            x=self._range(2, 3, 8),
+            skip=self._range(2, 3, 8, bias=3),
+            beta=self._range(8, bias=1),
+            gamma=self._range(8, bias=2),
+            bias=self._range(8, bias=0.1),
+        )
+        ref = ExtendedReferenceEvaluator(node)
+        expected = ref.run(None, feeds)
+        rt = OnnxruntimeEvaluator(node, verbose=10, opsets={"": 22})
+        got = rt.run(None, feeds)
+        self.assertEqualAny(expected, got, atol=1e-4)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/reference/ort_evaluator.py b/onnx_diagnostic/reference/ort_evaluator.py
index ed8b5f34..8ac90321 100644
--- a/onnx_diagnostic/reference/ort_evaluator.py
+++ b/onnx_diagnostic/reference/ort_evaluator.py
@@ -373,6 +373,12 @@ def _make_model_proto(
             )
         else:
             onx.opset_import.append(oh.make_opsetid("", onnx_opset_version()))
+        opsets = {d.domain: d.version for d in onx.opset_import}
+        add = {}
+        for node in nodes:
+            if node.domain and node.domain not in opsets and node.domain not in add:
+                add[node.domain] = 1
+        onx.opset_import.extend([oh.make_opsetid(k, v) for k, v in add.items()])

         # That helps fixing bugs.
         onx = shi.infer_shapes(onx)
diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py
index 0dfc3c1c..42336f2f 100644
--- a/onnx_diagnostic/torch_onnx/sbs.py
+++ b/onnx_diagnostic/torch_onnx/sbs.py
@@ -178,9 +178,9 @@ def post_process_run_aligned_obs(
         Optional[str],
         Optional[str],
         Optional[str],
-        Dict[str, Union[int, float]],
+        Dict[str, Optional[Union[int, float]]],
     ],
-) -> Dict[str, Union[str, float, int]]:
+) -> Dict[str, Optional[Union[str, float, int]]]:
     """
     Flattens an observation produced by function
    :func:`onnx_diagnostic.torch_onnx.sbs.run_aligned`.
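The `_make_model_proto` change above addresses a concrete failure: wrapping a single node from a contrib domain (such as the `com.microsoft` `SkipLayerNormalization` exercised by the new test) into a temporary ModelProto without an `opset_import` entry for that domain makes onnxruntime reject the model. A self-contained sketch of the same repair on a hand-built model; the `Gelu` node and the default version `1` are illustrative choices, mirroring `add[node.domain] = 1` in the patch:

import onnx
import onnx.helper as oh

# A one-node model whose operator lives in a contrib domain.
node = oh.make_node("Gelu", ["X"], ["Y"], domain="com.microsoft")
graph = oh.make_graph(
    [node],
    "g",
    [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [4])],
    [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [4])],
)
model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 18)])

# Same logic as the patch: register every node domain missing from opset_import.
opsets = {d.domain: d.version for d in model.opset_import}
add = {}
for n in model.graph.node:
    if n.domain and n.domain not in opsets and n.domain not in add:
        add[n.domain] = 1
model.opset_import.extend([oh.make_opsetid(k, v) for k, v in add.items()])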
From 99d3a52d4639002590a5900c1d69f777794d090c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sun, 16 Nov 2025 20:58:38 +0100 Subject: [PATCH 22/22] mypy --- onnx_diagnostic/torch_onnx/sbs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 42336f2f..6b4f2359 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -200,12 +200,12 @@ def post_process_run_aligned_obs( ) ) if "abs" in obs[-1] and obs[-1]["abs"] is not None: - dobs["err_abs"] = obs[-1]["abs"] + dobs["err_abs"] = obs[-1]["abs"] # type: ignore[assignment] if "rel" in obs[-1] and obs[-1]["rel"] is not None: - dobs["err_rel"] = obs[-1]["rel"] + dobs["err_rel"] = obs[-1]["rel"] # type: ignore[assignment] if "dev" in obs[-1] and obs[-1]["dev"] is not None: - dobs["err_dev"] = obs[-1]["dev"] - return dobs + dobs["err_dev"] = obs[-1]["dev"] # type: ignore[assignment] + return dobs # type: ignore[return-value] def run_aligned(
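Taken together, the series settles the public surface: `run_aligned` yields one tuple per aligned step, and `post_process_run_aligned_obs` flattens it into a row whose `err_abs` / `err_rel` / `err_dev` entries may be absent or None. A hypothetical end-to-end use, mirroring the tests and the `sbs` command line shown earlier; the exact keyword set passed to `run_aligned` is an assumption based on those calls:

import pandas
import torch
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
from onnx_diagnostic.torch_onnx.sbs import run_aligned, post_process_run_aligned_obs

class Model(torch.nn.Module):
    def forward(self, x):
        return (x.abs() + 1).log()

x = torch.randn((5, 4))
ep = torch.export.export(Model(), (x,))
onx = torch.onnx.export(ep, (x,)).model_proto  # as in the tests of this series

# One flattened row per aligned observation, ready for a DataFrame,
# the same way the sbs command line dumps its excel report.
rows = [
    post_process_run_aligned_obs(obs)
    for obs in run_aligned(
        ep, onx, args=(x,), run_cls=ExtendedReferenceEvaluator, verbose=1
    )
]
pandas.DataFrame(rows).to_excel("sbs.xlsx")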