Skip to content
This repository has been archived by the owner on Jan 13, 2024. It is now read-only.

Commit

Permalink
fix bug in runtime for operator cumsum
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Jun 26, 2020
1 parent e4bb085 commit 45a4e80
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions mlprodict/onnxrt/ops_cpu/op_cum_sum.py
Expand Up @@ -26,19 +26,23 @@ def _run(self, x, *axis): # pylint: disable=W0221
'reverse=1 or exclusive=1 not implemented')
if self.inplaces.get(0, False):
return (numpy.cumsum(x, out=x), )
else:
return (numpy.cumsum(x), )
if len(axis.shape) != 1 or axis.shape[0] != 1:
raise RuntimeError(
"axis must be an array of one number not {}".format(axis))
axis = axis[0]
return (numpy.cumsum(x), )
if isinstance(axis, (numpy.int32, numpy.int64)):
pass
else:
if (len(axis.shape) > 1 or (len(axis.shape) > 0 and
axis.shape[0] != 1)):
raise RuntimeError(
"axis must be an array of one number not {} "
"(shape {})".format(axis, axis.shape))
if len(axis.shape) > 0:
axis = axis[0]
if self.reverse or self.exclusive:
raise NotImplementedError(
'reverse=1 or exclusive=1 not implemented')
if self.inplaces.get(0, False):
return (numpy.cumsum(x, axis=axis, out=x), )
else:
return (numpy.cumsum(x, axis=axis), )
return (numpy.cumsum(x, axis=axis), )

def _infer_shapes(self, x, *axis): # pylint: disable=W0221
return (x, )
Expand Down

0 comments on commit 45a4e80

Please sign in to comment.