From 4f65860a6e3fcb19c74bb66110b837c76a8c5857 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Fri, 13 Sep 2019 16:02:40 -0700 Subject: [PATCH 1/2] move to c10 --- scripts/gen.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/scripts/gen.py b/scripts/gen.py index e35c6030e9e4..3fea82a05e9b 100755 --- a/scripts/gen.py +++ b/scripts/gen.py @@ -194,7 +194,7 @@ class AtenXlaTypeDefault {{ #include "torch_xla/csrc/aten_xla_type_default.h" #include -#include +#include #include #include "tensorflow/compiler/xla/xla_client/debug_macros.h" @@ -944,7 +944,7 @@ def parse_local_overrides(path): def generate_registrations(fgens, overrides): code = 'void RegisterAtenTypeFunctions() {\n' - code += ' auto& dispatch = at::globalATenDispatch();\n' + code += ' static auto dispatch = torch::RegisterOperators()\n' overridden = set() for fgen in fgens: mapsig_key = get_mapsig_key(fgen.mapsig) @@ -955,9 +955,11 @@ def generate_registrations(fgens, overrides): override_fn = fgen.xfunc if fgen.code else None if override_fn: code += ( - ' dispatch.registerOp<{}>(at::Backend::XLA, "{}", &{});\n'.format( - fgen.funsig, fgen.aten_sig, override_fn)) - return code + '}\n', overridden + ' .op(torch::RegisterOperators::options().schema("{}")\n' + ' .impl_unboxedOnlyKernel<{}, &{}>(at::TensorTypeId::XLATensorId)\n' + ' .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))\n'.format( + fgen.aten_sig, fgen.funsig, override_fn, override_fn, fgen.aten_sig)) + return code + ';\n}\n', overridden def generate_functions(fgens): From ee02ccbee2f27676d7d181c9f32cce903c68c167 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Fri, 13 Sep 2019 17:06:03 -0700 Subject: [PATCH 2/2] remove static --- scripts/gen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/gen.py b/scripts/gen.py index 3fea82a05e9b..4c775d18a573 100755 --- a/scripts/gen.py +++ b/scripts/gen.py @@ -944,7 +944,7 @@ def parse_local_overrides(path): def generate_registrations(fgens, overrides): code = 'void RegisterAtenTypeFunctions() {\n' - code += ' static auto dispatch = torch::RegisterOperators()\n' + code += ' auto dispatch = torch::RegisterOperators()\n' overridden = set() for fgen in fgens: mapsig_key = get_mapsig_key(fgen.mapsig)