Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion _unittests/ut_onnxrt/test_onnxrt_python_runtime_.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
OnnxRelu, OnnxReshape,
OnnxRound,
OnnxScatterElements,
OnnxSequenceConstruct,
OnnxSequenceAt, OnnxSequenceConstruct,
OnnxShape, OnnxSlice, OnnxSigmoid, OnnxSign,
OnnxSin, OnnxSinh,
OnnxSize, OnnxSoftmax,
Expand Down Expand Up @@ -3424,6 +3424,26 @@ def test_onnxt_runtime_scatter_elements2(self):
got = OnnxInference(model_def).run({'X': x})
self.assertEqualArray(y, got['Y'])

@wraplog()
def test_onnxt_runtime_sequence_at(self):
x = numpy.random.randn(20, 2).astype( # pylint: disable=E1101
numpy.float32) # pylint: disable=E1101
onx = OnnxSequenceAt(
OnnxSequenceConstruct(
'X', 'X', 'X',
op_version=get_opset_number_from_onnx()),
numpy.array(1, dtype=numpy.int64),
op_version=get_opset_number_from_onnx(),
output_names=['Y'])

model_def = onx.to_onnx({'X': x.astype(numpy.float32)},
target_opset=get_opset_number_from_onnx())
oinf = OnnxInference(model_def)
got = oinf.run({'X': x})
output = got['Y']
self.assertEqualArray(x, output)
python_tested.append(OnnxSequenceAt)

@wraplog()
def test_onnxt_runtime_sequence_construct(self):
x = numpy.random.randn(20, 2).astype( # pylint: disable=E1101
Expand Down
3 changes: 2 additions & 1 deletion _unittests/ut_tools/test_export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,8 @@ def test_export_einsum(self):
with self.subTest(rt='onnxruntime1'):
opts = SessionOptions()
opts.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
oinf = OnnxInference(onx, runtime='onnxruntime1', runtime_options=opts)
oinf = OnnxInference(
onx, runtime='onnxruntime1', runtime_options=opts)
rr = oinf.run({'X1': x1, 'X2': x2, 'X3': x3})
self.assertEqualArray(r, rr['Y'])
with self.subTest(rt='python'):
Expand Down
1 change: 1 addition & 0 deletions mlprodict/onnxrt/ops_cpu/_op_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
from .op_scaler import Scaler
from .op_scan import Scan
from .op_scatter_elements import ScatterElements
from .op_sequence_at import SequenceAt
from .op_sequence_construct import SequenceConstruct
from .op_sequence_insert import SequenceInsert
from .op_shape import Shape
Expand Down
32 changes: 32 additions & 0 deletions mlprodict/onnxrt/ops_cpu/op_sequence_at.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.

.. versionadded:: 0.8
"""
from ._op import OpRun
from ..shape_object import ShapeObject


class SequenceAt(OpRun):

atts = {}

def __init__(self, onnx_node, desc=None, **options):
OpRun.__init__(self, onnx_node, desc=desc,
atts=SequenceAt.atts, **options)

def _run(self, seq, index): # pylint: disable=W0221
return (seq[index], )

def _infer_shapes(self, seq, index): # pylint: disable=W0221
return (ShapeObject(None, dtype=seq.subtype.dtype), )

def _infer_types(self, *data): # pylint: disable=W0221
return (None, )

def _infer_sizes(self, *args): # pylint: disable=W0221
res = self.run(*args)
return (dict(temp=0), ) + res
2 changes: 1 addition & 1 deletion mlprodict/onnxrt/ops_cpu/op_sequence_construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _run(self, *data): # pylint: disable=W0221
return (data, )

def _infer_shapes(self, *data): # pylint: disable=W0221
return (ShapeObject(None, dtype="sequence"), )
return (ShapeObject(None, dtype="sequence", subtype=data[0]), )

def _infer_types(self, *data): # pylint: disable=W0221
return (list, )
Expand Down
9 changes: 7 additions & 2 deletions mlprodict/onnxrt/shape_object.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=C0302
"""
@file
@brief Shape object.
Expand Down Expand Up @@ -447,14 +448,17 @@ class ShapeObject(BaseDimensionShape):
print(mx.evaluate(n=4))
"""

def __init__(self, shape, dtype=None, use_n1=False, name=None):
def __init__(self, shape, dtype=None, use_n1=False, name=None,
subtype=None):
"""
@param shape tuple or `numpy.array`
@param dtype dtype
@param use_n1 use `'n'` if the first dimension is unknown
@param name optional, for debugging purposes
@param subtype element type if this type is a list
"""
self.name = name
self.subtype = subtype
if isinstance(shape, numpy.ndarray):
self._shape = [DimensionObject(s) for s in shape.shape]
self._dtype = shape.dtype
Expand Down Expand Up @@ -590,7 +594,8 @@ def copy(self, dtype=None, name=None):
return ShapeObject(None, dtype=self.dtype, name=name or self.name)
return ShapeObject(self._shape.copy(),
self.dtype if dtype is None else dtype,
name=name or self.name)
name=name or self.name,
subtype=self.subtype)

def __getitem__(self, index):
"""
Expand Down