diff --git a/scripts/gen.py b/scripts/gen.py
index 2eeb9541f89c..f7c06abf14d0 100755
--- a/scripts/gen.py
+++ b/scripts/gen.py
@@ -24,7 +24,7 @@ def namedtuple_with_defaults(typename, field_names, default_values=()):
     'tree, xtree, rwxtree, func, xfunc, code, sig, rwsig, cppsig, funsig, mapsig'
 )
 
-FuncOpts = namedtuple_with_defaults('FuncOpts', 'ref_param')
+FuncOpts = namedtuple_with_defaults('FuncOpts', 'ref_param, device_param')
 
 _GRAMMAR = r"""
 start: type fnname "(" params ")"
@@ -259,6 +259,8 @@ class {type_name} : public AtenXlaType {{
 }
 
 _FUNCTION_OPTIONS = {
+    'copy(Tensor, bool, optional) -> Tensor':
+        FuncOpts(device_param='*to_device'),
     'to(Tensor, TensorOptions, bool, bool) -> Tensor':
         FuncOpts(ref_param='options'),
     'to(Tensor, Device, ScalarType, bool, bool) -> Tensor':
@@ -497,7 +499,7 @@ def create_map_sig(tree, orig_sig):
   def emit_fn(t):
     if isinstance(t, lark.lexer.Token):
       return -1 if t.type in ['CONST', 'REF', 'PTR'] else 0
-    return -1 if t.data == 'param_name' else 0
+    return -1 if t.data in ['param_name', 'param_defval'] else 0
 
   emit = StringEmit(orig_sig)
   # Emit full function return type.
@@ -605,7 +607,13 @@ def param_type(t):
   return c
 
 
-def get_return_value(rtype, rname, param, var, ref_param):
+def get_optional(fnopts, name, defval=None):
+  if fnopts is None or not hasattr(fnopts, name):
+    return defval
+  return getattr(fnopts, name, defval) or defval
+
+
+def get_return_value(rtype, rname, param, var, ref_param, fnopts):
   crtype = type_core(rtype)
   if type_is_const(rtype) or type_is_refptr(rtype, '&'):
     # If the return type is a const or a reference, return the matching
@@ -620,7 +628,7 @@ def get_return_value(rtype, rname, param, var, ref_param):
     # wrapping the proper local variable which has been created by calling
     # into the CPU tensor implementation.
     return 'bridge::CreateXlaTensor({}, bridge::GetXlaDevice({}))'.format(
-        rname, param_name(ref_param))
+        rname, get_optional(fnopts, 'device_param', param_name(ref_param)))
 
 
 def get_reference_param(params, fnopts=None):
@@ -632,7 +640,7 @@ def get_reference_param(params, fnopts=None):
     ptype = param_type(p)
     cptype = type_core(ptype)
     pname = param_name(p)
-    if fnopts and fnopts.ref_param == pname:
+    if get_optional(fnopts, 'ref_param') == pname:
       return p
     if not other and (cptype == 'TensorOptions' or cptype == 'TensorList'):
       other = p
@@ -644,7 +652,8 @@ def get_reference_param(params, fnopts=None):
   return ref_param or other
 
 
-def get_tuple_return(rtype, rtype_str, rname, params, param_vars, ref_param):
+def get_tuple_return(rtype, rtype_str, rname, params, param_vars, ref_param,
+                     fnopts):
   types = tuple_type_list(rtype)
   retstr = '{}('.format(rtype_str)
   for i, ttype in enumerate(types):
@@ -652,7 +661,7 @@ def get_tuple_return(rtype, rtype_str, rname, params, param_vars, ref_param):
       retstr += ', '
     tuple_var = 'std::get<{}>({})'.format(i, rname)
     retstr += get_return_value(ttype, tuple_var, list_get(params, i),
-                               list_get(param_vars, i), ref_param)
+                               list_get(param_vars, i), ref_param, fnopts)
   return retstr + ')'
 
 
@@ -690,18 +699,19 @@ def generate_exit_debug_code(t, fname, rname, params, param_vars, ref_param):
 
 
 def generate_return_stmt(t, rtype_str, fname, rname, params, param_vars,
-                         ref_param):
+                         ref_param, fnopts):
   assert isinstance(t, lark.tree.Tree)
   rtype = t.children[0]
   ctype = type_core(rtype)
   if ctype == 'std::tuple':
     retstr = get_tuple_return(rtype, rtype_str, rname, params, param_vars,
-                              ref_param)
+                              ref_param, fnopts)
   elif ctype == 'std::vector':
     retstr = 'bridge::CreateXlaTensors({}, bridge::GetXlaDevice({}))'.format(
-        rname, param_name(ref_param))
+        rname, get_optional(fnopts, 'device_param', param_name(ref_param)))
   elif ctype == 'Tensor':
-    retstr = get_return_value(rtype, rname, params[0], param_vars[0], ref_param)
+    retstr = get_return_value(rtype, rname, params[0], param_vars[0], ref_param,
+                              fnopts)
   elif ctype == 'void' and not type_is_refptr(rtype, '*'):
     return ''
   else:
@@ -785,7 +795,7 @@ def gen_fnname(x):
     else:
       xname = tfetcher.add(pname, True)
       param_vars.append(xname)
-    if p == ref_param and not (fnopts and fnopts.ref_param):
+    if p == ref_param and not get_optional(fnopts, 'ref_param'):
       xla_ref_param = param_vars[-1]
   code += tfetcher.generate()
   result_assign = generate_result_assignment(tree, _RESULT_NAME)
@@ -804,7 +814,7 @@ def gen_fnname(x):
                                    params, param_vars, ref_param)
   code += generate_return_stmt(tree, get_return_type_str(rwxtree, rwsig), fname,
                                _RESULT_NAME if result_assign else None, params,
-                               param_vars, ref_param)
+                               param_vars, ref_param, fnopts)
   code += '}'
   return FuncGen(
       tree=tree,