From 96869445724363b31e47bb472f53f29c65261929 Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Thu, 21 Mar 2019 10:01:37 -0700 Subject: [PATCH] Add support to auto-generate *_out() functions. --- scripts/gen.py | 167 ++++++++++++++++++++++++------- torch_xla/csrc/aten_xla_type.cpp | 6 -- torch_xla/csrc/aten_xla_type.h | 3 - 3 files changed, 131 insertions(+), 45 deletions(-) diff --git a/scripts/gen.py b/scripts/gen.py index 525c1160719b..4796b71b5375 100755 --- a/scripts/gen.py +++ b/scripts/gen.py @@ -5,6 +5,7 @@ import lark import os import re +import string import sys @@ -19,12 +20,17 @@ def namedtuple_with_defaults(typename, field_names, default_values=()): return ntuple +class ArgTemplate(string.Template): + idpattern = r'[a-z0-9_]+' + + FuncGen = namedtuple_with_defaults( 'FuncGen', 'tree, xtree, rwxtree, func, xfunc, code, sig, rwsig, cppsig, funsig, mapsig' ) -FuncOpts = namedtuple_with_defaults('FuncOpts', 'ref_param, device_param') +FuncOpts = namedtuple_with_defaults('FuncOpts', + 'ref_param, device_param, outfn_template') _GRAMMAR = r""" start: type fnname "(" params ")" @@ -71,14 +77,14 @@ def namedtuple_with_defaults(typename, field_names, default_values=()): _FN_BLACKLIST = set([ # ATEN functions - 'toBackend', - 'toScalarType', 'backward', 'set_data', - 'tensorFromBlob', - 'tensorWithAllocator', 'storageFromBlob', 'storageWithAllocator', + 'tensorFromBlob', + 'tensorWithAllocator', + 'toBackend', + 'toScalarType', 'unsafeStorageFromTH', 'unsafeTensorFromTH', # XLA/TPU functions @@ -86,8 +92,8 @@ def namedtuple_with_defaults(typename, field_names, default_values=()): 'numel', 'ones', 'ones_like', - 'zeros', 'zero_', + 'zeros', 'zeros_like', ]) @@ -97,6 +103,20 @@ def namedtuple_with_defaults(typename, field_names, default_values=()): # XLA/TPU functions ] +_FN_OUT = { + 'arange_out(Tensor, Scalar, Scalar, Scalar) -> Tensor': + FuncOpts( + outfn_template=ArgTemplate('arange($1, $2, $3, $0.options())')), + 'kthvalue_out': + FuncOpts(), + 'log_out': + FuncOpts(), +} + +# List of tuples with the regex match first, and the corresponding FuncOpts() +# second. +_FN_OUT_REGEX = [] + _TYPE_NSMAP = { 'Tensor': 'at::Tensor', 'TensorList': 'at::TensorList', @@ -352,6 +372,16 @@ def is_blacklisted_fn(fname, mapsig): return False +def get_outfn_options(fname, mapsig): + for name in [fname, mapsig]: + fnopts = _FN_OUT.get(name, None) + if fnopts is not None: + return fnopts + for frx, fnopts in _FN_OUT_REGEX: + if re.match(frx, fname) or re.match(frx, mapsig): + return fnopts + + def create_type_instances(): code = '' code += _CLASS_INST_HEADER.format( @@ -665,7 +695,7 @@ def get_return_type_str(t, orig_sig): return orig_sig[0:token.column - 2] -def generate_entry_debug_code(t, fname, params, ref_param): +def generate_entry_debug_code(t, fname, params): # Emits debug code for a given intercepted ATEN type function. For now we use # a counter which will show up in the metrics reports. 
   code = ''
@@ -683,7 +713,7 @@ def generate_entry_debug_code(t, fname, params, ref_param):
   return code
 
 
-def generate_exit_debug_code(t, fname, rname, params, param_vars, ref_param):
+def generate_exit_debug_code(t, fname, rname, params, param_vars):
   code = ''
   return code
 
@@ -732,35 +762,66 @@ def rewrite_tensor_options(fname, pname):
   return code, xname
 
 
-def get_xla_wrapper(orig_sig, ctx):
-  tree = _PARSER.parse(orig_sig)
-  xtree = _XPARSER.parse(orig_sig)
-  mapsig = create_map_sig(xtree, orig_sig)
-  rwsig = rewrite_signature(orig_sig, _TYPE_NSMAP)
-  rwxtree = _XPARSER.parse(rwsig)
-  params = get_parameters(tree)
-  fnopts = _FUNCTION_OPTIONS.get(mapsig, None)
-  ref_param = get_reference_param(params, fnopts=fnopts)
+def generate_aten_out(ctx, tree, rwxtree, fname, sig, rwsig, params, fnopts):
+  rtype = tree.children[0]
+  num_outputs = None
+  if type_core(rtype) == 'std::tuple':
+    num_outputs = len(tuple_type_list(rtype))
 
-  # There are a few functions with the same function name but different
-  # parameter list. Generate a unique XL function name here.
-  def gen_fnname(x):
-    if ctx.gen_class_mode:
-      return 'AtenXlaTypeBase::{}'.format(x)
-    post = ''
-    if x in ctx.defdb:
-      post = '_{}'.format(ctx.defdb[x])
-      ctx.defdb[x] += 1
-    else:
-      ctx.defdb[x] = 1
-    return 'xla_' + x + post
+  code = '{} {}{{\n'.format(sig, 'const ' if ctx.gen_class_mode else '')
+  code += generate_entry_debug_code(tree, fname, params)
 
-  sig, fname, xfname = get_function_signature(rwxtree, rwsig, gen_fnname)
-  if is_blacklisted_fn(fname, mapsig):
-    return None
+  param_vars = []
+  for p in params:
+    pname = param_name(p)
+    param_vars.append(pname)
+
+  if fnopts.outfn_template is not None:
+    mdict = {}
+    for i, pname in enumerate(param_vars):
+      mdict[str(i)] = pname
+    fcall = fnopts.outfn_template.substitute(mdict)
+  else:
+    m = re.match(r'(.*)_out$', fname)
+    assert m is not None, fname
+    core_fname = m.group(1)
+    out_count = num_outputs if num_outputs is not None else 1
+    fcall = '{}('.format(core_fname)
+    for i in range(out_count, len(param_vars)):
+      if i > out_count:
+        fcall += ', '
+      fcall += param_vars[i]
+    fcall += ')'
+
+  if num_outputs is None:
+    code += '  {} = {};\n'.format(param_vars[0], fcall)
+    code += generate_exit_debug_code(tree, fname, param_vars[0], params,
+                                     param_vars)
+    code += '  return {};\n'.format(param_vars[0])
+  else:
+    code += '  std::tie('
+    for i in range(0, num_outputs):
+      if i > 0:
+        code += ', '
+      code += param_vars[i]
+    code += ') = {};\n'.format(fcall)
+    code += generate_exit_debug_code(tree, fname, param_vars[0:num_outputs],
+                                     params, param_vars)
+    code += '  return {}('.format(get_return_type_str(rwxtree, rwsig))
+    for i in range(0, num_outputs):
+      if i > 0:
+        code += ', '
+      code += param_vars[i]
+    code += ');\n'
+  code += '}'
+  return code
+
+
+def generate_aten_to_xla(ctx, tree, rwxtree, fname, sig, rwsig, params, fnopts):
+  ref_param = get_reference_param(params, fnopts=fnopts)
   code = '{} {}{{\n'.format(sig, 'const ' if ctx.gen_class_mode else '')
-  code += generate_entry_debug_code(tree, fname, params, ref_param)
+  code += generate_entry_debug_code(tree, fname, params)
   xla_ref_param = param_name(ref_param) if ref_param else None
   tfetcher = TensorFetcher('xlatens')
   param_vars = []
@@ -796,13 +857,47 @@ def gen_fnname(x):
   if result_assign:
     code += ('  static_cast<void>({});  // Avoid warnings in case not '
             'used\n'.format(_RESULT_NAME))
-  code += generate_exit_debug_code(tree, fname,
-                                   _RESULT_NAME if result_assign else None,
-                                   params, param_vars, ref_param)
+  code += generate_exit_debug_code(
+      tree, fname, _RESULT_NAME if result_assign else None, params, param_vars)
   code += generate_return_stmt(tree, get_return_type_str(rwxtree, rwsig),
                                fname, _RESULT_NAME if result_assign else None,
                                params, param_vars, ref_param, fnopts)
   code += '}'
+  return code
+
+
+def get_xla_wrapper(orig_sig, ctx):
+  tree = _PARSER.parse(orig_sig)
+  xtree = _XPARSER.parse(orig_sig)
+  mapsig = create_map_sig(xtree, orig_sig)
+  rwsig = rewrite_signature(orig_sig, _TYPE_NSMAP)
+  rwxtree = _XPARSER.parse(rwsig)
+  params = get_parameters(tree)
+  fnopts = _FUNCTION_OPTIONS.get(mapsig, None)
+
+  # There are a few functions with the same function name but different
+  # parameter list. Generate a unique XLA function name here.
+  def gen_fnname(x):
+    if ctx.gen_class_mode:
+      return 'AtenXlaTypeBase::{}'.format(x)
+    post = ''
+    if x in ctx.defdb:
+      post = '_{}'.format(ctx.defdb[x])
+      ctx.defdb[x] += 1
+    else:
+      ctx.defdb[x] = 1
+    return 'xla_' + x + post
+
+  sig, fname, xfname = get_function_signature(rwxtree, rwsig, gen_fnname)
+  if is_blacklisted_fn(fname, mapsig):
+    return None
+  ofnopts = get_outfn_options(fname, mapsig)
+  if ofnopts is not None:
+    code = generate_aten_out(ctx, tree, rwxtree, fname, sig, rwsig, params,
+                             ofnopts)
+  else:
+    code = generate_aten_to_xla(ctx, tree, rwxtree, fname, sig, rwsig, params,
+                                fnopts)
   return FuncGen(
       tree=tree,
       xtree=xtree,
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 8417a62a39e5..d702a5031bdc 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -444,12 +444,6 @@ at::Tensor AtenXlaType::arange(at::Scalar start, at::Scalar end,
                         xla_options.get_scalar_type()));
 }
 
-at::Tensor& AtenXlaType::arange_out(at::Tensor& out, at::Scalar start,
-                                    at::Scalar end, at::Scalar step) const {
-  out = arange(start, end, step, out.options());
-  return out;
-}
-
 at::Tensor AtenXlaType::argmax(const at::Tensor& self,
                                c10::optional<int64_t> dim, bool keepdim) const {
   return dim ? bridge::AtenFromXlaTensor(
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index 41752f65f728..1e9096da1e1a 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -168,9 +168,6 @@ class AtenXlaType : public AtenXlaTypeBase {
   at::Tensor arange(at::Scalar start, at::Scalar end, at::Scalar step,
                     const at::TensorOptions& options) const override;
 
-  at::Tensor& arange_out(at::Tensor& out, at::Scalar start, at::Scalar end,
-                         at::Scalar step) const override;
-
   at::Tensor argmax(const at::Tensor& self, c10::optional<int64_t> dim,
                     bool keepdim) const override;
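
Note on the generator changes: the custom ArgTemplate subclass is needed because string.Template's default idpattern ('[_a-z][_a-z0-9]*') never matches purely numeric names, so positional placeholders like $0 in an outfn_template would otherwise be left unexpanded. Below is a minimal standalone sketch of the two code paths in generate_aten_out(); the parameter name lists are illustrative stand-ins, not the real ATen signatures:

    import string

    class ArgTemplate(string.Template):
      # Widen the default idpattern so purely numeric placeholders
      # ($0, $1, ...) are recognized as substitution keys.
      idpattern = r'[a-z0-9_]+'

    # Template path, as used by the arange_out entry in _FN_OUT: mdict
    # maps parameter positions (output tensor first) to parameter names.
    param_vars = ['out', 'start', 'end', 'step']  # illustrative names
    mdict = {str(i): pname for i, pname in enumerate(param_vars)}
    fcall = ArgTemplate('arange($1, $2, $3, $0.options())').substitute(mdict)
    print(fcall)  # arange(start, end, step, out.options())

    # Default path, as used by the kthvalue_out/log_out entries: strip the
    # '_out' suffix and forward only the non-output parameters.
    fname, num_outputs = 'log_out', None  # None => single tensor output
    out_count = num_outputs if num_outputs is not None else 1
    log_params = ['out', 'self']  # illustrative again
    fcall = '{}({})'.format(fname[:-len('_out')],
                            ', '.join(log_params[out_count:]))
    print(fcall)  # log(self)

The generated wrapper then assigns the result of that call back into the leading output parameter(s), via std::tie() when the return type is a std::tuple, and returns them; for arange_out this reproduces exactly the hand-written override deleted from aten_xla_type.cpp above.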