diff --git a/.github/workflows/check-urls.yml b/.github/workflows/check-urls.yml index faf0ea99..77c84160 100644 --- a/.github/workflows/check-urls.yml +++ b/.github/workflows/check-urls.yml @@ -42,6 +42,6 @@ jobs: print_all: false timeout: 2 retry_count# : 2 - exclude_urls: https://hal.archives-,ouvertes.fr/hal-00990252/document,http://badge.fury.io/py/onnx-diagnostic,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://medium.com/@msouza.os/llm-from-scratch-with-pytorch-9f21808c6319,https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L5965,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311,https://www.linux.org/ - exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://github.com/pytorch/pytorch/blob/main/torch/,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/,https://codecov.io/,https://huggingface.co/,https://www.linux.org/ + exclude_urls: 
https://hal.archives-,ouvertes.fr/hal-00990252/document,http://badge.fury.io/py/onnx-diagnostic,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://medium.com/@msouza.os/llm-from-scratch-with-pytorch-9f21808c6319,https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L5965,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311,https://www.linux.org/,https://docs.scipy.org/doc/scipy/ + exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://github.com/pytorch/pytorch/blob/main/torch/,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/,https://codecov.io/,https://huggingface.co/,https://www.linux.org/,https://docs.scipy.org/doc/scipy/ # force_pass : true diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index aa065ad3..814f5d08 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -5,6 +5,7 @@ Change Logs +++++ * :pr:`330`: fixes access rope_parameters for ``transformers>=5`` +* :pr:`329`: supports lists with OnnxruntimeEvaluator * :pr:`326`: use ConcatFromSequence in LoopMHA with the loop * :pr:`325`: adds plug for LoopMHA, extends the unit tests to measure the discrepancies * :pr:`324`: supports FunctionProto with arguments in OnnxruntimeEvaluator diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py index 
a26b85dd..69aff474 100644 --- a/_unittests/ut_helpers/test_onnx_helper.py +++ b/_unittests/ut_helpers/test_onnx_helper.py @@ -402,7 +402,9 @@ def test_enumerate_results_loop(self): new_axis=0, ), ], - ) + ), + ir_version=10, + opset_imports=[oh.make_opsetid("", 22)], ) res = list(enumerate_results(model, "slice_start", verbose=2)) self.assertEqual(len(res), 2) diff --git a/_unittests/ut_reference/test_onnxruntime_evaluator.py b/_unittests/ut_reference/test_onnxruntime_evaluator.py index f871ee57..85d30ed3 100644 --- a/_unittests/ut_reference/test_onnxruntime_evaluator.py +++ b/_unittests/ut_reference/test_onnxruntime_evaluator.py @@ -20,6 +20,7 @@ TFLOAT = onnx.TensorProto.FLOAT +TINT64 = onnx.TensorProto.INT64 class TestOnnxruntimeEvaluator(ExtTestCase): @@ -319,6 +320,123 @@ def test_function_proto_with_kwargs(self): got = sess.run(None, feeds) self.assertEqualArray(expected, got[0], atol=1e-5) + @hide_stdout() + def test_ort_eval_loop_seq(self): + x = np.array([1, 2, 3, 4, 5]).astype(np.float32) + _mkv_ = oh.make_tensor_value_info + model = oh.make_model( + graph=oh.make_graph( + name="loop_test", + inputs=[ + oh.make_tensor_value_info("trip_count", TINT64, ["a"]), + oh.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), + ], + outputs=[oh.make_tensor_value_info("res", TFLOAT, [])], + nodes=[ + oh.make_node("SequenceEmpty", [], ["seq_empty"], dtype=TFLOAT), + oh.make_node( + "Loop", + inputs=["trip_count", "cond", "seq_empty"], + outputs=["seq_res"], + body=oh.make_graph( + [ + oh.make_node( + "Identity", inputs=["cond_in"], outputs=["cond_out"] + ), + oh.make_node( + "Constant", + inputs=[], + outputs=["x"], + value=oh.make_tensor( + name="const_tensor_x", + data_type=TFLOAT, + dims=x.shape, + vals=x.flatten().astype(float), + ), + ), + oh.make_node( + "Constant", + inputs=[], + outputs=["one"], + value=oh.make_tensor( + name="const_tensor_one", + data_type=TINT64, + dims=(), + vals=[1], + ), + ), + oh.make_node( + "Constant", + inputs=[], + 
outputs=["slice_start"], + value=oh.make_tensor( + name="const_tensor_zero", + data_type=TINT64, + dims=(1,), + vals=[0], + ), + ), + oh.make_node( + "Add", inputs=["iter_count", "one"], outputs=["end"] + ), + oh.make_node( + "Constant", + inputs=[], + outputs=["axes"], + value=oh.make_tensor( + name="const_tensor_axes", + data_type=TINT64, + dims=(1,), + vals=[0], + ), + ), + oh.make_node( + "Unsqueeze", inputs=["end", "axes"], outputs=["slice_end"] + ), + oh.make_node( + "Slice", + inputs=["x", "slice_start", "slice_end"], + outputs=["slice_out"], + ), + oh.make_node( + "SequenceInsert", + inputs=["seq_in", "slice_out"], + outputs=["seq_out"], + ), + ], + "loop_body", + [ + _mkv_("iter_count", TINT64, []), + _mkv_("cond_in", onnx.TensorProto.BOOL, []), + oh.make_tensor_sequence_value_info("seq_in", TFLOAT, None), + ], + [ + _mkv_("cond_out", onnx.TensorProto.BOOL, []), + oh.make_tensor_sequence_value_info("seq_out", TFLOAT, None), + ], + ), + ), + oh.make_node( + "ConcatFromSequence", + inputs=["seq_res"], + outputs=["res"], + axis=0, + new_axis=0, + ), + ], + ), + ir_version=10, + opset_imports=[oh.make_opsetid("", 22)], + ) + ev = OnnxruntimeEvaluator(model, verbose=10) + feeds = dict(trip_count=torch.tensor([3], dtype=torch.int64), cond=torch.tensor(True)) + got = ev.run(None, feeds) + self.assertEqual((6,), got[0].shape) + self.assertEqualArray( + torch.tensor([1.0, 1.0, 2.0, 1.0, 2.0, 3.0], dtype=torch.float32), got[0] + ) + self.assertIsInstance(got[0], torch.Tensor) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_reference/test_torch_onnx_evaluator.py b/_unittests/ut_reference/test_torch_onnx_evaluator.py index ef62517f..ce3d0fcd 100644 --- a/_unittests/ut_reference/test_torch_onnx_evaluator.py +++ b/_unittests/ut_reference/test_torch_onnx_evaluator.py @@ -1123,7 +1123,9 @@ def test_loop(self): new_axis=0, ), ], - ) + ), + ir_version=10, + opset_imports=[oh.make_opsetid("", 22)], ) self._finalize_test( model, 
torch.tensor(5, dtype=torch.int64), torch.tensor(1, dtype=torch.bool) diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index ac994033..98aecb63 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ b/_unittests/ut_torch_onnx/test_sbs.py @@ -8,8 +8,10 @@ ignore_errors, requires_cuda, ) +from onnx_diagnostic.helpers.rt_helper import make_feeds from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str +from onnx_diagnostic.torch_export_patches.patches.patch_transformers import patch_qwen2_5 from onnx_diagnostic.torch_onnx.sbs import run_aligned from onnx_diagnostic.torch_onnx.sbs_dataclasses import RunAlignedRecord, ReplayConfiguration from onnx_diagnostic.export.api import to_onnx @@ -671,6 +673,124 @@ def forward(self, x): ) self.clean_dump() + @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") + @hide_stdout() + def test_sbs_with_loops(self): + import torch + from onnx_diagnostic.torch_export_patches.patches.patch_transformers import ( + PLUGS_Qwen25, + ) + from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( + qwen_sdpa_attention_loopmha_versatile, + ) + + class Model(torch.nn.Module): + def forward(self, query, key, value, seq_lens): + rg1 = torch.arange(4, dtype=torch.int32).unsqueeze(0) + rg0 = torch.arange(4, dtype=torch.int32).unsqueeze(1) + mask = (rg0 <= rg1).flatten().reshape((1, -1, 1, 1)).to(query.dtype) + qs = query * mask + ks = key * mask + vs = value * mask + attn_output = qwen_sdpa_attention_loopmha_versatile( + qs, + ks, + vs, + seq_lens, + 0.11, + 16, + ( + onnx.TensorProto.FLOAT + if query.dtype == torch.float32 + else ( + onnx.TensorProto.FLOAT16 + if query.dtype == torch.float16 + else onnx.TensorProto.BFLOAT16 + ) + ), + ) + red = attn_output.mean(dim=-1, keepdim=True) + return attn_output - red + + model = Model() + inputs = ( + 
torch.rand((1, 16, 1292, 80), dtype=torch.float16), + torch.rand((1, 16, 1292, 80), dtype=torch.float16), + torch.rand((1, 16, 1292, 80), dtype=torch.float16), + torch.tensor( + [ + 0, + 64, + 128, + 192, + 256, + 304, + 368, + 432, + 496, + 560, + 608, + 672, + 736, + 800, + 864, + 912, + 976, + 1040, + 1104, + 1168, + 1216, + 1232, + 1248, + 1264, + 1280, + 1292, + ], + dtype=torch.int64, + ), + ) + expected = model(*inputs) + ds = ({2: "seq_length"}, {2: "seq_length"}, {2: "seq_length"}, {0: "num_patches"}) + onnx_file = self.get_dump_file("test_sbs_with_loops.onnx") + ep_file = self.get_dump_file("test_sbs_with_loops") + to_onnx( + model, + inputs, + dynamic_shapes=ds, + filename=onnx_file, + save_ep=(ep_file, 2**28), + exporter="custom", + onnx_plugs=PLUGS_Qwen25, + target_opset=22, + ) + input_file = ep_file + ".input.pt" + ep_file = ep_file + ".ep.pt2" + self.assertExists(onnx_file) + self.assertExists(ep_file) + self.assertExists(input_file) + sess = self.check_ort(onnx_file) + input_names = [i.name for i in sess.get_inputs()] + feeds = make_feeds(input_names, inputs, use_numpy=True) + got = sess.run(None, feeds) + self.assertEqualArray(expected, got[0], atol=1e-3) + # sbs + ep = torch.export.load(ep_file) + onx = onnx.load(onnx_file) + kwargs = make_feeds(input_names, inputs, use_numpy=False) + results = list( + run_aligned( + ep, + onx, + kwargs=kwargs, + run_cls=OnnxruntimeEvaluator, + verbose=11, + use_tensor=True, + ), + ) + df = pandas.DataFrame(list(results)).dropna(axis=1, how="all") + df.to_excel(self.get_dump_file("test_sbs_with_loops.xlsx")) + # self.clean_dump() + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py index 24b0a53c..bf2e09cd 100644 --- a/onnx_diagnostic/ext_test_case.py +++ b/onnx_diagnostic/ext_test_case.py @@ -1111,7 +1111,10 @@ def check_ort( ) -> "onnxruntime.InferenceSession": # noqa: F821 from onnxruntime import InferenceSession - return 
InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"]) + return InferenceSession( + onx if isinstance(onx, str) else onx.SerializeToString(), + providers=["CPUExecutionProvider"], + ) def assertRaise(self, fct: Callable, exc_type: type[Exception], msg: Optional[str] = None): """In the name""" diff --git a/onnx_diagnostic/helpers/ort_session.py b/onnx_diagnostic/helpers/ort_session.py index 8d26d99d..56f260c4 100644 --- a/onnx_diagnostic/helpers/ort_session.py +++ b/onnx_diagnostic/helpers/ort_session.py @@ -108,7 +108,10 @@ def __init__( session_options, providers=providers, ) - except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e: + except ( + onnxruntime.capi.onnxruntime_pybind11_state.Fail, + onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph, + ) as e: if isinstance(sess, onnx.ModelProto): debug_path = "_debug_InferenceSession_last_failure.onnx" onnx.save( diff --git a/onnx_diagnostic/reference/ort_evaluator.py b/onnx_diagnostic/reference/ort_evaluator.py index edae7f28..6d7f20d6 100644 --- a/onnx_diagnostic/reference/ort_evaluator.py +++ b/onnx_diagnostic/reference/ort_evaluator.py @@ -6,6 +6,7 @@ FunctionProto, ModelProto, NodeProto, + TensorProto, TypeProto, ValueInfoProto, helper as oh, @@ -16,7 +17,13 @@ from onnx.defs import onnx_opset_version import onnxruntime from ..helpers import string_type -from ..helpers.onnx_helper import pretty_onnx, dtype_to_tensor_dtype, to_array_extended +from ..helpers.onnx_helper import ( + pretty_onnx, + dtype_to_tensor_dtype, + to_array_extended, + np_dtype_to_tensor_dtype, +) +from ..helpers.torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype from ..helpers.ort_session import ( InferenceSessionForTorch, InferenceSessionForNumpy, @@ -31,6 +38,54 @@ Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto] +class OnnxList(list): + """Defines a list for the runtime.""" + + def __init__(self, itype: Union[list, int]): + super().__init__() + if isinstance(itype, 
int): + self.itype = itype + self.dtype = onnx_dtype_to_torch_dtype(itype) + else: + assert itype, "The list cannot be created with an empty list." + self.itype = ( + np_dtype_to_tensor_dtype(itype[0].dtype) + if isinstance(itype[0], np.ndarray) + else torch_dtype_to_onnx_dtype(itype[0].dtype) + ) + self.extend(itype) + self.dtype = itype[0].dtype + self.shape = "OnnxList" + + def get_device(self): + "Returns the device of the first tensor." + assert len(self) > 0, "Cannot access the device for an empty list." + return self[0].get_device() if hasattr(self[0], "get_device") else -1 + + def numpy(self): + "Returns a new list of numpy arrays, or self if it is already the case." + if all(isinstance(v, np.ndarray) for v in self): + return self + return OnnxList([v.detach().cpu().numpy() for v in self]) + + def to(self, tensor_like) -> "OnnxList": + "Creates a new list with all tensors on numpy or pytorch depending on `tensor_like`." + if isinstance(tensor_like, np.ndarray): + return self + import torch + + return OnnxList( + [ + torch.from_numpy(t).to(tensor_like.device) if isinstance(t, np.ndarray) else t + for t in self + ] + ) + + def clone(self) -> "OnnxList": + "Clone (torch)." 
+ return OnnxList([t.clone() for t in self]) if len(self) > 0 else OnnxList(self.itype) + + class OnnxruntimeEvaluator: """ This class loads an onnx model and the executes one by one the nodes @@ -209,6 +264,8 @@ def output_types(self) -> List[TypeProto]: def _log_arg(self, a: Any) -> Any: if isinstance(a, (str, int, float)): return a + if isinstance(a, OnnxList): + return string_type(a) device = f"D{a.get_device()}:" if hasattr(a, "detach") else "" if hasattr(a, "shape"): prefix = "A:" if hasattr(a, "astype") else "T:" @@ -231,6 +288,12 @@ def _log(self, level: int, pattern: str, *args: Any) -> None: def _is_local_function(self, node: NodeProto) -> bool: return (node.domain, node.op_type) in self.local_functions + def _run_init(self, feed_inputs): + if self.sess_ is None: + assert self.proto, "self.proto is empty" + _, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values())) + return self.sess_ + def run( self, outputs: Optional[List[str]], @@ -254,9 +317,7 @@ def run( """ if self.rt_nodes_ is None: # runs a whole - if self.sess_ is None: - assert self.proto, "self.proto is empty" - _, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values())) + self._run_init(feed_inputs) assert self.sess_, "mypy not happy" return self.sess_.run(outputs, feed_inputs) if outputs is None: @@ -283,7 +344,7 @@ def run( if node.op_type == "If" and node.domain == "": outputs = self._run_if(node, inputs, results) elif node.op_type in {"Scan", "Loop"} and node.domain == "": - outputs = self._run_scan(node, inputs, results) + outputs = self._run_scan_or_loop(node, inputs, results) elif self._is_local_function(node): outputs = self._run_local(node, inputs, results) else: @@ -412,35 +473,38 @@ def enumerate_nodes(self, nodes: List[NodeProto]) -> Iterator[NodeProto]: yield node @classmethod - def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]: + def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]: """ Returns the hidden inputs (inputs coming from an upper 
context) used by a subgraph. """ hidden = set() - memo = set(i.name for i in graph.initializer) - memo |= set(i.name for i in graph.sparse_initializer) + memo = ( + {i.name for i in graph.initializer} + | {i.name for i in graph.sparse_initializer} + | {i.name for i in graph.input} + ) for node in graph.node: for i in node.input: if i not in memo: hidden.add(i) for att in node.attribute: if att.type == AttributeProto.GRAPH and att.g: - hid = self._get_hidden_inputs(att.g) + hid = cls._get_hidden_inputs(att.g) less = set(h for h in hid if h not in memo) hidden |= less memo |= set(node.output) return hidden @classmethod - def _get_hidden_node_inputs(self, node: NodeProto) -> Set[str]: + def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]: """Calls multiple _get_hidden_inputs on every attribute.""" if node.op_type not in {"Loop", "Scan", "If"}: return set() hidden = set() for att in node.attribute: if att.type == AttributeProto.GRAPH: - hidden |= self._get_hidden_inputs(att.g) + hidden |= cls._get_hidden_inputs(att.g) return hidden - (hidden & set(node.input)) def _get_sess( @@ -472,6 +536,18 @@ def _get_sess( ) ] prenodes = [] # type: ignore[var-annotated] + elif node.op_type == "ConcatFromSequence" and node.domain == "": + # The input is a sequence of tensors typed after inputs[0].itype. 
+ vinputs = [ + oh.make_value_info( + node.input[0], + type_proto=oh.make_sequence_type_proto( + oh.make_tensor_type_proto(elem_type=inputs[0].itype, shape=None) + ), + ) + ] + voutputs = [oh.make_tensor_value_info(node.output[0], inputs[0].itype, None)] + prenodes = [] # type: ignore[var-annotated] else: unique_names = set() vinputs = [] @@ -535,7 +611,17 @@ def _get_sess_init_subgraph( if i == "" or i in unique_names: continue unique_names.add(i) - value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape) + if isinstance(it, OnnxList): + value = oh.make_value_info( + i, + type_proto=oh.make_sequence_type_proto( + oh.make_tensor_type_proto( + elem_type=dtype_to_tensor_dtype(it.dtype), shape=None + ) + ), + ) + else: + value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape) vinputs.append(value) reduced_set = self._get_hidden_inputs(g) @@ -544,6 +630,10 @@ def _get_sess_init_subgraph( unique_names.add(i) value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(v.dtype), v.shape) vinputs.append(value) + assert len(reduced_set & set(context)) == len(reduced_set), ( + f"Missing hidden inputs {sorted(reduced_set)} from context={sorted(context)} " + f"(len(inputs)={len([i for i in inputs if i])}) for node {pretty_onnx(node)}" + ) return vinputs def _get_sess_if( @@ -592,6 +682,14 @@ def _get_sess_local( def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> List[Any]: """Runs a node.""" + if node.op_type[0] == "S": + if node.op_type == "SequenceEmpty": + dtype = TensorProto.FLOAT + for att in node.attribute: + if att.name == "dtype": + dtype = att.i + return [OnnxList(itype=dtype)] + types = [(None if a is None else (a.dtype, a.shape)) for a in inputs] key = (id(node), *types) if key in self._cache: @@ -609,8 +707,22 @@ def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> L continue feeds[i] = val assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}" + + 
if node.op_type[0] == "C": + if node.op_type == "ConcatFromSequence": + res = sess.sess.run(None, self.feeds_to_numpy(feeds)) # type: ignore[union-attr] + if isinstance(inputs[0][0], np.ndarray): + return list(res) + import torch + + return [torch.from_numpy(r).to(inputs[0][0].device) for r in res] + outputs = list(sess.run(None, feeds)) assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}" + assert not any(type(v) is list for v in outputs), ( + f"One output type is a list, this should not be allowed, " + f"node.op_type={node.op_type}, feeds={string_type(feeds, with_shape=True)}" + ) return outputs def _run_if( @@ -636,7 +748,7 @@ def _run_if( assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}" return outputs - def _get_sess_scan( + def _get_sess_scan_or_loop( self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any] ) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]: g = None @@ -671,10 +783,26 @@ def _get_sess_scan( ) return onx, sess - def _run_scan( + def feeds_to_numpy(self, feeds): + new_feeds = {} + for k, v in feeds.items(): + if hasattr(v, "detach"): + new_feeds[k] = v.detach().cpu().numpy() + elif isinstance(v, OnnxList): + new_feeds[k] = v.numpy() + else: + new_feeds[k] = v + return new_feeds + + def _run_scan_or_loop( self, node: NodeProto, inputs: List[Any], results: Dict[str, Any] ) -> List[Any]: """Runs a node Scan.""" + assert not any(type(i) is list for i in inputs), ( + f"One input is a list but it should an OnnxList, " + f"node.op_type={node.op_type!r}, node.input={node.input}, " + f"inputs={string_type(inputs, with_shape=True)}" + ) feeds = dict(zip(node.input, inputs)) feeds.update(results) name = "body" @@ -682,10 +810,21 @@ def _run_scan( if key in self._cache: sess = self._cache[key][1] else: - self._cache[key] = _onx, sess = self._get_sess_scan(node, name, inputs, results) + self._cache[key] = _onx, sess = self._get_sess_scan_or_loop( + node, name, inputs, results + ) 
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}" feeds = {name: results[name] for name in sess.input_names} + if node.op_type == "Loop" and any(isinstance(v, OnnxList) for v in feeds.values()): + # This operator uses sequences. onnxruntime does not play well with sequences. + sess._run_init(feeds) # type: ignore[union-attr] + outputs = sess.sess_.sess.run(None, self.feeds_to_numpy(feeds)) # type: ignore[union-attr] + return [ + (OnnxList(v).to(feeds[node.input[0]]) if isinstance(v, list) else v) + for v in outputs + ] + outputs = sess.run(None, feeds) assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}" return outputs diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 9400c197..7cac5e79 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -9,6 +9,7 @@ from ..helpers.onnx_helper import pretty_onnx from ..helpers.torch_helper import to_numpy, from_numpy, to_tensor, torch_dtype_to_onnx_dtype from ..helpers.torch_fx_graph_helper import prepare_args_kwargs, run_fx_node +from ..reference.ort_evaluator import OnnxList, OnnxruntimeEvaluator from .sbs_dataclasses import ( ReplayConfiguration, RunAlignedRecord, @@ -26,11 +27,11 @@ def _check_tensor_(use_tensor, name, obj, flip_type=False): if isinstance(obj, torch.Tensor): obj = to_numpy(obj) - assert not use_tensor or isinstance(obj, torch.Tensor), ( + assert not use_tensor or isinstance(obj, (torch.Tensor, OnnxList)), ( f"Unexpected type {type(obj)} for {name!r}. " f"use_tensor is True so torch.Tensor is expected." ) - assert use_tensor or isinstance(obj, np.ndarray), ( + assert use_tensor or isinstance(obj, (np.ndarray, OnnxList)), ( f"Unexpected type {type(obj)} for {name!r}. " f"use_tensor is False so np.array is expected." 
) @@ -175,10 +176,12 @@ def _loop_onnx_node( ref = run_cls(node, **run_cls_kwargs) # We need to clone because the runtime maybe using dlpack to create OrtValue + hidden_inputs = OnnxruntimeEvaluator._get_hidden_node_inputs(node) + all_inputs = [*node.input, *hidden_inputs] if hidden_inputs else node.input feeds = ( - {k: onnx_results[k].clone() for k in node.input if k} + {k: onnx_results[k].clone() for k in all_inputs if k} if use_tensor - else {k: onnx_results[k].copy() for k in node.input if k} + else {k: onnx_results[k].copy() for k in all_inputs if k} ) assert "" not in feeds, f"Unexpected feeds={string_type(feeds, **str_kws)}" if verbose > 1: