Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion _doc/sphinxdoc/source/api/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ the possibility later to only show a part of a graph.

**benchmark**

.. autosignature:: mlprodict.plotting.validate_graph.plot_validate_benchmark
.. autosignature:: mlprodict.plotting.plot_validate_benchmark

Others
======
Expand Down
21 changes: 20 additions & 1 deletion _unittests/ut_onnxrt/test_onnxrt_python_runtime_.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
OnnxBatchNormalization,
OnnxAcos, OnnxAcosh, OnnxAsin, OnnxAsinh, OnnxAtan, OnnxAtanh,
OnnxAveragePool,
OnnxCast, OnnxCeil, OnnxClip,
OnnxCast, OnnxCastLike, OnnxCeil, OnnxClip,
OnnxCompress,
OnnxConcat, OnnxConv, OnnxConvTranspose,
OnnxConstant, OnnxConstant_9, OnnxConstant_11,
Expand Down Expand Up @@ -1083,6 +1083,25 @@ def test_onnxt_runtime_cast_in(self):

python_tested.append(OnnxCast)

@wraplog()
def test_onnxt_runtime_cast_like(self):
x = numpy.array([1.5, 2.1, 3.1, 4.1]).astype(
numpy.float32) # pylint: disable=E1101
y = numpy.array([1.]).astype(numpy.int64) # pylint: disable=E1101

for opset in range(15, get_opset_number_from_onnx() + 1):
with self.subTest(opset=opset):
onx = OnnxCastLike('X', 'Y', output_names=['Z'],
op_version=opset)
model_def = onx.to_onnx(
{'X': x, 'Y': y},
outputs=[('Z', Int64TensorType())],
target_opset=opset)
got = OnnxInference(model_def).run({'X': x, 'Y': y})
self.assertEqual(x.astype(numpy.int64), got['Z'])

python_tested.append(OnnxCastLike)

@wraplog()
def test_onnxt_runtime_ceil(self):
self.common_test_onnxt_runtime_unary(OnnxCeil, numpy.ceil)
Expand Down
2 changes: 1 addition & 1 deletion mlprodict/onnxrt/ops_cpu/_op_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .op_batch_normalization import BatchNormalization, BatchNormalization_14
from .op_binarizer import Binarizer
from .op_broadcast_gradient_args import BroadcastGradientArgs
from .op_cast import Cast
from .op_cast import Cast, CastLike
from .op_cdist import CDist
from .op_ceil import Ceil
from .op_celu import Celu
Expand Down
26 changes: 26 additions & 0 deletions mlprodict/onnxrt/ops_cpu/op_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,29 @@ def _infer_types(self, x): # pylint: disable=W0221
def _infer_sizes(self, *args, **kwargs):
res = self.run(*args, **kwargs)
return (dict(temp=0), ) + res


class CastLike(OpRun):

def __init__(self, onnx_node, desc=None, **options):
OpRun.__init__(self, onnx_node, desc=desc, **options)

def _run(self, x, y): # pylint: disable=W0221
if self.inplaces.get(0, False):
return self._run_inplace(x, y)
return (x.astype(y.dtype), )

def _run_inplace(self, x, y):
if x.dtype == y._dtype:
return (x, )
return (x.astype(y.dtype), )

def _infer_shapes(self, x, y): # pylint: disable=W0221
return (x.copy(dtype=y._dtype), )

def _infer_types(self, x, y): # pylint: disable=W0221
return (y._dtype, )

def _infer_sizes(self, *args, **kwargs):
res = self.run(*args, **kwargs)
return (dict(temp=0), ) + res