diff --git a/.gitignore b/.gitignore index 900e2ae8..0cc02aeb 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,7 @@ _cache/* .coverage dist/* build/* +_sbs_* .eggs/* .olive-cache/* .hypothesis/* diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index f8c85340..bce920d9 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -8,7 +8,7 @@ Change Logs * :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime * :pr:`310`: splits patches into multiple files * :pr:`308`: add option --save_ep to dump the exported program as well as torch input -* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`, :pr:`318`: improves side-by-side comparison, creates command line sbs +* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`, :pr:`318`, :pr:`319`: improves side-by-side comparison, creates command line sbs 0.8.2 +++++ diff --git a/_unittests/ut_helpers/test_torch_helper.py b/_unittests/ut_helpers/test_torch_helper.py index 9441e425..2c3c2990 100644 --- a/_unittests/ut_helpers/test_torch_helper.py +++ b/_unittests/ut_helpers/test_torch_helper.py @@ -28,7 +28,7 @@ ) from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model from onnx_diagnostic.helpers.onnx_helper import from_array_extended, to_array_extended -from onnx_diagnostic.helpers.torch_helper import to_tensor +from onnx_diagnostic.helpers.torch_helper import to_tensor, study_discrepancies TFLOAT = onnx.TensorProto.FLOAT @@ -425,6 +425,12 @@ def test_get_weight_type(self): dt = get_weight_type(model) self.assertEqual(torch.float32, dt) + def test_study_discrepancies(self): + t1 = torch.rand((3, 4)) + t2 = torch.rand((3, 4)) + ax = study_discrepancies(t1, t2) + self.assertEqual(ax.shape, ((3, 2))) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index c67fc770..182dcdd5 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ 
-1234,7 +1234,7 @@ def get_parser_sbs() -> ArgumentParser: "--replay-threshold", type=float, required=False, - default=1e6, + default=1e9, help="Triggers the replay if the discrepancies are higher than this value.", ) parser.add_argument( diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index ec220b29..b0d5cae6 100644 --- a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -1,6 +1,7 @@ import contextlib import ctypes import inspect +import math import os import sys import warnings @@ -1003,3 +1004,76 @@ def get_weight_type(model: torch.nn.Module) -> torch.dtype: counts[dt] += 1 final = max(list(counts.items())) return final[0] + + +def closest_factor_pair(n: int): + """Finds ``a, b`` such that ``n == a * b`` and ``a <= b``, with both factors as close as possible.""" + assert n > 0, f"n={n} must be a positive integer" + start = math.isqrt(n) + for a in range(start, 0, -1): + if n % a == 0: + b = n // a + return a, b + return 1, n + + +def study_discrepancies( + t1: torch.Tensor, + t2: torch.Tensor, + bins: int = 50, + figsize: Optional[Tuple[int, int]] = (15, 15), + title: Optional[str] = None, + name: Optional[str] = None, +) -> "matplotlib.axes.Axes": # noqa: F821 + """ + Plots the discrepancies between two tensors (color maps, histograms, XY graph). + Returns the matplotlib axes. 
+ """ + assert t1.dtype == t2.dtype, f"Type mismatch {t1.dtype} != {t2.dtype}" + assert t1.shape == t2.shape, f"Shape mismatch {t1.shape} != {t2.shape}" + d1, d2 = ( + (t1, t2) if t1.dtype == torch.float64 else (t1.to(torch.float32), t2.to(torch.float32)) + ) + + d1 = d1.squeeze() + d2 = d2.squeeze() + if len(d1.shape) == 1: + new_shape = closest_factor_pair(d1.shape[0]) + d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape) + elif len(d1.shape) > 2: + new_shape = (-1, max(d1.shape)) + d1, d2 = d1.reshape(new_shape), d2.reshape(new_shape) + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(3, 2, figsize=figsize) + vmin, vmax = d1.min().item(), d1.max().item() + ax[0, 0].imshow(d1.detach().cpu().numpy(), cmap="Greys", vmin=vmin, vmax=vmax) + ax[0, 0].set_title( + f"Color plot of the first tensor in\n[{vmin}, {vmax}]\n{t1.shape} -> {d1.shape}" + ) + + diff = d2 - d1 + vmin, vmax = diff.min().item(), diff.max().item() + ax[0, 1].imshow(diff.detach().cpu().numpy(), cmap="seismic", vmin=vmin, vmax=vmax) + ax[0, 1].set_title(f"Color plot of the differences in \n[{vmin}, {vmax}]") + + ax[1, 0].hist(d1.detach().cpu().numpy().ravel(), bins=bins) + ax[1, 0].set_title("Distribution of the first tensor") + + ax[1, 1].hist(diff.detach().cpu().numpy().ravel(), bins=bins) + ax[1, 1].set_title("Distribution of the differences") + + tf1 = d1.ravel() + td1 = diff.ravel() + ax[2, 1].plot(tf1.detach().cpu().numpy(), td1.detach().cpu().numpy(), ".") + ax[2, 1].set_title("Graph XY") + ax[2, 1].set_xlabel("First tensor values") + ax[2, 1].set_ylabel("Difference values") + + if title: + fig.suptitle(title) + fig.tight_layout() + if name: + fig.savefig(name) + return ax diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index a2f814bf..23586421 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -1,14 +1,26 @@ import inspect import os +import textwrap import time from dataclasses import dataclass from typing 
import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union + +try: + from typing import Self +except ImportError: + # python <= 3.10 + Self = "Self" # type: ignore[assignment] import onnx import onnx.helper as oh import numpy as np import torch from ..helpers import string_type, string_diff, max_diff, flatten_object -from ..helpers.onnx_helper import pretty_onnx, extract_subset_of_nodes, make_submodel +from ..helpers.onnx_helper import ( + pretty_onnx, + extract_subset_of_nodes, + make_submodel, + from_array_extended, +) from ..helpers.torch_helper import to_numpy, from_numpy, to_tensor, torch_dtype_to_onnx_dtype @@ -214,16 +226,124 @@ def select( :param err_abs: measured discrepancy :return: True if this should be dumped """ - if name and self.selected_names: - if name in self.selected_names: - return True - if op_type and self.selected_op_types: - if op_type in self.selected_op_types: - return True - if err_abs is not None and err_abs >= self.threshold: + if name and self.selected_names and name in self.selected_names: + return True + if op_type and self.selected_op_types and op_type in self.selected_op_types: + return True + if err_abs is not None and self.threshold is not None and err_abs >= self.threshold: return True return False + def get_replay_code(self) -> str: + """ + Returns the code of a script letting the user replay the onnx model. + It looks like the following. It may have to be adapted. + + .. 
runpython:: + :showcode: + + from onnx_diagnostic.torch_onnx.sbs import ReplayConfiguration + + rc = ReplayConfiguration(dump_folder="unused") + print(rc.get_replay_code()) + """ + return textwrap.dedent( + """ + import onnx + import torch + from onnx_diagnostic.helpers import max_diff, string_diff, string_type + from onnx_diagnostic.helpers.torch_helper import study_discrepancies + from onnx_diagnostic.helpers.onnx_helper import pretty_onnx + from onnx_diagnostic.reference import OnnxruntimeEvaluator + + skws = dict(with_shape=True, with_device=True) + + torch_inputs = torch.load("torch_inputs.pt") + onnx_inputs = torch.load("onnx_inputs.pt") + expected_outputs_and_mapping = torch.load("torch_outputs_and_mapping.pt") + expected = expected_outputs_and_mapping["expected"] + mapping = expected_outputs_and_mapping["mapping"] + + print(f"-- torch_inputs={string_type(torch_inputs, **skws)}") + print(f"-- onnx_inputs={string_type(onnx_inputs, **skws)}") + print(f"-- expected={string_type(expected, **skws)}") + print(f"-- mapping={mapping}") + + print() + print("-- model.onnx") + print() + + model = onnx.load("model.onnx") + print(pretty_onnx(model)) + + print() + print("-- range of inputs --") + print() + + for k, v in onnx_inputs.items(): + print(f"-- {k}: {string_type(v, **skws, with_min_max=True)}") + + print() + print("-- discrepancies of inputs --") + print() + + ep_feeds = {} + for k, v in onnx_inputs.items(): + tk = mapping.get(k, k) + tkv = torch_inputs[k] if k in torch_inputs else torch_inputs[tk] + ep_feeds[k] = tkv + diff = max_diff(v, tkv) + print( + f"-- {k} -> {tk} ep:{string_type(tkv, **skws)} " + f"nx:{string_type(v, **skws)} / diff {string_diff(diff)}" + ) + + print() + print("-- SVD --") + print() + + for k, v in onnx_inputs.items(): + if len(v.shape) == 2: + U, S, Vt = torch.linalg.svd(v.to(torch.float32)) + print(f" -- {k}: {S[:5]}") + + print() + print("-- run with onnx_inputs --") + print() + + sess = OnnxruntimeEvaluator(model, whole=True) + feeds 
= onnx_inputs + obtained = sess.run(None, feeds) + print(f"-- obtained={string_type(obtained, **skws)}") + diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01]) + print(f"-- diff: {string_diff(diff)}") + print() + print("-- plots --") + + for i in range(len(expected)): + study_discrepancies( + expected[i], + obtained[i], + title=f"study output {i}", + name=f"disc{i}.png", + bins=50, + ) + + print() + print("-- run with torch_inputs --") + print() + + obtained = sess.run(None, ep_feeds) + print(f"-- obtained={string_type(obtained, **skws)}") + diff = max_diff(expected, tuple(obtained), hist=[0.1, 0.01]) + print(f"-- diff: {string_diff(diff)}") + + print() + print("-- end --") + print() + """ + ) + def dump( self, name: str, @@ -280,21 +400,44 @@ def dump( os.makedirs(folder, exist_ok=True) if verbose: print(f"[ReplayConfiguration.dump] dumps into folder {folder!r}") - onnx.save(submodel, os.path.join(folder, "model.onnx")) + torch_inputs = {} + removed_inputs = set() for n in input_names: if n in onnx_name_to_ep_name: torch_inputs[n] = torch_results[onnx_name_to_ep_name[n]] else: - raise AssertionError(f"n={n!r}, onnx_name_to_ep_name={onnx_name_to_ep_name}") + # We add that input as an initializer because it is probably a constant. 
+ submodel.graph.initializer.append(from_array_extended(onnx_results[n], name=n)) + removed_inputs.add(n) + + if removed_inputs: + input_names = [i for i in input_names if i not in removed_inputs] + new_inputs = [i for i in submodel.graph.input if i.name not in removed_inputs] + del submodel.graph.input[:] + submodel.graph.input.extend(new_inputs) + if verbose: + print(f"[ReplayConfiguration.dump] removed input {removed_inputs}") + print(f"[ReplayConfiguration.dump] final model inputs {input_names}") + + onnx.save(submodel, os.path.join(folder, "model.onnx")) onnx_inputs = {n: onnx_results[n] for n in input_names} assert ( name in onnx_name_to_ep_name ), f"Unable to find {name!r} in {onnx_name_to_ep_name}" - expected_outputs = (torch_results[onnx_name_to_ep_name[name]],) + expected_outputs_and_mapping = dict( + expected=(torch_results[onnx_name_to_ep_name[name]],), + mapping={ + k: onnx_name_to_ep_name[k] for k in input_names if k in onnx_name_to_ep_name + }, + ) torch.save(torch_inputs, os.path.join(folder, "torch_inputs.pt")) torch.save(onnx_inputs, os.path.join(folder, "onnx_inputs.pt")) - torch.save(expected_outputs, os.path.join(folder, "torch_outputs.pt")) + torch.save( + expected_outputs_and_mapping, os.path.join(folder, "torch_outputs_and_mapping.pt") + ) + with open(os.path.join(folder, "replay.py"), "w") as f: + f.write(self.get_replay_code()) if verbose: print(f"[ReplayConfiguration.dump] done {folder!r}") return folder @@ -357,7 +500,7 @@ def __post_init__(self): f"ep_id_node={self.ep_id_node}" ) - def set_diff(self, diff: Dict[str, Any]): + def set_diff(self, diff: Dict[str, Any]) -> Self: """Sets error.""" if diff is None: return @@ -371,6 +514,40 @@ def set_diff(self, diff: Dict[str, Any]): self.err_nan = diff["nan"] if "rep" in diff: self.err_h01 = diff["rep"][">0.1"] + return self + + @property + def key( + self, + ) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[str], Optional[str]]: + "Creates a unique identifier." 
+ return ( + self.ep_id_node, + self.onnx_id_node, + self.onnx_id_output, + self.ep_name, + self.onnx_name, + ) + + def check( + self, + already_yielded: Dict[ + Tuple[Optional[int], Optional[int], Optional[int], Optional[str], Optional[str]], + int, + ], + ) -> Self: + "Checks a record was not already yielded." + if self.onnx_op_type == "reset": + # no record for this one + return self + key = self.key + assert key not in already_yielded, ( + f"Record with key={key} was already yielded, " + f"number of records={len(already_yielded)} and previous " + f"record at position {already_yielded[key]} (self={self})" + ) + already_yielded[key] = len(already_yielded) + return self @dataclass @@ -451,8 +628,6 @@ def run_aligned( for the onnx runtime :param atol: absolute tolerance :param rtol: relative tolerance - :param gemmlinear: if True, replaces ``Gemm(A,X.T,B)`` by - ``torch.nn.functional.linear(A,X,B)`` on onnx side :param verbose: verbosity level :param exc: stops if an exception :param reset_names: list of names, the onnx execution takes the torch outputs instead @@ -595,6 +770,7 @@ def forward(self, x): -v 1 --atol=0.1 --rtol=1 """ assert callable(run_cls), f"run_cls={run_cls} not a callable" + already_yielded = {} # type: ignore[var-annotated] reset_names = set(reset_names) if reset_names else set() # type: ignore[assignment] str_kws = dict(with_shape=True, with_device=True) has_cuda = any( @@ -625,6 +801,8 @@ def forward(self, x): if verbose: print(f"[run_aligned] run_cls={run_cls}") print(f"[run_aligned] run_cls_kwargs={run_cls_kwargs}") + if replay_configuration: + print(f"[run_aligned] replay={replay_configuration}") def _check_tensor_(name, obj, flip_type=False): if flip_type: @@ -694,7 +872,7 @@ def _loop_cmp( onnx_shape_type=string_type(r, **str_kws), ) r.set_diff(d) - mapping_onnx_to_torch[to] = o + mapping_onnx_to_torch[o] = to return r return None @@ -774,7 +952,7 @@ def _loop_onnx_node( list_node_output = list(node.output) node_output = [o for o in 
list_node_output if o] for o, r in zip(node_output, res): - if r is None or o is None: + if r is None or not o: continue tmp = _loop_cmp( mapping_onnx_to_torch, @@ -839,13 +1017,14 @@ def _loop_onnx_node( torch_results, onnx_results, o, - r, + torch_results[tmp.ep_name], verbose, atol, rtol, i, i_onnx, ) + assert tmp.err_abs == 0, f"Reset did not happen, tmp={tmp}" if tmp is not None: tmp.onnx_op_type = "reset" tmp.onnx_id_output = list_node_output.index(o) @@ -880,7 +1059,8 @@ def _duplicated_values(d): if verbose: print(f"[run_aligned] ep: walks through {len(ep.graph.nodes)} nodes from torch") - positions: Dict[str, Any] = {} + # dictionary mapping result names and their position in both graphs. + positions: Dict[str, Dict[str, int]] = {} ep_graph_nodes = list(ep.graph.nodes) torch_results: Dict[str, Any] = {} last_position = 0 @@ -932,6 +1112,18 @@ def _duplicated_values(d): print(f"[run_aligned] ep: found inputs {torch_input_names}") print(f"[run_aligned] ep: found outputs {torch_output_names}") print(f"[run_aligned] nx: walks through {len(onx.graph.node)} nodes from onnx") + for inp in onx.graph.input: + n = inp.name + if n in positions: + positions[n]["onnx"] = -1 + else: + positions[n] = dict(onnx=-1) + for inp in onx.graph.initializer: + n = inp.name + if n in positions: + positions[n]["onnx"] = -1 + else: + positions[n] = dict(onnx=-1) for i, node in enumerate(onx.graph.node): for n in node.output: if n in positions: @@ -1001,7 +1193,6 @@ def _duplicated_values(d): memory_cpu = 0 memory_cuda = 0 for init in onx.graph.initializer: # type: ignore - positions[init.name] = -1 t = None if init.name in torch_results: if init.name not in skip_mapping_torch_onnx: @@ -1033,7 +1224,7 @@ def _duplicated_values(d): onnx_name=init.name, onnx_op_type="initializer", onnx_shape_type=string_type(t, **str_kws), - ) + ).check(already_yielded) size = t.element_size() * t.numel() if t.is_cuda: @@ -1060,6 +1251,8 @@ def _duplicated_values(d): print(f"[run_aligned-nx] +ini: {k}: 
{string_type(v, **str_kws)}") # starts the side-by-side + if verbose: + print(f"[run_aligned] ep: starts side-by-side with {len(ep_graph_nodes)} nodes") if verbose == 1: import tqdm @@ -1067,8 +1260,6 @@ def _duplicated_values(d): else: loop = list(enumerate(ep_graph_nodes)) - if verbose: - print(f"[run_aligned] ep: starts side-by-side with {len(ep_graph_nodes)} nodes") already_run: Set[int] = set() ep_durations = {} status = StatusRunAligned() @@ -1115,7 +1306,7 @@ def _duplicated_values(d): onnx_results[torch_names_to_onnx_names[node.name]], **str_kws ), ) - yield record + yield record.check(already_yielded) else: assert node.name in placeholders_to_state_dict, ( f"Unable to find placeholder {node.name!r} (node.op={node.op!r}), " @@ -1155,7 +1346,7 @@ def _duplicated_values(d): hist=[0.1], ) ) - yield record + yield record.check(already_yielded) else: if verbose > 1: print( @@ -1166,7 +1357,7 @@ def _duplicated_values(d): ep_name=node.name, ep_target="placeholder", ep_shape_type=string_type(t, **str_kws), - ) + ).check(already_yielded) continue outputs = [node.name] if isinstance(node.name, str) else list(node.name) @@ -1190,13 +1381,33 @@ def _duplicated_values(d): max_pos = -2 for n in outputs: - if n in positions and "onnx" in positions[n]: - max_pos = max(max_pos, positions[n]["onnx"]) + if n in positions: + if "onnx" in positions[n]: + max_pos = max(max_pos, positions[n]["onnx"]) + if "fx" in positions[n]: + if positions[n]["fx"] > i: + max_pos = -2 + break if max_pos == -2: # we skip. continue + next_to_visit = last_position for i_onnx in range(last_position, max_pos + 1): + if i_onnx in already_run: + continue + # The onnx node may produce more than one output, in that + # case, we need to check the exported program is not behind. 
+ node = onx.graph.node[i_onnx] + ep_behind = False + for iname in node.output: + if iname in positions and "fx" in positions[iname]: + if positions[iname]["fx"] > i: + ep_behind = True + break + if ep_behind: + break + for r in _loop_onnx_node( onx, ep_graph_nodes, @@ -1216,9 +1427,10 @@ def _duplicated_values(d): verbose, ): if r: - yield r + yield r.check(already_yielded) + next_to_visit = i_onnx + 1 - last_position = max_pos + 1 + last_position = next_to_visit # complete the execution of the onnx graph if verbose: @@ -1227,6 +1439,8 @@ def _duplicated_values(d): f"to {len(onx.graph.node)}" ) for i_onnx in range(last_position, len(onx.graph.node)): + if i_onnx in already_run: + continue for r in _loop_onnx_node( onx, ep_graph_nodes, @@ -1246,9 +1460,7 @@ def _duplicated_values(d): verbose, ): if r: - yield r - - already_run.add(i_onnx) + yield r.check(already_yielded) if verbose: print(f"[run_aligned] done with status={status.to_str()}") diff --git a/pyproject.toml b/pyproject.toml index 5efe5e6d..f7ce6af7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,10 @@ disable_error_code = ["union-attr"] module = ["onnx_diagnostic.helpers.rt_helper"] disable_error_code = ["arg-type", "assignment", "attr-defined", "call-overload", "misc", "name-defined", "union-attr", "name-defined"] +[[tool.mypy.overrides]] +module = ["onnx_diagnostic.helpers.torch_helper"] +disable_error_code = ["arg-type", "assignment", "attr-defined", "call-overload", "misc", "name-defined", "union-attr", "name-defined"] + [[tool.mypy.overrides]] module = ["onnx_diagnostic.reference.report_results_comparison"] disable_error_code = ["name-defined"]