Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 112 additions & 111 deletions _doc/notebooks/numpy_api_onnx_ftr.ipynb

Large diffs are not rendered by default.

272 changes: 132 additions & 140 deletions _doc/notebooks/onnx_fft.ipynb

Large diffs are not rendered by default.

383 changes: 210 additions & 173 deletions _unittests/ut_npy/test_onnx_variable.py

Large diffs are not rendered by default.

310 changes: 155 additions & 155 deletions _unittests/ut_npy/test_onnx_variable_ort.py

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions _unittests/ut_npy/test_onnx_variable_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@ def common_test_abs_topk(x):


@onnxnumpy_default
def test_abs_topk(x: NDArray[Any, numpy.float32],
) -> (NDArray[Any, numpy.float32],
NDArray[Any, numpy.int64]):
def otest_abs_topk(x: NDArray[Any, numpy.float32],
) -> (NDArray[Any, numpy.float32],
NDArray[Any, numpy.int64]):
"onnx topk"
return common_test_abs_topk(x)


@onnxnumpy(runtime='onnxruntime1')
def test_abs_topk_ort(x: NDArray[Any, numpy.float32],
) -> (NDArray[Any, numpy.float32],
NDArray[Any, numpy.int64]):
def otest_abs_topk_ort(x: NDArray[Any, numpy.float32],
) -> (NDArray[Any, numpy.float32],
NDArray[Any, numpy.int64]):
"onnx topk"
return common_test_abs_topk(x)

Expand All @@ -51,8 +51,8 @@ class TestOnnxVariableTuple(ExtTestCase):
def test_py_abs_topk(self):
x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0],
dtype=numpy.float32).reshape((-1, 2))
y, yi = test_abs_topk(x) # pylint: disable=E0633
self.assertIn('output: "y"', str(test_abs_topk.compiled.onnx_))
y, yi = otest_abs_topk(x) # pylint: disable=E0633
self.assertIn('output: "y"', str(otest_abs_topk.compiled.onnx_))
exp_y = numpy.array([[6.1, 7.8, 6.7]], dtype=numpy.float32).T
exp_yi = numpy.array([[0, 1, 0]], dtype=numpy.float32).T
self.assertEqualArray(exp_y, y)
Expand All @@ -62,7 +62,7 @@ def test_py_abs_topk(self):
def test_py_abs_topk_ort(self):
x = numpy.array([6.1, -5, 3.5, -7.8, 6.7, -5.0],
dtype=numpy.float32).reshape((-1, 2))
y, yi = test_abs_topk_ort(x) # pylint: disable=E0633
y, yi = otest_abs_topk_ort(x) # pylint: disable=E0633
exp_y = numpy.array([[6.1, 7.8, 6.7]], dtype=numpy.float32).T
exp_yi = numpy.array([[0, 1, 0]], dtype=numpy.float32).T
self.assertEqualArray(exp_y, y)
Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_npy/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,5 +375,5 @@ def test_signature_optional3_kwargs_more(self):


if __name__ == "__main__":
TestWrappers().test_signature_optional_errors_runtime()
# TestWrappers().test_signature_optional_errors_runtime()
unittest.main()
17 changes: 17 additions & 0 deletions mlprodict/npy/onnx_numpy_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,23 @@ def _to_onnx(self, op_version=None, signature=None, version=None):
type(self), self.fct_))
return self.onnx_

def to_onnx(self, **kwargs):
"""
Returns the ONNX graph for the wrapped function.
It takes additional arguments to distinguish between multiple graphs.
This happens when a function needs to support multiple type.

:return: ONNX graph
"""
if len(kwargs) > 0:
raise NotImplementedError( # pragma: no cover
"kwargs is not empty, this case is not implemented. "
"kwargs=%r." % kwargs)
if hasattr(self, 'onnx_'):
return self.onnx_
raise NotImplementedError( # pragma: no cover
"Attribute 'onnx_' is missing.")

def _build_runtime(self, op_version=None, runtime=None,
signature=None, version=None):
"""
Expand Down
48 changes: 48 additions & 0 deletions mlprodict/npy/onnx_numpy_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ def __setstate__(self, state):
"""
self.compiled = state['compiled']

def to_onnx(self, **kwargs):
"""
Returns the ONNX graph for the wrapped function.
It takes additional arguments to distinguish between multiple graphs.
This happens when a function needs to support multiple type.

:return: ONNX graph
"""
return self.compiled.to_onnx(**kwargs)


def onnxnumpy(op_version=None, runtime=None, signature=None):
"""
Expand Down Expand Up @@ -202,6 +212,44 @@ def _populate(self, version):
def _validate_onnx_data(self, X):
return X

def to_onnx(self, **kwargs):
"""
Returns the ONNX graph for the wrapped function.
It takes additional arguments to distinguish between multiple graphs.
This happens when a function needs to support multiple type.

:return: ONNX graph
"""
if len(self.signed_compiled) == 0:
raise RuntimeError( # pragma: no cover
"No ONNX graph was compiled.")
if len(kwargs) == 0 and len(self.signed_compiled) == 1:
# We take the only one.
key = list(self.signed_compiled)[0]
cpl = self.signed_compiled[key]
return cpl.to_onnx()
if len(kwargs) == 0:
raise ValueError(
"There are multiple compiled ONNX graphs associated "
"with keys %r (add key=...)." % list(self.signed_compiled))
if list(kwargs) != ['key']:
raise ValueError(
"kwargs should contain one parameter key=... but "
"it is %r." % kwargs)
key = kwargs['key']
if key in self.signed_compiled:
return self.signed_compiled[key].compiled.onnx_
found = []
for k, v in self.signed_compiled.items():
if k.args == key or (
not isinstance(key, tuple) and k.args == (key, )):
found.append((k, v))
if len(found) == 1:
return found[0][1].compiled.onnx_
raise ValueError(
"Unable to find signature with key=%r among %r." % (
key, list(self.signed_compiled)))


def onnxnumpy_np(op_version=None, runtime=None, signature=None):
"""
Expand Down
2 changes: 1 addition & 1 deletion mlprodict/npy/onnx_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def as_tuple_with_sep(self, sep):
(tuple() if self.kwargs is None else self.kwargs))

def as_string(self):
"Returns a single stirng identifier."
"Returns a single string identifier."
val = "_".join(map(str, self.as_tuple_with_sep("_")))
val = val.replace("<class 'numpy.", "").replace(
'.', "_").replace("'>", "").replace(" ", "")
Expand Down