From 191bbc272f0312e76385462338a502fed9a4a0ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Wed, 6 Oct 2021 20:52:54 +0200 Subject: [PATCH 1/3] add more tests --- _unittests/ut_testing/test_einsum_bug.py | 37 +++++++++++++++++++++++- mlprodict/testing/einsum/einsum_fct.py | 2 ++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_testing/test_einsum_bug.py b/_unittests/ut_testing/test_einsum_bug.py index 8ffa0ff0e..4cec03cfd 100644 --- a/_unittests/ut_testing/test_einsum_bug.py +++ b/_unittests/ut_testing/test_einsum_bug.py @@ -2,8 +2,11 @@ @brief test log(time=3s) """ import unittest +import numpy from pyquickhelper.pycode import ExtTestCase -from mlprodict.testing.einsum import decompose_einsum_equation +from mlprodict.testing.einsum import ( + decompose_einsum_equation, optimize_decompose_einsum_equation) +from mlprodict.onnxrt import OnnxInference class TestEinsumBug(ExtTestCase): @@ -20,6 +23,38 @@ def test__pprint_forward(self): spl = pf.split("<- id") self.assertEqual(len(spl), 4) + def common_test_equation(self, equation, dim1, dim2): + seq = decompose_einsum_equation( + equation, clean=True, strategy='numpy') + onx = seq.to_onnx('Y', 'X1', 'X2') + seq = equation.replace(",", "_").replace("->", "__") + with open("temp_%s_A.onnx" % seq, "wb") as f: + f.write(onx.SerializeToString()) + a = numpy.random.rand(*list((2, ) * dim1)) + b = numpy.random.rand(*list((2, ) * dim2)) + oinf = OnnxInference(onx) + got = oinf.run({'X1': a, 'X2': b}) + expected = numpy.einsum(equation, a, b) + self.assertEqualArray(expected, got['Y']) + + res = optimize_decompose_einsum_equation( + equation, numpy.float64, optimize=True, runtime="python", + cache=False, opset=15, decompose=True, strategy='ml', + verbose=None) + new_eq = res.equation_ + new_onx = res.onnx_ + with open("temp_%s_B.onnx" % seq, "wb") as f: + f.write(onx.SerializeToString()) + oinf = OnnxInference(onx) + got = oinf.run({'X1': a, 'X2': b}) + self.assertEqualArray(expected, got['Y']) + + def test_decompose_einsum_abc_cde_abde(self): + self.common_test_equation("abc,cde->abde", 3, 3) + + def test_decompose_einsum_abcd_cde_abe(self): + self.common_test_equation("abcd,cde->abe", 4, 3) + if __name__ == "__main__": unittest.main() diff --git a/mlprodict/testing/einsum/einsum_fct.py b/mlprodict/testing/einsum/einsum_fct.py index c46579e81..d9c8b6e34 100644 --- a/mlprodict/testing/einsum/einsum_fct.py +++ b/mlprodict/testing/einsum/einsum_fct.py @@ -333,6 +333,8 @@ def _einsum(equation, dtype, optimize=False, runtime="batch_dot", if cache: key = equation, runtime, opset, optimize, dtype, decompose, strategy cached = _einsum_cache.get(key, None) + else: + key = None if cached is None: cached = CachedEinsum.build_einsum( equation, runtime, opset, optimize, From fd9243613d64a4bde883505618bf55a5fe7a7602 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Thu, 7 Oct 2021 10:41:17 +0200 Subject: [PATCH 2/3] lint --- mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor.py b/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor.py index 4b5735e97..9cd057c25 100644 --- a/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor.py +++ b/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor.py @@ -157,7 +157,7 @@ class TreeEnsembleRegressorDouble(TreeEnsembleRegressorCommon): onx64 = to_onnx(model, X, rewrite_ops=True, target_opset=15) assert 'TreeEnsembleRegressorDouble' in str(onx64) expected = model.predict(X) - + oinf = OnnxInference(onx64) got = oinf.run({'X': X}) diff = numpy.abs(got['variable'] - expected) From df795a3bed856008675b20010a8e702fe90bee94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Thu, 7 Oct 2021 11:02:15 +0200 Subject: [PATCH 3/3] lint --- _unittests/ut_testing/test_einsum_bug.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/_unittests/ut_testing/test_einsum_bug.py b/_unittests/ut_testing/test_einsum_bug.py index 4cec03cfd..1f3b7d085 100644 --- a/_unittests/ut_testing/test_einsum_bug.py +++ b/_unittests/ut_testing/test_einsum_bug.py @@ -27,26 +27,27 @@ def common_test_equation(self, equation, dim1, dim2): seq = decompose_einsum_equation( equation, clean=True, strategy='numpy') onx = seq.to_onnx('Y', 'X1', 'X2') - seq = equation.replace(",", "_").replace("->", "__") - with open("temp_%s_A.onnx" % seq, "wb") as f: + sequ = equation.replace(",", "_").replace("->", "__") + with open("temp_%s_A.onnx" % sequ, "wb") as f: f.write(onx.SerializeToString()) a = numpy.random.rand(*list((2, ) * dim1)) b = numpy.random.rand(*list((2, ) * dim2)) - oinf = OnnxInference(onx) + oinf = OnnxInference(onx) got = oinf.run({'X1': a, 'X2': b}) expected = numpy.einsum(equation, a, b) self.assertEqualArray(expected, got['Y']) - + res = optimize_decompose_einsum_equation( equation, numpy.float64, optimize=True, runtime="python", cache=False, opset=15, decompose=True, strategy='ml', verbose=None) new_eq = res.equation_ new_onx = res.onnx_ - with open("temp_%s_B.onnx" % seq, "wb") as f: - f.write(onx.SerializeToString()) - oinf = OnnxInference(onx) - got = oinf.run({'X1': a, 'X2': b}) + sequ = new_eq.replace(",", "_").replace("->", "__") + with open("temp_%s_B.onnx" % sequ, "wb") as f: + f.write(new_onx.SerializeToString()) + oinf = OnnxInference(new_onx) + got = oinf.run({'X0': a, 'X1': b}) self.assertEqualArray(expected, got['Y']) def test_decompose_einsum_abc_cde_abde(self):