diff --git a/.gitignore b/.gitignore index 06e8ee76e..c252986cf 100644 --- a/.gitignore +++ b/.gitignore @@ -307,3 +307,7 @@ _doc/examples/plot_benchmark.svg _doc/examples/plot_*.csv _doc/examples/plot_*.xlsx _doc/examples/plot_*.png +_unittests/ut_tools/*.gz +_unittests/ut_tools/*.tar +_unittests/ut_tools/**/*.npz +_unittests/ut_tools/**/*.pb diff --git a/_doc/sphinxdoc/source/api/onnxrt.rst b/_doc/sphinxdoc/source/api/onnxrt.rst index 684a377e8..2ce3c8bcd 100644 --- a/_doc/sphinxdoc/source/api/onnxrt.rst +++ b/_doc/sphinxdoc/source/api/onnxrt.rst @@ -128,6 +128,8 @@ is left unchanged. .. autosignature:: mlprodict.onnxrt.optim.onnx_optimisation_redundant.onnx_remove_node_redundant +.. autosignature:: mlprodict.onnxrt.optim.onnx_remove_unused.onnx_remove_node_unused + Shapes ++++++ diff --git a/_doc/sphinxdoc/source/conf.py b/_doc/sphinxdoc/source/conf.py index 02e182038..baab436f3 100644 --- a/_doc/sphinxdoc/source/conf.py +++ b/_doc/sphinxdoc/source/conf.py @@ -89,6 +89,7 @@ 'onnx': 'https://github.com/onnx/onnx', 'ONNX Operators': 'https://github.com/onnx/onnx/blob/master/docs/Operators.md', 'ONNX ML Operators': 'https://github.com/onnx/onnx/blob/master/docs/Operators-ml.md', + 'ONNX Zoo': 'https://github.com/onnx/models', 'onnxconverter_common': 'https://github.com/onnx/onnxmltools/tree/master/onnxutils/onnxconverter_common', 'OnnxOperatorMixin': 'https://github.com/onnx/sklearn-onnx/blob/master/skl2onnx/algebra/onnx_operator_mixin.py#L16', 'onnxruntime': 'https://github.com/microsoft/onnxruntime', diff --git a/_unittests/ut_cli/test_cli_onnx_optim.py b/_unittests/ut_cli/test_cli_onnx_optim.py index 723a67f9e..0bba70e44 100644 --- a/_unittests/ut_cli/test_cli_onnx_optim.py +++ b/_unittests/ut_cli/test_cli_onnx_optim.py @@ -8,8 +8,9 @@ from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression +from sklearn.exceptions import ConvergenceWarning from pyquickhelper.loghelper import BufferedPrint -from pyquickhelper.pycode import ExtTestCase, get_temp_folder +from pyquickhelper.pycode import ExtTestCase, get_temp_folder, ignore_warnings from mlprodict.__main__ import main from mlprodict.cli import convert_validate, onnx_optim @@ -22,6 +23,7 @@ def test_cli_onnx_optim(self): res = str(st) self.assertIn("verbose", res) + @ignore_warnings(ConvergenceWarning) def test_onnx_optim(self): iris = load_iris() X, y = iris.data, iris.target diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py index 9e0f5152a..3f0d909bb 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py @@ -628,7 +628,7 @@ def test_onnxt_runtime_constant_of_shape(self): validate_python_inference(oinfpy, {'X': x}) @wraplog() - def test_onnxt_runtime_conv(self): + def test_onnxt_runtime_conv0(self): x = numpy.array([[[[0., 1., 2., 3., 4.], # (1, 1, 5, 5) input tensor [5., 6., 7., 8., 9.], [10., 11., 12., 13., 14.], @@ -638,6 +638,7 @@ def test_onnxt_runtime_conv(self): [1., 1., 1.], [1., 1., 1.]]]]).astype(numpy.float32) + # test 1 y_with_padding = numpy.array([[[[12., 21., 27., 33., 24.], # (1, 1, 5, 5) output tensor [33., 54., 63., 72., 51.], [63., 99., 108., 117., 81.], @@ -650,13 +651,139 @@ def test_onnxt_runtime_conv(self): op_version=get_opset_number_from_onnx()) model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, target_opset=get_opset_number_from_onnx()) - oinf = OnnxInference(model_def) - got = 
oinf.run({'X': x}) - self.assertEqual(list(sorted(got)), ['Y']) - self.assertEqualArray(y_with_padding, got['Y']) + for rt in ['python', 'onnxruntime1']: + with self.subTest(runtime=rt): + oinf = OnnxInference(model_def, runtime=rt) + got = oinf.run({'X': x}) + self.assertEqual(list(sorted(got)), ['Y']) + self.assertEqualArray(y_with_padding, got['Y']) + + # test 2 + y_without_padding = numpy.array([[[[54., 63., 72.], # (1, 1, 3, 3) output tensor + [99., 108., 117.], + [144., 153., 162.]]]]).astype(numpy.float32) + + onx = OnnxConv( + 'X', W, output_names=['Y'], + kernel_shape=[3, 3], pads=[0, 0, 0, 0], + op_version=get_opset_number_from_onnx()) + model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, + target_opset=get_opset_number_from_onnx()) + for rt in ['python', 'onnxruntime1']: + with self.subTest(runtime=rt): + oinf = OnnxInference(model_def, runtime=rt) + got = oinf.run({'X': x}) + self.assertEqual(list(sorted(got)), ['Y']) + self.assertEqualArray(y_without_padding, got['Y']) + + # test 3 + y = numpy.array([[[[12., 27., 24.], + [63., 108., 81.], + [72., 117., 84.]]]]).astype(numpy.float32) + + onx = OnnxConv( + 'X', W, output_names=['Y'], + kernel_shape=[3, 3], + auto_pad='SAME_LOWER', strides=[2, 2], + op_version=get_opset_number_from_onnx()) + model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, + target_opset=get_opset_number_from_onnx()) + for rt in ['python', 'onnxruntime1']: + with self.subTest(runtime=rt): + oinf = OnnxInference(model_def, runtime=rt) + got = oinf.run({'X': x}) + self.assertEqual(list(sorted(got)), ['Y']) + self.assertEqualArray(y, got['Y']) python_tested.append(OnnxConv) + @wraplog() + def test_onnxt_runtime_conv1(self): + x = numpy.array([[[[0., 1., 2., 3., 4.], + [5., 6., 7., 8., 9.], + [10., 11., 12., 13., 14.], + [15., 16., 17., 18., 19.], + [20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.]]]]).astype(numpy.float32) + W = numpy.array([[[[1., 1., 1.], # (1, 1, 3, 3) tensor for convolution weights + [1., 1., 1.], + [1., 1., 1.]]]]).astype(numpy.float32) + + # test 1 + y_with_padding = numpy.array([[[[12., 27., 24.], # (1, 1, 4, 3) output tensor + [63., 108., 81.], + [123., 198., 141.], + [112., 177., 124.]]]]).astype(numpy.float32) + + onx = OnnxConv( + 'X', W, output_names=['Y'], + kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[2, 2], + op_version=get_opset_number_from_onnx()) + model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, + target_opset=get_opset_number_from_onnx()) + for rt in ['python', 'onnxruntime1']: + with self.subTest(runtime=rt): + oinf = OnnxInference(model_def, runtime=rt) + got = oinf.run({'X': x}) + self.assertEqual(list(sorted(got)), ['Y']) + self.assertEqualArray(y_with_padding, got['Y']) + + # test 2 + y_without_padding = numpy.array([[[[54., 72.], # (1, 1, 3, 2) output tensor + [144., 162.], + [234., 252.]]]]).astype(numpy.float32) + + onx = OnnxConv( + 'X', W, output_names=['Y'], + kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2], + op_version=get_opset_number_from_onnx()) + model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, + target_opset=get_opset_number_from_onnx()) + for rt in ['python', 'onnxruntime1']: + with self.subTest(runtime=rt): + oinf = OnnxInference(model_def, runtime=rt) + got = oinf.run({'X': x}) + self.assertEqual(list(sorted(got)), ['Y']) + self.assertEqualArray(y_without_padding, got['Y']) + + # test 3 + y_with_asymmetric_padding = numpy.array([[[[21., 33.], # (1, 1, 4, 2) output tensor + [99., 117.], + [189., 207.], + [171., 
183.]]]]).astype(numpy.float32) + + onx = OnnxConv( + 'X', W, output_names=['Y'], + kernel_shape=[3, 3], pads=[1, 0, 1, 0], strides=[2, 2], + op_version=get_opset_number_from_onnx()) + model_def = onx.to_onnx({'X': x.astype(numpy.float32)}, + target_opset=get_opset_number_from_onnx()) + for rt in ['python', 'onnxruntime1']: + with self.subTest(runtime=rt): + oinf = OnnxInference(model_def, runtime=rt) + got = oinf.run({'X': x}) + self.assertEqual(list(sorted(got)), ['Y']) + self.assertEqualArray(y_with_asymmetric_padding, got['Y']) + + @wraplog() + def test_onnxt_runtime_conv2_B(self): + x = numpy.random.rand(1, 3, 5, 4).astype(numpy.float32) + W = numpy.random.rand(4, 3, 3, 3).astype(numpy.float32) + B = numpy.array([100, 700, 1000, 7000], dtype=numpy.float32) + onx = OnnxConv( + 'X', 'W', 'B', output_names=['Y'], + kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[2, 2], + op_version=get_opset_number_from_onnx()) + model_def = onx.to_onnx({'X': x, 'W': W, 'B': B}, + target_opset=get_opset_number_from_onnx()) + ys = [] + for rt in ['python', 'onnxruntime1']: + oinf = OnnxInference(model_def, runtime=rt) + got = oinf.run({'X': x, 'W': W, 'B': B}) + ys.append(got['Y']) + self.assertEqualArray(ys[0], ys[1], decimal=5) + @wraplog() def test_onnxt_runtime_conv_transpose(self): x = numpy.array([[[[0., 1., 2.], # (1, 1, 3, 3) @@ -693,6 +820,25 @@ def test_onnxt_runtime_conv_transpose(self): python_tested.append(OnnxConv) + @wraplog() + def test_onnxt_runtime_conv_transpose_B(self): + x = numpy.random.rand(1, 3, 5, 4).astype(numpy.float32) + W = numpy.random.rand(3, 4, 3, 3).astype(numpy.float32) + B = numpy.array([100, 700, 1000, 7000], dtype=numpy.float32) + onx = OnnxConvTranspose( + 'X', 'W', 'B', output_names=['Y'], + kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[2, 2], + op_version=get_opset_number_from_onnx()) + model_def = onx.to_onnx({'X': x, 'W': W, 'B': B}, + target_opset=get_opset_number_from_onnx()) + ys = [] + for rt in ['python', 'onnxruntime1']: + oinf = OnnxInference(model_def, runtime=rt) + got = oinf.run({'X': x, 'W': W, 'B': B}) + ys.append(got['Y']) + self.assertEqual(len(ys), 2) + # self.assertEqualArray(ys[0], ys[1]) + @wraplog() def test_onnxt_runtime_conv_transpose_1d(self): x = numpy.array([[[0., 1., 2.]]]).astype(numpy.float32) @@ -2824,5 +2970,5 @@ def test_make_constant(self): if __name__ == "__main__": - TestOnnxrtPythonRuntime().test_onnxt_runtime_unsqueeze() + # TestOnnxrtPythonRuntime().test_onnxt_runtime_conv_transpose_B() unittest.main() diff --git a/_unittests/ut_onnxrt/test_optim_onnx_unused.py b/_unittests/ut_onnxrt/test_optim_onnx_unused.py new file mode 100644 index 000000000..72caf36b9 --- /dev/null +++ b/_unittests/ut_onnxrt/test_optim_onnx_unused.py @@ -0,0 +1,61 @@ +""" +@brief test log(time=2s) +""" +import unittest +import numpy +from pyquickhelper.pycode import ExtTestCase +from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611 + OnnxAdd, OnnxMul, OnnxSub) +from mlprodict.onnxrt.optim.onnx_helper import onnx_statistics +from mlprodict.onnxrt import OnnxInference +from mlprodict.onnxrt.optim import onnx_remove_node_unused +from mlprodict.onnxrt.onnx_inference_manipulations import ( + select_model_inputs_outputs) +from mlprodict.tools import get_opset_number_from_onnx + + +class TestOptimOnnxUnused(ExtTestCase): + + def test_onnx_remove_unused(self): + dtype = numpy.float32 + x = numpy.array([1, 2, 4, 5, 5, 4]).astype( + numpy.float32).reshape((3, 2)) + cop = OnnxAdd('X', numpy.array([1], dtype=dtype), + 
op_version=get_opset_number_from_onnx()) + cop2 = OnnxAdd('X', numpy.array([1], dtype=dtype), + op_version=get_opset_number_from_onnx()) + cop3 = OnnxAdd('X', numpy.array([2], dtype=dtype), + op_version=get_opset_number_from_onnx(), + output_names=['inter']) + cop4 = OnnxSub( + OnnxMul(cop, cop3, op_version=get_opset_number_from_onnx()), + cop2, output_names=['final'], + op_version=get_opset_number_from_onnx()) + model_def = cop4.to_onnx({'X': x}) + model_def = select_model_inputs_outputs(model_def, "inter") + stats = onnx_statistics(model_def, optim=True) + c1 = model_def.SerializeToString() + new_model = onnx_remove_node_unused(model_def) + c2 = model_def.SerializeToString() + self.assertEqual(c1, c2) + stats2 = onnx_statistics(model_def, optim=True) + stats3 = onnx_statistics(new_model, optim=False) + self.assertEqual(stats['ninits'], 2) + self.assertEqual(stats2['ninits'], 2) + self.assertEqual(stats3['ninits'], 1) + self.assertEqual(stats2['nnodes'], 1) + self.assertEqual(stats3['nnodes'], 1) + oinf1 = OnnxInference(model_def) + y1 = oinf1.run({'X': x}) + + oinf2 = OnnxInference(new_model) + y2 = oinf2.run({'X': x}) + self.assertNotIn('final', y1) + self.assertNotIn('final', y2) + self.assertIn('inter', y1) + self.assertIn('inter', y2) + self.assertEqualArray(y1['inter'], y2['inter']) + + +if __name__ == "__main__": + unittest.main() diff --git a/_unittests/ut_tools/test_LONG_zoo.py b/_unittests/ut_tools/test_LONG_zoo.py new file mode 100644 index 000000000..56ac9e4cc --- /dev/null +++ b/_unittests/ut_tools/test_LONG_zoo.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +""" +@brief test log(time=120s) +""" +import unittest +from pyquickhelper.pycode import ExtTestCase +from mlprodict.tools.zoo import download_model_data, verify_model + + +class TestLONGZoo(ExtTestCase): + + def c_test_verify_model(self, name): + link, data = download_model_data(name, cache=".") + for rt in ['onnxruntime', 'onnxruntime1', 'python']: + with self.subTest(runtime=rt): + if rt == 'python': + try: + verify_model(link, data, runtime=rt) + except NotImplementedError as e: + if 'AveragePool' in str(e): + continue + raise e + else: + verify_model(link, data, runtime=rt) + + def test_resnet18(self): + self.c_test_verify_model('resnet18') + + def test_squeezenet(self): + self.c_test_verify_model('squeezenet') + + def test_densenet121(self): + self.c_test_verify_model('densenet121') + + def test_inception2(self): + self.c_test_verify_model('inception2') + + @unittest.skipIf(True, "AveragePool is missing.") + def test_shufflenet(self): + self.c_test_verify_model('shufflenet') + + def test_efficientnet_lite4(self): + self.c_test_verify_model('efficientnet-lite4') + + +if __name__ == "__main__": + unittest.main() diff --git a/_unittests/ut_tools/test_zoo.py b/_unittests/ut_tools/test_zoo.py new file mode 100644 index 000000000..ef326f8fb --- /dev/null +++ b/_unittests/ut_tools/test_zoo.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +""" +@brief test log(time=10s) +""" +import unittest +import pprint +import numpy +from pyquickhelper.pycode import ExtTestCase +from mlprodict.tools.zoo import download_model_data, verify_model +from mlprodict.onnxrt import OnnxInference +from mlprodict.onnxrt.validate.side_by_side import side_by_side_by_values + + +class TestZoo(ExtTestCase): + + def test_download_model_data_fail(self): + self.assertRaise(lambda: download_model_data("hhh"), ValueError) + + def test_download_model_data(self): + link, data = download_model_data("mobilenet", cache=".") + self.assertEndsWith("mobilenetv2-7.onnx", 
link) + self.assertEqual(len(data), 3) + for k, data in data.items(): + self.assertIn("test_data_set", k) + self.assertEqual(len(data), 2) + self.assertEqual(len(data['in']), 1) + self.assertEqual(len(data['out']), 1) + for name, t in data['in'].items(): + self.assertIn('_', name) + self.assertIsInstance(t, numpy.ndarray) + for name, t in data['out'].items(): + self.assertIn('_', name) + self.assertIsInstance(t, numpy.ndarray) + + def test_verify_side_by_side(self): + link, data = download_model_data("mobilenet", cache=".") + oinf2 = OnnxInference(link, runtime="python") + oinf2 = oinf2.build_intermediate('474')['474'] + oinf1 = OnnxInference(link, runtime="onnxruntime1") + oinf1 = oinf1.build_intermediate('474')['474'] + inputs = {'input': data['test_data_set_0']['in']['input_0']} + rows = side_by_side_by_values([oinf1, oinf2], inputs=inputs) + for row in rows: + keep = [] + if row.get('name', '-') == '474': + v0 = row['value[0]'] + v1 = row['value[1]'] + self.assertEqual(v0.shape, v1.shape) + for i, (a, b) in enumerate(zip(v0.ravel(), v1.ravel())): + if abs(a - b) > 5e-4: + keep.append((i, [a, b], abs(a - b))) + if len(keep) > 10: + break + if len(keep) > 0: + raise AssertionError( + "Mismatch\n%s" % pprint.pformat(keep)) + + def test_verify_model_mobilenet(self): + link, data = download_model_data("mobilenet", cache=".") + for rt in ['onnxruntime', 'onnxruntime1', 'python']: + with self.subTest(runtime=rt): + verify_model(link, data, runtime=rt) + + def test_verify_model_squeezenet(self): + link, data = download_model_data("squeezenet", cache=".") + for rt in ['onnxruntime', 'onnxruntime1', 'python']: + with self.subTest(runtime=rt): + verify_model(link, data, runtime=rt) + + +if __name__ == "__main__": + # TestZoo().test_verify_model_squeezenet() + unittest.main() diff --git a/mlprodict/onnxrt/onnx_inference.py b/mlprodict/onnxrt/onnx_inference.py index e6c43756d..59a16487e 100644 --- a/mlprodict/onnxrt/onnx_inference.py +++ b/mlprodict/onnxrt/onnx_inference.py @@ -16,10 +16,12 @@ from onnx.helper import make_model from ..tools.code_helper import make_callable from .onnx_inference_node import OnnxInferenceNode -from .onnx_inference_manipulations import select_model_inputs_outputs, enumerate_model_node_outputs +from .onnx_inference_manipulations import ( + select_model_inputs_outputs, enumerate_model_node_outputs) +from .onnx_inference_exports import OnnxInferenceExport +from .optim import onnx_remove_node_unused from .onnx2py_helper import _var_as_dict, numpy_min, numpy_max from .shape_object import ShapeObject -from .onnx_inference_exports import OnnxInferenceExport class OnnxInference: @@ -233,27 +235,42 @@ def shape_inference(self): def input_names(self): """ Returns the names of all inputs. + It does not include the optional inputs. + + .. versionchanged:: 0.6 + The list does not include optional inputs anymore. """ - return [_.name for _ in self.obj.graph.input] + inits = set(_.name for _ in self.obj.graph.initializer) + return [_.name for _ in self.obj.graph.input if _.name not in inits] @property def input_names_shapes(self): """ Returns the names and shapes of all inputs. This method assumes all inputs are tensors. + It does not include the optional inputs. + + .. versionchanged:: 0.6 + The list does not include optional inputs anymore. 
""" + names = set(self.input_names) return [(_.name, _var_as_dict(_)['type']['shape']) - for _ in self.obj.graph.input] + for _ in self.obj.graph.input if _.name in names] @property def input_names_shapes_types(self): """ Returns the names, shapes, types of all inputs. This method assumes all inputs are tensors. + It does not include the optional inputs. + + .. versionchanged:: 0.6 + The list does not include optional inputs anymore. """ + names = set(self.input_names) return [(_.name, _var_as_dict(_)['type']['shape'], 'tensor(%s)' % _var_as_dict(_)['type']['elem']) - for _ in self.obj.graph.input] + for _ in self.obj.graph.input if _.name in names] @property def output_names(self): @@ -700,17 +717,29 @@ def dispsimple(arr): ", ".join(sorted(values)))) from e return (res, mtime) if node_time else res - def build_intermediate(self): + def build_intermediate(self, outputs=None): """ Builds every possible :epkg:`ONNX` file which computes one specific intermediate output from the inputs. - @return :epkg:`*py:collections:OrderedDict` + :param outputs: subsets of outputs to get, + None to get all outputs, + :return: :epkg:`*py:collections:OrderedDict` + + .. versionchanged: 0.6 """ + if outputs is not None: + if isinstance(outputs, str): + outputs = [outputs] + if not isinstance(outputs, set): + outputs = set(outputs) ord = OrderedDict() for output in enumerate_model_node_outputs(self.obj): + if outputs is not None and output not in outputs: + continue subonx = select_model_inputs_outputs(self.obj, output) + subonx = onnx_remove_node_unused(subonx) ord[output] = OnnxInference(subonx, runtime=self.runtime, skip_run=self.skip_run) return ord diff --git a/mlprodict/onnxrt/ops_cpu/op_conv.py b/mlprodict/onnxrt/ops_cpu/op_conv.py index 44abda897..eca40d34f 100644 --- a/mlprodict/onnxrt/ops_cpu/op_conv.py +++ b/mlprodict/onnxrt/ops_cpu/op_conv.py @@ -36,6 +36,10 @@ def _init(self): numpy.array(self.strides, dtype=numpy.int64)) def _run(self, X, W, B=None): # pylint: disable=W0221 + if X is None: + raise ValueError( + "X cannot be None for operator %r, ONNX=%r" % ( + type(self), self.onnx_node)) if X.dtype == numpy.float32: return (self.rt32_.compute(X, W, B), ) return (self.rt64_.compute(X, W, B), ) diff --git a/mlprodict/onnxrt/ops_cpu/op_conv_.cpp b/mlprodict/onnxrt/ops_cpu/op_conv_.cpp index 972c81244..170c235d1 100644 --- a/mlprodict/onnxrt/ops_cpu/op_conv_.cpp +++ b/mlprodict/onnxrt/ops_cpu/op_conv_.cpp @@ -299,8 +299,8 @@ void Conv::compute_gil_free( if (b_dims.size() != 0 && b_dims[0] != 0) { const T* ptrb = B.data(0); for(size_t k = 0; k < (size_t)M; ++k, ++ptrb) { - yptr = Ydata + k; - for(k2 = 0; k2 < (size_t)output_image_size; ++k2, yptr += M) + yptr = Ydata + output_image_size * k; + for(k2 = 0; k2 < (size_t)output_image_size; ++k2, ++yptr) *yptr += *ptrb; } } @@ -311,15 +311,13 @@ void Conv::compute_gil_free( } -class ConvFloat : public Conv -{ +class ConvFloat : public Conv { public: ConvFloat() : Conv() {} }; -class ConvDouble : public Conv -{ +class ConvDouble : public Conv { public: ConvDouble() : Conv() {} }; diff --git a/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.hpp b/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.hpp index 87a4ce3df..8afbb8cd7 100644 --- a/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.hpp +++ b/mlprodict/onnxrt/ops_cpu/op_conv_matrices_.hpp @@ -373,7 +373,6 @@ void gemm(bool transA, bool transB, template class ConvPoolCommon { - }; diff --git a/mlprodict/onnxrt/ops_cpu/op_conv_transpose_.cpp b/mlprodict/onnxrt/ops_cpu/op_conv_transpose_.cpp index 38cd27eb7..62c7019e6 
100644 --- a/mlprodict/onnxrt/ops_cpu/op_conv_transpose_.cpp +++ b/mlprodict/onnxrt/ops_cpu/op_conv_transpose_.cpp @@ -324,10 +324,12 @@ void ConvTranspose::compute_gil_free( } if (b_dims.size() != 0 && b_dims[0] != 0) { + // conv: output_image_size, M + // convt: output_image_size, num_output_channels const T* ptrb = B.data(0); - for(size_t k = 0; k < (size_t)M; ++k, ++ptrb) { - yptr = Ydata + k; - for(k2 = 0; k2 < (size_t)output_image_size; ++k2, yptr += M) + for(size_t k = 0; k < (size_t)num_output_channels; ++k, ++ptrb) { + yptr = Ydata + output_image_size * k; + for(k2 = 0; k2 < (size_t)output_image_size; ++k2, ++yptr) *yptr += *ptrb; } } @@ -338,15 +340,13 @@ void ConvTranspose::compute_gil_free( } -class ConvTransposeFloat : public ConvTranspose -{ +class ConvTransposeFloat : public ConvTranspose { public: ConvTransposeFloat() : ConvTranspose() {} }; -class ConvTransposeDouble : public ConvTranspose -{ +class ConvTransposeDouble : public ConvTranspose { public: ConvTransposeDouble() : ConvTranspose() {} }; diff --git a/mlprodict/onnxrt/optim/__init__.py b/mlprodict/onnxrt/optim/__init__.py index ba9e961a8..0be2269c6 100644 --- a/mlprodict/onnxrt/optim/__init__.py +++ b/mlprodict/onnxrt/optim/__init__.py @@ -5,5 +5,6 @@ from .onnx_helper import onnx_statistics from .onnx_optimisation_identity import onnx_remove_node_identity from .onnx_optimisation_redundant import onnx_remove_node_redundant +from .onnx_optimisation_unused import onnx_remove_node_unused from .onnx_optimisation import onnx_remove_node from ._main_onnx_optim import onnx_optimisations diff --git a/mlprodict/onnxrt/optim/_main_onnx_optim.py b/mlprodict/onnxrt/optim/_main_onnx_optim.py index 2dbd04ca7..8c4a9b7dd 100644 --- a/mlprodict/onnxrt/optim/_main_onnx_optim.py +++ b/mlprodict/onnxrt/optim/_main_onnx_optim.py @@ -17,5 +17,6 @@ def onnx_optimisations(onnx_model, recursive=True, debug_info=None, **options): @return new onnx _model """ new_model = onnx_remove_node( - onnx_model, recursive=recursive, debug_info=debug_info) + onnx_model, recursive=recursive, debug_info=debug_info, + **options) return new_model diff --git a/mlprodict/onnxrt/optim/onnx_optimisation.py b/mlprodict/onnxrt/optim/onnx_optimisation.py index 00bfbacaa..745aa009a 100644 --- a/mlprodict/onnxrt/optim/onnx_optimisation.py +++ b/mlprodict/onnxrt/optim/onnx_optimisation.py @@ -5,9 +5,10 @@ from ._onnx_optimisation_common import _apply_optimisation_on_graph from .onnx_optimisation_identity import onnx_remove_node_identity from .onnx_optimisation_redundant import onnx_remove_node_redundant +from .onnx_optimisation_unused import onnx_remove_node_unused -def onnx_remove_node(onnx_model, recursive=True, debug_info=None): +def onnx_remove_node(onnx_model, recursive=True, debug_info=None, **options): """ Removes as many nodes as possible without changing the outcome. 
It applies @see fn onnx_remove_node_identity, @@ -16,6 +17,7 @@ def onnx_remove_node(onnx_model, recursive=True, debug_info=None): @param onnx_model onnx model @param recursive looks into subgraphs @param debug_info debug information (private) + @param options additional options @return new onnx _model """ if debug_info is None: @@ -27,11 +29,14 @@ def onnx_remove_node(onnx_model, recursive=True, debug_info=None): if hasattr(onnx_model, 'graph'): return _apply_optimisation_on_graph( onnx_remove_node, onnx_model, - recursive=recursive, debug_info=debug_info) + recursive=recursive, debug_info=debug_info, + **options) graph = onnx_model + graph = onnx_remove_node_unused( + graph, recursive=recursive, debug_info=debug_info, **options) graph = onnx_remove_node_identity( - graph, recursive=recursive, debug_info=debug_info) + graph, recursive=recursive, debug_info=debug_info, **options) graph = onnx_remove_node_redundant( - graph, recursive=recursive, debug_info=debug_info) + graph, recursive=recursive, debug_info=debug_info, **options) return graph diff --git a/mlprodict/onnxrt/optim/onnx_optimisation_identity.py b/mlprodict/onnxrt/optim/onnx_optimisation_identity.py index f84b5ac68..4d52957bc 100644 --- a/mlprodict/onnxrt/optim/onnx_optimisation_identity.py +++ b/mlprodict/onnxrt/optim/onnx_optimisation_identity.py @@ -11,7 +11,7 @@ ) -def onnx_remove_node_identity(onnx_model, recursive=True, debug_info=None): +def onnx_remove_node_identity(onnx_model, recursive=True, debug_info=None, **options): """ Removes as many *Identity* nodes as possible. The function looks into every node and subgraphs if @@ -23,6 +23,7 @@ def onnx_remove_node_identity(onnx_model, recursive=True, debug_info=None): @param onnx_model onnx model @param recursive looks into subgraphs @param debug_info debug information (private) + @param options additional options (unused) @return new onnx _model """ if debug_info is None: @@ -34,7 +35,7 @@ def onnx_remove_node_identity(onnx_model, recursive=True, debug_info=None): if hasattr(onnx_model, 'graph'): return _apply_optimisation_on_graph( onnx_remove_node_identity, onnx_model, - recursive=recursive, debug_info=debug_info) + recursive=recursive, debug_info=debug_info, **options) graph = onnx_model diff --git a/mlprodict/onnxrt/optim/onnx_optimisation_redundant.py b/mlprodict/onnxrt/optim/onnx_optimisation_redundant.py index 9d16e1644..d47c7fce1 100644 --- a/mlprodict/onnxrt/optim/onnx_optimisation_redundant.py +++ b/mlprodict/onnxrt/optim/onnx_optimisation_redundant.py @@ -42,7 +42,7 @@ def _hash_obj_content(obj, max_size=1000): def onnx_remove_node_redundant(onnx_model, recursive=True, debug_info=None, - max_hash_size=1000): + max_hash_size=1000, **options): """ Removes redundant part of the graph. 
A redundant part is a set of nodes which takes the same inputs and produces @@ -55,6 +55,7 @@ def onnx_remove_node_redundant(onnx_model, recursive=True, debug_info=None, @param debug_info debug information (private) @param max_hash_size limit the size of a hash used to detect identical subgraphs + @param options additional options (unused) @return new onnx _model """ if debug_info is None: @@ -67,7 +68,7 @@ def onnx_remove_node_redundant(onnx_model, recursive=True, debug_info=None, return _apply_optimisation_on_graph( onnx_remove_node_redundant, onnx_model, recursive=recursive, debug_info=debug_info, - max_hash_size=max_hash_size) + max_hash_size=max_hash_size, **options) def _enumerate_rename_list_nodes_inputs(nodes, rename): for i, node in enumerate(nodes): diff --git a/mlprodict/onnxrt/optim/onnx_optimisation_unused.py b/mlprodict/onnxrt/optim/onnx_optimisation_unused.py new file mode 100644 index 000000000..8f6f2159a --- /dev/null +++ b/mlprodict/onnxrt/optim/onnx_optimisation_unused.py @@ -0,0 +1,81 @@ +""" +@file +@brief Optimisation of :epkg:`ONNX` graphs. +""" +from onnx.helper import make_graph +from ._onnx_optimisation_common import ( # pylint: disable=E0611 + _apply_optimisation_on_graph, _apply_remove_node_fct_node) + + +def onnx_remove_node_unused(onnx_model, recursive=True, debug_info=None, **options): + """ + Removes unused nodes of the graph. An unused node + is not involved in the output computation. + + @param onnx_model onnx model + @param recursive looks into subgraphs + @param debug_info debug information (private) + @param options unused + @return new onnx _model + """ + if debug_info is None: + debug_info = [str(type(onnx_model)).split('.')[-1].strip("'>")] + else: + debug_info = debug_info + \ + [str(type(onnx_model)).split('.')[-1].strip("'>")] + + if hasattr(onnx_model, 'graph'): + return _apply_optimisation_on_graph( + onnx_remove_node_unused, onnx_model, + recursive=recursive, debug_info=debug_info, + **options) + + graph = onnx_model + data = {} + valid = {} + edges = {} + + for init in graph.initializer: + data[init.name, 0] = init + + for node in graph.node: + data[node.name, 1] = node + for inp in node.input: + data[inp, 0] = node + edges[(inp, 0), (node.name, 1)] = node + for out in node.output: + data[out, 0] = node + edges[(node.name, 1), (out, 0)] = node + + for out in graph.output: + valid[out.name, 0] = True + + modif = 1 + while modif > 0: + modif = 0 + for e1, e2 in edges: # pylint: disable=E1141 + if valid.get(e2, False) and not valid.get(e1, False): + valid[e1] = True + modif += 1 + + new_nodes = [n for n in graph.node if (n.name, 1) in valid] + new_inits = [n for n in graph.initializer if (n.name, 0) in valid] + + if recursive: + # Handles subgraphs. + for i in range(len(new_nodes)): # pylint: disable=C0200 + node = new_nodes[i] + if node is None or not (node.attribute): # pylint: disable=C0325 + continue + new_nodes[i] = _apply_remove_node_fct_node( + onnx_remove_node_unused, + node, recursive=True, debug_info=debug_info + [node.name]) + + # Finally create the new graph. 
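+ # 'valid' now marks every node, result and initializer reachable + # backward from the graph outputs; the filter below also drops + # entries the recursive subgraph pass may have left as None.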
+ nodes = list(filter(lambda n: n is not None, new_nodes)) + graph = make_graph(nodes, onnx_model.name, + onnx_model.input, onnx_model.output, + new_inits) + + graph.value_info.extend(onnx_model.value_info) # pylint: disable=E1101 + return graph diff --git a/mlprodict/onnxrt/validate/side_by_side.py b/mlprodict/onnxrt/validate/side_by_side.py index 183687225..9e51c7b94 100644 --- a/mlprodict/onnxrt/validate/side_by_side.py +++ b/mlprodict/onnxrt/validate/side_by_side.py @@ -2,6 +2,7 @@ @file @brief Helpers to compare executions. """ +import copy from .validate_difference import measure_relative_difference @@ -45,7 +46,7 @@ def side_by_side_by_values(sessions, *args, inputs=None, **kwargs): new_inputs = inputs[i] else: new_sess = sess - new_inputs = inputs + new_inputs = copy.deepcopy(inputs) if verbose > 0 and fLOG: fLOG( # pragma: no cover '[side_by_side_by_values] run session {}/{}'.format( diff --git a/mlprodict/tools/zoo.py b/mlprodict/tools/zoo.py new file mode 100644 index 000000000..3c8eba43e --- /dev/null +++ b/mlprodict/tools/zoo.py @@ -0,0 +1,242 @@ +""" +@file +@brief Tools to test models from the :epkg:`ONNX Zoo`. + +.. versionadded:: 0.6 +""" +import os +import urllib.request +from collections import OrderedDict +import numpy +from onnx import TensorProto, numpy_helper + + +def short_list_zoo_models(): + """ + Returns a short list from :epkg:`ONNX Zoo`. + + :return: list of dictionaries. + + .. runpython:: + :showcode: + + import pprint + from mlprodict.tools.zoo import short_list_zoo_models + pprint.pprint(short_list_zoo_models()) + """ + return [ + dict(name="mobilenet", + model="https://github.com/onnx/models/raw/master/vision/" + "classification/mobilenet/model/mobilenetv2-7.tar.gz"), + dict(name="resnet18", + model="https://github.com/onnx/models/raw/master/vision/" + "classification/resnet/model/resnet18-v1-7.tar.gz"), + dict(name="squeezenet", + model="https://github.com/onnx/models/raw/master/vision/" + "classification/squeezenet/model/squeezenet1.0-9.tar.gz", + folder="squeezenet"), + dict(name="densenet121", + model="https://github.com/onnx/models/raw/master/vision/" + "classification/densenet-121/model/densenet-9.tar.gz", + folder="densenet121"), + dict(name="inception2", + model="https://github.com/onnx/models/raw/master/vision/" + "classification/inception_and_googlenet/inception_v2/" + "model/inception-v2-9.tar.gz"), + dict(name="shufflenet", + model="https://github.com/onnx/models/raw/master/vision/" + "classification/shufflenet/model/shufflenet-9.tar.gz"), + dict(name="efficientnet-lite4", + model="https://github.com/onnx/models/raw/master/vision/" + "classification/efficientnet-lite4/model/" + "efficientnet-lite4-11.tar.gz"), + ] + + +def _download_url(url, output_path, name, verbose=False): + if verbose: + from tqdm import tqdm + + class DownloadProgressBar(tqdm): + "progress bar hook" + + def update_to(self, b=1, bsize=1, tsize=None): + "progress bar hook" + if tsize is not None: + self.total = tsize + self.update(b * bsize - self.n) + + with DownloadProgressBar(unit='B', unit_scale=True, + miniters=1, desc=name) as t: + urllib.request.urlretrieve( + url, filename=output_path, reporthook=t.update_to) + else: + urllib.request.urlretrieve(url, filename=output_path) + + +def load_data(folder): + """ + Restores protobuf data stored in a folder. 
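+ Expects tensors serialized as ``input_*.pb`` and ``output_*.pb`` + files, the layout used by the test data sets of the :epkg:`ONNX Zoo`.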
+ + :param folder: folder + :return: dictionary + """ + res = OrderedDict() + res['in'] = OrderedDict() + res['out'] = OrderedDict() + files = os.listdir(folder) + for name in files: + noext, ext = os.path.splitext(name) + if ext == '.pb': + data = TensorProto() + with open(os.path.join(folder, name), 'rb') as f: + data.ParseFromString(f.read()) + if noext.startswith('input'): + res['in'][noext] = numpy_helper.to_array(data) + elif noext.startswith('output'): + res['out'][noext] = numpy_helper.to_array(data) + else: + raise ValueError( # pragma: no cover + "Unable to guess anything about %r." % noext) + + return res + + +def download_model_data(name, model=None, cache=None, verbose=False): + """ + Downloads a model and returns a link to the local + :epkg:`ONNX` file and data which can be used as inputs. + + :param name: model name (see @see fn short_list_zoo_models) + :param model: URL, or None to use the default value + returned by @see fn short_list_zoo_models + :param cache: folder to cache the downloaded data + :param verbose: display a progress bar + :return: local onnx file, input data + """ + suggested_folder = None + if model is None: + model_list = short_list_zoo_models() + for mod in model_list: + if mod['name'] == name: + model = mod['model'] + if 'folder' in mod: # pylint: disable=R1715 + suggested_folder = mod['folder'] + break + if model is None: + raise ValueError( + "Unable to find a default value for name=%r." % name) + + # downloads + last_name = model.split('/')[-1] + if cache is None: + cache = os.path.abspath('.') + dest = os.path.join(cache, last_name) + if not os.path.exists(dest): + _download_url(model, dest, name, verbose=verbose) + size = os.stat(dest).st_size + if size < 2 ** 20: # pragma: no cover + os.remove(dest) + raise RuntimeError( + "Unable to download model from %r." % model) + + outtar = os.path.splitext(dest)[0] + if not os.path.exists(outtar): + from pyquickhelper.filehelper.compression_helper import ( + ungzip_files) + ungzip_files(dest, unzip=False, where_to=cache, remove_space=False) + + onnx_file = os.path.splitext(outtar)[0] + if not os.path.exists(onnx_file): + from pyquickhelper.filehelper.compression_helper import ( + untar_files) + untar_files(outtar, where_to=cache) + + if suggested_folder is not None: + fold_onnx = [suggested_folder] + else: + fold_onnx = [onnx_file, onnx_file.split('-')[0], + '-'.join(onnx_file.split('-')[:-1]), + '-'.join(onnx_file.split('-')[:-1]).replace('-', '_')] + fold_onnx_ok = [_ for _ in fold_onnx if os.path.exists(_)] + if len(fold_onnx_ok) != 1: + raise FileNotFoundError( # pragma: no cover + "Unable to find an existing folder among %r." % fold_onnx) + onnx_file = fold_onnx_ok[0] + + onnx_files = [_ for _ in os.listdir(onnx_file) if _.endswith(".onnx")] + if len(onnx_files) != 1: + raise FileNotFoundError( # pragma: no cover + "Unable to find a unique onnx file in %r." % onnx_file) + final_onnx = os.path.join(onnx_file, onnx_files[0]) + + # data + data = [_ for _ in os.listdir(onnx_file) + if os.path.isdir(os.path.join(onnx_file, _))] + examples = OrderedDict() + for f in data: + examples[f] = load_data(os.path.join(onnx_file, f)) + + return final_onnx, examples + + +def verify_model(onnx_file, examples, runtime=None, abs_tol=5e-4, + verbose=0, fLOG=None): + """ + Verifies a model. 
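+ Every test data set is run through the given runtime and the outputs + are compared to the stored expected outputs within ``abs_tol``.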
+ + :param onnx_file: ONNX file + :param examples: dictionary of test examples to verify + (as returned by @see fn download_model_data) + :param runtime: a runtime to use + :param abs_tol: error tolerance when checking the output + :param verbose: verbosity level for runtimes other than + `'onnxruntime'` + :param fLOG: logging function when `verbose > 0` + :return: errors for every sample + """ + if runtime == 'onnxruntime': + from onnxruntime import InferenceSession + sess = InferenceSession(onnx_file) + meth = lambda data, s=sess: s.run(None, data) + names = [p.name for p in sess.get_inputs()] + onames = list(range(len(sess.get_outputs()))) + else: + def _lin_(sess, data, names): + r = sess.run(data, verbose=verbose, fLOG=fLOG) + return [r[n] for n in names] + + from ..onnxrt import OnnxInference + sess = OnnxInference(onnx_file, runtime=runtime) + names = sess.input_names + onames = sess.output_names + meth = lambda data, s=sess, ns=onames: _lin_(s, data, ns) + + rows = [] + for index, (name, data_inout) in enumerate(examples.items()): + data = data_inout["in"] + if len(data) != len(names): + raise RuntimeError( + "Mismatched number of inputs %d != %d\ninputs: %r\nmodel: %r." + "" % (len(data), len(names), list(sorted(data)), names)) + inputs = {n: data[v] for n, v in zip(names, data)} + outputs = meth(inputs) + expected = data_inout['out'] + if len(outputs) != len(onames): + raise RuntimeError( + "Number of outputs %d != expected number of outputs %d." % ( + len(outputs), len(onames))) + for i, (output, expect) in enumerate(zip(outputs, expected.items())): + if output.shape != expect[1].shape: + raise ValueError( + "Shape mismatch got %r != expected %r." % ( + output.shape, expect[1].shape)) + diff = numpy.abs(output - expect[1]).ravel() + absolute = diff.max() + relative = absolute / numpy.median(diff) if absolute > 0 else 0. + if absolute > abs_tol: + raise ValueError( + "Example %d, inferred and expected results are different " + "for output %d: abs=%r rel=%r (runtime=%r)." + "" % (index, i, absolute, relative, runtime)) + rows.append(dict(name=name, i=i, abs=absolute, rel=relative)) + return rows
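A minimal usage sketch, mirroring the new unit tests: it shows how ``select_model_inputs_outputs`` composes with the new ``onnx_remove_node_unused``; the graph, its names ('inter', 'final') and the array values are illustrative only.

    import numpy
    from skl2onnx.algebra.onnx_ops import OnnxAdd, OnnxMul  # pylint: disable=E0611
    from mlprodict.onnxrt import OnnxInference
    from mlprodict.onnxrt.optim import onnx_remove_node_unused
    from mlprodict.onnxrt.onnx_inference_manipulations import (
        select_model_inputs_outputs)
    from mlprodict.tools import get_opset_number_from_onnx

    opv = get_opset_number_from_onnx()
    x = numpy.array([[1., 2.], [4., 5.], [5., 4.]], dtype=numpy.float32)
    # Toy graph: 'inter' = X + 2, 'final' = (X + 1) * (X + 2).
    cop = OnnxAdd('X', numpy.array([1], dtype=numpy.float32), op_version=opv)
    cop3 = OnnxAdd('X', numpy.array([2], dtype=numpy.float32), op_version=opv,
                   output_names=['inter'])
    final = OnnxMul(cop, cop3, op_version=opv, output_names=['final'])
    model_def = final.to_onnx({'X': x})

    # Restrict the model to 'inter', then prune the nodes and
    # initializers which no longer contribute to that output.
    sub = select_model_inputs_outputs(model_def, 'inter')
    sub = onnx_remove_node_unused(sub)
    print(OnnxInference(sub).run({'X': x})['inter'])

The new zoo helpers chain the same way, assuming network access and a writable cache folder:

    from mlprodict.tools.zoo import download_model_data, verify_model
    link, data = download_model_data('mobilenet', cache='.')
    verify_model(link, data, runtime='onnxruntime1')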