diff --git a/scripts/gen.py b/scripts/gen.py index e35c6030e9e4..4c775d18a573 100755 --- a/scripts/gen.py +++ b/scripts/gen.py @@ -194,7 +194,7 @@ class AtenXlaTypeDefault {{ #include "torch_xla/csrc/aten_xla_type_default.h" #include -#include +#include #include #include "tensorflow/compiler/xla/xla_client/debug_macros.h" @@ -944,7 +944,7 @@ def parse_local_overrides(path): def generate_registrations(fgens, overrides): code = 'void RegisterAtenTypeFunctions() {\n' - code += ' auto& dispatch = at::globalATenDispatch();\n' + code += ' auto dispatch = torch::RegisterOperators()\n' overridden = set() for fgen in fgens: mapsig_key = get_mapsig_key(fgen.mapsig) @@ -955,9 +955,11 @@ def generate_registrations(fgens, overrides): override_fn = fgen.xfunc if fgen.code else None if override_fn: code += ( - ' dispatch.registerOp<{}>(at::Backend::XLA, "{}", &{});\n'.format( - fgen.funsig, fgen.aten_sig, override_fn)) - return code + '}\n', overridden + ' .op(torch::RegisterOperators::options().schema("{}")\n' + ' .impl_unboxedOnlyKernel<{}, &{}>(at::TensorTypeId::XLATensorId)\n' + ' .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))\n'.format( + fgen.aten_sig, fgen.funsig, override_fn, override_fn, fgen.aten_sig)) + return code + ';\n}\n', overridden def generate_functions(fgens):