This repository has been archived by the owner on Jan 13, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 13
/
op_cum_sum.py
56 lines (49 loc) · 1.97 KB
/
op_cum_sum.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import numpy
from ._op import OpRun
class CumSum(OpRun):
    """
    Runtime implementation of the ONNX *CumSum* operator.

    The operator attributes ``exclusive`` and ``reverse`` (both default 0)
    are declared through ``expected_attributes``; :class:`OpRun` is assumed
    to expose them as ``self.exclusive`` / ``self.reverse`` (not visible
    here, TODO confirm in ``_op.py``).
    """

    atts = {'exclusive': 0, 'reverse': 0}
    python_inputs = ['x', 'axis=None']

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=CumSum.atts,
                       **options)

    def _run(self, x, *axis):  # pylint: disable=W0221
        """
        Computes the cumulative sum of *x*.

        :param x: input array
        :param axis: optional single extra input. When it is missing or
            ``None``, the sum runs over the flattened array. Otherwise it is
            an integer scalar, a 0-D array or an array of shape ``(1,)``.
        :return: tuple holding one array with the same dtype as *x*
            (the ONNX output type equals the input type)
        :raises NotImplementedError: *reverse* or *exclusive* is set while
            *axis* is missing (flattened mode)
        :raises RuntimeError: *axis* holds more than one value
        """
        x = numpy.asarray(x)
        axis = None if len(axis) == 0 else axis[0]
        # In-place is only a hint: never write into a read-only buffer.
        inplace = self.inplaces.get(0, False) and x.flags['WRITEABLE']

        if axis is None:
            if self.reverse or self.exclusive:
                raise NotImplementedError(
                    'reverse=1 or exclusive=1 not implemented')
            # Without an axis numpy flattens x and returns a 1-D result,
            # so it can only be written back into x when x is 1-D already.
            if inplace and x.ndim == 1:
                return (numpy.cumsum(x, dtype=x.dtype, out=x), )
            return (numpy.cumsum(x, dtype=x.dtype), )

        axis = CumSum._single_axis(axis)
        if self.reverse:
            # A reversed cumulative sum is the forward one on a flipped view.
            x = numpy.flip(x, axis=axis)
        if self.exclusive:
            res = CumSum._exclusive_cumsum(x, axis)
        elif inplace and not self.reverse:
            res = numpy.cumsum(x, axis=axis, dtype=x.dtype, out=x)
        else:
            res = numpy.cumsum(x, axis=axis, dtype=x.dtype)
        if self.reverse:
            res = numpy.flip(res, axis=axis)
        return (res, )

    @staticmethod
    def _single_axis(axis):
        """
        Reduces the *axis* input to a single integer-like value.

        Python and numpy integer scalars are returned unchanged. Arrays must
        be 0-D or have shape ``(1,)``; a ``(1,)`` array is unwrapped.

        :raises RuntimeError: *axis* holds more than one value
        """
        if isinstance(axis, (int, numpy.integer)):
            return axis
        if (len(axis.shape) > 1 or (len(axis.shape) > 0 and
                                    axis.shape[0] != 1)):
            raise RuntimeError(
                "axis must be an array of one number not {} "
                "(shape {})".format(axis, axis.shape))
        if len(axis.shape) > 0:
            axis = axis[0]
        return axis

    @staticmethod
    def _exclusive_cumsum(x, axis):
        """
        Exclusive cumulative sum along *axis*.

        Position ``i`` holds the sum of positions ``0..i-1``, and the first
        position is zero. The inclusive sum is shifted by one slot instead
        of subtracting *x*, so floating-point results stay exact.
        """
        # numpy validates axis here (raises for an out-of-range axis).
        full = numpy.cumsum(x, axis=axis, dtype=x.dtype)
        res = numpy.zeros_like(full)
        src = [slice(None)] * full.ndim
        dst = [slice(None)] * full.ndim
        src[axis] = slice(0, -1)
        dst[axis] = slice(1, None)
        res[tuple(dst)] = full[tuple(src)]
        return res

    def _infer_shapes(self, x, *axis):  # pylint: disable=W0221
        """
        The output has the same shape and type as *x*; *x* is presumably
        the shape descriptor produced by the shape-inference machinery.
        """
        return (x, )

    def to_python(self, inputs):
        """
        Returns the Python code equivalent to this operator.

        :param inputs: input names (unused, the generated code relies on
            the names declared in ``python_inputs``)
        :return: tuple ``(import statement, function body)``. The body
            reads ``exclusive`` and ``reverse``, which the caller is
            expected to bind.
        """
        lines = [
            'x = numpy.asarray(x)',
            'if axis is None:',
            '    if exclusive or reverse:',
            '        raise NotImplementedError(',
            '            "reverse=1 or exclusive=1 not implemented")',
            '    return numpy.cumsum(x, dtype=x.dtype)',
            'axis = int(numpy.asarray(axis).reshape(-1)[0])',
            'if reverse:',
            '    x = numpy.flip(x, axis=axis)',
            'res = numpy.cumsum(x, axis=axis, dtype=x.dtype)',
            'if exclusive:',
            '    shifted = numpy.zeros_like(res)',
            '    src = [slice(None)] * res.ndim',
            '    dst = [slice(None)] * res.ndim',
            '    src[axis] = slice(0, -1)',
            '    dst[axis] = slice(1, None)',
            '    shifted[tuple(dst)] = res[tuple(src)]',
            '    res = shifted',
            'if reverse:',
            '    res = numpy.flip(res, axis=axis)',
            'return res',
        ]
        return 'import numpy', "\n".join(lines)