This repository has been archived by the owner on Jan 13, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 13
/
op_split.py
82 lines (64 loc) · 2.2 KB
/
op_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
from onnx.defs import onnx_opset_version
from ._op import OpRun
class CommonSplit(OpRun):
    """
    Runtime for operator *Split*.

    Shared implementation for every version of the operator. Subclasses
    only differ in where the sizes of the outputs come from: the attribute
    *split* (``self.split``) or an optional input given to ``_run``.
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        # Default 'split' to None when the caller did not provide it
        # (presumably consumed by OpRun as an attribute value - TODO confirm).
        if 'split' not in options:
            options['split'] = None
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)
        # One split piece is produced per output declared by the node.
        self.nb_outputs = len(onnx_node.output)

    def common_run(self, mat, split):  # pylint: disable=W0221
        """
        Splits *mat* along ``self.axis``.

        :param mat: array to split
        :param split: sizes of the pieces along ``self.axis``, or None
            to cut the axis into ``self.nb_outputs`` pieces of equal size,
            the last piece receiving the remainder of the division
        :return: tuple of arrays, one per size
        :raises ValueError: if *split* contains a negative size or if its
            sizes do not add up to the dimension being split
        """
        dim = mat.shape[self.axis]
        if split is None:
            div = dim // self.nb_outputs
            split = [div] * self.nb_outputs
            # Integer division may leave a remainder; the last piece takes it.
            split[-1] += dim - sum(split)
        else:
            # split may be a list attribute or an integer array input.
            split = [int(s) for s in split]
            if any(s < 0 for s in split):
                raise ValueError(
                    f"Split sizes must be non-negative, got {split}.")
            if sum(split) != dim:
                # Out-of-range slices would otherwise silently produce
                # truncated or empty outputs.
                raise ValueError(
                    f"Split sizes {split} do not add up to dimension "
                    f"{dim} of axis {self.axis}.")

        # Full slices on every axis except the one being split.
        index = [slice(None)] * len(mat.shape)
        res = []
        pos = 0
        for size in split:
            index[self.axis] = slice(pos, pos + size)
            res.append(mat[tuple(index)])
            pos += size
        return tuple(res)
class Split_2(CommonSplit):
    """
    Runtime for operator *Split* where the output sizes are given by
    the optional attribute *split* (None means equal-size pieces).
    """

    atts = {'axis': 0, 'split': None}

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc,
                         expected_attributes=Split_2.atts, **options)

    def _run(self, mat, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # The sizes are read from the node attribute, not from an input.
        sizes = self.split
        return self.common_run(mat, sizes)
class Split_11(Split_2):
    """
    Runtime for operator *Split*, version 11: identical to
    :class:`Split_2`, sizes still come from the attribute *split*.
    """
class Split_13(CommonSplit):
    """
    Runtime for operator *Split*, version 13: the output sizes are
    received through an optional second input instead of an attribute.
    """

    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc,
                         expected_attributes=Split_13.atts, **options)

    def _run(self, mat, split=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # A missing second input leaves split to None (equal-size pieces).
        return self.common_run(mat, split=split)
# Expose as ``Split`` the latest implementation the installed onnx supports.
if onnx_opset_version() < 13:  # pragma: no cover
    Split = Split_11 if onnx_opset_version() >= 11 else Split_2
else:
    Split = Split_13