diff --git a/_unittests/ut_testing/test_einsum_bug.py b/_unittests/ut_testing/test_einsum_bug.py index 8ffa0ff0e..1f3b7d085 100644 --- a/_unittests/ut_testing/test_einsum_bug.py +++ b/_unittests/ut_testing/test_einsum_bug.py @@ -2,8 +2,11 @@ @brief test log(time=3s) """ import unittest +import numpy from pyquickhelper.pycode import ExtTestCase -from mlprodict.testing.einsum import decompose_einsum_equation +from mlprodict.testing.einsum import ( + decompose_einsum_equation, optimize_decompose_einsum_equation) +from mlprodict.onnxrt import OnnxInference class TestEinsumBug(ExtTestCase): @@ -20,6 +23,39 @@ def test__pprint_forward(self): spl = pf.split("<- id") self.assertEqual(len(spl), 4) + def common_test_equation(self, equation, dim1, dim2): + seq = decompose_einsum_equation( + equation, clean=True, strategy='numpy') + onx = seq.to_onnx('Y', 'X1', 'X2') + sequ = equation.replace(",", "_").replace("->", "__") + with open("temp_%s_A.onnx" % sequ, "wb") as f: + f.write(onx.SerializeToString()) + a = numpy.random.rand(*list((2, ) * dim1)) + b = numpy.random.rand(*list((2, ) * dim2)) + oinf = OnnxInference(onx) + got = oinf.run({'X1': a, 'X2': b}) + expected = numpy.einsum(equation, a, b) + self.assertEqualArray(expected, got['Y']) + + res = optimize_decompose_einsum_equation( + equation, numpy.float64, optimize=True, runtime="python", + cache=False, opset=15, decompose=True, strategy='ml', + verbose=None) + new_eq = res.equation_ + new_onx = res.onnx_ + sequ = new_eq.replace(",", "_").replace("->", "__") + with open("temp_%s_B.onnx" % sequ, "wb") as f: + f.write(new_onx.SerializeToString()) + oinf = OnnxInference(new_onx) + got = oinf.run({'X0': a, 'X1': b}) + self.assertEqualArray(expected, got['Y']) + + def test_decompose_einsum_abc_cde_abde(self): + self.common_test_equation("abc,cde->abde", 3, 3) + + def test_decompose_einsum_abcd_cde_abe(self): + self.common_test_equation("abcd,cde->abe", 4, 3) + if __name__ == "__main__": unittest.main() diff --git a/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor.py b/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor.py index 4b5735e97..9cd057c25 100644 --- a/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor.py +++ b/mlprodict/onnxrt/ops_cpu/op_tree_ensemble_regressor.py @@ -157,7 +157,7 @@ class TreeEnsembleRegressorDouble(TreeEnsembleRegressorCommon): onx64 = to_onnx(model, X, rewrite_ops=True, target_opset=15) assert 'TreeEnsembleRegressorDouble' in str(onx64) expected = model.predict(X) - + oinf = OnnxInference(onx64) got = oinf.run({'X': X}) diff = numpy.abs(got['variable'] - expected) diff --git a/mlprodict/testing/einsum/einsum_fct.py b/mlprodict/testing/einsum/einsum_fct.py index c46579e81..d9c8b6e34 100644 --- a/mlprodict/testing/einsum/einsum_fct.py +++ b/mlprodict/testing/einsum/einsum_fct.py @@ -333,6 +333,8 @@ def _einsum(equation, dtype, optimize=False, runtime="batch_dot", if cache: key = equation, runtime, opset, optimize, dtype, decompose, strategy cached = _einsum_cache.get(key, None) + else: + key = None if cached is None: cached = CachedEinsum.build_einsum( equation, runtime, opset, optimize,