Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions _unittests/ut_module/test_code_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_style_src(self):
check_pep8(src_, fLOG=fLOG,
pylint_ignore=('C0103', 'C1801', 'R0201', 'R1705', 'W0108', 'W0613',
'R1702', 'W0212', 'W0640', 'W0223', 'W0201',
'W0622', 'C0123', 'W0107',
'W0622', 'C0123', 'W0107', 'R1728',
'C0415', 'R1721', 'C0411'),
skip=["Instance of 'tuple' has no ",
"do not compare types, use 'isinstance()'",
Expand All @@ -33,7 +33,7 @@ def test_style_test(self):
test = os.path.normpath(os.path.join(thi, "..", ))
check_pep8(test, fLOG=fLOG, neg_pattern="temp_.*",
pylint_ignore=('C0103', 'C1801', 'R0201', 'R1705', 'W0108', 'W0613',
'C0111', 'W0107', 'C0415',
'C0111', 'W0107', 'C0415', 'R1728',
'R1721', 'C0302', 'C0411'),
skip=["Instance of 'tuple' has no ",
"R1720",
Expand Down
34 changes: 33 additions & 1 deletion _unittests/ut_npy/test_onnx_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,28 @@ def test_abs_not(x: NDArray[Any, numpy.float32],
@onnxnumpy_default
def test_abs_filter(x: NDArray[Any, numpy.float32],
) -> NDArray[Any, numpy.float32]:
"onnx numpy not"
"onnx numpy filter"
return nxnp.abs(x)[x[:, 0] > numpy.float32(15)]


@onnxnumpy_default
def test_abs_set2(x: NDArray[Any, numpy.float32],
) -> NDArray[Any, numpy.bool]:
"onnx numpy set"
temp = nxnp.abs(x).copy()
temp[:2, 0] = numpy.float32(-1)
return temp


@onnxnumpy_default
def test_abs_set3(x: NDArray[Any, numpy.float32],
) -> NDArray[Any, numpy.bool]:
"onnx numpy set"
temp = nxnp.abs(x).copy()
temp[:2, :1] = numpy.array([[-1.5, -1.5]], dtype=numpy.float32).T
return temp


class TestOnnxVariable(ExtTestCase):

def test_onnx_variable_abs(self):
Expand Down Expand Up @@ -414,6 +432,20 @@ def test_onnx_variable_abs_filter(self):
y = test_abs_filter(x)
self.assertEqualArray(y, numpy.abs(x)[x[:, 0] > 15])

def test_onnx_variable_abs_set(self):
x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
y = test_abs_set2(x)
temp = numpy.abs(x)
temp[:, 0] = -1
self.assertEqualArray(y, temp)

def test_onnx_variable_abs_set3(self):
x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32)
y = test_abs_set3(x)
temp = numpy.abs(x)
temp[:, 0] = -1.5
self.assertEqualArray(y, temp)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion _unittests/ut_onnxrt/test_onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_change_input_first_dimension(self):
for inp in model_onnx.graph.input:
dim = inp.type.tensor_type.shape.dim[0].dim_value
self.assertEqual(dim, 0)
for inp in new_model.graph.input:
for inp in new_model.graph.input: # pylint: disable=E1101
dim = inp.type.tensor_type.shape.dim[0].dim_value
self.assertEqual(dim, 2)

Expand Down
47 changes: 46 additions & 1 deletion _unittests/ut_onnxrt/test_onnxrt_python_runtime_.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
OnnxReduceSum, OnnxReduceSumApi11, OnnxReduceSum_11, OnnxReduceSum_1,
OnnxReduceSumSquare,
OnnxRelu, OnnxReshape,
OnnxShape, OnnxSlice, OnnxSigmoid, OnnxSign, OnnxSin,
OnnxScatterElements, OnnxShape, OnnxSlice, OnnxSigmoid, OnnxSign, OnnxSin,
OnnxSplitApi11,
OnnxSoftmax, OnnxSplit,
OnnxSqrt, OnnxSub, OnnxSum,
Expand Down Expand Up @@ -2320,6 +2320,51 @@ def test_onnxt_runtime_reshape(self):
self.assertEqualArray(exp, got['Y'])
python_tested.append(OnnxReshape)

@wraplog()
def test_onnxt_runtime_scatter_elements1(self):
for opset in [11, get_opset_number_from_onnx()]:
if opset > get_opset_number_from_onnx():
continue
with self.subTest(opset=opset):
data = numpy.array(
[[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=numpy.float32)
indices = numpy.array([[1, 3]], dtype=numpy.int64)
updates = numpy.array([[1.1, 2.1]], dtype=numpy.float32)
output = numpy.array(
[[1.0, 1.1, 3.0, 2.1, 5.0]], dtype=numpy.float32)

onx = OnnxScatterElements(
'X', indices, updates, axis=1,
output_names=['Y'], op_version=opset)
model_def = onx.to_onnx(
{'X': data}, target_opset=opset)
got = OnnxInference(model_def).run({'X': data})
self.assertEqualArray(output, got['Y'])
python_tested.append(OnnxScatterElements)

@wraplog()
def test_onnxt_runtime_scatter_elements2(self):
for opset in [11, get_opset_number_from_onnx()]:
if opset > get_opset_number_from_onnx():
continue
with self.subTest(opset=opset):
x = numpy.arange(20).reshape((4, 5)).astype( # pylint: disable=E1101
numpy.float32) # pylint: disable=E1101
indices = numpy.array([[1, 1, 1, 1]], dtype=numpy.int64).T
updates = numpy.array(
[[-1, -1, -1, -1]], dtype=numpy.float32).T
y = x.copy()
y[:, 1] = -1

onx = OnnxScatterElements(
'X', indices, updates, axis=1,
output_names=['Y'], op_version=opset)
model_def = onx.to_onnx(
{'X': x, 'indices': indices, 'updates': updates},
target_opset=opset)
got = OnnxInference(model_def).run({'X': x})
self.assertEqualArray(y, got['Y'])

@wraplog()
def test_onnxt_runtime_shape(self):
x = numpy.random.randn(20, 2).astype( # pylint: disable=E1101
Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_onnxrt/test_shape_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def fct2():
st = sh.to_string()
self.assertEqual(st, '(1)+(2)')

x, y = sh._args # pylint: disable=W0212
x, y = sh._args # pylint: disable=W0212,W0632
self.assertEqual(sh._to_string1(x, y), "12") # pylint: disable=W0212
self.assertEqual(sh._to_string2(x, y), "1+2") # pylint: disable=W0212
self.assertEqual(sh._to_string2b( # pylint: disable=W0212
Expand Down
100 changes: 98 additions & 2 deletions mlprodict/npy/onnx_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,28 @@
OnnxDiv,
OnnxEqual,
OnnxGather, OnnxGreater,
OnnxIdentity,
OnnxLess,
OnnxMatMul, OnnxMod, OnnxMul,
OnnxNeg, OnnxNot,
OnnxOr,
OnnxPow,
OnnxReduceSum, OnnxReshape,
OnnxSlice, OnnxSqueeze, OnnxSub,
OnnxScatterElements, OnnxSlice, OnnxSqueeze, OnnxSub,
OnnxTopK, OnnxTranspose
)


try:
numpy_bool = numpy.bool_
except AttributeError:
numpy_bool = bool
try:
numpy_str = numpy.str
except AttributeError:
numpy_str = str


class OnnxVar:
"""
Variables used into :epkg:`onnx` computation.
Expand Down Expand Up @@ -61,19 +72,21 @@ def to_algebra(self, op_version=None):
if not hasattr(self, 'alg_'):
raise RuntimeError( # pragma: no cover
"Missing attribute 'alg_'.")
self.alg_ = alg
return alg

new_inputs = []
for inp in self.inputs:
if isinstance(inp, (
int, float, str, numpy.ndarray, numpy.int32,
numpy.int64, numpy.float32, numpy.float64,
numpy.bool_, numpy.str, numpy.int8, numpy.uint8,
numpy_bool, numpy_str, numpy.int8, numpy.uint8,
numpy.int16, numpy.uint16, numpy.uint32, numpy.uint64)):
new_inputs.append(inp)
else:
new_inputs.append(
inp.to_algebra(op_version=op_version))

res = self.onnx_op(*new_inputs, op_version=op_version,
**self.onnx_op_kwargs)
if self.select_output is None:
Expand Down Expand Up @@ -244,3 +257,86 @@ def __getitem__(self, index):
if steps is None:
return OnnxVar(self, starts, ends, axes, op=OnnxSlice)
return OnnxVar(self, starts, ends, axes, steps, op=OnnxSlice)

def _matrix_multiply(self, indices, axis):
"""
Creates a matrix.
"""
shapes = tuple(i[1] - i[0] for i in indices)
mat = numpy.empty(shapes, dtype=numpy.int64)
ind = [slice(None) for i in indices]
ind[axis] = numpy.arange(0, indices[axis][1] - indices[axis][0])
values = numpy.arange(indices[axis][0], indices[axis][1])
mat[ind] = values
return mat

def __setitem__(self, index, value):
"""
Deals with multiple scenarios.
* *index* is an integer or a slice, a tuple of integers and slices,
example: `[0, 1]`, `[:5, :6]`, `[::2]` (**scenario 1**)
* *index* is an *ONNX* object (more precisely an instance of
@see cl OnnxVar), then the method assumes it is an array of
boolean to select a subset of the tensor along the first axis,
example: `mat[mat == 0]` (**scenario 2**)
This processing is applied before the operator it contains.
A copy should be made (Identity node or copy method).
"""
if self.onnx_op is not None and self.onnx_op is not OnnxIdentity:
raise RuntimeError(
"A copy should be made before setting new values on a matrix. "
"Method copy() would do that.")
if isinstance(index, OnnxVar):
# scenario 2
raise NotImplementedError()

if not isinstance(index, tuple):
index = (index, )

# scenario 1
indices = []
for d, ind in enumerate(index):
if isinstance(ind, int):
indices.append((ind, ind + 1))
elif isinstance(ind, slice):
if ind.step is not None:
raise NotImplementedError(
"Unable to assign new values with step defined "
"on dimension %r." % d)
start = 0 if ind.start is None else ind.start
if ind.stop is None:
raise NotImplementedError(
"Unable to assign new values with end undefined "
"on dimension %r." % d)
stop = ind.stop
indices.append((start, stop))
else:
raise NotImplementedError(
"Unable to assign new values due to unexpected type %r "
"on dimension %r." % (type(ind), d))

axis = len(index) - 1
mat_indices = self._matrix_multiply(indices, axis)

if isinstance(value, (OnnxVar, numpy.ndarray)):
mat_updates = value
elif isinstance(value, (numpy.float32, numpy.float64, numpy.int32,
numpy.int64, numpy.uint32, numpy.uint64,
numpy_bool, numpy_str)):
mat_updates = numpy.full(
mat_indices.shape, value, dtype=value.dtype)
else:
raise NotImplementedError(
"Unable to assign new values due to unexpected type %r "
"for value." % type(value))

self.inputs = [
OnnxVar(self.inputs[0], mat_indices, mat_updates, op=OnnxScatterElements,
axis=axis)]
return self

def copy(self):
"""
Returns a copy of self (use of Identity node).
"""
return OnnxVar(self, op=OnnxIdentity)
1 change: 1 addition & 0 deletions mlprodict/onnxrt/ops_cpu/_op_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from .op_rnn import RNN
from .op_scaler import Scaler
from .op_scan import Scan
from .op_scatter_elements import ScatterElements
from .op_shape import Shape
from .op_sigmoid import Sigmoid
from .op_sign import Sign
Expand Down
63 changes: 63 additions & 0 deletions mlprodict/onnxrt/ops_cpu/op_scatter_elements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import numpy
from ..shape_object import ShapeObject
from ._op import OpRun


def scatter_elements(data, indices, updates, axis=0):
if axis < 0:
axis = data.ndim + axis

idx_xsection_shape = indices.shape[:axis] + indices.shape[axis + 1:]

def make_slice(arr, axis, i):
slc = [slice(None)] * arr.ndim
slc[axis] = i
return slc

def unpack(packed):
unpacked = packed[0]
for i in range(1, len(packed)):
unpacked = unpacked, packed[i]
return unpacked

# We use indices and axis parameters to create idx
# idx is in a form that can be used as a NumPy advanced
# indices for scattering of updates param. in data
idx = [[unpack(numpy.indices(idx_xsection_shape).reshape(indices.ndim - 1, -1)),
indices[tuple(make_slice(indices, axis, i))].reshape(1, -1)[0]]
for i in range(indices.shape[axis])]
idx = list(numpy.concatenate(idx, axis=1))
idx.insert(axis, idx.pop())

# updates_idx is a NumPy advanced indices for indexing
# of elements in the updates
updates_idx = list(idx)
updates_idx.pop(axis)
updates_idx.insert(axis, numpy.repeat(numpy.arange(indices.shape[axis]),
numpy.prod(idx_xsection_shape)))

scattered = numpy.copy(data)
scattered[tuple(idx)] = updates[tuple(updates_idx)]
return scattered


class ScatterElements(OpRun):

atts = {'axis': 0}

def __init__(self, onnx_node, desc=None, **options):
OpRun.__init__(self, onnx_node, desc=desc,
**options)

def _run(self, data, indices, updates): # pylint: disable=W0221
res = scatter_elements(data, indices, updates, axis=self.axis)
return (res, )

def _infer_shapes(self, data, indices, updates): # pylint: disable=W0221
return (ShapeObject(data.shape, dtype=data.dtype), )
4 changes: 2 additions & 2 deletions mlprodict/testing/test_utils/utils_backend_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def compare_outputs(expected, output, verbose=False, **kwargs):
# as one dimension is useless.
expected = expected.reshape(
tuple([d for d in expected.shape if d > 1]))
output = output.reshape(tuple([d for d in expected.shape
if d > 1]))
output = output.reshape(
tuple([d for d in expected.shape if d > 1]))
if NoProb or NoProbOpp:
# One vector is (N,) with scores, negative for class 0
# positive for class 1
Expand Down