Skip to content
This repository has been archived by the owner on Jan 13, 2024. It is now read-only.

Commit

Permalink
fix failing unittest, onnx grammar correctly interprets unary operator -
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Sep 19, 2020
1 parent 45301ae commit eea3ba4
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 7 deletions.
7 changes: 4 additions & 3 deletions _unittests/ut_onnx_grammar/test_onnx_grammar_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,17 +328,18 @@ def trs(x):
onnx_code = translate_fct2onnx(
trs, context={'numpy.transpose': numpy.transpose},
output_names=['Z'])
print(onnx_code)

fct = translate_fct2onnx(
trs, context=None, cpl=True, output_names=['Z'])
self.assertTrue(callable(fct))

from skl2onnx.algebra.onnx_ops import ( # pylint: disable=E0611
OnnxAdd, OnnxTranspose, OnnxMul, OnnxIdentity)
OnnxAdd, OnnxTranspose, OnnxMul, OnnxIdentity,
OnnxNeg)
ctx = {'OnnxAdd': OnnxAdd,
'OnnxTranspose': OnnxTranspose,
'OnnxMul': OnnxMul, 'OnnxIdentity': OnnxIdentity}
'OnnxMul': OnnxMul, 'OnnxIdentity': OnnxIdentity,
'OnnxNeg': OnnxNeg}

fct = translate_fct2onnx(
trs, context={'numpy.transpose': numpy.transpose},
Expand Down
8 changes: 6 additions & 2 deletions _unittests/ut_sklapi/test_onnx_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import convert_sklearn
from skl2onnx.algebra.onnx_ops import OnnxMul # pylint: disable=E0611
from pyquickhelper.pycode import ExtTestCase, skipif_appveyor
from pyquickhelper.pycode import ExtTestCase, skipif_appveyor, ignore_warnings
from mlprodict.sklapi import OnnxTransformer
from mlprodict.tools import get_opset_number_from_onnx

Expand Down Expand Up @@ -81,8 +81,10 @@ def test_multiple_transform(self):
self.assertNotEmpty(res)
for _, tr in res:
tr.fit()
self.assertRaise(lambda tr=tr: tr.transform(x), KeyError)
self.assertRaise(lambda tr=tr: tr.transform(x),
(KeyError, RuntimeError))

@ignore_warnings(DeprecationWarning)
def test_pipeline_iris(self):
iris = load_iris()
X, y = iris.data, iris.target
Expand All @@ -104,6 +106,7 @@ def test_pipeline_iris(self):
shapes = set(shapes)
self.assertEqual(shapes, {(150, 3), (150, 4), (150, 2), (150,)})

@ignore_warnings(DeprecationWarning)
@skipif_appveyor("crashes")
def test_pipeline_iris_change_dim(self):
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument # pylint: disable=E0611
Expand All @@ -121,6 +124,7 @@ def test_pipeline_iris_change_dim(self):
self.assertEqual(len(y.shape), 2)
self.assertEqual(y.shape[0], 2)

@ignore_warnings(DeprecationWarning)
def test_pipeline_iris_intermediate(self):
iris = load_iris()
X, y = iris.data, iris.target
Expand Down
6 changes: 5 additions & 1 deletion mlprodict/onnx_grammar/onnx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def _post_process(self, op, node):
continue
o, v = node['args'][i]
if (o == 'UnaryOp' and len(v['args']) == 1 and
isinstance(v['args'][0], (int, float, str, numpy.int64,
isinstance(v['args'][0], (int, float, numpy.int64,
numpy.float32, numpy.float64))):
if v['op'] == 'Sub':
node['args'][i] = -v['args'][0]
Expand Down Expand Up @@ -555,6 +555,7 @@ def depart(self, node, info):
"\n{}".format(
child['type'], pprint.pformat(info)))
return

if kind == "Name":
op, buf = self._get_last(
('Assign', 'BinOp', 'Call', 'Return', 'FunctionDef', 'keyword',
Expand All @@ -572,6 +573,9 @@ def depart(self, node, info):
elif op == 'keyword':
buf['value'] = info['str']
return
elif op == 'UnaryOp':
buf['args'].append(info['str'])
return
elif op == 'FunctionDef':
raise RuntimeError("Default value must be constant, variable '{}' was "
"detected.".format(info['str']))
Expand Down
2 changes: 1 addition & 1 deletion mlprodict/sklapi/onnx_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _check_arrays(self, inputs):
if not hasattr(self, "onnxrt_"):
continue
exp = sht[i]
if exp[1][1:] != v.shape[1:]:
if exp[1] != ('?', ) and exp[1][1:] != v.shape[1:]:
raise RuntimeError( # pragma: no cover
"Unexpected shape for input '{}': {} != {} "
"(expected).".format(
Expand Down

0 comments on commit eea3ba4

Please sign in to comment.