diff --git a/_unittests/ut_reference/test_onnxruntime_evaluator.py b/_unittests/ut_reference/test_onnxruntime_evaluator.py
index 8454a9c4..3a9871ec 100644
--- a/_unittests/ut_reference/test_onnxruntime_evaluator.py
+++ b/_unittests/ut_reference/test_onnxruntime_evaluator.py
@@ -259,6 +259,31 @@ def test_skip_layer_normalization(self):
         got = rt.run(None, feeds)
         self.assertEqualAny(expected, got, atol=1e-4)
 
+    @hide_stdout()
+    def test_skip_simplified_layer_normalization(self):
+        node = oh.make_node(
+            "SkipSimplifiedLayerNormalization",
+            ["x", "skip", "beta", "gamma"],
+            ["Z", "", "", "bias"],
+            epsilon=1.0e-5,
+            domain="com.microsoft",
+        )
+        feeds = dict(
+            x=self._range(2, 3, 8),
+            skip=self._range(2, 3, 8, bias=3),
+            beta=self._range(8, bias=1),
+            gamma=self._range(8, bias=2),
+        )
+        rt = OnnxruntimeEvaluator(node, verbose=10, opsets={"": 22})
+        got = rt.run(None, feeds)
+        self.assertEqual(len(got), 2)
+        self.assertIsInstance(got[0], np.ndarray)
+        self.assertIsInstance(got[1], np.ndarray)
+        self.assertEqual(got[0].shape, feeds["x"].shape)
+        self.assertEqual(got[0].dtype, feeds["x"].dtype)
+        self.assertEqual(got[1].shape, feeds["x"].shape)
+        self.assertEqual(got[1].dtype, feeds["x"].dtype)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/reference/ort_evaluator.py b/onnx_diagnostic/reference/ort_evaluator.py
index 8ac90321..083a2ee8 100644
--- a/onnx_diagnostic/reference/ort_evaluator.py
+++ b/onnx_diagnostic/reference/ort_evaluator.py
@@ -278,9 +278,11 @@ def run(
                 outputs = self._run_local(node, inputs, results)
             else:
                 outputs = self._run(node, inputs, results)
-            for name, value in zip(node.output, outputs):
-                if name == "":
-                    continue
+            node_output = [o for o in node.output if o]
+            assert len(node_output) == len(
+                outputs
+            ), f"Length mismatch between node output={node.output} and outputs={outputs}"
+            for name, value in zip(node_output, outputs):
                 self._log(2, " + %s: %s", name, value)  # type: ignore[arg-type]
                 assert isinstance(name, str), f"unexpected type for name {type(name)}"
                 results[name] = value
@@ -384,6 +386,11 @@ def _make_model_proto(
             onx = shi.infer_shapes(onx)
         return onx
 
+    def _make_model_outputs(
+        self, node: NodeProto, inputs: List[ValueInfoProto]
+    ) -> Tuple[List[NodeProto], List[ValueInfoProto]]:
+        return [], [oh.make_value_info(o, TypeProto()) for o in node.output if o]
+
     @classmethod
     def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
         """
@@ -434,6 +441,7 @@ def _get_sess(
                     node.output[0], dtype_to_tensor_dtype(cst.dtype), cst.shape
                 )
             ]
+            prenodes = []  # type: ignore[var-annotated]
         else:
             unique_names = set()
             vinputs = []
@@ -447,9 +455,9 @@ def _get_sess(
                 vinputs.append(value)
 
         # no need to run shape inference
-        voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
+        prenodes, voutputs = self._make_model_outputs(node, vinputs)
 
-        onx = self._make_model_proto([node], vinputs, voutputs)
+        onx = self._make_model_proto([*prenodes, node], vinputs, voutputs)
 
         if node.op_type in {"Shape", "Size"}:
             on_cpu = True
diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py
index 6b4f2359..30576a4a 100644
--- a/onnx_diagnostic/torch_onnx/sbs.py
+++ b/onnx_diagnostic/torch_onnx/sbs.py
@@ -492,11 +492,7 @@ def _loop_cmp(
                             f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
                         )
                     else:
-                        print(
-                            f"[run_align-dx] discrepancies "
-                            f"{string_diff(d, with_shape=True, with_device=True)} - "
-                            f"[{to}/{o}]"
-                        )
+                        print(f"[run_align-dx] discrepancies {string_diff(d)} - [{to}/{o}]")
                     return (i, i_onnx, o, to,
                             string_type(torch_results[to], **str_kws), d)
     return None