diff --git a/mlprodict/onnxrt/ops_cpu/op_cum_sum.py b/mlprodict/onnxrt/ops_cpu/op_cum_sum.py index ab8532dee..53e8013da 100644 --- a/mlprodict/onnxrt/ops_cpu/op_cum_sum.py +++ b/mlprodict/onnxrt/ops_cpu/op_cum_sum.py @@ -26,19 +26,23 @@ def _run(self, x, *axis): # pylint: disable=W0221 'reverse=1 or exclusive=1 not implemented') if self.inplaces.get(0, False): return (numpy.cumsum(x, out=x), ) - else: - return (numpy.cumsum(x), ) - if len(axis.shape) != 1 or axis.shape[0] != 1: - raise RuntimeError( - "axis must be an array of one number not {}".format(axis)) - axis = axis[0] + return (numpy.cumsum(x), ) + if isinstance(axis, (numpy.int32, numpy.int64)): + pass + else: + if (len(axis.shape) > 1 or (len(axis.shape) > 0 and + axis.shape[0] != 1)): + raise RuntimeError( + "axis must be an array of one number not {} " + "(shape {})".format(axis, axis.shape)) + if len(axis.shape) > 0: + axis = axis[0] if self.reverse or self.exclusive: raise NotImplementedError( 'reverse=1 or exclusive=1 not implemented') if self.inplaces.get(0, False): return (numpy.cumsum(x, axis=axis, out=x), ) - else: - return (numpy.cumsum(x, axis=axis), ) + return (numpy.cumsum(x, axis=axis), ) def _infer_shapes(self, x, *axis): # pylint: disable=W0221 return (x, )