167 changes: 131 additions & 36 deletions scripts/gen.py
@@ -5,6 +5,7 @@
import lark
import os
import re
import string
import sys


@@ -19,12 +20,17 @@ def namedtuple_with_defaults(typename, field_names, default_values=()):
  return ntuple


class ArgTemplate(string.Template):
  idpattern = r'[a-z0-9_]+'
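
string.Template's default idpattern rejects names that begin with a digit, so
the override above is what lets templates address parameters positionally as
$0, $1, and so on. A minimal standalone sketch of the behavior:

  tmpl = ArgTemplate('arange($1, $2, $3, $0.options())')
  print(tmpl.substitute({'0': 'out', '1': 'start', '2': 'end', '3': 'step'}))
  # arange(start, end, step, out.options())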


FuncGen = namedtuple_with_defaults(
    'FuncGen',
    'tree, xtree, rwxtree, func, xfunc, code, sig, rwsig, cppsig, funsig, mapsig'
)

FuncOpts = namedtuple_with_defaults(
    'FuncOpts', 'ref_param, device_param, outfn_template')

_GRAMMAR = r"""
    start: type fnname "(" params ")"
@@ -71,23 +77,23 @@ def namedtuple_with_defaults(typename, field_names, default_values=()):

_FN_BLACKLIST = set([
    # ATEN functions
    'backward',
    'set_data',
    'storageFromBlob',
    'storageWithAllocator',
    'tensorFromBlob',
    'tensorWithAllocator',
    'toBackend',
    'toScalarType',
    'unsafeStorageFromTH',
    'unsafeTensorFromTH',
    # XLA/TPU functions
    '_s_copy_from',
    'numel',
    'ones',
    'ones_like',
    'zero_',
    'zeros',
    'zeros_like',
])

@@ -97,6 +103,20 @@ def namedtuple_with_defaults(typename, field_names, default_values=()):
    # XLA/TPU functions
]

_FN_OUT = {
    'arange_out(Tensor, Scalar, Scalar, Scalar) -> Tensor':
        FuncOpts(
            outfn_template=ArgTemplate('arange($1, $2, $3, $0.options())')),
    'kthvalue_out':
        FuncOpts(),
    'log_out':
        FuncOpts(),
}

# List of tuples with the regex match first, and the corresponding FuncOpts()
# second.
_FN_OUT_REGEX = []
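
The list is empty in this change; a hypothetical entry (illustration only)
would pair a regex with the FuncOpts applied to every function whose name or
mapped signature matches:

  _FN_OUT_REGEX = [
      (r'.*_out$', FuncOpts()),
  ]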

_TYPE_NSMAP = {
    'Tensor': 'at::Tensor',
    'TensorList': 'at::TensorList',
@@ -352,6 +372,16 @@ def is_blacklisted_fn(fname, mapsig):
  return False


def get_outfn_options(fname, mapsig):
  for name in [fname, mapsig]:
    fnopts = _FN_OUT.get(name, None)
    if fnopts is not None:
      return fnopts
  for frx, fnopts in _FN_OUT_REGEX:
    if re.match(frx, fname) or re.match(frx, mapsig):
      return fnopts
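
Lookup order is exact fname first, then the full mapsig, then the regex
fallbacks, with an implicit None when nothing matches. A quick sketch (the
mapsig strings are illustrative):

  get_outfn_options('log_out', 'log_out(Tensor, Tensor) -> Tensor')  # the FuncOpts() entry
  get_outfn_options('argmax', 'argmax(Tensor, int, bool) -> Tensor')  # None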


def create_type_instances():
  code = ''
  code += _CLASS_INST_HEADER.format(
@@ -665,7 +695,7 @@ def get_return_type_str(t, orig_sig):
  return orig_sig[0:token.column - 2]


def generate_entry_debug_code(t, fname, params):
  # Emits debug code for a given intercepted ATEN type function. For now we use
  # a counter which will show up in the metrics reports.
  code = ''
@@ -683,7 +713,7 @@
  return code


def generate_exit_debug_code(t, fname, rname, params, param_vars):
  code = ''
  return code

@@ -732,35 +762,66 @@ def rewrite_tensor_options(fname, pname):
  return code, xname


def generate_aten_out(ctx, tree, rwxtree, fname, sig, rwsig, params, fnopts):
  rtype = tree.children[0]
  num_outputs = None
  if type_core(rtype) == 'std::tuple':
    num_outputs = len(tuple_type_list(rtype))

  code = '{} {}{{\n'.format(sig, 'const ' if ctx.gen_class_mode else '')
  code += generate_entry_debug_code(tree, fname, params)

  param_vars = []
  for p in params:
    pname = param_name(p)
    param_vars.append(pname)

  if fnopts.outfn_template is not None:
    mdict = {}
    for i, pname in enumerate(param_vars):
      mdict[str(i)] = pname
    fcall = fnopts.outfn_template.substitute(mdict)
  else:
    m = re.match(r'(.*)_out$', fname)
    assert m is not None, fname
    core_fname = m.group(1)
    out_count = num_outputs if num_outputs is not None else 1
    fcall = '{}('.format(core_fname)
    for i in range(out_count, len(param_vars)):
      if i > out_count:
        fcall += ', '
      fcall += param_vars[i]
    fcall += ')'

  if num_outputs is None:
    code += '  {} = {};\n'.format(param_vars[0], fcall)
    code += generate_exit_debug_code(tree, fname, param_vars[0], params,
                                     param_vars)
    code += '  return {};\n'.format(param_vars[0])
  else:
    code += '  std::tie('
    for i in range(0, num_outputs):
      if i > 0:
        code += ', '
      code += param_vars[i]
    code += ') = {};\n'.format(fcall)
    code += generate_exit_debug_code(tree, fname, param_vars[0:num_outputs],
                                     params, param_vars)
    code += '  return {}('.format(get_return_type_str(rwxtree, rwsig))
    for i in range(0, num_outputs):
      if i > 0:
        code += ', '
      code += param_vars[i]
    code += ');\n'
  code += '}'
  return code
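
With the arange_out entry registered above, the wrapper this emits should come
out roughly as follows (a sketch; the entry debug counter is omitted and the
exact signature comes from the rewritten ATen declaration):

  at::Tensor& xla_arange_out(at::Tensor& out, at::Scalar start, at::Scalar end,
                             at::Scalar step) {
    out = arange(start, end, step, out.options());
    return out;
  }

This mirrors the handwritten AtenXlaType::arange_out override deleted from
aten_xla_type.cpp below.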


def generate_aten_to_xla(ctx, tree, rwxtree, fname, sig, rwsig, params, fnopts):
  ref_param = get_reference_param(params, fnopts=fnopts)

  code = '{} {}{{\n'.format(sig, 'const ' if ctx.gen_class_mode else '')
  code += generate_entry_debug_code(tree, fname, params)
  xla_ref_param = param_name(ref_param) if ref_param else None
  tfetcher = TensorFetcher('xlatens')
  param_vars = []
@@ -796,13 +857,47 @@ def gen_fnname(x):
  if result_assign:
    code += ('  static_cast<void>({}); // Avoid warnings in case not '
             'used\n'.format(_RESULT_NAME))
  code += generate_exit_debug_code(
      tree, fname, _RESULT_NAME if result_assign else None, params, param_vars)
  code += generate_return_stmt(tree, get_return_type_str(rwxtree, rwsig), fname,
                               _RESULT_NAME if result_assign else None, params,
                               param_vars, ref_param, fnopts)
  code += '}'
  return code


def get_xla_wrapper(orig_sig, ctx):
  tree = _PARSER.parse(orig_sig)
  xtree = _XPARSER.parse(orig_sig)
  mapsig = create_map_sig(xtree, orig_sig)
  rwsig = rewrite_signature(orig_sig, _TYPE_NSMAP)
  rwxtree = _XPARSER.parse(rwsig)
  params = get_parameters(tree)
  fnopts = _FUNCTION_OPTIONS.get(mapsig, None)

  # There are a few functions with the same function name but different
  # parameter list. Generate a unique XLA function name here.
  def gen_fnname(x):
    if ctx.gen_class_mode:
      return 'AtenXlaTypeBase::{}'.format(x)
    post = ''
    if x in ctx.defdb:
      post = '_{}'.format(ctx.defdb[x])
      ctx.defdb[x] += 1
    else:
      ctx.defdb[x] = 1
    return 'xla_' + x + post

  sig, fname, xfname = get_function_signature(rwxtree, rwsig, gen_fnname)
  if is_blacklisted_fn(fname, mapsig):
    return None
  ofnopts = get_outfn_options(fname, mapsig)
  if ofnopts is not None:
    code = generate_aten_out(ctx, tree, rwxtree, fname, sig, rwsig, params,
                             ofnopts)
  else:
    code = generate_aten_to_xla(ctx, tree, rwxtree, fname, sig, rwsig, params,
                                fnopts)
  return FuncGen(
      tree=tree,
      xtree=xtree,
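
The ctx.defdb counter in gen_fnname is what keeps same-named overloads apart; a
quick sketch of the naming it produces (assuming non-class mode and a fresh
counter):

  gen_fnname('add')  # 'xla_add'
  gen_fnname('add')  # 'xla_add_1'
  gen_fnname('add')  # 'xla_add_2'
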
6 changes: 0 additions & 6 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -444,12 +444,6 @@ at::Tensor AtenXlaType::arange(at::Scalar start, at::Scalar end,
                                     xla_options.get_scalar_type()));
}

at::Tensor& AtenXlaType::arange_out(at::Tensor& out, at::Scalar start,
                                    at::Scalar end, at::Scalar step) const {
  out = arange(start, end, step, out.options());
  return out;
}

at::Tensor AtenXlaType::argmax(const at::Tensor& self,
                               c10::optional<int64_t> dim, bool keepdim) const {
  return dim ? bridge::AtenFromXlaTensor(
3 changes: 0 additions & 3 deletions torch_xla/csrc/aten_xla_type.h
@@ -168,9 +168,6 @@ class AtenXlaType : public AtenXlaTypeBase {
  at::Tensor arange(at::Scalar start, at::Scalar end, at::Scalar step,
                    const at::TensorOptions& options) const override;

  at::Tensor& arange_out(at::Tensor& out, at::Scalar start, at::Scalar end,
                         at::Scalar step) const override;

  at::Tensor argmax(const at::Tensor& self, c10::optional<int64_t> dim,
                    bool keepdim) const override;
