Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion _unittests/ut_onnxrt/test_onnxrt_python_runtime_.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
except ImportError:
from sklearn.utils.testing import ignore_warnings
from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611
OnnxAbs, OnnxAdd, OnnxArgMax, OnnxArgMin,
OnnxAbs, OnnxAdd, OnnxArgMax, OnnxArgMin, OnnxAtan,
OnnxBatchNormalization,
OnnxConcat,
OnnxCeil, OnnxClip, OnnxConstant, OnnxConstantOfShape,
Expand Down Expand Up @@ -396,6 +396,28 @@ def test_onnxt_runtime_argmin_12(self):
self.assertEqualArray(numpy.array([2, 1], dtype=numpy.int64),
got['Y'], decimal=6)

def test_onnxt_runtime_atan(self):
self.common_test_onnxt_runtime_unary(OnnxAtan, numpy.arctan)

def test_onnxt_runtime_atan2(self):
test_pairs = [[y, x] for x in [3., -4., 0.] for y in [5., -6., 0.]]
y_val = numpy.array([y for y, x in test_pairs], dtype=numpy.float32)
x_val = numpy.array([x for y, x in test_pairs], dtype=numpy.float32)

def atan2(y, x):
# size: 100000
# timeit arctan: 0.00205
# timeit arctan2: 0.00361
# timeit atan2: 0.00599
sx = numpy.sign(x)
sy = numpy.sign(y)
pi_part = (sy + sx * (sy ** 2 - 1)) * (sx - 1) * (-numpy.pi / 2)
atan_part = numpy.arctan(y / (x + (1 - sx ** 2))) * sx ** 2
return atan_part + pi_part

self.assertEqualArray(
numpy.arctan2(y_val, x_val), atan2(y_val, x_val))

def test_onnxt_runtime_batch_normalization(self):
# input size: (1, 2, 1, 3)
x = numpy.array([[[[-1, 0, 1]], [[2, 3, 4]]]]).astype(numpy.float32)
Expand Down
19 changes: 15 additions & 4 deletions mlprodict/onnxrt/onnx_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,9 @@ def _build_compile_run(self):
# to the onnx graph
print(oinf2)
"""
def clean_name(name):
return name.replace(":", "_").replace('.', '_')

# inits
code = ['def compiled_run(dict_inputs):']
context = {}
Expand All @@ -1018,25 +1021,33 @@ def _build_compile_run(self):
inputs = self.input_names
code.append(" # inputs")
for inp in inputs:
code.append(" {0} = dict_inputs['{0}']".format(inp))
code.append(" {0} = dict_inputs['{1}']".format(
clean_name(inp), inp))

# code
for i, node in enumerate(self.sequence_):
name = "n{}_{}".format(i, node.ops_.__class__.__name__.lower())
context[name] = node.ops_._run
code.append(' ({1}, ) = {2}({0})'.format(
', '.join(node.inputs), ', '.join(node.outputs), name))
', '.join(map(clean_name, node.inputs)),
', '.join(map(clean_name, node.outputs)),
name))

# return
code.append(' return {')
for out in self.output_names:
code.append(" '{0}': {0},".format(out))
code.append(" '{1}': {0},".format(
clean_name(out), out))
code.append(' }')
final_code = '\n'.join(code)

# compile the outcome
context['self'] = self
obj = compile(final_code, "<string>", 'exec')
try:
obj = compile(final_code, "<string>", 'exec')
except SyntaxError as e:
raise SyntaxError(
"Unable to compile\n#####\n{}".format(final_code)) from e
fcts_obj = [_ for _ in obj.co_consts
if _ is not None and not isinstance(_, (bool, str, int))]
fct = make_callable("compiled_run", fcts_obj[0], final_code, context)
Expand Down
1 change: 1 addition & 0 deletions mlprodict/onnxrt/ops_cpu/_op_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .op_argmax import ArgMax
from .op_argmin import ArgMin
from .op_array_feature_extractor import ArrayFeatureExtractor
from .op_atan import Atan
from .op_batch_normalization import BatchNormalization
from .op_binarizer import Binarizer
from .op_cast import Cast
Expand Down
26 changes: 26 additions & 0 deletions mlprodict/onnxrt/ops_cpu/op_atan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import numpy
from ._op import OpRunUnaryNum


class Atan(OpRunUnaryNum):

def __init__(self, onnx_node, desc=None, **options):
OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
**options)

def _run(self, x): # pylint: disable=W0221
if self.inplaces.get(0, False):
return self._run_inplace(x)
return (numpy.arctan(x), )

def _run_inplace(self, x):
return (numpy.arctan(x, out=x), )

def to_python(self, inputs):
return self._to_python_numpy(inputs, 'arctan')