diff --git a/_unittests/ut_onnxrt/test_cpu_ops.py b/_unittests/ut_onnxrt/test_cpu_ops.py
index b03808c25..3d9837be2 100644
--- a/_unittests/ut_onnxrt/test_cpu_ops.py
+++ b/_unittests/ut_onnxrt/test_cpu_ops.py
@@ -5,13 +5,18 @@
 from logging import getLogger
 import numpy
 import onnx
-from pyquickhelper.pycode import ExtTestCase
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.multiclass import OneVsRestClassifier
+from pyquickhelper.pycode import ExtTestCase, ignore_warnings
 from skl2onnx.algebra.onnx_ops import (  # pylint: disable=E0611
     OnnxConv)
+from mlprodict.onnx_conv import to_onnx
 from mlprodict.onnxrt.ops_cpu.op_conv import Conv
 from mlprodict.onnxrt.onnx2py_helper import _var_as_dict
 from mlprodict.tools.asv_options_helper import get_opset_number_from_onnx
 from mlprodict.onnxrt import OnnxInference
+from mlprodict.testing.test_utils.tests_helper import fit_multilabel_classification_model
+from mlprodict.testing.test_utils import TARGET_OPSET


 class TestCpuOps(ExtTestCase):
@@ -20,6 +25,7 @@
     def setUp(self):
         logger = getLogger('skl2onnx')
         logger.disabled = True

