From d20a78e8405d568ae1528a51067271019a71d6e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sun, 21 Feb 2021 18:09:45 +0100 Subject: [PATCH 1/4] Implements __setitem__ for class OnnxVar --- _unittests/ut_npy/test_onnx_variable.py | 34 +++++- .../ut_onnxrt/test_onnxrt_python_runtime_.py | 47 +++++++- mlprodict/npy/onnx_variable.py | 100 +++++++++++++++++- mlprodict/onnxrt/ops_cpu/_op_list.py | 1 + .../onnxrt/ops_cpu/op_scatter_elements.py | 63 +++++++++++ 5 files changed, 241 insertions(+), 4 deletions(-) create mode 100644 mlprodict/onnxrt/ops_cpu/op_scatter_elements.py diff --git a/_unittests/ut_npy/test_onnx_variable.py b/_unittests/ut_npy/test_onnx_variable.py index a565bb69f..535f11628 100644 --- a/_unittests/ut_npy/test_onnx_variable.py +++ b/_unittests/ut_npy/test_onnx_variable.py @@ -238,10 +238,28 @@ def test_abs_not(x: NDArray[Any, numpy.float32], @onnxnumpy_default def test_abs_filter(x: NDArray[Any, numpy.float32], ) -> NDArray[Any, numpy.float32]: - "onnx numpy not" + "onnx numpy filter" return nxnp.abs(x)[x[:, 0] > numpy.float32(15)] +@onnxnumpy_default +def test_abs_set2(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.bool]: + "onnx numpy set" + temp = nxnp.abs(x).copy() + temp[:2, 0] = numpy.float32(-1) + return temp + + +@onnxnumpy_default +def test_abs_set3(x: NDArray[Any, numpy.float32], + ) -> NDArray[Any, numpy.bool]: + "onnx numpy set" + temp = nxnp.abs(x).copy() + temp[:2, :1] = numpy.array([[-1.5, -1.5]], dtype=numpy.float32).T + return temp + + class TestOnnxVariable(ExtTestCase): def test_onnx_variable_abs(self): @@ -414,6 +432,20 @@ def test_onnx_variable_abs_filter(self): y = test_abs_filter(x) self.assertEqualArray(y, numpy.abs(x)[x[:, 0] > 15]) + def test_onnx_variable_abs_set(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_set2(x) + temp = numpy.abs(x) + temp[:, 0] = -1 + self.assertEqualArray(y, temp) + + def test_onnx_variable_abs_set3(self): + x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) + y = test_abs_set3(x) + temp = numpy.abs(x) + temp[:, 0] = -1.5 + self.assertEqualArray(y, temp) + if __name__ == "__main__": unittest.main() diff --git a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py index e11042025..14f062fab 100644 --- a/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py +++ b/_unittests/ut_onnxrt/test_onnxrt_python_runtime_.py @@ -49,7 +49,7 @@ OnnxReduceSum, OnnxReduceSumApi11, OnnxReduceSum_11, OnnxReduceSum_1, OnnxReduceSumSquare, OnnxRelu, OnnxReshape, - OnnxShape, OnnxSlice, OnnxSigmoid, OnnxSign, OnnxSin, + OnnxScatterElements, OnnxShape, OnnxSlice, OnnxSigmoid, OnnxSign, OnnxSin, OnnxSplitApi11, OnnxSoftmax, OnnxSplit, OnnxSqrt, OnnxSub, OnnxSum, @@ -2320,6 +2320,51 @@ def test_onnxt_runtime_reshape(self): self.assertEqualArray(exp, got['Y']) python_tested.append(OnnxReshape) + @wraplog() + def test_onnxt_runtime_scatter_elements1(self): + for opset in [11, get_opset_number_from_onnx()]: + if opset > get_opset_number_from_onnx(): + continue + with self.subTest(opset=opset): + data = numpy.array( + [[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=numpy.float32) + indices = numpy.array([[1, 3]], dtype=numpy.int64) + updates = numpy.array([[1.1, 2.1]], dtype=numpy.float32) + output = numpy.array( + [[1.0, 1.1, 3.0, 2.1, 5.0]], dtype=numpy.float32) + + onx = OnnxScatterElements( + 'X', indices, updates, axis=1, + output_names=['Y'], op_version=opset) + model_def = onx.to_onnx( + {'X': data}, target_opset=opset) + got = OnnxInference(model_def).run({'X': data}) + self.assertEqualArray(output, got['Y']) + python_tested.append(OnnxScatterElements) + + @wraplog() + def test_onnxt_runtime_scatter_elements2(self): + for opset in [11, get_opset_number_from_onnx()]: + if opset > get_opset_number_from_onnx(): + continue + with self.subTest(opset=opset): + x = numpy.arange(20).reshape((4, 5)).astype( # pylint: disable=E1101 + numpy.float32) # pylint: disable=E1101 + indices = numpy.array([[1, 1, 1, 1]], dtype=numpy.int64).T + updates = numpy.array( + [[-1, -1, -1, -1]], dtype=numpy.float32).T + y = x.copy() + y[:, 1] = -1 + + onx = OnnxScatterElements( + 'X', indices, updates, axis=1, + output_names=['Y'], op_version=opset) + model_def = onx.to_onnx( + {'X': x, 'indices': indices, 'updates': updates}, + target_opset=opset) + got = OnnxInference(model_def).run({'X': x}) + self.assertEqualArray(y, got['Y']) + @wraplog() def test_onnxt_runtime_shape(self): x = numpy.random.randn(20, 2).astype( # pylint: disable=E1101 diff --git a/mlprodict/npy/onnx_variable.py b/mlprodict/npy/onnx_variable.py index 19b047093..037aa0bf8 100644 --- a/mlprodict/npy/onnx_variable.py +++ b/mlprodict/npy/onnx_variable.py @@ -11,17 +11,28 @@ OnnxDiv, OnnxEqual, OnnxGather, OnnxGreater, + OnnxIdentity, OnnxLess, OnnxMatMul, OnnxMod, OnnxMul, OnnxNeg, OnnxNot, OnnxOr, OnnxPow, OnnxReduceSum, OnnxReshape, - OnnxSlice, OnnxSqueeze, OnnxSub, + OnnxScatterElements, OnnxSlice, OnnxSqueeze, OnnxSub, OnnxTopK, OnnxTranspose ) +try: + numpy_bool = numpy.bool_ +except AttributeError: + numpy_bool = bool +try: + numpy_str = numpy.str +except AttributeError: + numpy_str = str + + class OnnxVar: """ Variables used into :epkg:`onnx` computation. @@ -61,6 +72,7 @@ def to_algebra(self, op_version=None): if not hasattr(self, 'alg_'): raise RuntimeError( # pragma: no cover "Missing attribute 'alg_'.") + self.alg_ = alg return alg new_inputs = [] @@ -68,12 +80,13 @@ def to_algebra(self, op_version=None): if isinstance(inp, ( int, float, str, numpy.ndarray, numpy.int32, numpy.int64, numpy.float32, numpy.float64, - numpy.bool_, numpy.str, numpy.int8, numpy.uint8, + numpy_bool, numpy_str, numpy.int8, numpy.uint8, numpy.int16, numpy.uint16, numpy.uint32, numpy.uint64)): new_inputs.append(inp) else: new_inputs.append( inp.to_algebra(op_version=op_version)) + res = self.onnx_op(*new_inputs, op_version=op_version, **self.onnx_op_kwargs) if self.select_output is None: @@ -244,3 +257,86 @@ def __getitem__(self, index): if steps is None: return OnnxVar(self, starts, ends, axes, op=OnnxSlice) return OnnxVar(self, starts, ends, axes, steps, op=OnnxSlice) + + def _matrix_multiply(self, indices, axis): + """ + Creates a matrix. + """ + shapes = tuple(i[1] - i[0] for i in indices) + mat = numpy.empty(shapes, dtype=numpy.int64) + ind = [slice(None) for i in indices] + ind[axis] = numpy.arange(0, indices[axis][1] - indices[axis][0]) + values = numpy.arange(indices[axis][0], indices[axis][1]) + mat[ind] = values + return mat + + def __setitem__(self, index, value): + """ + Deals with multiple scenarios. + * *index* is an integer or a slice, a tuple of integers and slices, + example: `[0, 1]`, `[:5, :6]`, `[::2]` (**scenario 1**) + * *index* is an *ONNX* object (more precisely an instance of + @see cl OnnxVar), then the method assumes it is an array of + boolean to select a subset of the tensor along the first axis, + example: `mat[mat == 0]` (**scenario 2**) + This processing is applied before the operator it contains. + A copy should be made (Identity node or copy method). + """ + if self.onnx_op is not None and self.onnx_op is not OnnxIdentity: + raise RuntimeError( + "A copy should be made before setting new values on a matrix. " + "Method copy() would do that.") + if isinstance(index, OnnxVar): + # scenario 2 + raise NotImplementedError() + + if not isinstance(index, tuple): + index = (index, ) + + # scenario 1 + indices = [] + for d, ind in enumerate(index): + if isinstance(ind, int): + indices.append((ind, ind + 1)) + elif isinstance(ind, slice): + if ind.step is not None: + raise NotImplementedError( + "Unable to assign new values with step defined " + "on dimension %r." % d) + start = 0 if ind.start is None else ind.start + if ind.stop is None: + raise NotImplementedError( + "Unable to assign new values with end undefined " + "on dimension %r." % d) + stop = ind.stop + indices.append((start, stop)) + else: + raise NotImplementedError( + "Unable to assign new values due to unexpected type %r " + "on dimension %r." % (type(ind), d)) + + axis = len(index) - 1 + mat_indices = self._matrix_multiply(indices, axis) + + if isinstance(value, (OnnxVar, numpy.ndarray)): + mat_updates = value + elif isinstance(value, (numpy.float32, numpy.float64, numpy.int32, + numpy.int64, numpy.uint32, numpy.uint64, + numpy_bool, numpy_str)): + mat_updates = numpy.full( + mat_indices.shape, value, dtype=value.dtype) + else: + raise NotImplementedError( + "Unable to assign new values due to unexpected type %r " + "for value." % type(value)) + + self.inputs = [ + OnnxVar(self.inputs[0], mat_indices, mat_updates, op=OnnxScatterElements, + axis=axis)] + return self + + def copy(self): + """ + Returns a copy of self (use of Identity node). + """ + return OnnxVar(self, op=OnnxIdentity) diff --git a/mlprodict/onnxrt/ops_cpu/_op_list.py b/mlprodict/onnxrt/ops_cpu/_op_list.py index 17adaa858..1ab782eb1 100644 --- a/mlprodict/onnxrt/ops_cpu/_op_list.py +++ b/mlprodict/onnxrt/ops_cpu/_op_list.py @@ -83,6 +83,7 @@ from .op_rnn import RNN from .op_scaler import Scaler from .op_scan import Scan +from .op_scatter_elements import ScatterElements from .op_shape import Shape from .op_sigmoid import Sigmoid from .op_sign import Sign diff --git a/mlprodict/onnxrt/ops_cpu/op_scatter_elements.py b/mlprodict/onnxrt/ops_cpu/op_scatter_elements.py new file mode 100644 index 000000000..d590de53f --- /dev/null +++ b/mlprodict/onnxrt/ops_cpu/op_scatter_elements.py @@ -0,0 +1,63 @@ +# -*- encoding: utf-8 -*- +# pylint: disable=E0203,E1101,C0111 +""" +@file +@brief Runtime operator. +""" +import numpy +from ..shape_object import ShapeObject +from ._op import OpRun + + +def scatter_elements(data, indices, updates, axis=0): + if axis < 0: + axis = data.ndim + axis + + idx_xsection_shape = indices.shape[:axis] + indices.shape[axis + 1:] + + def make_slice(arr, axis, i): + slc = [slice(None)] * arr.ndim + slc[axis] = i + return slc + + def unpack(packed): + unpacked = packed[0] + for i in range(1, len(packed)): + unpacked = unpacked, packed[i] + return unpacked + + # We use indices and axis parameters to create idx + # idx is in a form that can be used as a NumPy advanced + # indices for scattering of updates param. in data + idx = [[unpack(numpy.indices(idx_xsection_shape).reshape(indices.ndim - 1, -1)), + indices[tuple(make_slice(indices, axis, i))].reshape(1, -1)[0]] + for i in range(indices.shape[axis])] + idx = list(numpy.concatenate(idx, axis=1)) + idx.insert(axis, idx.pop()) + + # updates_idx is a NumPy advanced indices for indexing + # of elements in the updates + updates_idx = list(idx) + updates_idx.pop(axis) + updates_idx.insert(axis, numpy.repeat(numpy.arange(indices.shape[axis]), + numpy.prod(idx_xsection_shape))) + + scattered = numpy.copy(data) + scattered[tuple(idx)] = updates[tuple(updates_idx)] + return scattered + + +class ScatterElements(OpRun): + + atts = {'axis': 0} + + def __init__(self, onnx_node, desc=None, **options): + OpRun.__init__(self, onnx_node, desc=desc, + **options) + + def _run(self, data, indices, updates): # pylint: disable=W0221 + res = scatter_elements(data, indices, updates, axis=self.axis) + return (res, ) + + def _infer_shapes(self, data, indices, updates): # pylint: disable=W0221 + return (ShapeObject(data.shape, dtype=data.dtype), ) From 07dd71d2147110014284663b29f5c81ad20982d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sun, 21 Feb 2021 22:39:46 +0100 Subject: [PATCH 2/4] lint --- _unittests/ut_onnxrt/test_onnx_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_onnxrt/test_onnx_helper.py b/_unittests/ut_onnxrt/test_onnx_helper.py index 0940334b9..de89c5033 100644 --- a/_unittests/ut_onnxrt/test_onnx_helper.py +++ b/_unittests/ut_onnxrt/test_onnx_helper.py @@ -54,7 +54,7 @@ def test_change_input_first_dimension(self): for inp in model_onnx.graph.input: dim = inp.type.tensor_type.shape.dim[0].dim_value self.assertEqual(dim, 0) - for inp in new_model.graph.input: + for inp in new_model.graph.input: # pylint: disable=E1101 dim = inp.type.tensor_type.shape.dim[0].dim_value self.assertEqual(dim, 2) From 1bbcff181f1e46d0a3b138c3e7f8c4c14a7303e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Sun, 21 Feb 2021 23:04:55 +0100 Subject: [PATCH 3/4] lint --- _unittests/ut_onnxrt/test_shape_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_onnxrt/test_shape_object.py b/_unittests/ut_onnxrt/test_shape_object.py index 533e52fa1..4793dd7f9 100644 --- a/_unittests/ut_onnxrt/test_shape_object.py +++ b/_unittests/ut_onnxrt/test_shape_object.py @@ -72,7 +72,7 @@ def fct2(): st = sh.to_string() self.assertEqual(st, '(1)+(2)') - x, y = sh._args # pylint: disable=W0212 + x, y = sh._args # pylint: disable=W0212,W0632 self.assertEqual(sh._to_string1(x, y), "12") # pylint: disable=W0212 self.assertEqual(sh._to_string2(x, y), "1+2") # pylint: disable=W0212 self.assertEqual(sh._to_string2b( # pylint: disable=W0212 From 6fd8530d96033b15f0f0b0aea382b2a0b4f7db94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?xavier=20dupr=C3=A9?= Date: Mon, 22 Feb 2021 00:03:51 +0100 Subject: [PATCH 4/4] skip one warning --- _unittests/ut_module/test_code_style.py | 4 ++-- mlprodict/testing/test_utils/utils_backend_common.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/_unittests/ut_module/test_code_style.py b/_unittests/ut_module/test_code_style.py index e3bbb668d..19670c3bc 100644 --- a/_unittests/ut_module/test_code_style.py +++ b/_unittests/ut_module/test_code_style.py @@ -16,7 +16,7 @@ def test_style_src(self): check_pep8(src_, fLOG=fLOG, pylint_ignore=('C0103', 'C1801', 'R0201', 'R1705', 'W0108', 'W0613', 'R1702', 'W0212', 'W0640', 'W0223', 'W0201', - 'W0622', 'C0123', 'W0107', + 'W0622', 'C0123', 'W0107', 'R1728', 'C0415', 'R1721', 'C0411'), skip=["Instance of 'tuple' has no ", "do not compare types, use 'isinstance()'", @@ -33,7 +33,7 @@ def test_style_test(self): test = os.path.normpath(os.path.join(thi, "..", )) check_pep8(test, fLOG=fLOG, neg_pattern="temp_.*", pylint_ignore=('C0103', 'C1801', 'R0201', 'R1705', 'W0108', 'W0613', - 'C0111', 'W0107', 'C0415', + 'C0111', 'W0107', 'C0415', 'R1728', 'R1721', 'C0302', 'C0411'), skip=["Instance of 'tuple' has no ", "R1720", diff --git a/mlprodict/testing/test_utils/utils_backend_common.py b/mlprodict/testing/test_utils/utils_backend_common.py index 1d14b04ec..8617ab188 100644 --- a/mlprodict/testing/test_utils/utils_backend_common.py +++ b/mlprodict/testing/test_utils/utils_backend_common.py @@ -149,8 +149,8 @@ def compare_outputs(expected, output, verbose=False, **kwargs): # as one dimension is useless. expected = expected.reshape( tuple([d for d in expected.shape if d > 1])) - output = output.reshape(tuple([d for d in expected.shape - if d > 1])) + output = output.reshape( + tuple([d for d in expected.shape if d > 1])) if NoProb or NoProbOpp: # One vector is (N,) with scores, negative for class 0 # positive for class 1