From f56d486d32215002c8e9d46ceee155a5447c5881 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 11 Sep 2019 00:27:20 +0200 Subject: [PATCH] Fix bug while removing identity nodes --- .../test_onnxrt_simple_gaussian_process.py | 40 +++++++++++++++++++ .../optim/onnx_optimisation_identity.py | 1 + mlprodict/onnxrt/validate/validate.py | 2 +- 3 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 _unittests/ut_onnxrt/test_onnxrt_simple_gaussian_process.py diff --git a/_unittests/ut_onnxrt/test_onnxrt_simple_gaussian_process.py b/_unittests/ut_onnxrt/test_onnxrt_simple_gaussian_process.py new file mode 100644 index 000000000..a93785317 --- /dev/null +++ b/_unittests/ut_onnxrt/test_onnxrt_simple_gaussian_process.py @@ -0,0 +1,40 @@ +""" +@brief test log(time=2s) +""" +import unittest +from logging import getLogger +import numpy +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from sklearn.gaussian_process import GaussianProcessRegressor +from sklearn.gaussian_process.kernels import ExpSineSquared +from pyquickhelper.pycode import ExtTestCase +from skl2onnx import __version__ as skl2onnx_version +from mlprodict.onnxrt import OnnxInference, to_onnx +from mlprodict.onnxrt.optim import onnx_optimisations + + +class TestOnnxrtSimpleGaussianProcess(ExtTestCase): + + def setUp(self): + logger = getLogger('skl2onnx') + logger.disabled = True + + def test_onnxt_gpr_iris(self): + iris = load_iris() + X, y = iris.data, iris.target + X_train, _, y_train, __ = train_test_split(X, y, random_state=11) + clr = GaussianProcessRegressor(ExpSineSquared(), alpha=20.) + clr.fit(X_train, y_train) + + model_def = to_onnx(clr, X_train, dtype=numpy.float64) + oinf = OnnxInference(model_def) + res1 = oinf.run({'X': X_train}) + new_model = onnx_optimisations(model_def) + oinf = OnnxInference(new_model) + res2 = oinf.run({'X': X_train}) + self.assertEqualArray(res1['GPmean'], res2['GPmean']) + + +if __name__ == "__main__": + unittest.main() diff --git a/mlprodict/onnxrt/optim/onnx_optimisation_identity.py b/mlprodict/onnxrt/optim/onnx_optimisation_identity.py index d4ac67428..07e4a8556 100644 --- a/mlprodict/onnxrt/optim/onnx_optimisation_identity.py +++ b/mlprodict/onnxrt/optim/onnx_optimisation_identity.py @@ -79,6 +79,7 @@ def retrieve_idnodes(graph, existing_nodes): restart = True nodes[i] = None rem += 1 + break if not restart and inp not in inputs: # We cannot change an input name. for j in range(len(nodes)): # pylint: disable=C0200 diff --git a/mlprodict/onnxrt/validate/validate.py b/mlprodict/onnxrt/validate/validate.py index c17fb5fa8..72156a8ce 100644 --- a/mlprodict/onnxrt/validate/validate.py +++ b/mlprodict/onnxrt/validate/validate.py @@ -454,7 +454,7 @@ def fct_batch(se=sess, xo=Xort_test, it=init_types): # pylint: disable=W0102 {init_types[0][0]: xo}, node_time=node_time), Xort_test) except (RuntimeError, TypeError, ValueError, KeyError, IndexError) as e: if debug: - raise e + raise RuntimeError("Issue with {}.".format(obs_op)) from e obs_op['_6ort_run_batch_exc'] = e if (benchmark or node_time) and 'lambda-batch' in obs_op: try: