25 changes: 25 additions & 0 deletions _unittests/ut_reference/test_onnxruntime_evaluator.py
@@ -259,6 +259,31 @@ def test_skip_layer_normalization(self):
got = rt.run(None, feeds)
self.assertEqualAny(expected, got, atol=1e-4)

@hide_stdout()
def test_skip_simplified_layer_normalization(self):
node = oh.make_node(
"SkipSimplifiedLayerNormalization",
["x", "skip", "beta", "gamma"],
["Z", "", "", "bias"],
epsilon=1.0e-5,
domain="com.microsoft",
)
feeds = dict(
x=self._range(2, 3, 8),
skip=self._range(2, 3, 8, bias=3),
beta=self._range(8, bias=1),
gamma=self._range(8, bias=2),
)
rt = OnnxruntimeEvaluator(node, verbose=10, opsets={"": 22})
got = rt.run(None, feeds)
self.assertEqual(len(got), 2)
self.assertIsInstance(got[0], np.ndarray)
self.assertIsInstance(got[1], np.ndarray)
self.assertEqual(got[0].shape, feeds["x"].shape)
self.assertEqual(got[0].dtype, feeds["x"].dtype)
self.assertEqual(got[1].shape, feeds["x"].shape)
self.assertEqual(got[1].dtype, feeds["x"].dtype)


if __name__ == "__main__":
unittest.main(verbosity=2)
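
For context, SkipSimplifiedLayerNormalization is a com.microsoft contrib op that, as far as I can tell, applies a simplified (RMS-style) layer normalization to the sum of the input and the skip tensor. Below is a rough NumPy sketch of that computation; the function name and the handling of the optional bias input are assumptions for illustration, not taken from the onnxruntime implementation. The test above declares only two non-empty output names ("Z" and the input+skip sum returned as "bias"), which is why it expects len(got) == 2.

import numpy as np

def skip_simplified_layer_norm(x, skip, gamma, epsilon=1e-5):
    # Hypothetical reference: add the skip connection, then normalize by the
    # root-mean-square over the last axis (no mean subtraction, no beta add).
    s = x + skip
    variance = np.mean(s * s, axis=-1, keepdims=True)
    normalized = s / np.sqrt(variance + epsilon) * gamma
    # Mirror the two non-empty outputs declared in the test: the normalized
    # result ("Z") and the input+skip sum (named "bias" there).
    return normalized, s
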
18 changes: 13 additions & 5 deletions onnx_diagnostic/reference/ort_evaluator.py
@@ -278,9 +278,11 @@ def run(
outputs = self._run_local(node, inputs, results)
else:
outputs = self._run(node, inputs, results)
for name, value in zip(node.output, outputs):
if name == "":
continue
node_output = [o for o in node.output if o]
assert len(node_output) == len(
outputs
), f"Length mismatch between node output={node.output} and outputs={outputs}"
for name, value in zip(node_output, outputs):
self._log(2, " + %s: %s", name, value) # type: ignore[arg-type]
assert isinstance(name, str), f"unexpected type for name {type(name)}"
results[name] = value
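
The new zip above relies on the fact that an empty string in node.output marks an optional output the node does not produce, so the session returns one value per non-empty name only. A minimal illustration of the intended alignment, with made-up values standing in for the arrays returned by onnxruntime:

node_output_names = ["Z", "", "", "bias"]   # "" = optional outputs left unset
session_outputs = ["normalized", "sum"]     # placeholders for returned arrays
kept = [o for o in node_output_names if o]  # -> ["Z", "bias"]
assert len(kept) == len(session_outputs)
results = dict(zip(kept, session_outputs))  # {"Z": "normalized", "bias": "sum"}
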
@@ -384,6 +386,11 @@ def _make_model_proto(
onx = shi.infer_shapes(onx)
return onx

def _make_model_outputs(
self, node: NodeProto, inputs: List[ValueInfoProto]
) -> Tuple[List[NodeProto], List[ValueInfoProto]]:
return [], [oh.make_value_info(o, TypeProto()) for o in node.output if o]

@classmethod
def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
"""
@@ -434,6 +441,7 @@ def _get_sess(
node.output[0], dtype_to_tensor_dtype(cst.dtype), cst.shape
)
]
prenodes = [] # type: ignore[var-annotated]
else:
unique_names = set()
vinputs = []
@@ -447,9 +455,9 @@
vinputs.append(value)

# no need to run shape inference
voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
prenodes, voutputs = self._make_model_outputs(node, vinputs)

onx = self._make_model_proto([node], vinputs, voutputs)
onx = self._make_model_proto([*prenodes, node], vinputs, voutputs)
if node.op_type in {"Shape", "Size"}:
on_cpu = True

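
The _make_model_outputs hook declares a ValueInfoProto only for the non-empty output names, so the single-node model handed to onnxruntime exposes exactly the outputs that run() later zips against. A small standalone sketch of that step, reusing the node from the new test (the surrounding _make_model_proto plumbing is omitted):

from onnx import TypeProto, helper as oh

node = oh.make_node(
    "SkipSimplifiedLayerNormalization",
    ["x", "skip", "beta", "gamma"],
    ["Z", "", "", "bias"],
    domain="com.microsoft",
)
# Untyped value infos for the requested outputs only; the omitted
# optional outputs ("") are never declared on the model.
voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output if o]
print([v.name for v in voutputs])  # ['Z', 'bias']
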
6 changes: 1 addition & 5 deletions onnx_diagnostic/torch_onnx/sbs.py
@@ -492,11 +492,7 @@ def _loop_cmp(
f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
)
else:
print(
f"[run_align-dx] discrepancies "
f"{string_diff(d, with_shape=True, with_device=True)} - "
f"[{to}/{o}]"
)
print(f"[run_align-dx] discrepancies {string_diff(d)} - [{to}/{o}]")
return (i, i_onnx, o, to, string_type(torch_results[to], **str_kws), d)
return None
