Skip to content
This repository has been archived by the owner on Jan 13, 2024. It is now read-only.

Commit

Permalink
Adds code to turn onnx example into python unit test (#375)
Browse files Browse the repository at this point in the history
* Adds code to turn onnx example into python unit test

* Update test_onnx_backend.py
  • Loading branch information
sdpython committed Mar 5, 2022
1 parent 438fcb4 commit c839d7c
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 4 deletions.
128 changes: 124 additions & 4 deletions _unittests/ut_testing/test_onnx_backend.py
@@ -1,7 +1,11 @@
"""
@brief test log(time=10s)
@brief test log(time=40s)
"""
import unittest
from numpy import array, float32
from onnx.helper import (
make_model, make_node, set_model_props, make_graph,
make_tensor_value_info)
from pyquickhelper.pycode import ExtTestCase
from mlprodict.testing.onnx_backend import enumerate_onnx_tests
from mlprodict.onnxrt import OnnxInference
Expand Down Expand Up @@ -61,11 +65,127 @@ def test_enumerate_onnx_tests_run(self):
if __name__ == '__main__':
print(len(missed), len(failed), len(mismatch))
for t in failed:
print("failed", t[0])
print("failed", str(t[0]).replace('\\\\', '\\'))
for t in mismatch:
print("mismatch", t[0])
print("mismatch", str(t[0]).replace('\\\\', '\\'))
for t in missed:
print("missed", t[0])
print("missed", str(t[0]).replace('\\\\', '\\'))

def test_onnx_backend_test_to_python(self):
name = 'test_abs'
code = []
for te in enumerate_onnx_tests('node', lambda folder: folder == name):
code.append(te.to_python())
self.assertEqual(len(code), 1)
self.assertIn('def test_abs(self):', code[0])
self.assertIn('from onnx.helper', code[0])
self.assertIn('for y, gy in zip(ys, goty):', code[0])
if __name__ == '__main__':
print(code[0])

def test_abs(self):

def create_model():
'''
Converted ``test_abs``.
* producer: backend-test
* version: 0
* description:
'''

initializers = []
nodes = []
inputs = []
outputs = []

opsets = {'': 9}

value = make_tensor_value_info('x', 1, [3, 4, 5])
inputs.append(value)

value = make_tensor_value_info('y', 1, [3, 4, 5])
outputs.append(value)

node = make_node(
'Abs',
['x'],
['y'],
domain='')
nodes.append(node)

graph = make_graph(nodes, 'test_abs', inputs, outputs, initializers)

onnx_model = make_model(graph)
onnx_model.ir_version = 3
onnx_model.producer_name = 'backend-test'
onnx_model.producer_version = ''
onnx_model.domain = ''
onnx_model.model_version = 0
onnx_model.doc_string = ''
set_model_props(onnx_model, {})

del onnx_model.opset_import[:] # pylint: disable=E1101
for dom, value in opsets.items():
op_set = onnx_model.opset_import.add()
op_set.domain = dom
op_set.version = value

return onnx_model

onnx_model = create_model()

oinf = OnnxInference(onnx_model)
xs = [
array([[[1.7640524, 0.4001572, 0.978738, 2.2408931,
1.867558],
[-0.9772779, 0.95008844, -0.1513572, -0.10321885,
0.41059852],
[0.14404356, 1.4542735, 0.7610377, 0.12167501,
0.44386324],
[0.33367434, 1.4940791, -0.20515826, 0.3130677,
-0.85409576]],

[[-2.5529897, 0.6536186, 0.8644362, -0.742165,
2.2697546],
[-1.4543657, 0.04575852, -0.18718386, 1.5327792,
1.4693588],
[0.15494743, 0.37816253, -0.88778573, -1.9807965,
-0.34791216],
[0.15634897, 1.2302907, 1.2023798, -0.3873268,
-0.30230275]],

[[-1.048553, -1.420018, -1.7062702, 1.9507754,
-0.5096522],
[-0.4380743, -1.2527953, 0.7774904, -1.6138978,
-0.21274029],
[-0.89546657, 0.3869025, -0.51080513, -1.1806322,
-0.02818223],
[0.42833188, 0.06651722, 0.3024719, -0.6343221,
-0.36274117]]], dtype=float32),
]
ys = [
array([[[1.7640524, 0.4001572, 0.978738, 2.2408931, 1.867558],
[0.9772779, 0.95008844, 0.1513572, 0.10321885, 0.41059852],
[0.14404356, 1.4542735, 0.7610377, 0.12167501, 0.44386324],
[0.33367434, 1.4940791, 0.20515826, 0.3130677, 0.85409576]],

[[2.5529897, 0.6536186, 0.8644362, 0.742165, 2.2697546],
[1.4543657, 0.04575852, 0.18718386, 1.5327792, 1.4693588],
[0.15494743, 0.37816253, 0.88778573, 1.9807965, 0.34791216],
[0.15634897, 1.2302907, 1.2023798, 0.3873268, 0.30230275]],

[[1.048553, 1.420018, 1.7062702, 1.9507754, 0.5096522],
[0.4380743, 1.2527953, 0.7774904, 1.6138978, 0.21274029],
[0.89546657, 0.3869025, 0.51080513, 1.1806322, 0.02818223],
[0.42833188, 0.06651722, 0.3024719, 0.6343221, 0.36274117]]],
dtype=float32),
]
feeds = {n: x for n, x in zip(oinf.input_names, xs)}
got = oinf.run(feeds)
goty = [got[k] for k in oinf.output_names]
for y, gy in zip(ys, goty):
self.assertEqualArray(y, gy)


if __name__ == "__main__":
Expand Down
42 changes: 42 additions & 0 deletions mlprodict/testing/onnx_backend.py
Expand Up @@ -3,6 +3,7 @@
@brief Tests with onnx backend.
"""
import os
import textwrap
from numpy.testing import assert_almost_equal
import onnx
from onnx.numpy_helper import to_array
Expand Down Expand Up @@ -126,6 +127,47 @@ def run(self, load_fct, run_fct, index=None, decimal=5):
"Output %d of test %d in folder %r failed." % (
i, index, self.folder)) from ex

def to_python(self):
"""
Returns a python code equivalent to the ONNX test.
:return: code
"""
from ..onnx_tools.onnx_export import export2onnx
rows = []
code = export2onnx(self.onnx_model)
lines = code.split('\n')
lines = [line for line in lines
if not line.strip().startswith('print') and
not line.strip().startswith('# ')]
rows.append(textwrap.dedent("\n".join(lines)))
rows.append("oinf = OnnxInference(onnx_model)")
for test in self.tests:
rows.append("xs = [")
for inp in test['inputs']:
rows.append(textwrap.indent(repr(inp) + ',',
' ' * 2))
rows.append("]")
rows.append("ys = [")
for out in test['outputs']:
rows.append(textwrap.indent(repr(out) + ',',
' ' * 2))
rows.append("]")
rows.append("feeds = {n: x for n, x in zip(oinf.input_names, xs)}")
rows.append("got = oinf.run(feeds)")
rows.append("goty = [got[k] for k in oinf.output_names]")
rows.append("for y, gy in zip(ys, goty):")
rows.append(" self.assertEqualArray(y, gy)")
rows.append("")
code = "\n".join(rows)
final = "\n".join(["def %s(self):" % self.name,
textwrap.indent(code, ' ')])
try:
from pyquickhelper.pycode.code_helper import remove_extra_spaces_and_pep8
except ImportError:
return final
return remove_extra_spaces_and_pep8(final)


def enumerate_onnx_tests(series, fct_filter=None):
"""
Expand Down

0 comments on commit c839d7c

Please sign in to comment.