forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
WrapDim.py
42 lines (31 loc) · 1.34 KB
/
WrapDim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from . import CWrapPlugin
from string import Template
class WrapDim(CWrapPlugin):
NDIM_TEMPLATE = Template(
"""${arg_tensor}->nDimension""")
CODE_TEMPLATE = Template(
"""THPUtils_assert(${ndim} > 0,
"dimension specified as %d, but tensor has no dimensions", ${arg_dim});
THPUtils_assert(${arg_dim} >= -(${ndim}) && ${arg_dim} < (${ndim}),
"dimension out of range (expected to be in range of [%d, %d], but got %d)",
-(${ndim}), (${ndim})-1, ${arg_dim});
if (${arg_dim} < 0) ${arg_dim} += (${ndim});""")
def initialize(self, cwrap):
self.cwrap = cwrap
def process_option_code_template(self, template, option):
new_code = []
for i, arg in enumerate(option['arguments']):
if 'wrap_dim' not in arg:
continue
params = arg.get('wrap_dim').split("+")
arg_tensor = params[0]
arg_tensor = "arg_" + arg_tensor
arg_dim = "arg_" + arg.get('assign_name', arg['name'])
params[0] = self.NDIM_TEMPLATE.substitute(arg_tensor=arg_tensor)
ndim = "+".join(params)
new_code.append(self.CODE_TEMPLATE.substitute(
arg_dim=arg_dim,
ndim=ndim))
new_code.append("")
template = new_code + template
return template