diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 30fd0e48..dba062c7 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.8.4
 +++++
+* :pr:`338`: fixes ReplayConfiguration.dump, adds a function to select a part of a model
 * :pr:`337`: fixes extract_subset_of_nodes
 * :pr:`336`: implements versioned onnx plugs
diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py
index 9a9242e8..9b2cebea 100644
--- a/_unittests/ut_helpers/test_onnx_helper.py
+++ b/_unittests/ut_helpers/test_onnx_helper.py
@@ -7,8 +7,14 @@
 import onnx.numpy_helper as onh
 from onnx import TensorProto, FunctionProto, ValueInfoProto
 from onnx.checker import check_model
+from onnx.external_data_helper import (
+    load_external_data_for_model,
+    _get_all_tensors,
+    uses_external_data,
+)
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
+from onnx_diagnostic.reference import ExtendedReferenceEvaluator
 from onnx_diagnostic.helpers.onnx_helper import (
     onnx_lighten,
     onnx_unlighten,
@@ -23,6 +29,7 @@
     onnx_dtype_name,
     extract_subset_of_nodes,
     make_submodel,
+    select_model_inputs_outputs,
 )


@@ -570,6 +577,97 @@ def test_extract_subset_of_nodes_bigger(self):
             [n.op_type for n in nodes],
         )

+    def _get_model_select(self):
+        X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
+        Z = oh.make_tensor_value_info("Z", TensorProto.INT64, [None, None])
+        graph = oh.make_graph(
+            [
+                oh.make_node("Mul", ["X", "X"], ["X2"]),
+                oh.make_node("Add", ["X2", "Y"], ["z1"]),
+                oh.make_node("Mul", ["z1", "W"], ["z2"]),
+                oh.make_node("Cast", ["z2"], ["Z"], to=TensorProto.INT64),
+            ],
+            "add",
+            [X],
+            [Z],
+            [
+                onh.from_array(np.arange(16).reshape((-1, 4)).astype(np.float32), name="Y"),
+                onh.from_array(
+                    (np.arange(16).reshape((-1, 4)) + 100).astype(np.float32), name="W"
+                ),
+            ],
+        )
+        onnx_model = oh.make_model(
+            graph, opset_imports=[oh.make_opsetid("", 18)], ir_version=8
+        )
+        return onnx_model
+
+    def test_select_model_inputs_outputs(self):
+        def enumerate_model_tensors(model):
+            for tensor in _get_all_tensors(model):
+                yield tensor, uses_external_data(tensor)
+
+        model = self._get_model_select()
+        root = self.get_dump_folder("test_select_model_inputs_outputs")
+        name = os.path.join(root, "model_ext.onnx")
+        location = os.path.basename(name) + ".data"
+        onnx.save(
+            model, name, save_as_external_data=True, size_threshold=15, location=location
+        )
+        self.assertEqual(
+            list(sorted(os.listdir(root))),
+            ["model_ext.onnx", "model_ext.onnx.data"],
+        )
+
+        # X
+        name2 = os.path.join(root, "sub_model_ext.onnx")
+        model2 = onnx.load(name, load_external_data=False)
+        new_model = select_model_inputs_outputs(model2, outputs=["X2"])
+        onnx.save(new_model, name2)
+
+        x = np.arange(16).reshape((-1, 4)).astype(np.float32)
+        y = np.arange(16).reshape((-1, 4)).astype(np.float32)
+
+        sess = ExtendedReferenceEvaluator(new_model)
+        got = sess.run(None, {"X": x})[0]
+        self.assertEqual((x**2).tolist(), got.tolist())
+
+        sess = ExtendedReferenceEvaluator(name2)
+        got = sess.run(None, {"X": x})[0]
+        self.assertEqual((x**2).tolist(), got.tolist())
+
+        # z1
+        name3 = os.path.join(root, "sub_model_ext_z1.onnx")
+        model2 = onnx.load(name, load_external_data=False)
+        new_model = select_model_inputs_outputs(model2, outputs=["z1"])
+        onnx.save(new_model, name3)
+        self.assertEqual(
+            [
+                "model_ext.onnx",
+                "model_ext.onnx.data",
+                "sub_model_ext.onnx",
+                "sub_model_ext_z1.onnx",
+            ],
+            list(sorted(os.listdir(root))),
+        )
+
+        x = np.arange(16).reshape((-1, 
4)).astype(np.float32) + + sess = ExtendedReferenceEvaluator(name3) + got = sess.run(None, {"X": x})[0] + self.assertEqual((x**2 + y).tolist(), got.tolist()) + + tensors = list(enumerate_model_tensors(new_model)) + self.assertEqual(len(tensors), 1) + self.assertIsInstance(tensors[0], tuple) + self.assertEqual(len(tensors[0]), 2) + self.assertTrue(tensors[0][-1]) + self.assertIsInstance(tensors[0][0], TensorProto) + load_external_data_for_model(new_model, root) + sess = ExtendedReferenceEvaluator(new_model) + got = sess.run(None, {"X": x})[0] + self.assertEqual((x**2 + y).tolist(), got.tolist()) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_onnx/data/attention_loopa24.onnx b/_unittests/ut_torch_onnx/data/attention_loopa24.onnx new file mode 100644 index 00000000..a904be2a Binary files /dev/null and b/_unittests/ut_torch_onnx/data/attention_loopa24.onnx differ diff --git a/_unittests/ut_torch_onnx/data/attention_loopmha.onnx b/_unittests/ut_torch_onnx/data/attention_loopmha.onnx new file mode 100644 index 00000000..cba6a71e Binary files /dev/null and b/_unittests/ut_torch_onnx/data/attention_loopmha.onnx differ diff --git a/_unittests/ut_torch_onnx/test_discrepancies.py b/_unittests/ut_torch_onnx/test_discrepancies.py new file mode 100644 index 00000000..9d344ff4 --- /dev/null +++ b/_unittests/ut_torch_onnx/test_discrepancies.py @@ -0,0 +1,176 @@ +import os +import unittest +import numpy as np +import onnx +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, has_onnxruntime +from onnx_diagnostic.reference import OnnxruntimeEvaluator +from onnx_diagnostic.helpers import max_diff, string_diff + + +class TestDiscrepancies(ExtTestCase): + @ignore_warnings(DeprecationWarning) + def test_attention_opset15_in_a_loop(self): + import torch + from onnx_diagnostic.torch_export_patches.patches._patch_transformers_attention import ( # noqa: E501 + patched_sdpa_attention_forward, + ) + + def qwen_sdpa_attention( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + cu_seqlens: torch.Tensor, + scaling: float = 0, + num_heads: int = 16, + ) -> torch.Tensor: + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) + for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + patched_sdpa_attention_forward( + None, + q, + k, + v, + attention_mask=None, + scaling=scaling, + dropout=0.0, + is_causal=False, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + return attn_output + + for model_name in ["attention_loopa24.onnx", "attention_loopmha.onnx"]: + if model_name == "attention_loopa24.onnx" and not has_onnxruntime("1.24"): + # not available + continue + with self.subTest(model=model_name): + model = onnx.load(os.path.join(os.path.dirname(__file__), "data", model_name)) + sess = self.check_ort(model) + + feeds = dict( + c_lifted_tensor_0=np.array([0], dtype=np.int64), + cat_2=np.array( + [ + 0, + 64, + 128, + 192, + 256, + 304, + 368, + 432, + 496, + 560, + 608, + 672, + 736, + 800, + 864, + 912, + 976, + 1040, + 1104, + 1168, + 1216, + 1232, + 1248, + 1264, + 1280, + 1292, + ], + dtype=np.int64, + ), + unsqueeze_4=np.random.randn(1, 16, 1292, 80).astype(np.float32), + unsqueeze_5=np.random.randn(1, 16, 1292, 80).astype(np.float32), + unsqueeze_6=np.random.randn(1, 16, 1292, 80).astype(np.float32), + ) + + dummy_inputs = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "dump_test", + "replay", 
+ "qwen_sdpa_attention_loopmha", + "onnx_inputs.pt", + ) + if os.path.exists(dummy_inputs): + print("-- use dummy inputs") + + feeds1 = torch.load(dummy_inputs) + res1 = qwen_sdpa_attention( + feeds1["unsqueeze_4"], + feeds1["unsqueeze_5"], + feeds1["unsqueeze_6"], + feeds1["cat_2"], + scaling=0.11180339753627777, + num_heads=16, + ) + feeds1o = {k: v.detach().cpu().numpy() for k, v in feeds1.items()} + reso1 = sess.run(None, feeds1o)[0] + dummy_inputs2 = dummy_inputs.replace("onnx_inputs", "torch_inputs") + assert dummy_inputs != dummy_inputs2 + feeds2 = torch.load(dummy_inputs2) + res2 = qwen_sdpa_attention( + feeds2["unsqueeze_4"], + feeds2["unsqueeze_5"], + feeds2["unsqueeze_6"], + feeds2["cat_2"], + scaling=0.11180339753627777, + num_heads=16, + ) + feeds2o = {k: v.detach().cpu().numpy() for k, v in feeds2.items()} + reso2 = sess.run(None, feeds2o)[0] + diff = max_diff(res1, res2, hist=[0.1]) + print(f"-- diff torch-onnx: {string_diff(diff)}") + diff = max_diff(res2, reso2, hist=[0.1]) + print(f"-- diff torch-onnxo1: {string_diff(diff)}") + diff = max_diff(res1, reso1, hist=[0.1]) + print(f"-- diff torch-onnxo2: {string_diff(diff)}") + if diff["abs"] > 0.1: + for k in feeds1: + print( + f"-- {k}: " + f"{string_diff(max_diff(feeds1[k], feeds2[k], hist=[0.1]))}" + ) + + feeds = { + k: v.detach().cpu().numpy() + for k, v in torch.load(dummy_inputs).items() + } + + for k, v in feeds.items(): + print( + f"-- {k}: " + f"{self.string_type(v, with_shape=True, with_min_max=True)}" + ) + + # feeds["cat_2"] = np.array([0, 1292], dtype=np.int64) + got = sess.run(None, feeds) + self.assertEqual(len(got), 1) + self.assertEqual((1, 1292, 16, 80), got[0].shape) + expected = qwen_sdpa_attention( + torch.from_numpy(feeds["unsqueeze_4"]), + torch.from_numpy(feeds["unsqueeze_5"]), + torch.from_numpy(feeds["unsqueeze_6"]), + torch.from_numpy(feeds["cat_2"]), + scaling=0.11180339753627777, + num_heads=16, + ) + self.assertEqualArray(expected, got[0], atol=1e-5) + + tfeeds = {k: torch.from_numpy(v) for k, v in feeds.items()} + ev = OnnxruntimeEvaluator(model) + got2 = ev.run(None, tfeeds) + self.assertEqual(len(got2), 1) + self.assertEqualArray(got[0], got2[0], atol=1e-5) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 519e5dc2..ae885966 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1342,6 +1342,20 @@ def get_parser_sbs() -> ArgumentParser: default="replay", help="If the replay is triggered, this defines the folder where everything is dumped.", ) + parser.add_argument( + "-p", + "--replay-prefix-model", + action=BooleanOptionalAction, + default=False, + help=textwrap.dedent( + """ + There are two ways to recompute an intermediate output, the first one is to " + produce the minimal model between torch and onnx. + The second one is to dump onnx models from the inputs + to the considered intermediate results. This enables the second one. 
+ """ + ), + ) return parser @@ -1431,6 +1445,7 @@ def _size(name): set(args.replay_op_types.split(",")) if args.replay_op_types else None ), dump_folder=args.replay_folder, + dump_prefix_model=args.replay_prefix_model, ) print("-- starts side-by-side") diff --git a/onnx_diagnostic/export/control_flow_research.py b/onnx_diagnostic/export/control_flow_research.py index 261d0a5a..c10d9b07 100644 --- a/onnx_diagnostic/export/control_flow_research.py +++ b/onnx_diagnostic/export/control_flow_research.py @@ -92,10 +92,11 @@ def _loop_for_op_wrapper(*args, **kwargs): from torch._higher_order_ops.utils import setup_compilation_env - with setup_compilation_env() as backend: - return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)( - n_iter, body_fn, operands - ) + with setup_compilation_env() as _backend: + return _loop_for_op_wrapper(n_iter, body_fn, *operands) + # return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)( + # n_iter, body_fn, operands + # ) def trace_loop_for(proxy_mode, func_overload, n_iter, body_fn, operands): @@ -127,9 +128,13 @@ def loop_for_op_dense(n_iter, body_fn, operands): ), f"Dense implementation operands must be a list of tensors and ints {operands}" mode = _get_current_dispatch_mode() assert mode is None, "Mode should never be enabled for CPU/CUDA key" - return _loop_for_onnx_fn(body_fn, n_iter, None, *operands) + return _loop_for_onnx_fn(body_fn, n_iter, None, operands) @simple_loop_for_op.py_impl(ProxyTorchDispatchMode) def inner(mode, n_iter, body_fn, operands): return trace_loop_for(mode, simple_loop_for_op, n_iter, body_fn, operands) + + +simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU) +simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA) diff --git a/onnx_diagnostic/helpers/dot_helper.py b/onnx_diagnostic/helpers/dot_helper.py index 360e3910..09d16f29 100644 --- a/onnx_diagnostic/helpers/dot_helper.py +++ b/onnx_diagnostic/helpers/dot_helper.py @@ -1,29 +1,9 @@ -from typing import Dict, Set +from typing import Dict import numpy as np import onnx import onnx.numpy_helper as onh from ..reference import ExtendedReferenceEvaluator as Inference -from .onnx_helper import onnx_dtype_name, pretty_onnx - - -def _get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]: - hidden = set() - memo = ( - {i.name for i in graph.initializer} - | {i.values.name for i in graph.sparse_initializer} - | {i.name for i in graph.input} - ) - for node in graph.node: - for i in node.input: - if i not in memo: - hidden.add(i) - for att in node.attribute: - if att.type == onnx.AttributeProto.GRAPH and att.g: - hid = _get_hidden_inputs(att.g) - less = set(h for h in hid if h not in memo) - hidden |= less - memo |= set(node.output) - return hidden +from .onnx_helper import onnx_dtype_name, pretty_onnx, get_hidden_inputs def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str: @@ -221,7 +201,7 @@ def _mkn(obj: object) -> int: unique = set() for att in node.attribute: if att.type == onnx.AttributeProto.GRAPH: - unique |= _get_hidden_inputs(att.g) + unique |= get_hidden_inputs(att.g) for i in unique: edge = name_to_ids[i], _mkn(node) # type: ignore[assignment] if edge in done: diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 85713459..68fcf599 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -3,7 +3,19 @@ import os import sys import warnings -from typing import Any, Callable, Dict, Iterator, List, Optional, 
Sequence, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) import numpy as np import numpy.typing as npt import onnx @@ -1198,6 +1210,43 @@ def shadowing_names( return shadow, post_shadow, created +def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]: + """ + Returns the hidden inputs (inputs coming from an upper context) + used by a subgraph. It excludes empty names. + """ + hidden = set() + memo = ( + set(i.name for i in graph.initializer) + | set(i.name for i in graph.sparse_initializer) + | set(i.name for i in graph.input) + ) + for node in graph.node: + for i in node.input: + if i and i not in memo: + hidden.add(i) + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH and att.g: + hid = get_hidden_inputs(att.g) + less = set(h for h in hid if h not in memo) + hidden |= less + memo |= set(node.output) + return hidden + + +def get_all_node_inputs(node: onnx.NodeProto) -> Set[str]: + """ + Returns input and hidden inputs of a node. + See :func:`get_hidden_inputs`. It excludes empty names. + """ + start = {i for i in node.input if i} + if node.op_type in {"Scan", "Loop", "If"}: + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH: + start |= get_hidden_inputs(att.g) + return start + + def extract_subset_of_nodes( model: ModelProto, name: str, @@ -1240,14 +1289,28 @@ def extract_subset_of_nodes( current_input_index = 0 intermediate = {name} cut_points -= {name} + cached: Dict[int, List[str]] = {} inputs = set(k for k in node.input if k) while not (inputs <= cut_points) and current_node_index >= 0: node = model.graph.node[current_node_index] - if current_input_index == 0 or not node.input: + # node inputs including hidden ones + if current_node_index in cached: + node_inputs = cached[current_node_index] + else: + set_inputs = set(i for i in node.input if i) + if node.op_type in {"Scan", "If", "Loop"}: + # there are hidden inputs + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH: + set_inputs |= get_hidden_inputs(att.g) + node_inputs = list(set_inputs) + cached[current_node_index] = node_inputs + # processing + if current_input_index == 0 or not node_inputs: needs = [o for o in node.output if o in intermediate and o not in cut_points] if needs: selected.add(current_node_index) - if not node.input: + if not node_inputs: current_node_index -= 1 current_input_index = 0 continue @@ -1255,15 +1318,16 @@ def extract_subset_of_nodes( current_node_index -= 1 current_input_index = 0 continue - assert current_input_index < len(node.input), ( - f"current_input_index={current_input_index} but node.input={node.input}, " + # more intermediate results + assert current_input_index < len(node_inputs), ( + f"current_input_index={current_input_index} but node_inputs={node_inputs}, " f"node={pretty_onnx(node)}" ) - res = node.input[current_input_index] + res = node_inputs[current_input_index] if res not in cut_points: intermediate.add(res) current_input_index += 1 - if current_input_index >= len(node.input): + if current_input_index >= len(node_inputs): current_node_index -= 1 current_input_index = 0 @@ -1296,8 +1360,13 @@ def _mkv_(name, itype, irank): not_known: Set[str] = set() for node in nodes[::-1]: - not_known -= set(node.output) - not_known |= set(node.input) + not_known -= {o for o in node.output if o} + not_known |= {i for i in node.input if i} + if node.op_type in {"Scan", "If", "Loop"}: + # there are hidden inputs + for att in 
node.attribute:
+                if att.type == onnx.AttributeProto.GRAPH:
+                    not_known |= get_hidden_inputs(att.g)

     model = oh.make_model(
         oh.make_graph(
@@ -1310,3 +1379,337 @@ def _mkv_(name, itype, irank):
         opset_imports=opset_imports,
     )
     return model
+
+
+def get_tensor_shape(
+    obj: Union[onnx.ValueInfoProto, onnx.TypeProto, onnx.TensorProto],
+) -> Optional[List[Optional[Union[int, str]]]]:
+    """
+    Returns the shape if that makes sense for this object.
+    """
+    if isinstance(obj, ValueInfoProto):
+        return get_tensor_shape(obj.type)
+    elif not isinstance(obj, onnx.TypeProto):
+        raise TypeError(f"Unexpected type {type(obj)!r}.")
+    if not obj.tensor_type.HasField("shape"):
+        return None
+    shape = []
+    for d in obj.tensor_type.shape.dim:
+        v = d.dim_value if d.dim_value > 0 else d.dim_param
+        shape.append(v)
+    if not shape:
+        return shape
+    return [None if s in (0, "") else s for s in shape]
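Not part of the patch: a minimal sketch of what `get_tensor_shape` returns for the three kinds of dimensions (symbolic, static, unknown), assuming the import path of the module this hunk extends.

```python
import onnx.helper as oh
from onnx import TensorProto

from onnx_diagnostic.helpers.onnx_helper import get_tensor_shape

# "batch" is a symbolic dimension, 4 is static, None leaves the last one unknown.
info = oh.make_tensor_value_info("X", TensorProto.FLOAT, ["batch", 4, None])
print(get_tensor_shape(info))  # ['batch', 4, None]
```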
+
+
+def _enumerate_model_node_outputs(
+    model: ModelProto, add_node: bool = False, order: bool = False
+) -> Iterable[Union[str, Tuple[str, NodeProto]]]:
+    """
+    Enumerates all the outputs of a model.
+
+    :param model: :epkg:`ONNX` graph
+    :param add_node: if False, the function enumerates
+        all output names from every node, otherwise, it
+        enumerates tuples (output name, node)
+    :param order: goes through outputs following the graph order
+    :return: enumerator
+    """
+    assert hasattr(model, "graph"), f"Parameter model is not an ONNX model but {type(model)}"
+    if order:
+        edges = []
+        d_order = {}
+        node_names = {}
+        for inp in model.graph.input:
+            d_order[0, inp.name] = 0
+        for node in model.graph.node:
+            d_order[1, node.name] = 0
+            for i in node.input:
+                edges.append(("in", i, node.name))
+            for o in node.output:
+                edges.append(("out", o, node.name))
+                node_names[o] = node
+                d_order[0, o] = 0
+
+        modif = 1
+        n_iter = 0
+        while modif > 0 and n_iter <= len(model.graph.node):
+            modif = 0
+            n_iter += 1
+            for kind, data_name, node_name in edges:
+                if kind == "in":
+                    if (0, data_name) not in d_order:
+                        continue
+                    if d_order[0, data_name] + 1 > d_order[1, node_name]:
+                        modif += 1
+                        d_order[1, node_name] = d_order[0, data_name] + 1
+                else:
+                    if d_order[1, node_name] + 1 > d_order[0, data_name]:
+                        modif += 1
+                        d_order[0, data_name] = d_order[1, node_name] + 1
+
+        orders = [(v, k) for k, v in d_order.items()]
+        orders.sort()
+
+        for _, k in orders:
+            if k[0] == 1:
+                continue
+            out = k[1]
+            if out not in node_names:
+                continue
+            yield (out, node_names[out]) if add_node else out
+    else:
+        for node in model.graph.node:
+            for out in node.output:
+                yield (out, node) if add_node else out
+
+
+def onnx_remove_node_unused(
+    graph: Union[onnx.GraphProto, onnx.FunctionProto], recursive=True
+) -> Union[onnx.GraphProto, onnx.FunctionProto]:
+    """
+    Removes unused nodes of the graph. An unused node
+    is not involved in the output computation.
+
+    :param graph: onnx graph or function
+    :param recursive: looks into subgraphs
+    :return: new graph or function
+    """
+    is_function = isinstance(graph, FunctionProto)
+
+    # mark outputs
+    marked: Dict[str, Set[str]] = (
+        {o: set() for o in graph.output}
+        if is_function
+        else {o.name: set() for o in graph.output}
+    )
+    nodes = list(graph.node)
+
+    # mark the nodes producing a marked result, their inputs become marked too
+    for node in reversed(nodes):
+        used = False
+        for o in node.output:
+            if o and o in marked:
+                for i in get_all_node_inputs(node):
+                    marked[o].add(i)
+                used = True
+        if used:
+            for i in get_all_node_inputs(node):
+                marked[i] = set()
+
+    # remove the unused nodes
+    removed = set()
+    marked_set = set(marked)
+    for ind, node in enumerate(nodes):
+        if not ({o for o in node.output if o} & marked_set):
+            removed.add(ind)
+
+    if not is_function:
+        initializers = [i for i in graph.initializer if i.name in marked]
+        sparse_initializers = [i for i in graph.sparse_initializer if i.name in marked]
+    new_nodes = [node for i, node in enumerate(nodes) if i not in removed]
+
+    # Finally create the new graph.
+    if is_function:
+        return oh.make_function(
+            graph.domain,
+            graph.name,
+            graph.input,
+            graph.output,
+            new_nodes,
+            opset_imports=graph.opset_import,
+            attributes=graph.attribute,
+            doc_string=graph.doc_string,
+        )
+    new_graph = oh.make_graph(
+        new_nodes,
+        graph.name,
+        graph.input,
+        graph.output,
+        initializers,
+        sparse_initializer=sparse_initializers,
+    )
+    new_graph.value_info.extend(graph.value_info)
+    return new_graph
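Not part of the patch: a small sketch of the pruning behavior, assuming the helper is imported from `onnx_diagnostic.helpers.onnx_helper` like the other new functions. The `Exp` node feeds no graph output, so it is removed.

```python
import onnx.helper as oh
from onnx import TensorProto

from onnx_diagnostic.helpers.onnx_helper import onnx_remove_node_unused

X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None])
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
graph = oh.make_graph(
    [
        oh.make_node("Neg", ["X"], ["n"]),     # contributes to Y, kept
        oh.make_node("Abs", ["n"], ["Y"]),     # produces Y, kept
        oh.make_node("Exp", ["X"], ["dead"]),  # output never consumed, removed
    ],
    "g",
    [X],
    [Y],
)
pruned = onnx_remove_node_unused(graph)
print([n.op_type for n in pruned.node])  # ['Neg', 'Abs']
```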
+
+
+def select_model_inputs_outputs(
+    model: ModelProto,
+    outputs: Optional[List[str]] = None,
+    inputs: Optional[List[str]] = None,
+    infer_shapes: bool = True,
+    overwrite: Optional[Dict[str, Any]] = None,
+    remove_unused: bool = True,
+    verbose: int = 0,
+):
+    """
+    Takes a model and changes its outputs.
+
+    :param model: :epkg:`ONNX` model
+    :param inputs: new inputs, same ones if None
+    :param outputs: new outputs, same ones if None
+    :param infer_shapes: infer inputs and outputs shapes
+    :param overwrite: overwrite type and shapes for
+        inputs or outputs, *overwrite* is a
+        dictionary `{'name': (numpy dtype, shape)}`
+    :param remove_unused: remove unused nodes from the graph
+    :param verbose: display information while converting
+    :return: modified model
+
+    The function removes unneeded nodes.
+
+    The following example shows how to change the inputs of a model
+    to bypass the first nodes. Shape inference fails to determine
+    the new input types. They need to be overwritten.
+    `verbose=1` shows the number of deleted nodes.
+
+    ::
+
+        import numpy
+        import onnx
+        from onnx_diagnostic.helpers.onnx_helper import select_model_inputs_outputs
+
+        onx = onnx.load(path)
+        onx2 = select_model_inputs_outputs(
+            onx, inputs=["a", "b"],
+            infer_shapes=True, verbose=1,
+            overwrite={'a': (numpy.int32, None), 'b': (numpy.int64, None)})
+        onnx.save(onx2, path2)
+    """
+    if not isinstance(model, ModelProto):
+        raise TypeError(f"Unexpected type {type(model)} for model.")
+    if inputs is not None and not isinstance(inputs, list):
+        inputs = [inputs]
+    if outputs is not None and not isinstance(outputs, list):
+        outputs = [outputs]
+    if inputs is None:
+        inputs = [i.name for i in model.graph.input]
+    if outputs is None:
+        outputs = [o.name for o in model.graph.output]
+
+    mark_var = {}
+    for out in _enumerate_model_node_outputs(model):
+        mark_var[out] = 0
+    for inp in inputs:
+        mark_var[inp] = 0
+    for out in outputs:
+        assert out in mark_var, f"Output {out!r} not found in model."
+        mark_var[out] = 1
+
+    nodes = list(model.graph.node[::-1])
+    mark_op = {}
+    for node in list(nodes):
+        mark_op[id(node)] = 0
+
+    # We mark all the nodes we need to keep.
+    nb = 1
+    while nb > 0:
+        nb = 0
+        for node in nodes:
+            if mark_op[id(node)] == 1:
+                continue
+            mod = False
+            for out in node.output:
+                if mark_var[out] == 1:
+                    mark_op[id(node)] = 1
+                    mod = True
+                    break
+            if not mod:
+                continue
+
+            node_inputs = get_all_node_inputs(node)
+
+            nb += 1
+            for inp in node_inputs:
+                if inp in inputs:
+                    continue
+                if mark_var.get(inp, 0) == 1:
+                    continue
+                mark_var[inp] = 1
+                nb += 1
+
+    # Every kept node verifies mark_op[id(node)] == 1.
+    keep_nodes = [node for node in nodes[::-1] if mark_op[id(node)] == 1]
+
+    known_shapes = {}
+    if infer_shapes:
+        shapes = onnx.shape_inference.infer_shapes(model)
+        for shape in shapes.graph.value_info:
+            known_shapes[shape.name] = shape.type
+        for shape in shapes.graph.input:
+            known_shapes[shape.name] = shape.type
+        for shape in shapes.graph.output:
+            known_shapes[shape.name] = shape.type
+    else:
+        for shape in model.graph.input:
+            known_shapes[shape.name] = shape.type
+        for shape in model.graph.output:
+            known_shapes[shape.name] = shape.type
+
+    var_in = []
+    for name in inputs:
+        if overwrite is not None and name in overwrite:
+            dtype, shape = overwrite[name]
+            proto_dtype = np_dtype_to_tensor_dtype(dtype)
+            value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
+        elif name in known_shapes:
+            info = known_shapes[name].tensor_type
+            proto_dtype = info.elem_type
+            if proto_dtype == 0:
+                value_info = ValueInfoProto()
+                value_info.name = name
+            else:
+                shape = get_tensor_shape(known_shapes[name])
+                value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
+        else:
+            value_info = ValueInfoProto()
+            value_info.name = name
+        var_in.append(value_info)
+
+    var_out = []
+    for name in outputs:
+        if overwrite is not None and name in overwrite:
+            dtype, shape = overwrite[name]
+            proto_dtype = np_dtype_to_tensor_dtype(dtype)
+            value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
+        elif name in known_shapes:
+            info = known_shapes[name].tensor_type
+            proto_dtype = info.elem_type
+            if proto_dtype == 0:
+                value_info = ValueInfoProto()
+                value_info.name = name
+            else:
+                shape = get_tensor_shape(known_shapes[name])
+                value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
+        else:
+            value_info = ValueInfoProto()
+            value_info.name = name
+        var_out.append(value_info)
+
+    graph = oh.make_graph(
+        keep_nodes,
+        model.graph.name,
+        var_in,
+        var_out,
+        model.graph.initializer,
+        sparse_initializer=model.graph.sparse_initializer,
+    )
+    if remove_unused:
+        graph = onnx_remove_node_unused(graph, recursive=False)
+    onnx_model = oh.make_model(graph, functions=model.functions)
+    onnx_model.ir_version = model.ir_version
+    onnx_model.producer_name = model.producer_name
+    onnx_model.producer_version = model.producer_version
+    onnx_model.domain = model.domain
+    onnx_model.model_version = model.model_version
+    onnx_model.doc_string = model.doc_string
+    if model.metadata_props:
+        values = {p.key: p.value for p in model.metadata_props}
+        oh.set_model_props(onnx_model, values)
+
+    del onnx_model.opset_import[:]
+    for oimp in model.opset_import:
+        op_set = onnx_model.opset_import.add()
+        op_set.domain = oimp.domain
+        op_set.version = oimp.version
+
+    return onnx_model
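Not part of the patch: a sketch of the typical call, mirroring the new unit test above. The model is truncated after the first node, `X2` becomes the only output and the nodes (and initializers) it does not need are dropped by `onnx_remove_node_unused`.

```python
import numpy as np
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnx import TensorProto

from onnx_diagnostic.helpers.onnx_helper import select_model_inputs_outputs
from onnx_diagnostic.reference import ExtendedReferenceEvaluator

X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
Z = oh.make_tensor_value_info("Z", TensorProto.INT64, [None, None])
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Mul", ["X", "X"], ["X2"]),
            oh.make_node("Add", ["X2", "Y"], ["z1"]),
            oh.make_node("Cast", ["z1"], ["Z"], to=TensorProto.INT64),
        ],
        "demo",
        [X],
        [Z],
        [onh.from_array(np.ones((1, 4), dtype=np.float32), name="Y")],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
)

# Keep only the subgraph computing "X2"; Add, Cast and the initializer "Y" disappear.
sub = select_model_inputs_outputs(model, outputs=["X2"])
x = np.arange(8).reshape((2, 4)).astype(np.float32)
got = ExtendedReferenceEvaluator(sub).run(None, {"X": x})[0]
assert (got == x**2).all()
```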
diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py
index 11fc05a1..f7c0bbfc 100644
--- a/onnx_diagnostic/helpers/torch_helper.py
+++ b/onnx_diagnostic/helpers/torch_helper.py
@@ -811,7 +811,8 @@ def torch_deepcopy(value: Any) -> Any:
     if isinstance(value, tuple):
         return tuple(torch_deepcopy(v) for v in value)
     if isinstance(value, list):
-        return [torch_deepcopy(v) for v in value]
+        if type(value) is list:
+            return [torch_deepcopy(v) for v in value]
     if isinstance(value, set):
         return {torch_deepcopy(v) for v in value}
     if isinstance(value, dict):
diff --git a/onnx_diagnostic/reference/ort_evaluator.py b/onnx_diagnostic/reference/ort_evaluator.py
index 6d7f20d6..ffa511d6 100644
--- a/onnx_diagnostic/reference/ort_evaluator.py
+++ b/onnx_diagnostic/reference/ort_evaluator.py
@@ -18,10 +18,11 @@
 import onnxruntime
 from ..helpers import string_type
 from ..helpers.onnx_helper import (
-    pretty_onnx,
+    get_hidden_inputs,
     dtype_to_tensor_dtype,
-    to_array_extended,
     np_dtype_to_tensor_dtype,
+    to_array_extended,
+    pretty_onnx,
 )
 from ..helpers.torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
 from ..helpers.ort_session import (
@@ -472,39 +473,15 @@ def enumerate_nodes(self, nodes: List[NodeProto]) -> Iterator[NodeProto]:
             yield from self.enumerate_nodes(att.g.node)
         yield node

-    @classmethod
-    def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]:
-        """
-        Returns the hidden inputs (inputs coming from an upper context)
-        used by a subgraph.
-        """
-        hidden = set()
-        memo = (
-            {i.name for i in graph.initializer}
-            | {i.name for i in graph.sparse_initializer}
-            | {i.name for i in graph.input}
-        )
-        for node in graph.node:
-            for i in node.input:
-                if i not in memo:
-                    hidden.add(i)
-            for att in node.attribute:
-                if att.type == AttributeProto.GRAPH and att.g:
-                    hid = cls._get_hidden_inputs(att.g)
-                    less = set(h for h in hid if h not in memo)
-                    hidden |= less
-            memo |= set(node.output)
-        return hidden
-
     @classmethod
     def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]:
-        """Calls multiple _get_hidden_inputs on every attribute."""
+        """Calls get_hidden_inputs on every graph attribute."""
         if node.op_type not in {"Loop", "Scan", "If"}:
             return set()
         hidden = set()
         for att in node.attribute:
             if att.type == AttributeProto.GRAPH:
-                hidden |= cls._get_hidden_inputs(att.g)
+                hidden |= get_hidden_inputs(att.g)
         return hidden - (hidden & set(node.input))

     def _get_sess(
@@ -624,7 +601,7 @@ def _get_sess_init_subgraph(
             value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
             vinputs.append(value)

-        reduced_set = self._get_hidden_inputs(g)
+        reduced_set = get_hidden_inputs(g)
         for i, v in context.items():
             if i in reduced_set and i not in unique_names:
                 unique_names.add(i)
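Context for the refactoring above and in `dot_helper.py` / `runtime_info.py`: three private copies of `_get_hidden_inputs` are replaced by the single public `get_hidden_inputs` added to `onnx_helper.py`. Not part of the patch, a minimal sketch of what it computes: a name read by a subgraph but defined in the outer scope is a hidden input.

```python
import onnx.helper as oh
from onnx import TensorProto

from onnx_diagnostic.helpers.onnx_helper import get_hidden_inputs

# A branch of an If node reading "A", which is produced by the outer graph:
# the subgraph declares no input called "A", so "A" is a hidden input.
out = oh.make_tensor_value_info("out", TensorProto.FLOAT, [None])
branch = oh.make_graph([oh.make_node("Neg", ["A"], ["out"])], "then_branch", [], [out])
print(get_hidden_inputs(branch))  # {'A'}
```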
diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py
index 8be65493..98cdabde 100644
--- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py
+++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py
@@ -118,6 +118,7 @@ def patched_sdpa_attention_forward(
     torch._check(value.shape[1] > 0)
     torch._check(value.shape[2] > 0)
     torch._check(value.shape[3] > 0)
+
     return (
         torch.nn.functional.scaled_dot_product_attention(
             query,
diff --git a/onnx_diagnostic/torch_onnx/runtime_info.py b/onnx_diagnostic/torch_onnx/runtime_info.py
index 86664dd9..5d3ed0a4 100644
--- a/onnx_diagnostic/torch_onnx/runtime_info.py
+++ b/onnx_diagnostic/torch_onnx/runtime_info.py
@@ -4,6 +4,7 @@
 import torch
 from ..api import TensorLike
 from ..helpers import string_type
+from ..helpers.onnx_helper import get_hidden_inputs


 class RuntimeValueKind(enum.IntEnum):
@@ -151,30 +152,6 @@ def is_initializer(self) -> bool:
         return self.kind == RuntimeValueKind.INITIALIZER


-def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
-    """
-    Returns the hidden inputs (inputs coming from an upper context)
-    used by a subgraph.
-    """
-    hidden = set()
-    memo = (
-        set(i.name for i in graph.initializer)
-        | set(i.name for i in graph.sparse_initializer)
-        | set(i.name for i in graph.input)
-    )
-    for node in graph.node:
-        for i in node.input:
-            if i not in memo:
-                hidden.add(i)
-        for att in node.attribute:
-            if att.type == onnx.AttributeProto.GRAPH and att.g:
-                hid = get_hidden_inputs(att.g)
-                less = set(h for h in hid if h not in memo)
-                hidden |= less
-        memo |= set(node.output)
-    return hidden
-
-
 def set_is_shape(
     node: onnx.NodeProto, values: Dict[str, RuntimeValue], drop: Optional[Set[str]] = None
 ) -> List[str]:
diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py
index 4bdfed0a..c6018e90 100644
--- a/onnx_diagnostic/torch_onnx/sbs.py
+++ b/onnx_diagnostic/torch_onnx/sbs.py
@@ -7,7 +7,13 @@
 import torch
 from ..helpers import string_type, string_diff, max_diff, flatten_object
 from ..helpers.onnx_helper import pretty_onnx
-from ..helpers.torch_helper import to_numpy, from_numpy, to_tensor, torch_dtype_to_onnx_dtype
+from ..helpers.torch_helper import (
+    to_numpy,
+    from_numpy,
+    to_tensor,
+    torch_dtype_to_onnx_dtype,
+    torch_deepcopy,
+)
 from ..helpers.torch_fx_graph_helper import prepare_args_kwargs, run_fx_node
 from ..reference.ort_evaluator import OnnxList, OnnxruntimeEvaluator
 from .sbs_dataclasses import (
@@ -188,7 +194,7 @@ def _loop_onnx_node(
         print(f"[run_aligned] feeds={string_type(feeds, **str_kws)}")
     begin = time.perf_counter()
     try:
-        res = ref.run(None, feeds)  # type: ignore[attr-defined]
+        res = ref.run(None, torch_deepcopy(feeds))  # type: ignore[attr-defined]
     except Exception as e:
         raise RuntimeError(
             f"Unable to run node {node.op_type}, domain={node.domain} "
@@ -241,7 +247,7 @@ def _loop_onnx_node(
                 f"[run_aligned] feeds for second run="
                 f"{string_type(new_feeds, **str_kws)}"
             )
-            cross = ref.run(None, new_feeds)
+            cross = ref.run(None, torch_deepcopy(new_feeds))
             if verbose > 1:
                 print(f"[run_aligned] got for second run={string_type(cross, **str_kws)}")
             # Gemm = torch.nn.function.linear, in that case, we just run it as well
diff --git a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py
index 2b7255d2..b0ef5a5e 100644
--- a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py
+++ b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py
@@ -11,7 +11,12 @@
 import onnx
 import numpy as np
 import torch
-from ..helpers.onnx_helper import extract_subset_of_nodes, make_submodel, from_array_extended
+from ..helpers.onnx_helper import (
+    extract_subset_of_nodes,
+    make_submodel,
+    from_array_extended,
+    select_model_inputs_outputs,
+)
 from ..helpers.torch_helper import torch_dtype_to_onnx_dtype


@@ -61,12 +66,16 @@ class ReplayConfiguration:
     :param selected_names: list of results names to dump
     :param selected_op_types: list of onnx operators to dump
     :param threshold: only keep those whose discrepancies is greater than that threshold
+    :param dump_prefix_model: after dumping the smallest model able to replicate
+        one given output, it also dumps the models producing the inputs
+        and the outputs, truncated from the big one
     """

     dump_folder: str
     selected_names: Optional[Set[str]] = None
     selected_op_types: 
Optional[Set[str]] = None threshold: float = 0.1 + dump_prefix_model: bool = False def __post_init__(self): assert self.dump_folder, "dump_folder is empty and this is not allowed for the replay" @@ -297,7 +306,7 @@ def dump( del submodel.graph.input[:] submodel.graph.input.extend(new_inputs) if verbose: - print(f"[ReplayConfiguration.dump] removed input {removed_inputs}") + print(f"[ReplayConfiguration.dump] removed inputs {removed_inputs}") print(f"[ReplayConfiguration.dump] final model inputs {input_names}") onnx.save(submodel, os.path.join(folder, "model.onnx")) @@ -318,6 +327,30 @@ def dump( ) with open(os.path.join(folder, "replay.py"), "w") as f: f.write(self.get_replay_code()) + + if self.dump_prefix_model: + main_inputs = { + i.name: onnx_inputs.get(i.name, torch_inputs.get(i.name, None)) + for i in model.graph.input + } + # only saving onnx inputs, torch should be the same + torch.save(main_inputs, os.path.join(folder, "onnx_main_inputs.pt")) + + model_inputs_file = os.path.join(folder, "model.inputs.onnx") + exclude = {i.name for i in model.graph.input} | { + i.name for i in model.graph.initializer + } + model_inputs = select_model_inputs_outputs( + model, outputs=[i.name for i in submodel.graph.input if i.name not in exclude] + ) + onnx.save(model_inputs, model_inputs_file) + + model_outputs_file = os.path.join(folder, "model.outputs.onnx") + model_outputs = select_model_inputs_outputs( + model, outputs=[i.name for i in submodel.graph.output] + ) + onnx.save(model_outputs, model_outputs_file) + if verbose: print(f"[ReplayConfiguration.dump] done {folder!r}") return folder
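Not part of the patch: a hedged sketch of how the new artifacts written by `ReplayConfiguration.dump` could be consumed. The file names (`onnx_main_inputs.pt`, `model.inputs.onnx`, `model.outputs.onnx`) come from the code above; the folder path and the replay workflow itself are hypothetical.

```python
import os

import onnx
import torch

from onnx_diagnostic.reference import OnnxruntimeEvaluator

folder = "replay/some_intermediate_result"  # hypothetical dump folder

# Inputs of the original model, saved as torch tensors.
feeds = torch.load(os.path.join(folder, "onnx_main_inputs.pt"))

# model.inputs.onnx: the big model truncated at the inputs of the extracted piece,
# model.outputs.onnx: the big model truncated at its outputs.
prefix = onnx.load(os.path.join(folder, "model.inputs.onnx"))
ev = OnnxruntimeEvaluator(prefix)
# Recomputes, from the original inputs, the tensors feeding the extracted submodel.
submodel_inputs = ev.run(None, feeds)
```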