+    @ignore_warnings(DeprecationWarning)
     def test_cpu_conv(self):
         x = numpy.array([[[[0., 1., 2., 3., 4.],  # (1, 1, 5, 5) input tensor
@@ -132,6 +138,28 @@ def test_cpu_conv_group(self):
                         ii, diff[ii], gotrt['Y'].ravel()[ii], got['Y'].ravel()[ii]))
         self.assertEqualArray(gotrt['Y'], got['Y'], decimal=5)

+    def test_slice_bug(self):
+
+        for opset in [9, 12, TARGET_OPSET]:
+            if opset > TARGET_OPSET:
+                continue
+            model = OneVsRestClassifier(
+                RandomForestClassifier(n_estimators=2, max_depth=3))
+            model, X = fit_multilabel_classification_model(
+                model, 3, is_int=False, n_features=5)
+            model_onnx = to_onnx(
+                model, X[:1], target_opset=opset,
+                options={id(model): {'zipmap': False}})
+            X = X[:7]
+            for rt in ['python', 'onnxruntime1']:
+                with self.subTest(opset=opset, rt=rt):
+                    oinf = OnnxInference(model_onnx, runtime=rt)
+                    got = oinf.run({'X': X})
+                    exp = model.predict(X), model.predict_proba(X)
+                    self.assertEqual(exp[1].shape[1], 3)
+                    self.assertEqualArray(exp[0], got['label'])
+                    self.assertEqualArray(exp[1], got['probabilities'])
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py
index c59c943c7..74465aa59 100644
--- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py
+++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py
@@ -2161,60 +2161,79 @@ def test_onnxt_runtime_sin(self):
     @wraplog()
     def test_onnxt_runtime_slice(self):
-        # steps
-        x = numpy.random.randn(20, 10, 5).astype(  # pylint: disable=E1101
-            numpy.float32)  # pylint: disable=E1101
-        y = x[0:3:2, 0:10:2]
-        starts = numpy.array([0, 0], dtype=numpy.int64)
-        ends = numpy.array([3, 10], dtype=numpy.int64)
-        axes = numpy.array([0, 1], dtype=numpy.int64)
-        steps = numpy.array([2, 2], dtype=numpy.int64)
-        onx = OnnxSlice('X', starts, ends, axes, steps, output_names=['Y'],
-                        op_version=get_opset_number_from_onnx())
-        model_def = onx.to_onnx({'X': x.astype(numpy.float32)},
-                                target_opset=get_opset_number_from_onnx())
-        got = OnnxInference(model_def).run({'X': x})
-        self.assertEqualArray(y, got['Y'])
-
-        # other
-        x = numpy.random.randn(20, 10, 5).astype(  # pylint: disable=E1101
-            numpy.float32)  # pylint: disable=E1101
-        y = x[0:3, 0:10]
-        starts = numpy.array([0, 0], dtype=numpy.int64)
-        ends = numpy.array([3, 10], dtype=numpy.int64)
-        onx = OnnxSlice('X', starts, ends, output_names=['Y'],
-                        op_version=get_opset_number_from_onnx())
-        model_def = onx.to_onnx({'X': x.astype(numpy.float32)},
-                                target_opset=get_opset_number_from_onnx())
-        got = OnnxInference(model_def).run({'X': x})
-        self.assertEqualArray(y, got['Y'])
-
-        x = numpy.random.randn(20, 10, 5).astype(  # pylint: disable=E1101
-            numpy.float32)  # pylint: disable=E1101
-        y = x[0:3, 0:10]
-        starts = numpy.array([0, 0], dtype=numpy.int64)
-        ends = numpy.array([3, 10], dtype=numpy.int64)
-        axes = numpy.array([0, 1], dtype=numpy.int64)
-        onx = OnnxSlice('X', starts, ends, axes, output_names=['Y'],
-                        op_version=get_opset_number_from_onnx())
-        model_def = onx.to_onnx({'X': x.astype(numpy.float32)},
-                                target_opset=get_opset_number_from_onnx())
-        got = OnnxInference(model_def).run({'X': x})
-        self.assertEqualArray(y, got['Y'])
-
-        x = numpy.random.randn(20, 10, 5).astype(  # pylint: disable=E1101
-            numpy.float32)  # pylint: disable=E1101
-        y = x[0:3:-1, 0:10:2]
-        starts = numpy.array([0, 0], dtype=numpy.int64)
-        ends = numpy.array([3, 10], dtype=numpy.int64)
-        axes = numpy.array([0, 1], dtype=numpy.int64)
-        steps = numpy.array([-1, 2], dtype=numpy.int64)
-        onx = OnnxSlice('X', starts, ends, axes, steps, output_names=['Y'],
-                        op_version=get_opset_number_from_onnx())
-        model_def = onx.to_onnx({'X': x.astype(numpy.float32)},
-                                target_opset=get_opset_number_from_onnx())
-        got = OnnxInference(model_def).run({'X': x})
-        self.assertEqualArray(y, got['Y'])
+        for opset in [9, get_opset_number_from_onnx()]:
+            if opset > get_opset_number_from_onnx():
+                continue
+            with self.subTest(opset=opset):
+                # steps
+                x = numpy.random.randn(20, 10, 5).astype(  # pylint: disable=E1101
+                    numpy.float32)  # pylint: disable=E1101
+                y = x[0:3:2, 0:10:2]
+                starts = numpy.array([0, 0], dtype=numpy.int64)
+                ends = numpy.array([3, 10], dtype=numpy.int64)
+                axes = numpy.array([0, 1], dtype=numpy.int64)
+                steps = numpy.array([2, 2], dtype=numpy.int64)
+                if opset < 10:
+                    onx = OnnxSlice('X', starts=starts, ends=ends, axes=axes,
+                                    output_names=['Y'], op_version=opset)
+                    y = x[0:3, 0:10]
+                else:
+                    onx = OnnxSlice('X', starts, ends, axes, steps,
+                                    output_names=['Y'], op_version=opset)
+                model_def = onx.to_onnx({'X': x.astype(numpy.float32)},
+                                        target_opset=opset)
+                got = OnnxInference(model_def).run({'X': x})
+                self.assertEqualArray(y, got['Y'])
+
+                # other
+                x = numpy.random.randn(20, 10, 5).astype(  # pylint: disable=E1101
+                    numpy.float32)
+                y = x[0:3, 0:10]
+                starts = numpy.array([0, 0], dtype=numpy.int64)
+                ends = numpy.array([3, 10], dtype=numpy.int64)
+                if opset < 10:
+                    onx = OnnxSlice('X', starts=starts, ends=ends,
+                                    output_names=['Y'], op_version=opset)
+                else:
+                    onx = OnnxSlice('X', starts, ends, output_names=['Y'],
+                                    op_version=opset)
+                model_def = onx.to_onnx({'X': x.astype(numpy.float32)},
+                                        target_opset=opset)
+                got = OnnxInference(model_def).run({'X': x})
+                self.assertEqualArray(y, got['Y'])
+
+                x = numpy.random.randn(20, 10, 5).astype(  # pylint: disable=E1101
+                    numpy.float32)
+                y = x[0:3, 0:10]
+                starts = numpy.array([0, 0], dtype=numpy.int64)
+                ends = numpy.array([3, 10], dtype=numpy.int64)
+                axes = numpy.array([0, 1], dtype=numpy.int64)
+                if opset < 10:
+                    onx = OnnxSlice('X', starts=starts, ends=ends, axes=axes,
+                                    output_names=['Y'], op_version=opset)
+                else:
+                    onx = OnnxSlice('X', starts, ends, axes, output_names=['Y'],
+                                    op_version=opset)
+                model_def = onx.to_onnx({'X': x.astype(numpy.float32)},
+                                        target_opset=opset)
+                got = OnnxInference(model_def).run({'X': x})
+                self.assertEqualArray(y, got['Y'])
+
+                if opset < 10:
+                    continue
+                x = numpy.random.randn(20, 10, 5).astype(  # pylint: disable=E1101
+                    numpy.float32)
+                y = x[0:3:-1, 0:10:2]
+                starts = numpy.array([0, 0], dtype=numpy.int64)
+                ends = numpy.array([3, 10], dtype=numpy.int64)
+                axes = numpy.array([0, 1], dtype=numpy.int64)
+                steps = numpy.array([-1, 2], dtype=numpy.int64)
+                onx = OnnxSlice('X', starts, ends, axes, steps, output_names=['Y'],
+                                op_version=opset)
+                model_def = onx.to_onnx({'X': x.astype(numpy.float32)},
+                                        target_opset=opset)
+                got = OnnxInference(model_def).run({'X': x})
+                self.assertEqualArray(y, got['Y'])

         python_tested.append(OnnxSlice)

     @wraplog()
@@ -2792,5 +2811,5 @@ def test_make_constant(self):


 if __name__ == "__main__":
-    # TestOnnxrtPythonRuntime().test_make_constant()
+    # TestOnnxrtPythonRuntime().test_onnxt_runtime_slice()
     unittest.main()
diff --git a/mlprodict/onnxrt/ops_cpu/_op.py b/mlprodict/onnxrt/ops_cpu/_op.py
index 50a99bf27..62f12e819 100644
--- a/mlprodict/onnxrt/ops_cpu/_op.py
+++ b/mlprodict/onnxrt/ops_cpu/_op.py
@@ -558,8 +558,13 @@ def __init__(self, numpy_fct, onnx_node, desc=None,
                        expected_attributes=expected_attributes,
                        **options)
         self.numpy_fct = numpy_fct
+        self._cannot_inplace_int = self.numpy_fct in (
+            numpy.divide, numpy.true_divide)

     def _run(self, a, b):  # pylint: disable=W0221
+        if (self._cannot_inplace_int and
+                numpy.issubdtype(a.dtype, numpy.integer)):
+            return (self.numpy_fct(a, b), )
         if self.inplaces.get(0, False) and a.size >= b.size:
             if len(a.shape) == 1 and b.shape == (1, 1):
                 a = a.reshape(1, a.shape[0])
diff --git a/mlprodict/onnxrt/ops_cpu/_op_list.py b/mlprodict/onnxrt/ops_cpu/_op_list.py
index 3397b4765..b4671a4f2 100644
--- a/mlprodict/onnxrt/ops_cpu/_op_list.py
+++ b/mlprodict/onnxrt/ops_cpu/_op_list.py
@@ -86,7 +86,7 @@
 from .op_sigmoid import Sigmoid
 from .op_sign import Sign
 from .op_sin import Sin
-from .op_slice import Slice
+from .op_slice import Slice, Slice_1, Slice_10
 from .op_split import Split
 from .op_softmax import Softmax
 from .op_solve import Solve
diff --git a/mlprodict/onnxrt/ops_cpu/op_slice.py b/mlprodict/onnxrt/ops_cpu/op_slice.py
index aa1e72f75..a14d53cb8 100644
--- a/mlprodict/onnxrt/ops_cpu/op_slice.py
+++ b/mlprodict/onnxrt/ops_cpu/op_slice.py
@@ -4,11 +4,12 @@
 @file
 @brief Runtime operator.
 """
-from ._op import OpRun
+from onnx.defs import onnx_opset_version
 from ..shape_object import ShapeObject
+from ._op import OpRun


-class Slice(OpRun):
+class SliceCommon(OpRun):
     def __init__(self, onnx_node, desc=None, **options):
         OpRun.__init__(self, onnx_node, desc=desc,
                        **options)
@@ -34,5 +35,40 @@ def _run(self, data, starts, ends, axes=None, steps=None):  # pylint: disable=W0

     def _infer_shapes(self, data, starts, ends, axes=None, steps=None):  # pylint: disable=W0221
         pref = str(hex(id(self))[2:])
-        shape = ["nslice%s_%d" % (pref, i) for i in range(len(data))]
+        shape = ["nslice%s_%d" % (pref, i) for i in range(len(data.shape))]
         return (ShapeObject(shape, data.dtype), )
+
+
+class Slice_10(SliceCommon):
+    def __init__(self, onnx_node, desc=None, **options):
+        SliceCommon.__init__(self, onnx_node, desc=desc,
+                             **options)
+
+
+class Slice_1(SliceCommon):
+
+    atts = {'starts': [], 'ends': [], 'axes': []}
+
+    def __init__(self, onnx_node, desc=None, **options):
+        SliceCommon.__init__(self, onnx_node, desc=desc,
+                             expected_attributes=Slice_1.atts,
+                             **options)
+        for f in ['starts', 'ends', 'steps', 'axes']:
+            if not hasattr(self, f):
+                continue
+            if getattr(self, f) is not None and len(getattr(self, f)) == 0:
+                setattr(self, f, None)
+
+    def _run(self, data):  # pylint: disable=W0221
+        return SliceCommon._run(
+            self, data, self.starts, self.ends, self.axes)
+
+    def _infer_shapes(self, data):  # pylint: disable=W0221
+        return SliceCommon._infer_shapes(
+            self, data, self.starts, self.ends, self.axes)
+
+
+if onnx_opset_version() >= 10:
+    Slice = Slice_10
+else:
+    Slice = Slice_1  # pragma: no cover
diff --git a/mlprodict/onnxrt/shape_object.py b/mlprodict/onnxrt/shape_object.py
index 45b0fea9e..72c795831 100644
--- a/mlprodict/onnxrt/shape_object.py
+++ b/mlprodict/onnxrt/shape_object.py
@@ -45,7 +45,8 @@ def __init__(self, name, fct, fct_string, *args):
         for a in self._args:
             if not isinstance(a, DimensionObject):
                 raise TypeError(
-                    "All arguments must be of type DimensionObject not '{}'.".format(type(a)))
+                    "All arguments must be of type DimensionObject not '{}'."
+                    "".format(type(a)))

     def __repr__(self):
         """
@@ -649,7 +650,7 @@ def __repr__(self):
         st_shape = []
         for s in self.shape:
-            if isinstance(s._dim, (int, str)):
+            if isinstance(getattr(s, "_dim", None), (int, str)):
                 st_shape.append(str(s._dim))
             else:
                 st_shape.append(repr(s))