Skip to content
This repository has been archived by the owner on Jan 13, 2024. It is now read-only.

Commit

Permalink
Fix bug while removing identity nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Sep 10, 2019
1 parent 8d7b633 commit f56d486
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 1 deletion.
40 changes: 40 additions & 0 deletions _unittests/ut_onnxrt/test_onnxrt_simple_gaussian_process.py
@@ -0,0 +1,40 @@
"""
@brief test log(time=2s)
"""
import unittest
from logging import getLogger
import numpy
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ExpSineSquared
from pyquickhelper.pycode import ExtTestCase
from skl2onnx import __version__ as skl2onnx_version
from mlprodict.onnxrt import OnnxInference, to_onnx
from mlprodict.onnxrt.optim import onnx_optimisations


class TestOnnxrtSimpleGaussianProcess(ExtTestCase):

def setUp(self):
logger = getLogger('skl2onnx')
logger.disabled = True

def test_onnxt_gpr_iris(self):
iris = load_iris()
X, y = iris.data, iris.target
X_train, _, y_train, __ = train_test_split(X, y, random_state=11)
clr = GaussianProcessRegressor(ExpSineSquared(), alpha=20.)
clr.fit(X_train, y_train)

model_def = to_onnx(clr, X_train, dtype=numpy.float64)
oinf = OnnxInference(model_def)
res1 = oinf.run({'X': X_train})
new_model = onnx_optimisations(model_def)
oinf = OnnxInference(new_model)
res2 = oinf.run({'X': X_train})
self.assertEqualArray(res1['GPmean'], res2['GPmean'])


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions mlprodict/onnxrt/optim/onnx_optimisation_identity.py
Expand Up @@ -79,6 +79,7 @@ def retrieve_idnodes(graph, existing_nodes):
restart = True
nodes[i] = None
rem += 1
break
if not restart and inp not in inputs:
# We cannot change an input name.
for j in range(len(nodes)): # pylint: disable=C0200
Expand Down
2 changes: 1 addition & 1 deletion mlprodict/onnxrt/validate/validate.py
Expand Up @@ -454,7 +454,7 @@ def fct_batch(se=sess, xo=Xort_test, it=init_types): # pylint: disable=W0102
{init_types[0][0]: xo}, node_time=node_time), Xort_test)
except (RuntimeError, TypeError, ValueError, KeyError, IndexError) as e:
if debug:
raise e
raise RuntimeError("Issue with {}.".format(obs_op)) from e
obs_op['_6ort_run_batch_exc'] = e
if (benchmark or node_time) and 'lambda-batch' in obs_op:
try:
Expand Down

0 comments on commit f56d486

Please sign in to comment.