38 changes: 22 additions & 16 deletions scripts/gen.py
@@ -25,11 +25,11 @@ class ArgTemplate(string.Template):
idpattern = r'[a-z0-9_]+'


FuncDef = namedtuple_with_defaults('FuncDef', 'cpp_sig, aten_sig, leaf')
FuncDef = namedtuple_with_defaults('FuncDef', 'cpp_sig, aten_sig, leaf, math')

FuncGen = namedtuple_with_defaults(
'FuncGen',
'tree, xtree, rwxtree, func, xfunc, code, sig, rwsig, cppsig, funsig, mapsig, aten_sig, leaf'
'tree, xtree, rwxtree, func, xfunc, code, sig, rwsig, cppsig, funsig, mapsig, aten_sig, leaf, math'
)

FuncOpts = namedtuple_with_defaults(
@@ -80,12 +80,12 @@ class ArgTemplate(string.Template):
_XPARSER = lark.Lark(
_GRAMMAR, parser='lalr', propagate_positions=True, keep_all_tokens=True)

# _FN_FULL_OVERRIDE/_FN_BLACKLIST takes either name or mapsig.
# _FN_AUTOGRAD_XLA/_FN_BLACKLIST takes either name or mapsig.
_FN_BLACKLIST = set([])

# List of non-leaf ops we want to override both forward + backward.
# TODO(https://github.com/pytorch/pytorch/issues/39959)
_FN_FULL_OVERRIDE = set([
_FN_AUTOGRAD_XLA = set([
'max_pool2d(Tensor, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool) -> Tensor',
'max_pool3d(Tensor, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool) -> Tensor',
])
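
For reference, `requires_registration` below checks membership against both `fgen.mapsig` and `fgen.func`, so an op can appear here either as a full mapsig or as a bare name. A minimal sketch; the commented name-form entry is hypothetical and not part of this PR:

```python
# Ops listed here are registered under the AutogradXLA dispatch key, so XLA
# intercepts them ahead of autograd and overrides forward + backward together.
_FN_AUTOGRAD_XLA = set([
    # mapsig form (as used in this PR):
    'max_pool2d(Tensor, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool) -> Tensor',
    # a bare-name form would also match, via the `fgen.func in _FN_AUTOGRAD_XLA` check:
    # 'max_pool3d',
])
```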
@@ -927,7 +927,8 @@ def gen_fnname(x):
mapsig=mapsig,
funsig=create_stdfunc_sig(rwxtree, rwsig),
aten_sig=fndef.aten_sig,
leaf=fndef.leaf)
leaf=fndef.leaf,
math=fndef.math)


def is_tensor_api(fndef):
@@ -942,7 +943,8 @@ def create_funcdef(fndef, jdata):
return FuncDef(
cpp_sig=fndef,
aten_sig=fields['schema'],
leaf=fields.get('compound', 'false') == 'false')
leaf=fields.get('compound', 'False') == 'False',
math=fields.get('has_math_kernel', 'False') == 'True')


def extract_functions(path):
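
For context, a sketch of the per-op metadata `create_funcdef` consumes. The field names (`schema`, `compound`, `has_math_kernel`) are taken from the diff above; the concrete op and the JSON-parsing step are illustrative assumptions about the generator's input format:

```python
import json

# Hypothetical registration-declaration entry as the generator might see it.
fndef = 'at::Tensor cos(const at::Tensor & self);'
jdata = ('{"schema": "aten::cos(Tensor self) -> Tensor", '
         '"compound": "False", "has_math_kernel": "False"}')

fields = json.loads(jdata)
func_def = FuncDef(
    cpp_sig=fndef,
    aten_sig=fields['schema'],
    # leaf: the op is not decomposed into other ops by autograd
    leaf=fields.get('compound', 'False') == 'False',
    # math: PyTorch core ships a backend-agnostic ("math") kernel for it
    math=fields.get('has_math_kernel', 'False') == 'True')
```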
@@ -1010,10 +1012,10 @@ def generate_unboxed(aten_sig, overload, override_fn):

def generate_registrations(fgens, overrides):
aten_code = 'TORCH_LIBRARY_IMPL(aten, XLA, m) {\n'
preautograd_code = 'TORCH_LIBRARY_IMPL(aten, AutogradXLA, m) {\n'
autogradxla_code = 'TORCH_LIBRARY_IMPL(aten, AutogradXLA, m) {\n'
overridden = set()
for fgen in fgens:
if not is_overrideable(fgen):
if not requires_registration(fgen):
continue
mapsig_key = get_mapsig_key(fgen.mapsig)
if mapsig_key in overrides:
@@ -1025,30 +1027,34 @@ def generate_registrations(fgens, overrides):
pos = fgen.funsig.find('(')
overload = fgen.funsig[:pos] + ' (*)' + fgen.funsig[pos:]
unboxed = generate_unboxed(fgen.aten_sig, overload, override_fn)
if fgen.mapsig in _FN_FULL_OVERRIDE:
preautograd_code += unboxed
if fgen.mapsig in _FN_AUTOGRAD_XLA:
autogradxla_code += unboxed
else:
aten_code += unboxed
return aten_code + '\n}\n' + preautograd_code + '\n}\n', overridden
return aten_code + '\n}\n' + autogradxla_code + '\n}\n', overridden


# XLA is only able to override leaf ops and whitelisted non-leaf ops.
def is_overrideable(fgen):
return fgen.leaf or fgen.mapsig in _FN_FULL_OVERRIDE or fgen.func in _FN_FULL_OVERRIDE
# XLA is required to register a kernel for leaf ops when PyTorch core
# provides no math kernel for them.
# For non-leaf ops, PyTorch core covers both forward and backward, but XLA
# can still optionally override them as necessary.
def requires_registration(fgen):
return (fgen.leaf and not fgen.math
) or fgen.mapsig in _FN_AUTOGRAD_XLA or fgen.func in _FN_AUTOGRAD_XLA


def generate_functions(fgens):
code = ''
for fgen in fgens:
if fgen.code and is_overrideable(fgen):
if fgen.code and requires_registration(fgen):
code += '{}\n\n'.format(fgen.code)
return code


def generate_class_functions(fgens):
code = ''
for fgen in fgens:
if fgen.code and is_overrideable(fgen):
if fgen.code and requires_registration(fgen):
code += ' static {};\n'.format(fgen.rwsig)
return code

6 changes: 3 additions & 3 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -1550,7 +1550,7 @@ TEST_F(AtenXlaTensorTest, TestGroupNorm) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::native_group_norm",
ExpectCounterChanged("xla::native_batch_norm",
cpp_test::GetIgnoredCounters());
}
}
@@ -1582,9 +1582,9 @@ TEST_F(AtenXlaTensorTest, TestGroupNormBackward) {
device, testfn,
/*rtol=*/1e-3, /*atol=*/1e-4);
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::native_group_norm",
ExpectCounterChanged("xla::native_batch_norm",
cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::native_group_norm_backward",
ExpectCounterChanged("xla::native_batch_norm_backward",
cpp_test::GetIgnoredCounters());
});
}
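
The expected counters above change from `xla::native_group_norm` to `xla::native_batch_norm` because `native_group_norm` now reaches XLA through PyTorch core's math kernel, which reduces it to a batch norm over a reshaped input, the same reduction performed by the `aten_tensor_ops.cpp` code removed below. A minimal sketch of that reduction against the public torch API (the helper name is ours):

```python
import torch
import torch.nn.functional as F

def group_norm_via_batch_norm(x, num_groups, weight=None, bias=None, eps=1e-5):
    # Normalize each (sample, group) slice by batch-norming a reshaped view:
    # (N, C, *) -> (1, N * num_groups, -1), so every "channel" is one group of
    # one sample, exactly the statistics group norm needs.
    N, C = x.shape[0], x.shape[1]
    out = F.batch_norm(x.reshape(1, N * num_groups, -1), running_mean=None,
                       running_var=None, training=True, momentum=0.0, eps=eps)
    out = out.reshape(x.shape)
    # Apply the per-channel affine parameters, broadcast over the other dims.
    shape = [1] * x.dim()
    shape[1] = C
    if weight is not None:
        out = out * weight.reshape(shape)
    if bias is not None:
        out = out + bias.reshape(shape)
    return out

x = torch.randn(2, 6, 4, 4)
w, b = torch.randn(6), torch.randn(6)
print(torch.allclose(group_norm_via_batch_norm(x, 3, w, b),
                     F.group_norm(x, 3, w, b), atol=1e-4))  # True
```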
25 changes: 0 additions & 25 deletions torch_xla/csrc/aten_tensor_ops.cpp
@@ -18,31 +18,6 @@ at::Tensor& celu_(at::Tensor& self, at::Scalar alpha) {
return at::elu_(self, alpha, at::Scalar(1.0), at::Scalar(inv_alpha));
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> native_group_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias, int64_t N, int64_t C, int64_t HxW,
int64_t group, double eps) {
auto input_shape = input.sizes();
at::Tensor input_reshaped = input.view({1, N * group, N ? -1 : 1});
auto outputs = at::native_batch_norm(
input_reshaped, /*weight=*/{}, /*bias=*/{}, /*running_mean=*/{},
/*running_var=*/{}, /*training=*/true, /*momentum=*/0, eps);
at::Tensor out = std::get<0>(outputs);
out = out.view(input_shape);
std::vector<int64_t> affine_param_shape(input.dim(), 1);
affine_param_shape[1] = C;
if (torch_xla::IsDefined(weight) && torch_xla::IsDefined(bias)) {
out = bias.value()
.view(affine_param_shape)
.addcmul(out, weight.value().view(affine_param_shape), 1);
} else if (torch_xla::IsDefined(weight)) {
out = out.mul(weight.value().view(affine_param_shape));
} else if (torch_xla::IsDefined(bias)) {
out = out.add(bias.value().view(affine_param_shape));
}
return std::make_tuple(out, std::get<1>(outputs), std::get<2>(outputs));
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> native_group_norm_backward(
const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& mean,
const at::Tensor& rstd, const c10::optional<at::Tensor>& weight, int64_t N,
5 changes: 0 additions & 5 deletions torch_xla/csrc/aten_tensor_ops.h
@@ -8,11 +8,6 @@ at::Tensor celu(const at::Tensor& self, at::Scalar alpha);

at::Tensor& celu_(at::Tensor& self, at::Scalar alpha);

std::tuple<at::Tensor, at::Tensor, at::Tensor> native_group_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias, int64_t N, int64_t C, int64_t HxW,
int64_t group, double eps);

std::tuple<at::Tensor, at::Tensor, at::Tensor> native_group_norm_backward(
const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& mean,
const at::Tensor& rstd, const c10::optional<at::Tensor>& weight, int64_t N,
9 changes: 0 additions & 9 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2217,15 +2217,6 @@ AtenXlaType::native_batch_norm_backward(
: undefined);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenXlaType::native_group_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias, int64_t N, int64_t C, int64_t HxW,
int64_t group, double eps) {
XLA_FN_COUNTER("xla::");
return aten_tensor_ops::native_group_norm(input, weight, bias, N, C, HxW,
group, eps);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor>
AtenXlaType::native_group_norm_backward(
const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& mean,
5 changes: 0 additions & 5 deletions torch_xla/csrc/aten_xla_type.h
@@ -693,11 +693,6 @@ class AtenXlaType {
bool train, double eps,
std::array<bool, 3> output_mask);

static std::tuple<at::Tensor, at::Tensor, at::Tensor> native_group_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias, int64_t N, int64_t C, int64_t HxW,
int64_t group, double eps);

static std::tuple<at::Tensor, at::Tensor, at::Tensor>
native_group_norm_backward(const at::Tensor& grad_out,
const at::Tensor& input, const at::Tensor& mean,