diff --git a/scripts/gen.py b/scripts/gen.py
index b31391a5380b..9a00a67b3f71 100755
--- a/scripts/gen.py
+++ b/scripts/gen.py
@@ -25,11 +25,11 @@ class ArgTemplate(string.Template):
   idpattern = r'[a-z0-9_]+'
 
 
-FuncDef = namedtuple_with_defaults('FuncDef', 'cpp_sig, aten_sig, leaf')
+FuncDef = namedtuple_with_defaults('FuncDef', 'cpp_sig, aten_sig, leaf, math')
 
 FuncGen = namedtuple_with_defaults(
     'FuncGen',
-    'tree, xtree, rwxtree, func, xfunc, code, sig, rwsig, cppsig, funsig, mapsig, aten_sig, leaf'
+    'tree, xtree, rwxtree, func, xfunc, code, sig, rwsig, cppsig, funsig, mapsig, aten_sig, leaf, math'
 )
 
 FuncOpts = namedtuple_with_defaults(
@@ -80,12 +80,12 @@ class ArgTemplate(string.Template):
 _XPARSER = lark.Lark(
     _GRAMMAR, parser='lalr', propagate_positions=True, keep_all_tokens=True)
 
-# _FN_FULL_OVERRIDE/_FN_BLACKLIST takes either name or mapsig.
+# _FN_AUTOGRAD_XLA/_FN_BLACKLIST takes either name or mapsig.
 _FN_BLACKLIST = set([])
 
 # List of non-leaf ops we want to override both forward + backward.
 # TODO(https://github.com/pytorch/pytorch/issues/39959)
-_FN_FULL_OVERRIDE = set([
+_FN_AUTOGRAD_XLA = set([
     'max_pool2d(Tensor, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool) -> Tensor',
     'max_pool3d(Tensor, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool) -> Tensor',
 ])
@@ -927,7 +927,8 @@ def gen_fnname(x):
       mapsig=mapsig,
       funsig=create_stdfunc_sig(rwxtree, rwsig),
       aten_sig=fndef.aten_sig,
-      leaf=fndef.leaf)
+      leaf=fndef.leaf,
+      math=fndef.math)
 
 
 def is_tensor_api(fndef):
@@ -942,7 +943,8 @@ def create_funcdef(fndef, jdata):
   return FuncDef(
       cpp_sig=fndef,
       aten_sig=fields['schema'],
-      leaf=fields.get('compound', 'false') == 'false')
+      leaf=fields.get('compound', 'False') == 'False',
+      math=fields.get('has_math_kernel', 'False') == 'True')
 
 
 def extract_functions(path):
@@ -1010,10 +1012,10 @@ def generate_unboxed(aten_sig, overload, override_fn):
 
 def generate_registrations(fgens, overrides):
   aten_code = 'TORCH_LIBRARY_IMPL(aten, XLA, m) {\n'
-  preautograd_code = 'TORCH_LIBRARY_IMPL(aten, AutogradXLA, m) {\n'
+  autogradxla_code = 'TORCH_LIBRARY_IMPL(aten, AutogradXLA, m) {\n'
   overridden = set()
   for fgen in fgens:
-    if not is_overrideable(fgen):
+    if not requires_registration(fgen):
       continue
     mapsig_key = get_mapsig_key(fgen.mapsig)
     if mapsig_key in overrides:
@@ -1025,22 +1027,26 @@ def generate_registrations(fgens, overrides):
     pos = fgen.funsig.find('(')
     overload = fgen.funsig[:pos] + ' (*)' + fgen.funsig[pos:]
     unboxed = generate_unboxed(fgen.aten_sig, overload, override_fn)
-    if fgen.mapsig in _FN_FULL_OVERRIDE:
-      preautograd_code += unboxed
+    if fgen.mapsig in _FN_AUTOGRAD_XLA:
+      autogradxla_code += unboxed
     else:
       aten_code += unboxed
-  return aten_code + '\n}\n' + preautograd_code + '\n}\n', overridden
+  return aten_code + '\n}\n' + autogradxla_code + '\n}\n', overridden
 
 
-# XLA is only able to override leaf ops and whitelisted non-leaf ops.
-def is_overrideable(fgen):
-  return fgen.leaf or fgen.mapsig in _FN_FULL_OVERRIDE or fgen.func in _FN_FULL_OVERRIDE
+# XLA is required to register kernels for leaf ops when there is no math
+# kernel provided in PyTorch core.
+# For other non-leaf ops, PyTorch covers both forward and backward, but
+# XLA can still optionally override them as necessary.
+def requires_registration(fgen):
+  return (fgen.leaf and not fgen.math
+         ) or fgen.mapsig in _FN_AUTOGRAD_XLA or fgen.func in _FN_AUTOGRAD_XLA
 
 
 def generate_functions(fgens):
   code = ''
   for fgen in fgens:
-    if fgen.code and is_overrideable(fgen):
+    if fgen.code and requires_registration(fgen):
      code += '{}\n\n'.format(fgen.code)
   return code
 
@@ -1048,7 +1054,7 @@ def generate_functions(fgens):
 def generate_class_functions(fgens):
   code = ''
   for fgen in fgens:
-    if fgen.code and is_overrideable(fgen):
+    if fgen.code and requires_registration(fgen):
       code += '  static {};\n'.format(fgen.rwsig)
   return code
 
diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index f5650f25680d..ffc1ad1973d8 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -1550,7 +1550,7 @@ TEST_F(AtenXlaTensorTest, TestGroupNorm) {
     });
 
     ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-    ExpectCounterChanged("xla::native_group_norm",
+    ExpectCounterChanged("xla::native_batch_norm",
                         cpp_test::GetIgnoredCounters());
   }
 }
@@ -1582,9 +1582,9 @@ TEST_F(AtenXlaTensorTest, TestGroupNormBackward) {
          device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4);
 
      ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-      ExpectCounterChanged("xla::native_group_norm",
+      ExpectCounterChanged("xla::native_batch_norm",
                           cpp_test::GetIgnoredCounters());
-      ExpectCounterChanged("xla::native_group_norm_backward",
+      ExpectCounterChanged("xla::native_batch_norm_backward",
                           cpp_test::GetIgnoredCounters());
    });
  }
diff --git a/torch_xla/csrc/aten_tensor_ops.cpp b/torch_xla/csrc/aten_tensor_ops.cpp
index 48b0306b6cc3..933931711545 100644
--- a/torch_xla/csrc/aten_tensor_ops.cpp
+++ b/torch_xla/csrc/aten_tensor_ops.cpp
@@ -18,31 +18,6 @@ at::Tensor& celu_(at::Tensor& self, at::Scalar alpha) {
   return at::elu_(self, alpha, at::Scalar(1.0), at::Scalar(inv_alpha));
 }
 
-std::tuple<at::Tensor, at::Tensor, at::Tensor> native_group_norm(
-    const at::Tensor& input, const c10::optional<at::Tensor>& weight,
-    const c10::optional<at::Tensor>& bias, int64_t N, int64_t C, int64_t HxW,
-    int64_t group, double eps) {
-  auto input_shape = input.sizes();
-  at::Tensor input_reshaped = input.view({1, N * group, N ? -1 : 1});
-  auto outputs = at::native_batch_norm(
-      input_reshaped, /*weight=*/{}, /*bias=*/{}, /*running_mean=*/{},
-      /*running_var=*/{}, /*training=*/true, /*momentum=*/0, eps);
-  at::Tensor out = std::get<0>(outputs);
-  out = out.view(input_shape);
-  std::vector<int64_t> affine_param_shape(input.dim(), 1);
-  affine_param_shape[1] = C;
-  if (torch_xla::IsDefined(weight) && torch_xla::IsDefined(bias)) {
-    out = bias.value()
-              .view(affine_param_shape)
-              .addcmul(out, weight.value().view(affine_param_shape), 1);
-  } else if (torch_xla::IsDefined(weight)) {
-    out = out.mul(weight.value().view(affine_param_shape));
-  } else if (torch_xla::IsDefined(bias)) {
-    out = out.add(bias.value().view(affine_param_shape));
-  }
-  return std::make_tuple(out, std::get<1>(outputs), std::get<2>(outputs));
-}
-
 std::tuple<at::Tensor, at::Tensor, at::Tensor> native_group_norm_backward(
     const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& mean,
     const at::Tensor& rstd, const c10::optional<at::Tensor>& weight, int64_t N,
diff --git a/torch_xla/csrc/aten_tensor_ops.h b/torch_xla/csrc/aten_tensor_ops.h
index a8f67e3f8026..4a53bdd23821 100644
--- a/torch_xla/csrc/aten_tensor_ops.h
+++ b/torch_xla/csrc/aten_tensor_ops.h
@@ -8,11 +8,6 @@
 at::Tensor celu(const at::Tensor& self, at::Scalar alpha);
 at::Tensor& celu_(at::Tensor& self, at::Scalar alpha);
 
-std::tuple<at::Tensor, at::Tensor, at::Tensor> native_group_norm(
-    const at::Tensor& input, const c10::optional<at::Tensor>& weight,
-    const c10::optional<at::Tensor>& bias, int64_t N, int64_t C, int64_t HxW,
-    int64_t group, double eps);
-
 std::tuple<at::Tensor, at::Tensor, at::Tensor> native_group_norm_backward(
     const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& mean,
     const at::Tensor& rstd, const c10::optional<at::Tensor>& weight, int64_t N,
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 0bd6c3de5025..a0a5dbabef33 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2217,15 +2217,6 @@ AtenXlaType::native_batch_norm_backward(
                     : undefined);
 }
 
-std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenXlaType::native_group_norm(
-    const at::Tensor& input, const c10::optional<at::Tensor>& weight,
-    const c10::optional<at::Tensor>& bias, int64_t N, int64_t C, int64_t HxW,
-    int64_t group, double eps) {
-  XLA_FN_COUNTER("xla::");
-  return aten_tensor_ops::native_group_norm(input, weight, bias, N, C, HxW,
-                                            group, eps);
-}
-
 std::tuple<at::Tensor, at::Tensor, at::Tensor>
 AtenXlaType::native_group_norm_backward(
     const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& mean,
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index f3b4aec487cf..9fa893342fd7 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -693,11 +693,6 @@ class AtenXlaType {
       bool train, double eps, std::array<bool, 3> output_mask);
 
-  static std::tuple<at::Tensor, at::Tensor, at::Tensor> native_group_norm(
-      const at::Tensor& input, const c10::optional<at::Tensor>& weight,
-      const c10::optional<at::Tensor>& bias, int64_t N, int64_t C, int64_t HxW,
-      int64_t group, double eps);
-
   static std::tuple<at::Tensor, at::Tensor, at::Tensor>
   native_group_norm_backward(const at::Tensor& grad_out,
                              const at::Tensor& input, const at::Tensor& mean,
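
Illustrative note (not part of the patch): the scripts/gen.py change above reduces to a per-op, three-way registration decision driven by the new `math` field parsed from `has_math_kernel`. The sketch below mirrors that logic under stated assumptions: `Op`, `registration_key`, and the sample schema strings are made-up stand-ins, while the `func`/`mapsig`/`leaf`/`math` names follow the FuncDef/FuncGen fields in gen.py; the flag values for native_group_norm and native_batch_norm are assumed to match what this patch relies on (group norm now has an upstream math kernel, which is why the C++ tests count xla::native_batch_norm instead).

from collections import namedtuple

# Stand-in for the FuncDef/FuncGen records produced by scripts/gen.py.
Op = namedtuple('Op', 'func mapsig leaf math')

# Subset of the _FN_AUTOGRAD_XLA allowlist from the diff above.
_FN_AUTOGRAD_XLA = {
    'max_pool2d(Tensor, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool) -> Tensor',
}


def registration_key(op):
  """Hypothetical helper: requires_registration() plus the dispatch-key split."""
  if op.mapsig in _FN_AUTOGRAD_XLA or op.func in _FN_AUTOGRAD_XLA:
    return 'AutogradXLA'  # XLA overrides both forward and backward.
  if op.leaf and not op.math:
    return 'XLA'  # Leaf op with no math kernel upstream: XLA must lower it.
  return None  # A math/composite kernel in PyTorch core covers it.


# Assumed flag values for illustration: native_group_norm now reports
# has_math_kernel=True, so it is skipped and decomposes into native_batch_norm,
# which stays registered under the XLA key.
assert registration_key(Op('native_group_norm', 'native_group_norm(...)', True, True)) is None
assert registration_key(Op('native_batch_norm', 'native_batch_norm(...)', True, False)) == 'XLA'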