This repository has been archived by the owner on Jan 13, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implements numpy functions with onnx (#214)
* Implements numpy functions with onnx * finalize OnnxNumpyCompiler * First sketch of easy function onnx numpy * add operator +, /, *, - * support constants * simplify when variable are reused * add function transformer * Fix issue with RNN opset 14 * Update requirements.txt
- Loading branch information
Showing
36 changed files
with
1,259 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ API | |
onnxrt_ops | ||
onnx_conv | ||
sklapi | ||
npy | ||
asv | ||
validation | ||
testing | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
|
||
.. _l-numpy-onnxpy: | ||
|
||
Numpy revisited with ONNX | ||
========================= | ||
|
||
Converting custom code into :epkg:`ONNX` is not necessarily easy. | ||
One big obstacle is :epkg:`ONNX` does not represent all numpy functions | ||
with a single operator. One possible option is to provide a | ||
:epkg:`numpy` API to :epkg:`ONNX`. That's the purpose of wrapper | ||
:class:`onnxnumpy <mlprodict.npy.onnx_numpy_wrapper.onnxnumpy>`. | ||
It takes a function written with functions following the same | ||
signature as :epkg:`numpy` and provides a way to execute them | ||
with an :epkg:`ONNX` runtime. In the below example, | ||
`custom_fct` creates an :epkg:`ONNX` graph, the wrapper | ||
loads it in a runtime and runs it everytime the function | ||
is called. | ||
|
||
.. runpython:: | ||
:showcode: | ||
|
||
import numpy | ||
from typing import Any | ||
from mlprodict.npy import onnxnumpy_default, NDArray | ||
import mlprodict.npy.numpy_impl as nxnp | ||
|
||
@onnxnumpy_default | ||
def custom_fct(x: NDArray[Any, numpy.float32], | ||
) -> NDArray[Any, numpy.float32]: | ||
"onnx numpy abs" | ||
return nxnp.abs(x) + numpy.float32(1) | ||
|
||
x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) | ||
y = custom_fct(x) | ||
print(y) | ||
|
||
Annotations are mandatory to indicate inputs and outputs type. | ||
As a result, the returned function is strict about types | ||
as opposed to numpy. This approach is similar to what | ||
:epkg:`tensorflow` with `autograph | ||
<https://www.tensorflow.org/api_docs/python/tf/autograph>`_. | ||
|
||
.. contents:: | ||
:local: | ||
|
||
NDArray | ||
+++++++ | ||
|
||
.. autosignature:: mlprodict.npy.onnx_numpy_compiler.NDArray | ||
:members: | ||
|
||
onnxnumpy | ||
+++++++++ | ||
|
||
.. autosignature:: mlprodict.npy.onnx_numpy_wrapper.onnxnumpy | ||
|
||
.. autosignature:: mlprodict.npy.onnx_numpy_wrapper.onnxnumpy_default | ||
|
||
OnnxNumpyCompiler | ||
+++++++++++++++++ | ||
|
||
.. autosignature:: mlprodict.npy.onnx_numpy_compiler.OnnxNumpyCompiler | ||
:members: | ||
|
||
OnnxVar | ||
+++++++ | ||
|
||
.. autosignature:: mlprodict.npy.onnx_variable.OnnxVar | ||
:members: | ||
|
||
Available numpy functions | ||
+++++++++++++++++++++++++ | ||
|
||
.. autosignature:: mlprodict.npy.numpy_impl.abs | ||
|
||
.. autosignature:: mlprodict.npy.numpy_impl.sum |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
""" | ||
@brief test tree node (time=2s) | ||
@brief test tree node (time=15s) | ||
""" | ||
import unittest | ||
import math | ||
|
2 changes: 1 addition & 1 deletion
2
_unittests/ut__skl2onnx/test_sklearn_gaussian_mixture_converter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
""" | ||
@brief test tree node (time=2s) | ||
@brief test tree node (time=30s) | ||
""" | ||
import unittest | ||
import numpy as np | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
_unittests/ut__skl2onnx/test_sklearn_glm_regressor_converter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
""" | ||
@brief test log(time=10s) | ||
@brief test log(time=9s) | ||
""" | ||
|
||
import unittest | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
""" | ||
@brief test tree node (time=3s) | ||
@brief test tree node (time=14s) | ||
""" | ||
import unittest | ||
import numpy | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
""" | ||
@brief test tree node (time=2s) | ||
@brief test tree node (time=8s) | ||
""" | ||
import unittest | ||
import numpy | ||
|
2 changes: 1 addition & 1 deletion
2
_unittests/ut__skl2onnx/test_sklearn_label_encoder_converter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
""" | ||
@brief test tree node (time=2s) | ||
@brief test tree node (time=5s) | ||
""" | ||
import unittest | ||
import numpy | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
""" | ||
@brief test tree node (time=2s) | ||
@brief test tree node (time=12s) | ||
""" | ||
import unittest | ||
import numpy | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
""" | ||
@brief test tree node (time=2s) | ||
@brief test tree node (time=10s) | ||
""" | ||
import unittest | ||
import warnings | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
@brief test log(time=3s) | ||
""" | ||
import unittest | ||
import warnings | ||
from logging import getLogger | ||
from typing import Any | ||
import numpy | ||
from sklearn.preprocessing import FunctionTransformer | ||
from pyquickhelper.pycode import ExtTestCase, ignore_warnings | ||
from mlprodict.onnx_conv import register_rewritten_operators, to_onnx | ||
from mlprodict.onnxrt import OnnxInference | ||
from mlprodict.npy import onnxnumpy_default | ||
import mlprodict.npy.numpy_impl as nxnp | ||
from mlprodict.npy import NDArray | ||
|
||
|
||
@onnxnumpy_default | ||
def custom_fct(x: NDArray[Any, numpy.float32], | ||
) -> NDArray[Any, numpy.float32]: | ||
"onnx custom function" | ||
return (nxnp.abs(x) + x) / numpy.float32(2) | ||
|
||
|
||
class TestOnnxFunctionTransformer(ExtTestCase): | ||
|
||
def setUp(self): | ||
logger = getLogger('skl2onnx') | ||
logger.disabled = True | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter("ignore", ResourceWarning) | ||
res = register_rewritten_operators() | ||
self.assertGreater(len(res), 2) | ||
self.assertIn('SklearnFunctionTransformer', res[0]) | ||
self.assertIn('SklearnFunctionTransformer', res[1]) | ||
|
||
@ignore_warnings(DeprecationWarning) | ||
def test_function_transformer(self): | ||
x = numpy.array([[6.1, -5], [3.5, -7.8]], dtype=numpy.float32) | ||
tr = FunctionTransformer(custom_fct) | ||
tr.fit(x) | ||
y_exp = tr.transform(x) | ||
self.assertEqualArray( | ||
numpy.array([[6.1, 0.], [3.5, 0.]], dtype=numpy.float32), | ||
y_exp) | ||
|
||
onnx_model = to_onnx(tr, x) | ||
oinf = OnnxInference(onnx_model) | ||
y_onx = oinf.run({'X': x}) | ||
self.assertEqualArray(y_exp, y_onx['variable']) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.