Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion _unittests/ut_testing/test_einsum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
@brief test log(time=30s)
@brief test log(time=8s)
"""
import unittest
import io
Expand Down
25 changes: 25 additions & 0 deletions _unittests/ut_testing/test_einsum_bug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
@brief test log(time=3s)
"""
import unittest
from pyquickhelper.pycode import ExtTestCase
from mlprodict.testing.einsum import decompose_einsum_equation


class TestEinsumBug(ExtTestCase):

def test_abbba(self):
res = decompose_einsum_equation(
"ab,b->ba", strategy='numpy', clean=True)
self.assertNotEmpty(res)

def test__pprint_forward(self):
res = decompose_einsum_equation(
"ab,b->ba", strategy='numpy', clean=True)
pf = res._pprint_forward() # pylint: disable=W0212
spl = pf.split("<- id")
self.assertEqual(len(spl), 4)


if __name__ == "__main__":
unittest.main()
35 changes: 32 additions & 3 deletions _unittests/ut_testing/test_einsum_einsum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
@brief test log(time=12s)
@brief test log(time=15s)
"""
import unittest
import numpy
Expand All @@ -12,7 +12,8 @@
class TestEinsumEinsum(ExtTestCase):

def common_test(self, equation, runtime=None, opset=None, N=5,
optimize=False, decompose=True, strategy=None):
optimize=False, decompose=True, strategy=None,
double=True):
if opset is None:
opset = get_opset_number_from_onnx()
inps = equation.split('->')[0].split(',')
Expand All @@ -28,6 +29,8 @@ def common_test(self, equation, runtime=None, opset=None, N=5,
runtime = [runtime]
for rt in runtime:
for dtype in [numpy.float32, numpy.float64]:
if not double and dtype == numpy.float64:
continue
decimal = 5 if dtype == numpy.float32 else 8
with self.subTest(dt=dtype, rt=rt,
eq=equation, opset=opset,
Expand Down Expand Up @@ -57,10 +60,36 @@ def test_einsum_optimize(self):
def test_einsum_optimize_ml(self):
self.common_test("abc,cd->abd", optimize=True, strategy='ml')

def test_einsum_optimize_ml_merge(self):
self.common_test("abce,cd->abd", optimize=True, strategy='ml')

def test_einsum_optimize_ml_reduceprod(self):
self.common_test("ab,ab->ab", optimize=True, strategy='ml',
double=False)

def test_einsum_optimize_ml_mul(self):
self.common_test("ab,b->ab", optimize=True,
strategy='ml', double=False)
self.common_test("ab,b->a", optimize=True, strategy='ml')
self.common_test("ab,a->a", optimize=True, strategy='ml', double=False)
self.common_test("ab,b->b", optimize=True, strategy='ml', double=False)
self.common_test("ab,a->b", optimize=True, strategy='ml')

def test_einsum_optimize_ml_mul2(self):
self.common_test("ba,b->ba", optimize=False, double=False)

def test_einsum_optimize_no(self):
self.common_test("abc,cd->abd", optimize=True, decompose=False)

def test_einsum_optimize_ml_cases(self):
self.common_test("ab,cd->abcd", optimize=True, strategy='ml')
# self.common_test("ab,cd,ef->acdf", optimize=True, strategy='ml')
# self.common_test("ab,cd,de->abcde", optimize=True, strategy='ml')
# self.common_test("ab,cd,de->be", optimize=True, strategy='ml')
# self.common_test("ab,bcd,cd->abcd", optimize=True, strategy='ml')
# self.common_test("ab,bcd,cd->abd", optimize=True, strategy='ml')


if __name__ == "__main__":
# TestEinsumEinsum().test_einsum_optimize_ml()
# TestEinsumEinsum().test_einsum_optimize_ml_mul2()
unittest.main()
2 changes: 1 addition & 1 deletion mlprodict/testing/einsum/einsum_fct.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def _build_optimize_ml(self):
transposes.append(
[shape, list(node.attribute[0].ints)])

delta = sum(predict_transposition_cost(*v)
delta = sum(max(0, predict_transposition_cost(*v))
for v in transposes)

confs.append((delta, eq))
Expand Down
6 changes: 5 additions & 1 deletion mlprodict/testing/einsum/einsum_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,15 @@ def decompose_einsum_equation(equation, *shapes, strategy="simple",
raise ValueError("Unknown strategy %r." % strategy)

# Last step: clean unused nodes.
graph.mark_last_node()
if clean:
last_node = graph.last_added_op
graph.append(EinsumSubOp(last_node.full_dim, 'id', last_node))
graph.mark_last_node()
graph.simplify_mm_nodes(verbose=verbose)
graph.remove_duplicate_transpose(verbose=verbose)
graph.clean_unused_nodes(verbose=verbose)
else:
graph.mark_last_node()
return graph


Expand Down
14 changes: 13 additions & 1 deletion mlprodict/testing/einsum/einsum_impl_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,15 @@ def _get_forward_nodes(self):
forward[key] = [op]
return forward

def _pprint_forward(self):
rows = []
for op in self:
line = "%r <- %s(%s)" % (
id(op), op.name,
", ".join(map(str, [id(_) for _ in op.inputs])))
rows.append(line)
return "\n".join(rows)

def _replace_node_sequence(self, added, deleted):
"""
Removes a sequence of nodes. The method does not check
Expand All @@ -1284,7 +1293,10 @@ def _replace_node_sequence(self, added, deleted):
key = id(deleted[-1])
if key not in forward:
raise RuntimeError( # pragma: no cover
"key %r missing in all forward nodes." % key)
"Key {} missing in all forward nodes (other keys {}), "
"all keys:\n{}".format(
key, [id(_) for _ in deleted],
self._pprint_forward()))

# deletion
mark_input = None
Expand Down
33 changes: 33 additions & 0 deletions mlprodict/tools/onnx_micro_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,39 @@ def _op_matmul(self, x, y):
"Runtime for operator :epkg:`Op:MatMul`."
return (numpy.matmul(x, y), )

def _op_max(self, *inps):
"Runtime for operator :epkg:`Op:Max`."
return (numpy.maximum(*inps), )

def _op_mul(self, x, y):
"Runtime for operator :epkg:`Op:Mul`."
return (x * y, )

def _op_reduceprod(self, data, axes=None, keepdims=None):
"Runtime for operator :epkg:`Op:ReduceProd`."
if axes is not None and not isinstance(axes, int):
if isinstance(axes, numpy.ndarray) and len(axes.shape) == 0:
axes = int(axes)
else:
axes = tuple(axes) if len(axes) > 0 else None
return (numpy.prod(data, axis=axes,
keepdims=keepdims,
dtype=data.dtype), )

def _op_reducesum(self, data, axes, keepdims=None,
noop_with_empty_axes=None):
"Runtime for operator :epkg:`Op:ReduceSum`."
if axes is None and noop_with_empty_axes:
return (data, )
if axes is not None and not isinstance(axes, int):
if isinstance(axes, numpy.ndarray) and len(axes.shape) == 0:
axes = int(axes)
else:
axes = tuple(axes) if len(axes) > 0 else None
return (numpy.sum(data, axis=axes,
keepdims=keepdims,
dtype=data.dtype), )

def _op_reshape(self, x, shape):
"Runtime for operator :epkg:`Op:Reshape`."
return (x.reshape(shape), )
Expand Down