Merged PR: scripts/gen.py, 36 changes (23 additions, 13 deletions)
@@ -24,7 +24,7 @@ def namedtuple_with_defaults(typename, field_names, default_values=()):
     'tree, xtree, rwxtree, func, xfunc, code, sig, rwsig, cppsig, funsig, mapsig'
 )
 
-FuncOpts = namedtuple_with_defaults('FuncOpts', 'ref_param')
+FuncOpts = namedtuple_with_defaults('FuncOpts', 'ref_param, device_param')
 
 _GRAMMAR = r"""
     start: type fnname "(" params ")"
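
The FuncOpts record gains a device_param field alongside ref_param. The body of namedtuple_with_defaults is above the fold here, but for the new field to be optional the helper must default unset fields to None; a minimal sketch under that assumption (the standard defaulted-namedtuple recipe, not necessarily the exact implementation in this file):

    import collections
    import collections.abc

    def namedtuple_with_defaults(typename, field_names, default_values=()):
      # Make every field default to None, then overlay any caller-supplied
      # defaults. With this, FuncOpts(device_param='*to_device') leaves
      # ref_param as None instead of raising a TypeError.
      ntuple = collections.namedtuple(typename, field_names)
      ntuple.__new__.__defaults__ = (None,) * len(ntuple._fields)
      if isinstance(default_values, collections.abc.Mapping):
        prototype = ntuple(**default_values)
      else:
        prototype = ntuple(*default_values)
      ntuple.__new__.__defaults__ = tuple(prototype)
      return ntuple
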
@@ -259,6 +259,8 @@ class {type_name} : public AtenXlaType {{
 }
 
 _FUNCTION_OPTIONS = {
+    'copy(Tensor, bool, optional<Device>) -> Tensor':
+        FuncOpts(device_param='*to_device'),
     'to(Tensor, TensorOptions, bool, bool) -> Tensor':
         FuncOpts(ref_param='options'),
     'to(Tensor, Device, ScalarType, bool, bool) -> Tensor':
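
The new entry registers the copy() overload. Its device_param value is spliced verbatim into the generated C++ (see get_return_value below), so the leading '*' in '*to_device' is a C++ dereference of the optional<Device> argument: the emitted code reads bridge::GetXlaDevice(*to_device) instead of deriving the device from a reference parameter.
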
@@ -497,7 +499,7 @@ def create_map_sig(tree, orig_sig):
   def emit_fn(t):
     if isinstance(t, lark.lexer.Token):
       return -1 if t.type in ['CONST', 'REF', 'PTR'] else 0
-    return -1 if t.data == 'param_name' else 0
+    return -1 if t.data in ['param_name', 'param_defval'] else 0
 
   emit = StringEmit(orig_sig)
   # Emit full function return type.
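
Skipping 'param_defval' tokens alongside 'param_name' keeps default arguments out of the normalized map signature, so a declaration whose trailing parameters carry defaults (as the copy() overload above presumably does) still normalizes to the bare key form used in _FUNCTION_OPTIONS, e.g. 'copy(Tensor, bool, optional<Device>) -> Tensor'.
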
Expand Down Expand Up @@ -605,7 +607,13 @@ def param_type(t):
return c


def get_return_value(rtype, rname, param, var, ref_param):
def get_optional(fnopts, name, defval=None):
if fnopts is None or not hasattr(fnopts, name):
return defval
return getattr(fnopts, name, defval) or defval


def get_return_value(rtype, rname, param, var, ref_param, fnopts):
crtype = type_core(rtype)
if type_is_const(rtype) or type_is_refptr(rtype, '&'):
# If the return type is a const or a reference, return the matching
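
A quick behavioral sketch of get_optional with the new options (assuming FuncOpts fields default to None, per the recipe above):

    opts = FuncOpts(device_param='*to_device')
    get_optional(opts, 'device_param', 'self')  # -> '*to_device'
    get_optional(opts, 'ref_param', 'self')     # -> 'self' (field is None)
    get_optional(None, 'ref_param', 'self')     # -> 'self' (no options at all)
    get_optional(opts, 'bogus', 'self')         # -> 'self' (missing attribute)
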
@@ -620,7 +628,7 @@ def get_return_value(rtype, rname, param, var, ref_param):
     # wrapping the proper local variable which has been created by calling
     # into the CPU tensor implementation.
     return 'bridge::CreateXlaTensor({}, bridge::GetXlaDevice({}))'.format(
-        rname, param_name(ref_param))
+        rname, get_optional(fnopts, 'device_param', param_name(ref_param)))
 
 
 def get_reference_param(params, fnopts=None):
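
Concretely, for the copy() entry registered above, the format call produces the following return expression (a sketch; 'x_result' is a hypothetical local variable name):

    fnopts = FuncOpts(device_param='*to_device')
    'bridge::CreateXlaTensor({}, bridge::GetXlaDevice({}))'.format(
        'x_result', get_optional(fnopts, 'device_param', 'self'))
    # -> 'bridge::CreateXlaTensor(x_result, bridge::GetXlaDevice(*to_device))'
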
@@ -632,7 +640,7 @@ def get_reference_param(params, fnopts=None):
     ptype = param_type(p)
     cptype = type_core(ptype)
     pname = param_name(p)
-    if fnopts and fnopts.ref_param == pname:
+    if get_optional(fnopts, 'ref_param') == pname:
       return p
     if not other and (cptype == 'TensorOptions' or cptype == 'TensorList'):
       other = p
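
The rewritten check is None-safe in two ways: fnopts may be absent entirely, or it may be a FuncOpts built without an explicit ref_param, which is now the common case for entries that only set device_param.
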
@@ -644,15 +652,16 @@
   return ref_param or other
 
 
-def get_tuple_return(rtype, rtype_str, rname, params, param_vars, ref_param):
+def get_tuple_return(rtype, rtype_str, rname, params, param_vars, ref_param,
+                     fnopts):
   types = tuple_type_list(rtype)
   retstr = '{}('.format(rtype_str)
   for i, ttype in enumerate(types):
     if i > 0:
       retstr += ', '
     tuple_var = 'std::get<{}>({})'.format(i, rname)
     retstr += get_return_value(ttype, tuple_var, list_get(params, i),
-                               list_get(param_vars, i), ref_param)
+                               list_get(param_vars, i), ref_param, fnopts)
   return retstr + ')'
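
For a std::tuple<Tensor, Tensor> return named 'result', the threaded fnopts reach every element, so the emitted expression looks roughly like this (illustrative generated C++, with reference parameter 'self' assumed):

    std::tuple<Tensor, Tensor>(
        bridge::CreateXlaTensor(std::get<0>(result), bridge::GetXlaDevice(self)),
        bridge::CreateXlaTensor(std::get<1>(result), bridge::GetXlaDevice(self)))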


@@ -690,18 +699,19 @@ def generate_exit_debug_code(t, fname, rname, params, param_vars, ref_param):
 
 
 def generate_return_stmt(t, rtype_str, fname, rname, params, param_vars,
-                         ref_param):
+                         ref_param, fnopts):
   assert isinstance(t, lark.tree.Tree)
   rtype = t.children[0]
   ctype = type_core(rtype)
   if ctype == 'std::tuple':
     retstr = get_tuple_return(rtype, rtype_str, rname, params, param_vars,
-                              ref_param)
+                              ref_param, fnopts)
   elif ctype == 'std::vector':
     retstr = 'bridge::CreateXlaTensors({}, bridge::GetXlaDevice({}))'.format(
-        rname, param_name(ref_param))
+        rname, get_optional(fnopts, 'device_param', param_name(ref_param)))
   elif ctype == 'Tensor':
-    retstr = get_return_value(rtype, rname, params[0], param_vars[0], ref_param)
+    retstr = get_return_value(rtype, rname, params[0], param_vars[0], ref_param,
+                              fnopts)
   elif ctype == 'void' and not type_is_refptr(rtype, '*'):
     return ''
   else:
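
With this change all four return shapes (std::tuple, std::vector, Tensor, and void) receive fnopts, so a device_param override applies uniformly no matter how the wrapped function returns its tensors.
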
@@ -785,7 +795,7 @@ def gen_fnname(x):
     else:
       xname = tfetcher.add(pname, True)
     param_vars.append(xname)
-    if p == ref_param and not (fnopts and fnopts.ref_param):
+    if p == ref_param and not get_optional(fnopts, 'ref_param'):
       xla_ref_param = param_vars[-1]
   code += tfetcher.generate()
   result_assign = generate_result_assignment(tree, _RESULT_NAME)
@@ -804,7 +814,7 @@
                                    params, param_vars, ref_param)
   code += generate_return_stmt(tree, get_return_type_str(rwxtree, rwsig), fname,
                                _RESULT_NAME if result_assign else None, params,
-                               param_vars, ref_param)
+                               param_vars, ref_param, fnopts)
   code += '}'
   return FuncGen(
       tree=tree,