From 3ead02dbc017e3fe2643b4de4fd11e411de84bed Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 13 Sep 2019 11:15:27 -0400 Subject: [PATCH 1/5] Multiple dispatch compatibility for copy_ in XLA. Signed-off-by: Edward Z. Yang --- torch_xla/csrc/aten_xla_type.cpp | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 7a5104847ec3..02f0564e2f20 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -866,12 +866,26 @@ AtenXlaType::convolution_backward_overrideable( at::Tensor& AtenXlaType::copy_(at::Tensor& self, const at::Tensor& src, bool non_blocking) { - XLATensor self_tensor = bridge::GetXlaTensor(self); + c10::optional self_tensor = bridge::TryGetXlaTensor(self); c10::optional src_tensor = bridge::TryGetXlaTensor(src); - if (src_tensor) { - XLATensor::copy_(self_tensor, *src_tensor); + + if (!src_tensor) { + TORCH_ASSERT(self_tensor); + self_tensor.SetTensor(CopyTensor(src, self_tensor->scalar_type())); + } else if (!self_tensor) { + // TODO: Is self_tensor good enough? I don't think so... therefore + // the hack below: + // + // Do not mark the tensor creation as writeable to not discard the XLA tensor + // device context, but make a copy to avoid core data to be shared. + std::vector tensors = {src}; + auto xla_tensors = bridge::XlaCreateTensorList(tensors); + // Hack in an overwrite of a const tensor. + at::Tensor t = CopyTensor(xla_tensors.front(), self.scalar_type()); + const_cast(self).unsafeGetTensorImpl()->shallow_copy_from( + t.getIntrusivePtr()); } else { - self_tensor.SetTensor(CopyTensor(src, self.scalar_type())); + XLATensor::copy_(self_tensor, *src_tensor); } return self; } From 25cc271688e83b6d0387e3f4a1558ee51e0f007b Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 13 Sep 2019 11:21:48 -0400 Subject: [PATCH 2/5] Add upstream patch. Signed-off-by: Edward Z. 
Yang --- torch_patches/25653.diff | 5558 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 5558 insertions(+) create mode 100644 torch_patches/25653.diff diff --git a/torch_patches/25653.diff b/torch_patches/25653.diff new file mode 100644 index 000000000000..40a277fbeebc --- /dev/null +++ b/torch_patches/25653.diff @@ -0,0 +1,5558 @@ +From c0f199b932527a79f9f48674c7345b789a37967a Mon Sep 17 00:00:00 2001 +From: "Edward Z. Yang" +Date: Fri, 13 Sep 2019 11:05:42 -0400 +Subject: [PATCH] Implement multiple dispatch + +Instead of considering only the TensorTypeSet of the first argument, +we collect all Tensor and TensorList arguments and union them together +before computing the dispatch type id. + +A minor bit of refactoring I had to do to get here was move the IterArgs +functionality in torch/csrc/utils/variadic.h into ATen/core. There's +some refactoring due on that file too (it has copies of some C++ helper +pieces which already live in c10). + +There is a little bit of a hack in the code generator to turn 'self' +arguments into '*this'. I think this may be duplicated with some +logic somewhere else but I have to double check. + +Signed-off-by: Edward Z. 
Yang + +ghstack-source-id: 18d37ab48c63a7ee04ce19ecd3b803ba7096afca +Pull Request resolved: https://github.com/pytorch/pytorch/pull/25653 +--- + aten/src/ATen/SparseTensorUtils.h | 2 + + aten/src/ATen/core/ATenDispatch.cpp | 28 + + aten/src/ATen/core/ATenDispatch.h | 6 +- + aten/src/ATen/core/TensorBody.h | 11 +- + aten/src/ATen/core/TensorMethods.h | 982 +++++++++++---------- + aten/src/ATen/core/Variadic.h | 74 ++ + aten/src/ATen/core/aten_interned_strings.h | 1 - + aten/src/ATen/function_wrapper.py | 55 +- + aten/src/ATen/native/BinaryOps.cpp | 59 +- + aten/src/ATen/native/LegacyBridge.cpp | 79 -- + aten/src/ATen/native/native_functions.yaml | 147 ++- + aten/src/ATen/native/sparse/SparseTensorMath.cpp | 102 ++- + aten/src/ATen/native/sparse/SparseTensorMath.h | 11 + + .../ATen/native/sparse/cuda/SparseCUDATensor.cpp | 2 +- + .../native/sparse/cuda/SparseCUDATensorMath.cu | 52 +- + aten/src/ATen/templates/TensorBody.h | 11 +- + aten/src/ATen/templates/TensorMethods.h | 22 + + c10/core/TensorTypeId.h | 6 +- + test/test_nn.py | 3 +- + test/test_sparse.py | 6 +- + tools/autograd/templates/Functions.cpp | 2 +- + torch/csrc/utils/variadic.h | 64 +- + torch/lib/c10d/ProcessGroupGloo.cpp | 8 + + 23 files changed, 948 insertions(+), 785 deletions(-) + create mode 100644 aten/src/ATen/core/Variadic.h + create mode 100644 aten/src/ATen/native/sparse/SparseTensorMath.h + +diff --git a/aten/src/ATen/SparseTensorUtils.h b/aten/src/ATen/SparseTensorUtils.h +index ecc52b2cb3..45aa79eef9 100644 +--- a/aten/src/ATen/SparseTensorUtils.h ++++ b/aten/src/ATen/SparseTensorUtils.h +@@ -1,3 +1,5 @@ ++#pragma once ++ + #include + #include + +diff --git a/aten/src/ATen/core/ATenDispatch.cpp b/aten/src/ATen/core/ATenDispatch.cpp +index 26deaef09a..50ade874af 100644 +--- a/aten/src/ATen/core/ATenDispatch.cpp ++++ b/aten/src/ATen/core/ATenDispatch.cpp +@@ -7,4 +7,32 @@ ATenDispatch & globalATenDispatch() { + return singleton; + } + ++void* ATenOpTable::getFallbackOp(TensorTypeId tid) const 
{ ++ // TODO: an alternate strategy here would be to mask out the dead key ++ // and then redispatch again (automatic delegation). I haven't done this ++ // for now to make it easier to smoke out error cases. ++ if (function_table_[static_cast(TensorTypeId::UndefinedTensorId)] == nullptr) { ++ // If there is no fallback dispatch, and dispatch failed because we didn't ++ // find any valid keys to dispatch on, this usually means the user gave ++ // us a non-empty list of tensors. So report a better error in this case. ++ // TODO: Maybe we should reword this error message ++ if (tid == TensorTypeId::UndefinedTensorId) { ++ TORCH_CHECK(false, "expected a non-empty list of Tensors") ++ } ++ std::ostringstream oss; ++ bool first = true; ++ for (int64_t i = 0; i < static_cast(TensorTypeId::NumTensorIds); i++) { ++ if (function_table_[i] != nullptr) { ++ if (!first) oss << ", "; ++ oss << toString(static_cast(i)); ++ first = false; ++ } ++ } ++ TORCH_CHECK(false, ++ "No function is registered for schema ", schema_, " on tensor type ", toString(tid), ++ "; available functions are ", oss.str()); ++ } ++ return function_table_[static_cast(TensorTypeId::UndefinedTensorId)]; ++} ++ + } // namespace at +diff --git a/aten/src/ATen/core/ATenDispatch.h b/aten/src/ATen/core/ATenDispatch.h +index ba2940feb3..af16f880fb 100644 +--- a/aten/src/ATen/core/ATenDispatch.h ++++ b/aten/src/ATen/core/ATenDispatch.h +@@ -58,6 +58,8 @@ class CAFFE2_API ATenOpTable { + function_table_[static_cast(tid)] = fn; + } + ++ void* getFallbackOp(TensorTypeId tid) const; ++ + void* getOp(TensorTypeId tid) const { + // You might think we can minorly optimize this further by maintaining a + // bitmask of registered operator keys, so we don't select dispatch ids +@@ -65,9 +67,7 @@ class CAFFE2_API ATenOpTable { + // get a Variable CPUTensor, if there is no variable registration, you'll + // fall back to the CPU implementation. Is this what you want? Unlikely... 
+ if (function_table_[static_cast(tid)] == nullptr) { +- TORCH_CHECK(function_table_[static_cast(TensorTypeId::UndefinedTensorId)] != nullptr, +- "No function is registered for schema ", schema_, " on tensor type ", toString(tid)); +- return function_table_[static_cast(TensorTypeId::UndefinedTensorId)]; ++ return getFallbackOp(tid); + } + return function_table_[static_cast(tid)]; + } +diff --git a/aten/src/ATen/core/TensorBody.h b/aten/src/ATen/core/TensorBody.h +index ded8489538..d150bb57bd 100644 +--- a/aten/src/ATen/core/TensorBody.h ++++ b/aten/src/ATen/core/TensorBody.h +@@ -919,7 +919,7 @@ protected: + }; + + namespace detail { +-// Helper creator for Tensor clas which doesn't requires the users to pass ++// Helper creator for Tensor class which doesn't requires the users to pass + // in an intrusive_ptr instead it just converts the argument passed to + // requested intrusive_ptr type. + template +@@ -927,15 +927,6 @@ Tensor make_tensor(Args&&... args) { + return Tensor(c10::make_intrusive(std::forward(args)...)); + } + +-inline TensorTypeSet infer_tensor_type_set(const Tensor & tl) { +- return tl.type_set(); +-} +- +-inline TensorTypeSet infer_tensor_type_set(TensorList tl) { +- TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors"); +- return tl[0].type_set(); +-} +- + } // namespace detail + + static inline TensorTypeId legacyExtractTypeId(const Tensor& t) { +diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h +index 5c9cdda828..03b7de6675 100644 +--- a/aten/src/ATen/core/TensorMethods.h ++++ b/aten/src/ATen/core/TensorMethods.h +@@ -11,6 +11,7 @@ + #if !defined(CAFFE2_IS_XPLAT_BUILD) + #include + #endif ++#include + #include + #ifdef USE_STATIC_DISPATCH + #include +@@ -21,6 +22,27 @@ + + namespace at { + ++namespace detail { ++ ++struct MultiDispatchTensorTypeSet : IterArgs { ++ TensorTypeSet ts; ++ void operator()(const at::Tensor& x) { ++ ts = ts | x.type_set(); ++ } ++ void operator()(at::ArrayRef xs) { ++ 
for (const auto& x : xs) { ++ ts = ts | x.type_set(); ++ } ++ } ++}; ++ ++template ++TensorTypeSet multi_dispatch_tensor_type_set(Args&&... args) { ++ return MultiDispatchTensorTypeSet().apply(std::forward(args)...).ts; ++} ++ ++} ++ + struct Quantizer; + // This is temporary typedef to enable Quantizer in aten native function API + // we'll remove them when we are actually exposing Quantizer class +@@ -62,7 +84,7 @@ inline void Tensor::backward(const Tensor & gradient, bool keep_graph, bool crea + TypeDefault::backward(const_cast(*this), gradient, keep_graph, create_graph); + #else + static auto table = globalATenDispatch().getOpTable("aten::backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> void"); +- return table->getOp(type_set())(const_cast(*this), gradient, keep_graph, create_graph); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, gradient))(const_cast(*this), gradient, keep_graph, create_graph); + #endif + } + inline void Tensor::set_data(const Tensor & new_data) const { +@@ -70,7 +92,7 @@ inline void Tensor::set_data(const Tensor & new_data) const { + TypeDefault::set_data(const_cast(*this), new_data); + #else + static auto table = globalATenDispatch().getOpTable("aten::set_data(Tensor(a!) 
self, Tensor new_data) -> void"); +- return table->getOp(type_set())(const_cast(*this), new_data); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, new_data))(const_cast(*this), new_data); + #endif + } + inline Tensor Tensor::data() const { +@@ -78,7 +100,7 @@ inline Tensor Tensor::data() const { + return TypeDefault::data(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::data(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -87,7 +109,7 @@ inline Tensor & Tensor::names_(c10::optional names) const { + return TypeDefault::names_(const_cast(*this), names); + #else + static auto table = globalATenDispatch().getOpTable("aten::names_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)"); +- return table->getOp)>(type_set())(const_cast(*this), names); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), names); + #endif + } + #endif +@@ -97,7 +119,7 @@ inline Tensor Tensor::renamed(c10::optional names) const { + return TypeDefault::renamed(const_cast(*this), names); + #else + static auto table = globalATenDispatch().getOpTable("aten::renamed(Tensor(a) self, Dimname[]? 
names) -> Tensor(a)"); +- return table->getOp)>(type_set())(const_cast(*this), names); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), names); + #endif + } + #endif +@@ -107,7 +129,7 @@ inline Tensor Tensor::align_to(DimnameList names) const { + return TypeDefault::align_to(const_cast(*this), names); + #else + static auto table = globalATenDispatch().getOpTable("aten::align_to(Tensor(a) self, DimnameList names) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), names); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), names); + #endif + } + #endif +@@ -117,7 +139,7 @@ inline Tensor Tensor::align_as(const Tensor & other) const { + return TypeDefault::align_as(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::align_as(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + #endif +@@ -127,7 +149,7 @@ inline Tensor Tensor::refine_names(DimnameList names) const { + return TypeDefault::refine_names(const_cast(*this), names); + #else + static auto table = globalATenDispatch().getOpTable("aten::refine_names(Tensor(a) self, DimnameList names) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), names); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), names); + #endif + } + #endif +@@ -136,7 +158,7 @@ inline Tensor Tensor::abs() const { + return TypeDefault::abs(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::abs(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::abs_() const { +@@ -150,7 +172,7 
@@ inline Tensor & Tensor::abs_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::abs_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::acos() const { +@@ -158,7 +180,7 @@ inline Tensor Tensor::acos() const { + return TypeDefault::acos(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::acos(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::acos_() const { +@@ -172,7 +194,7 @@ inline Tensor & Tensor::acos_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::acos_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::add(const Tensor & other, Scalar alpha) const { +@@ -189,7 +211,7 @@ inline Tensor Tensor::add(const Tensor & other, Scalar alpha) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, alpha); + #endif + } + inline Tensor & Tensor::add_(const Tensor & other, Scalar alpha) const { +@@ -206,7 +228,7 @@ inline Tensor & Tensor::add_(const Tensor & other, Scalar alpha) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::add_.Tensor(Tensor(a!) 
self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, alpha); + #endif + } + inline Tensor Tensor::add(Scalar other, Scalar alpha) const { +@@ -214,7 +236,7 @@ inline Tensor Tensor::add(Scalar other, Scalar alpha) const { + return TypeDefault::add(const_cast(*this), other, alpha); + #else + static auto table = globalATenDispatch().getOpTable("aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other, alpha); + #endif + } + inline Tensor & Tensor::add_(Scalar other, Scalar alpha) const { +@@ -222,7 +244,7 @@ inline Tensor & Tensor::add_(Scalar other, Scalar alpha) const { + return TypeDefault::add_(const_cast(*this), other, alpha); + #else + static auto table = globalATenDispatch().getOpTable("aten::add_.Scalar(Tensor(a!) 
self, Scalar other, Scalar alpha=1) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other, alpha); + #endif + } + inline Tensor Tensor::addmv(const Tensor & mat, const Tensor & vec, Scalar beta, Scalar alpha) const { +@@ -236,7 +258,7 @@ inline Tensor Tensor::addmv(const Tensor & mat, const Tensor & vec, Scalar beta, + } + #else + static auto table = globalATenDispatch().getOpTable("aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), mat, vec, beta, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mat, vec))(const_cast(*this), mat, vec, beta, alpha); + #endif + } + inline Tensor & Tensor::addmv_(const Tensor & mat, const Tensor & vec, Scalar beta, Scalar alpha) const { +@@ -250,7 +272,7 @@ inline Tensor & Tensor::addmv_(const Tensor & mat, const Tensor & vec, Scalar be + } + #else + static auto table = globalATenDispatch().getOpTable("aten::addmv_(Tensor(a!) 
self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), mat, vec, beta, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mat, vec))(const_cast(*this), mat, vec, beta, alpha); + #endif + } + inline Tensor Tensor::addr(const Tensor & vec1, const Tensor & vec2, Scalar beta, Scalar alpha) const { +@@ -258,7 +280,7 @@ inline Tensor Tensor::addr(const Tensor & vec1, const Tensor & vec2, Scalar beta + return TypeDefault::addr(const_cast(*this), vec1, vec2, beta, alpha); + #else + static auto table = globalATenDispatch().getOpTable("aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), vec1, vec2, beta, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, vec1, vec2))(const_cast(*this), vec1, vec2, beta, alpha); + #endif + } + inline Tensor & Tensor::addr_(const Tensor & vec1, const Tensor & vec2, Scalar beta, Scalar alpha) const { +@@ -266,7 +288,7 @@ inline Tensor & Tensor::addr_(const Tensor & vec1, const Tensor & vec2, Scalar b + return TypeDefault::addr_(const_cast(*this), vec1, vec2, beta, alpha); + #else + static auto table = globalATenDispatch().getOpTable("aten::addr_(Tensor(a!) 
self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), vec1, vec2, beta, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, vec1, vec2))(const_cast(*this), vec1, vec2, beta, alpha); + #endif + } + inline Tensor Tensor::all(int64_t dim, bool keepdim) const { +@@ -274,7 +296,7 @@ inline Tensor Tensor::all(int64_t dim, bool keepdim) const { + return TypeDefault::all(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + inline bool Tensor::allclose(const Tensor & other, double rtol, double atol, bool equal_nan) const { +@@ -282,7 +304,7 @@ inline bool Tensor::allclose(const Tensor & other, double rtol, double atol, boo + return TypeDefault::allclose(const_cast(*this), other, rtol, atol, equal_nan); + #else + static auto table = globalATenDispatch().getOpTable("aten::allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool"); +- return table->getOp(type_set())(const_cast(*this), other, rtol, atol, equal_nan); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, rtol, atol, equal_nan); + #endif + } + inline Tensor Tensor::any(int64_t dim, bool keepdim) const { +@@ -290,7 +312,7 @@ inline Tensor Tensor::any(int64_t dim, bool keepdim) const { + return TypeDefault::any(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, keepdim); ++ return 
table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + inline Tensor Tensor::argmax(c10::optional dim, bool keepdim) const { +@@ -298,7 +320,7 @@ inline Tensor Tensor::argmax(c10::optional dim, bool keepdim) const { + return TypeDefault::argmax(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"); +- return table->getOp, bool)>(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + inline Tensor Tensor::argmin(c10::optional dim, bool keepdim) const { +@@ -306,7 +328,7 @@ inline Tensor Tensor::argmin(c10::optional dim, bool keepdim) const { + return TypeDefault::argmin(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"); +- return table->getOp, bool)>(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + inline Tensor Tensor::as_strided(IntArrayRef size, IntArrayRef stride, c10::optional storage_offset) const { +@@ -323,7 +345,7 @@ inline Tensor Tensor::as_strided(IntArrayRef size, IntArrayRef stride, c10::opti + } + #else + static auto table = globalATenDispatch().getOpTable("aten::as_strided(Tensor(a) self, int[] size, int[] stride, int? 
storage_offset=None) -> Tensor(a)"); +- return table->getOp)>(type_set())(const_cast(*this), size, stride, storage_offset); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size, stride, storage_offset); + #endif + } + inline Tensor & Tensor::as_strided_(IntArrayRef size, IntArrayRef stride, c10::optional storage_offset) const { +@@ -331,7 +353,7 @@ inline Tensor & Tensor::as_strided_(IntArrayRef size, IntArrayRef stride, c10::o + return TypeDefault::as_strided_(const_cast(*this), size, stride, storage_offset); + #else + static auto table = globalATenDispatch().getOpTable("aten::as_strided_(Tensor(a!) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a!)"); +- return table->getOp)>(type_set())(const_cast(*this), size, stride, storage_offset); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size, stride, storage_offset); + #endif + } + inline Tensor Tensor::asin() const { +@@ -339,7 +361,7 @@ inline Tensor Tensor::asin() const { + return TypeDefault::asin(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::asin(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::asin_() const { +@@ -353,7 +375,7 @@ inline Tensor & Tensor::asin_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::asin_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::atan() const { +@@ -361,7 +383,7 @@ inline Tensor Tensor::atan() const { + return TypeDefault::atan(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::atan(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::atan_() const { +@@ -375,7 +397,7 @@ inline Tensor & Tensor::atan_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::atan_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::baddbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { +@@ -389,7 +411,7 @@ inline Tensor Tensor::baddbmm(const Tensor & batch1, const Tensor & batch2, Scal + } + #else + static auto table = globalATenDispatch().getOpTable("aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), batch1, batch2, beta, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2))(const_cast(*this), batch1, batch2, beta, alpha); + #endif + } + inline Tensor & Tensor::baddbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { +@@ -403,7 +425,7 @@ inline Tensor & Tensor::baddbmm_(const Tensor & batch1, const Tensor & batch2, S + } + #else + static auto table = globalATenDispatch().getOpTable("aten::baddbmm_(Tensor(a!) 
self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), batch1, batch2, beta, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2))(const_cast(*this), batch1, batch2, beta, alpha); + #endif + } + inline Tensor Tensor::bernoulli(Generator * generator) const { +@@ -411,7 +433,7 @@ inline Tensor Tensor::bernoulli(Generator * generator) const { + return TypeDefault::bernoulli(const_cast(*this), generator); + #else + static auto table = globalATenDispatch().getOpTable("aten::bernoulli(Tensor self, *, Generator? generator=None) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), generator); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), generator); + #endif + } + inline Tensor & Tensor::bernoulli_(const Tensor & p, Generator * generator) const { +@@ -425,7 +447,7 @@ inline Tensor & Tensor::bernoulli_(const Tensor & p, Generator * generator) cons + } + #else + static auto table = globalATenDispatch().getOpTable("aten::bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), p, generator); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, p))(const_cast(*this), p, generator); + #endif + } + inline Tensor & Tensor::bernoulli_(double p, Generator * generator) const { +@@ -439,7 +461,7 @@ inline Tensor & Tensor::bernoulli_(double p, Generator * generator) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? 
generator=None) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), p, generator); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, generator); + #endif + } + inline Tensor Tensor::bernoulli(double p, Generator * generator) const { +@@ -447,7 +469,7 @@ inline Tensor Tensor::bernoulli(double p, Generator * generator) const { + return TypeDefault::bernoulli(const_cast(*this), p, generator); + #else + static auto table = globalATenDispatch().getOpTable("aten::bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), p, generator); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, generator); + #endif + } + inline Tensor Tensor::bincount(const Tensor & weights, int64_t minlength) const { +@@ -461,7 +483,7 @@ inline Tensor Tensor::bincount(const Tensor & weights, int64_t minlength) const + } + #else + static auto table = globalATenDispatch().getOpTable("aten::bincount(Tensor self, Tensor? 
weights=None, int minlength=0) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), weights, minlength); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, weights))(const_cast(*this), weights, minlength); + #endif + } + inline Tensor Tensor::bitwise_not() const { +@@ -469,7 +491,7 @@ inline Tensor Tensor::bitwise_not() const { + return TypeDefault::bitwise_not(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::bitwise_not(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::bitwise_not_() const { +@@ -477,7 +499,7 @@ inline Tensor & Tensor::bitwise_not_() const { + return TypeDefault::bitwise_not_(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::bitwise_not_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::logical_not() const { +@@ -485,7 +507,7 @@ inline Tensor Tensor::logical_not() const { + return TypeDefault::logical_not(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::logical_not(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::logical_not_() const { +@@ -493,7 +515,7 @@ inline Tensor & Tensor::logical_not_() const { + return TypeDefault::logical_not_(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::logical_not_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::logical_xor(const Tensor & other) const { +@@ -501,7 +523,7 @@ inline Tensor Tensor::logical_xor(const Tensor & other) const { + return TypeDefault::logical_xor(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::logical_xor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::logical_xor_(const Tensor & other) const { +@@ -509,7 +531,7 @@ inline Tensor & Tensor::logical_xor_(const Tensor & other) const { + return TypeDefault::logical_xor_(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::logical_xor_(Tensor(a!) 
self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::bmm(const Tensor & mat2) const { +@@ -523,7 +545,7 @@ inline Tensor Tensor::bmm(const Tensor & mat2) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::bmm(Tensor self, Tensor mat2) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), mat2); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mat2))(const_cast(*this), mat2); + #endif + } + inline Tensor Tensor::ceil() const { +@@ -531,7 +553,7 @@ inline Tensor Tensor::ceil() const { + return TypeDefault::ceil(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::ceil(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::ceil_() const { +@@ -539,7 +561,7 @@ inline Tensor & Tensor::ceil_() const { + return TypeDefault::ceil_(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::ceil_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline std::vector Tensor::chunk(int64_t chunks, int64_t dim) const { +@@ -547,7 +569,7 @@ inline std::vector Tensor::chunk(int64_t chunks, int64_t dim) const { + return TypeDefault::chunk(const_cast(*this), chunks, dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::chunk(Tensor(a) self, int chunks, int dim=0) -> Tensor(a)[]"); +- return table->getOp (const Tensor &, int64_t, int64_t)>(type_set())(const_cast(*this), chunks, dim); ++ return table->getOp (const Tensor &, int64_t, int64_t)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), chunks, dim); + #endif + } + inline Tensor Tensor::clamp(c10::optional min, c10::optional max) const { +@@ -555,7 +577,7 @@ inline Tensor Tensor::clamp(c10::optional min, c10::optional max + return TypeDefault::clamp(const_cast(*this), min, max); + #else + static auto table = globalATenDispatch().getOpTable("aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"); +- return table->getOp, c10::optional)>(type_set())(const_cast(*this), min, max); ++ return table->getOp, c10::optional)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), min, max); + #endif + } + inline Tensor & Tensor::clamp_(c10::optional min, c10::optional max) const { +@@ -569,7 +591,7 @@ inline Tensor & Tensor::clamp_(c10::optional min, c10::optional + } + #else + static auto table = globalATenDispatch().getOpTable("aten::clamp_(Tensor(a!) self, Scalar? min=None, Scalar? 
max=None) -> Tensor(a!)"); +- return table->getOp, c10::optional)>(type_set())(const_cast(*this), min, max); ++ return table->getOp, c10::optional)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), min, max); + #endif + } + inline Tensor Tensor::clamp_max(Scalar max) const { +@@ -577,7 +599,7 @@ inline Tensor Tensor::clamp_max(Scalar max) const { + return TypeDefault::clamp_max(const_cast(*this), max); + #else + static auto table = globalATenDispatch().getOpTable("aten::clamp_max(Tensor self, Scalar max) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), max); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), max); + #endif + } + inline Tensor & Tensor::clamp_max_(Scalar max) const { +@@ -591,7 +613,7 @@ inline Tensor & Tensor::clamp_max_(Scalar max) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), max); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), max); + #endif + } + inline Tensor Tensor::clamp_min(Scalar min) const { +@@ -599,7 +621,7 @@ inline Tensor Tensor::clamp_min(Scalar min) const { + return TypeDefault::clamp_min(const_cast(*this), min); + #else + static auto table = globalATenDispatch().getOpTable("aten::clamp_min(Tensor self, Scalar min) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), min); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), min); + #endif + } + inline Tensor & Tensor::clamp_min_(Scalar min) const { +@@ -613,7 +635,7 @@ inline Tensor & Tensor::clamp_min_(Scalar min) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::clamp_min_(Tensor(a!) 
self, Scalar min) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), min); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), min); + #endif + } + inline Tensor Tensor::contiguous(MemoryFormat memory_format) const { +@@ -621,7 +643,7 @@ inline Tensor Tensor::contiguous(MemoryFormat memory_format) const { + return TypeDefault::contiguous(const_cast(*this), memory_format); + #else + static auto table = globalATenDispatch().getOpTable("aten::contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), memory_format); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), memory_format); + #endif + } + inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const { +@@ -629,7 +651,7 @@ inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const { + return TypeDefault::copy_(const_cast(*this), src, non_blocking); + #else + static auto table = globalATenDispatch().getOpTable("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), src, non_blocking); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, src))(const_cast(*this), src, non_blocking); + #endif + } + inline Tensor Tensor::cos() const { +@@ -637,7 +659,7 @@ inline Tensor Tensor::cos() const { + return TypeDefault::cos(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::cos(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::cos_() const { +@@ -651,7 +673,7 @@ inline Tensor & Tensor::cos_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::cos_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::cosh() const { +@@ -659,7 +681,7 @@ inline Tensor Tensor::cosh() const { + return TypeDefault::cosh(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::cosh(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::cosh_() const { +@@ -673,7 +695,7 @@ inline Tensor & Tensor::cosh_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::cosh_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::cumsum(int64_t dim, c10::optional dtype) const { +@@ -681,7 +703,7 @@ inline Tensor Tensor::cumsum(int64_t dim, c10::optional dtype) const + return TypeDefault::cumsum(const_cast(*this), dim, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dim, dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); + #endif + } + inline Tensor Tensor::cumprod(int64_t dim, c10::optional dtype) const { +@@ -689,7 +711,7 @@ inline Tensor Tensor::cumprod(int64_t dim, c10::optional dtype) cons + return TypeDefault::cumprod(const_cast(*this), dim, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::cumprod(Tensor self, int dim, *, ScalarType? 
dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dim, dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); + #endif + } + inline Tensor Tensor::det() const { +@@ -697,7 +719,7 @@ inline Tensor Tensor::det() const { + return TypeDefault::det(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::det(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::diag_embed(int64_t offset, int64_t dim1, int64_t dim2) const { +@@ -705,7 +727,7 @@ inline Tensor Tensor::diag_embed(int64_t offset, int64_t dim1, int64_t dim2) con + return TypeDefault::diag_embed(const_cast(*this), offset, dim1, dim2); + #else + static auto table = globalATenDispatch().getOpTable("aten::diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), offset, dim1, dim2); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), offset, dim1, dim2); + #endif + } + inline Tensor Tensor::diagflat(int64_t offset) const { +@@ -713,7 +735,7 @@ inline Tensor Tensor::diagflat(int64_t offset) const { + return TypeDefault::diagflat(const_cast(*this), offset); + #else + static auto table = globalATenDispatch().getOpTable("aten::diagflat(Tensor self, int offset=0) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), offset); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), offset); + #endif + } + inline Tensor Tensor::diagonal(int64_t offset, int64_t dim1, int64_t dim2) const { +@@ -721,7 +743,7 @@ inline Tensor Tensor::diagonal(int64_t offset, int64_t dim1, int64_t dim2) const + return TypeDefault::diagonal(const_cast(*this), offset, dim1, dim2); + #else + static auto 
table = globalATenDispatch().getOpTable("aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), offset, dim1, dim2); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), offset, dim1, dim2); + #endif + } + inline Tensor & Tensor::fill_diagonal_(Scalar fill_value, bool wrap) const { +@@ -729,23 +751,41 @@ inline Tensor & Tensor::fill_diagonal_(Scalar fill_value, bool wrap) const { + return TypeDefault::fill_diagonal_(const_cast(*this), fill_value, wrap); + #else + static auto table = globalATenDispatch().getOpTable("aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), fill_value, wrap); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), fill_value, wrap); + #endif + } + inline Tensor Tensor::div(const Tensor & other) const { + #ifdef USE_STATIC_DISPATCH +- return TypeDefault::div(const_cast(*this), other); ++ switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { ++ case Backend::CPU: ++ return CPUType::div(const_cast(*this), other); ++ break; ++ case Backend::SparseCPU: ++ return SparseCPUType::div(const_cast(*this), other); ++ break; ++ default: ++ AT_ERROR("div not implemented for ", at::toString(type_set())); ++ } + #else + static auto table = globalATenDispatch().getOpTable("aten::div.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::div_(const Tensor & other) const { + #ifdef USE_STATIC_DISPATCH +- return TypeDefault::div_(const_cast(*this), other); ++ switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { ++ case Backend::CPU: ++ return CPUType::div_(const_cast(*this), other); ++ 
break; ++ case Backend::SparseCPU: ++ return SparseCPUType::div_(const_cast(*this), other); ++ break; ++ default: ++ AT_ERROR("div_ not implemented for ", at::toString(type_set())); ++ } + #else + static auto table = globalATenDispatch().getOpTable("aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::div(Scalar other) const { +@@ -753,7 +793,7 @@ inline Tensor Tensor::div(Scalar other) const { + return TypeDefault::div(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::div.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::div_(Scalar other) const { +@@ -761,7 +801,7 @@ inline Tensor & Tensor::div_(Scalar other) const { + return TypeDefault::div_(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::div_.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::dot(const Tensor & tensor) const { +@@ -775,7 +815,7 @@ inline Tensor Tensor::dot(const Tensor & tensor) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::dot(Tensor self, Tensor tensor) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), tensor); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, tensor))(const_cast(*this), tensor); + #endif + } + inline Tensor Tensor::new_empty(IntArrayRef size, const TensorOptions & options) const { +@@ -783,7 +823,7 @@ inline Tensor Tensor::new_empty(IntArrayRef size, const TensorOptions & options) + return TypeDefault::new_empty(const_cast(*this), size, options); + #else + static auto table = globalATenDispatch().getOpTable("aten::new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), size, options); ++ return table->getOp(type_set(/* HMMMM */))(const_cast(*this), size, options); + #endif + } + inline Tensor Tensor::new_full(IntArrayRef size, Scalar fill_value, const TensorOptions & options) const { +@@ -791,7 +831,7 @@ inline Tensor Tensor::new_full(IntArrayRef size, Scalar fill_value, const Tensor + return TypeDefault::new_full(const_cast(*this), size, fill_value, options); + #else + static auto table = globalATenDispatch().getOpTable("aten::new_full(Tensor self, int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), size, fill_value, options); ++ return table->getOp(type_set(/* HMMMM */))(const_cast(*this), size, fill_value, options); + #endif + } + inline Tensor & Tensor::resize_(IntArrayRef size) const { +@@ -805,7 +845,7 @@ inline Tensor & Tensor::resize_(IntArrayRef size) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::resize_(Tensor(a!) self, int[] size) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), size); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size); + #endif + } + inline Tensor Tensor::erf() const { +@@ -813,7 +853,7 @@ inline Tensor Tensor::erf() const { + return TypeDefault::erf(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::erf(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::erf_() const { +@@ -827,7 +867,7 @@ inline Tensor & Tensor::erf_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::erf_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::erfc() const { +@@ -835,7 +875,7 @@ inline Tensor Tensor::erfc() const { + return TypeDefault::erfc(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::erfc(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::erfc_() const { +@@ -849,7 +889,7 @@ inline Tensor & Tensor::erfc_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::erfc_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::exp() const { +@@ -857,7 +897,7 @@ inline Tensor Tensor::exp() const { + return TypeDefault::exp(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::exp(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::exp_() const { +@@ -871,7 +911,7 @@ inline Tensor & Tensor::exp_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::exp_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::expm1() const { +@@ -879,7 +919,7 @@ inline Tensor Tensor::expm1() const { + return TypeDefault::expm1(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::expm1(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::expm1_() const { +@@ -893,7 +933,7 @@ inline Tensor & Tensor::expm1_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::expm1_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::expand(IntArrayRef size, bool implicit) const { +@@ -901,7 +941,7 @@ inline Tensor Tensor::expand(IntArrayRef size, bool implicit) const { + return TypeDefault::expand(const_cast(*this), size, implicit); + #else + static auto table = globalATenDispatch().getOpTable("aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), size, implicit); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size, implicit); + #endif + } + inline Tensor Tensor::expand_as(const Tensor & other) const { +@@ -909,7 +949,7 @@ inline Tensor Tensor::expand_as(const Tensor & other) const { + return TypeDefault::expand_as(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::expand_as(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return 
table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim) const { +@@ -917,7 +957,7 @@ inline Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim) const { + return TypeDefault::flatten(const_cast(*this), start_dim, end_dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::flatten(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), start_dim, end_dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), start_dim, end_dim); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -926,7 +966,7 @@ inline Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim, Dimname out_di + return TypeDefault::flatten(const_cast(*this), start_dim, end_dim, out_dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::flatten(Tensor self, int start_dim, int end_dim, Dimname out_dim) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), start_dim, end_dim, out_dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), start_dim, end_dim, out_dim); + #endif + } + #endif +@@ -936,7 +976,7 @@ inline Tensor Tensor::flatten(Dimname start_dim, Dimname end_dim, Dimname out_di + return TypeDefault::flatten(const_cast(*this), start_dim, end_dim, out_dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::flatten(Tensor self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), start_dim, end_dim, out_dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), start_dim, end_dim, out_dim); + #endif + } + #endif +@@ -946,7 +986,7 @@ inline Tensor Tensor::flatten(DimnameList dims, Dimname out_dim) const { + return TypeDefault::flatten(const_cast(*this), dims, out_dim); 
+ #else + static auto table = globalATenDispatch().getOpTable("aten::flatten(Tensor self, DimnameList dims, Dimname out_dim) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dims, out_dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dims, out_dim); + #endif + } + #endif +@@ -955,7 +995,7 @@ inline Tensor & Tensor::fill_(Scalar value) const { + return TypeDefault::fill_(const_cast(*this), value); + #else + static auto table = globalATenDispatch().getOpTable("aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), value); + #endif + } + inline Tensor & Tensor::fill_(const Tensor & value) const { +@@ -963,7 +1003,7 @@ inline Tensor & Tensor::fill_(const Tensor & value) const { + return TypeDefault::fill_(const_cast(*this), value); + #else + static auto table = globalATenDispatch().getOpTable("aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, value))(const_cast(*this), value); + #endif + } + inline Tensor Tensor::floor() const { +@@ -971,7 +1011,7 @@ inline Tensor Tensor::floor() const { + return TypeDefault::floor(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::floor(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::floor_() const { +@@ -985,7 +1025,7 @@ inline Tensor & Tensor::floor_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::floor_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::frac() const { +@@ -993,7 +1033,7 @@ inline Tensor Tensor::frac() const { + return TypeDefault::frac(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::frac(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::frac_() const { +@@ -1007,7 +1047,7 @@ inline Tensor & Tensor::frac_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::frac_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::ger(const Tensor & vec2) const { +@@ -1021,7 +1061,7 @@ inline Tensor Tensor::ger(const Tensor & vec2) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::ger(Tensor self, Tensor vec2) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), vec2); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, vec2))(const_cast(*this), vec2); + #endif + } + inline Tensor Tensor::fft(int64_t signal_ndim, bool normalized) const { +@@ -1029,7 +1069,7 @@ inline Tensor Tensor::fft(int64_t signal_ndim, bool normalized) const { + return TypeDefault::fft(const_cast(*this), signal_ndim, normalized); + #else + static auto table = globalATenDispatch().getOpTable("aten::fft(Tensor self, int signal_ndim, bool normalized=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), signal_ndim, normalized); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), signal_ndim, normalized); + #endif + } + 
inline Tensor Tensor::ifft(int64_t signal_ndim, bool normalized) const { +@@ -1037,7 +1077,7 @@ inline Tensor Tensor::ifft(int64_t signal_ndim, bool normalized) const { + return TypeDefault::ifft(const_cast(*this), signal_ndim, normalized); + #else + static auto table = globalATenDispatch().getOpTable("aten::ifft(Tensor self, int signal_ndim, bool normalized=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), signal_ndim, normalized); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), signal_ndim, normalized); + #endif + } + inline Tensor Tensor::rfft(int64_t signal_ndim, bool normalized, bool onesided) const { +@@ -1045,7 +1085,7 @@ inline Tensor Tensor::rfft(int64_t signal_ndim, bool normalized, bool onesided) + return TypeDefault::rfft(const_cast(*this), signal_ndim, normalized, onesided); + #else + static auto table = globalATenDispatch().getOpTable("aten::rfft(Tensor self, int signal_ndim, bool normalized=False, bool onesided=True) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), signal_ndim, normalized, onesided); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), signal_ndim, normalized, onesided); + #endif + } + inline Tensor Tensor::irfft(int64_t signal_ndim, bool normalized, bool onesided, IntArrayRef signal_sizes) const { +@@ -1053,7 +1093,7 @@ inline Tensor Tensor::irfft(int64_t signal_ndim, bool normalized, bool onesided, + return TypeDefault::irfft(const_cast(*this), signal_ndim, normalized, onesided, signal_sizes); + #else + static auto table = globalATenDispatch().getOpTable("aten::irfft(Tensor self, int signal_ndim, bool normalized=False, bool onesided=True, int[] signal_sizes=[]) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), signal_ndim, normalized, onesided, signal_sizes); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), signal_ndim, normalized, onesided, 
signal_sizes); + #endif + } + inline Tensor Tensor::index(TensorList indices) const { +@@ -1061,7 +1101,7 @@ inline Tensor Tensor::index(TensorList indices) const { + return TypeDefault::index(const_cast(*this), indices); + #else + static auto table = globalATenDispatch().getOpTable("aten::index(Tensor self, Tensor?[] indices) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), indices); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, indices))(const_cast(*this), indices); + #endif + } + inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Tensor & source) const { +@@ -1069,7 +1109,7 @@ inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Ten + return TypeDefault::index_copy_(const_cast(*this), dim, index, source); + #else + static auto table = globalATenDispatch().getOpTable("aten::index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), dim, index, source); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, source))(const_cast(*this), dim, index, source); + #endif + } + inline Tensor Tensor::index_copy(int64_t dim, const Tensor & index, const Tensor & source) const { +@@ -1077,7 +1117,7 @@ inline Tensor Tensor::index_copy(int64_t dim, const Tensor & index, const Tensor + return TypeDefault::index_copy(const_cast(*this), dim, index, source); + #else + static auto table = globalATenDispatch().getOpTable("aten::index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, index, source); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, source))(const_cast(*this), dim, index, source); + #endif + } + inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values, bool accumulate) const { +@@ -1085,7 +1125,7 @@ inline Tensor & Tensor::index_put_(TensorList 
indices, const Tensor & values, bo + return TypeDefault::index_put_(const_cast(*this), indices, values, accumulate); + #else + static auto table = globalATenDispatch().getOpTable("aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), indices, values, accumulate); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, indices, values))(const_cast(*this), indices, values, accumulate); + #endif + } + inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const { +@@ -1093,7 +1133,7 @@ inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool + return TypeDefault::index_put(const_cast(*this), indices, values, accumulate); + #else + static auto table = globalATenDispatch().getOpTable("aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), indices, values, accumulate); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, indices, values))(const_cast(*this), indices, values, accumulate); + #endif + } + inline Tensor Tensor::inverse() const { +@@ -1101,7 +1141,7 @@ inline Tensor Tensor::inverse() const { + return TypeDefault::inverse(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::inverse(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::isclose(const Tensor & other, double rtol, double atol, bool equal_nan) const { +@@ -1109,7 +1149,7 @@ inline Tensor Tensor::isclose(const Tensor & other, double rtol, double atol, bo + return TypeDefault::isclose(const_cast(*this), other, rtol, atol, equal_nan); + #else + static auto table = 
globalATenDispatch().getOpTable("aten::isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other, rtol, atol, equal_nan); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, rtol, atol, equal_nan); + #endif + } + inline bool Tensor::is_distributed() const { +@@ -1117,7 +1157,7 @@ inline bool Tensor::is_distributed() const { + return TypeDefault::is_distributed(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::is_distributed(Tensor self) -> bool"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline bool Tensor::is_floating_point() const { +@@ -1125,7 +1165,7 @@ inline bool Tensor::is_floating_point() const { + return TypeDefault::is_floating_point(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::is_floating_point(Tensor self) -> bool"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline bool Tensor::is_complex() const { +@@ -1133,7 +1173,7 @@ inline bool Tensor::is_complex() const { + return TypeDefault::is_complex(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::is_complex(Tensor self) -> bool"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline bool Tensor::is_nonzero() const { +@@ -1141,7 +1181,7 @@ inline bool Tensor::is_nonzero() const { + return TypeDefault::is_nonzero(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::is_nonzero(Tensor self) -> bool"); +- return 
table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline bool Tensor::is_same_size(const Tensor & other) const { +@@ -1149,7 +1189,7 @@ inline bool Tensor::is_same_size(const Tensor & other) const { + return TypeDefault::is_same_size(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::is_same_size(Tensor self, Tensor other) -> bool"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline bool Tensor::is_signed() const { +@@ -1157,7 +1197,7 @@ inline bool Tensor::is_signed() const { + return TypeDefault::is_signed(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::is_signed(Tensor self) -> bool"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline std::tuple Tensor::kthvalue(int64_t k, int64_t dim, bool keepdim) const { +@@ -1165,7 +1205,7 @@ inline std::tuple Tensor::kthvalue(int64_t k, int64_t dim, bool k + return TypeDefault::kthvalue(const_cast(*this), k, dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)"); +- return table->getOp (const Tensor &, int64_t, int64_t, bool)>(type_set())(const_cast(*this), k, dim, keepdim); ++ return table->getOp (const Tensor &, int64_t, int64_t, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), k, dim, keepdim); + #endif + } + inline Tensor Tensor::log() const { +@@ -1173,7 +1213,7 @@ inline Tensor Tensor::log() const { + return TypeDefault::log(const_cast(*this)); + #else + static auto table = 
globalATenDispatch().getOpTable("aten::log(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::log_() const { +@@ -1187,7 +1227,7 @@ inline Tensor & Tensor::log_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::log_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::log10() const { +@@ -1195,7 +1235,7 @@ inline Tensor Tensor::log10() const { + return TypeDefault::log10(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::log10(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::log10_() const { +@@ -1209,7 +1249,7 @@ inline Tensor & Tensor::log10_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::log10_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::log1p() const { +@@ -1217,7 +1257,7 @@ inline Tensor Tensor::log1p() const { + return TypeDefault::log1p(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::log1p(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::log1p_() const { +@@ -1234,7 +1274,7 @@ inline Tensor & Tensor::log1p_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::log1p_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::log2() const { +@@ -1242,7 +1282,7 @@ inline Tensor Tensor::log2() const { + return TypeDefault::log2(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::log2(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::log2_() const { +@@ -1256,7 +1296,7 @@ inline Tensor & Tensor::log2_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::log2_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::logdet() const { +@@ -1264,7 +1304,7 @@ inline Tensor Tensor::logdet() const { + return TypeDefault::logdet(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::logdet(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::log_softmax(int64_t dim, c10::optional dtype) const { +@@ -1272,7 +1312,7 @@ inline Tensor Tensor::log_softmax(int64_t dim, c10::optional dtype) + return TypeDefault::log_softmax(const_cast(*this), dim, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::log_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dim, dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -1281,7 +1321,7 @@ inline Tensor Tensor::log_softmax(Dimname dim, c10::optional dtype) + return TypeDefault::log_softmax(const_cast(*this), dim, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::log_softmax(Tensor self, Dimname dim, *, ScalarType? 
dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dim, dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); + #endif + } + #endif +@@ -1290,7 +1330,7 @@ inline Tensor Tensor::logsumexp(IntArrayRef dim, bool keepdim) const { + return TypeDefault::logsumexp(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -1299,7 +1339,7 @@ inline Tensor Tensor::logsumexp(DimnameList dim, bool keepdim) const { + return TypeDefault::logsumexp(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + #endif +@@ -1308,7 +1348,7 @@ inline Tensor Tensor::matmul(const Tensor & other) const { + return TypeDefault::matmul(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::matmul(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::matrix_power(int64_t n) const { +@@ -1316,7 +1356,7 @@ inline Tensor Tensor::matrix_power(int64_t n) const { + return TypeDefault::matrix_power(const_cast(*this), n); + #else + static auto table = globalATenDispatch().getOpTable("aten::matrix_power(Tensor self, int n) -> Tensor"); +- return 
table->getOp(type_set())(const_cast(*this), n); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), n); + #endif + } + inline std::tuple Tensor::max(int64_t dim, bool keepdim) const { +@@ -1324,7 +1364,7 @@ inline std::tuple Tensor::max(int64_t dim, bool keepdim) const { + return TypeDefault::max(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); +- return table->getOp (const Tensor &, int64_t, bool)>(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp (const Tensor &, int64_t, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + inline Tensor Tensor::max_values(IntArrayRef dim, bool keepdim) const { +@@ -1332,7 +1372,7 @@ inline Tensor Tensor::max_values(IntArrayRef dim, bool keepdim) const { + return TypeDefault::max_values(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::max_values(Tensor self, int[1] dim, bool keepdim=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -1341,7 +1381,7 @@ inline std::tuple Tensor::max(Dimname dim, bool keepdim) const { + return TypeDefault::max(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); +- return table->getOp (const Tensor &, Dimname, bool)>(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp (const Tensor &, Dimname, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + #endif +@@ -1351,7 +1391,7 @@ inline Tensor 
Tensor::max_values(DimnameList dim, bool keepdim) const { + return TypeDefault::max_values(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::max_values.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + #endif +@@ -1360,7 +1400,7 @@ inline Tensor Tensor::mean(c10::optional dtype) const { + return TypeDefault::mean(const_cast(*this), dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dtype); + #endif + } + inline Tensor Tensor::mean(IntArrayRef dim, bool keepdim, c10::optional dtype) const { +@@ -1368,7 +1408,7 @@ inline Tensor Tensor::mean(IntArrayRef dim, bool keepdim, c10::optional(*this), dim, keepdim, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dim, keepdim, dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -1377,7 +1417,7 @@ inline Tensor Tensor::mean(DimnameList dim, bool keepdim, c10::optional(*this), dim, keepdim, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dim, keepdim, dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); + #endif + } + #endif +@@ -1386,7 +1426,7 @@ inline std::tuple Tensor::median(int64_t dim, bool keepdim) const + return TypeDefault::median(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); +- return table->getOp (const Tensor &, int64_t, bool)>(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp (const Tensor &, int64_t, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -1395,7 +1435,7 @@ inline std::tuple Tensor::median(Dimname dim, bool keepdim) const + return TypeDefault::median(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); +- return table->getOp (const Tensor &, Dimname, bool)>(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp (const Tensor &, Dimname, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + #endif +@@ -1404,7 +1444,7 @@ inline std::tuple Tensor::min(int64_t dim, bool keepdim) const { + return TypeDefault::min(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); +- return table->getOp (const Tensor &, int64_t, bool)>(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp (const Tensor &, int64_t, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + inline 
Tensor Tensor::min_values(IntArrayRef dim, bool keepdim) const { +@@ -1412,7 +1452,7 @@ inline Tensor Tensor::min_values(IntArrayRef dim, bool keepdim) const { + return TypeDefault::min_values(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::min_values(Tensor self, int[1] dim, bool keepdim=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -1421,7 +1461,7 @@ inline std::tuple Tensor::min(Dimname dim, bool keepdim) const { + return TypeDefault::min(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); +- return table->getOp (const Tensor &, Dimname, bool)>(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp (const Tensor &, Dimname, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + #endif +@@ -1431,7 +1471,7 @@ inline Tensor Tensor::min_values(DimnameList dim, bool keepdim) const { + return TypeDefault::min_values(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::min_values.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + #endif +@@ -1449,7 +1489,7 @@ inline Tensor Tensor::mm(const Tensor & mat2) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::mm(Tensor self, Tensor mat2) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), mat2); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, 
mat2))(const_cast(*this), mat2); + #endif + } + inline std::tuple Tensor::mode(int64_t dim, bool keepdim) const { +@@ -1457,7 +1497,7 @@ inline std::tuple Tensor::mode(int64_t dim, bool keepdim) const { + return TypeDefault::mode(const_cast(*this), dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)"); +- return table->getOp (const Tensor &, int64_t, bool)>(type_set())(const_cast(*this), dim, keepdim); ++ return table->getOp (const Tensor &, int64_t, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); + #endif + } + inline Tensor Tensor::mul(const Tensor & other) const { +@@ -1474,7 +1514,7 @@ inline Tensor Tensor::mul(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::mul.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::mul_(const Tensor & other) const { +@@ -1491,7 +1531,7 @@ inline Tensor & Tensor::mul_(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::mul_.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::mul(Scalar other) const { +@@ -1499,7 +1539,7 @@ inline Tensor Tensor::mul(Scalar other) const { + return TypeDefault::mul(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::mul.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::mul_(Scalar other) const { +@@ -1507,7 +1547,7 @@ inline Tensor & Tensor::mul_(Scalar other) const { + return TypeDefault::mul_(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::mv(const Tensor & vec) const { +@@ -1521,7 +1561,7 @@ inline Tensor Tensor::mv(const Tensor & vec) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::mv(Tensor self, Tensor vec) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), vec); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, vec))(const_cast(*this), vec); + #endif + } + inline Tensor Tensor::mvlgamma(int64_t p) const { +@@ -1529,7 +1569,7 @@ inline Tensor Tensor::mvlgamma(int64_t p) const { + return TypeDefault::mvlgamma(const_cast(*this), p); + #else + static auto table = globalATenDispatch().getOpTable("aten::mvlgamma(Tensor self, int p) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), p); ++ return 
table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p); + #endif + } + inline Tensor & Tensor::mvlgamma_(int64_t p) const { +@@ -1537,7 +1577,7 @@ inline Tensor & Tensor::mvlgamma_(int64_t p) const { + return TypeDefault::mvlgamma_(const_cast(*this), p); + #else + static auto table = globalATenDispatch().getOpTable("aten::mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), p); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p); + #endif + } + inline Tensor Tensor::narrow_copy(int64_t dim, int64_t start, int64_t length) const { +@@ -1554,7 +1594,7 @@ inline Tensor Tensor::narrow_copy(int64_t dim, int64_t start, int64_t length) co + } + #else + static auto table = globalATenDispatch().getOpTable("aten::narrow_copy(Tensor self, int dim, int start, int length) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, start, length); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, start, length); + #endif + } + inline Tensor Tensor::narrow(int64_t dim, int64_t start, int64_t length) const { +@@ -1562,7 +1602,7 @@ inline Tensor Tensor::narrow(int64_t dim, int64_t start, int64_t length) const { + return TypeDefault::narrow(const_cast(*this), dim, start, length); + #else + static auto table = globalATenDispatch().getOpTable("aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), dim, start, length); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, start, length); + #endif + } + inline Tensor Tensor::permute(IntArrayRef dims) const { +@@ -1570,7 +1610,7 @@ inline Tensor Tensor::permute(IntArrayRef dims) const { + return TypeDefault::permute(const_cast(*this), dims); + #else + static auto table = globalATenDispatch().getOpTable("aten::permute(Tensor(a) self, 
int[] dims) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), dims); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dims); + #endif + } + inline Tensor Tensor::numpy_T() const { +@@ -1578,7 +1618,7 @@ inline Tensor Tensor::numpy_T() const { + return TypeDefault::numpy_T(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::numpy_T(Tensor(a) self) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline bool Tensor::is_pinned() const { +@@ -1586,7 +1626,7 @@ inline bool Tensor::is_pinned() const { + return TypeDefault::is_pinned(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::is_pinned(Tensor self) -> bool"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::pin_memory() const { +@@ -1594,7 +1634,7 @@ inline Tensor Tensor::pin_memory() const { + return TypeDefault::pin_memory(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::pin_memory(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::pinverse(double rcond) const { +@@ -1602,7 +1642,7 @@ inline Tensor Tensor::pinverse(double rcond) const { + return TypeDefault::pinverse(const_cast(*this), rcond); + #else + static auto table = globalATenDispatch().getOpTable("aten::pinverse(Tensor self, float rcond=1e-15) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), rcond); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), rcond); + #endif + } + inline Tensor 
Tensor::reciprocal() const { +@@ -1610,7 +1650,7 @@ inline Tensor Tensor::reciprocal() const { + return TypeDefault::reciprocal(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::reciprocal(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::reciprocal_() const { +@@ -1624,7 +1664,7 @@ inline Tensor & Tensor::reciprocal_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::reciprocal_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::neg() const { +@@ -1632,7 +1672,7 @@ inline Tensor Tensor::neg() const { + return TypeDefault::neg(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::neg(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::neg_() const { +@@ -1640,7 +1680,7 @@ inline Tensor & Tensor::neg_() const { + return TypeDefault::neg_(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::neg_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::repeat(IntArrayRef repeats) const { +@@ -1648,7 +1688,7 @@ inline Tensor Tensor::repeat(IntArrayRef repeats) const { + return TypeDefault::repeat(const_cast(*this), repeats); + #else + static auto table = globalATenDispatch().getOpTable("aten::repeat(Tensor self, int[] repeats) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), repeats); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), repeats); + #endif + } + inline Tensor Tensor::repeat_interleave(const Tensor & repeats, c10::optional dim) const { +@@ -1656,7 +1696,7 @@ inline Tensor Tensor::repeat_interleave(const Tensor & repeats, c10::optional(*this), repeats, dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), repeats, dim); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this, repeats))(const_cast(*this), repeats, dim); + #endif + } + inline Tensor Tensor::repeat_interleave(int64_t repeats, c10::optional dim) const { +@@ -1664,7 +1704,7 @@ inline Tensor Tensor::repeat_interleave(int64_t repeats, c10::optional + return TypeDefault::repeat_interleave(const_cast(*this), repeats, dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::repeat_interleave.self_int(Tensor self, int repeats, int? 
dim=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), repeats, dim); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), repeats, dim); + #endif + } + inline Tensor Tensor::reshape(IntArrayRef shape) const { +@@ -1672,7 +1712,7 @@ inline Tensor Tensor::reshape(IntArrayRef shape) const { + return TypeDefault::reshape(const_cast(*this), shape); + #else + static auto table = globalATenDispatch().getOpTable("aten::reshape(Tensor self, int[] shape) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), shape); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), shape); + #endif + } + inline Tensor Tensor::reshape_as(const Tensor & other) const { +@@ -1680,7 +1720,7 @@ inline Tensor Tensor::reshape_as(const Tensor & other) const { + return TypeDefault::reshape_as(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::reshape_as(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::round() const { +@@ -1688,7 +1728,7 @@ inline Tensor Tensor::round() const { + return TypeDefault::round(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::round(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::round_() const { +@@ -1702,7 +1742,7 @@ inline Tensor & Tensor::round_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::round_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::relu() const { +@@ -1719,7 +1759,7 @@ inline Tensor Tensor::relu() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::relu(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::relu_() const { +@@ -1736,7 +1776,7 @@ inline Tensor & Tensor::relu_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::relu_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::prelu(const Tensor & weight) const { +@@ -1750,7 +1790,7 @@ inline Tensor Tensor::prelu(const Tensor & weight) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::prelu(Tensor self, Tensor weight) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), weight); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, weight))(const_cast(*this), weight); + #endif + } + inline std::tuple Tensor::prelu_backward(const Tensor & grad_output, const Tensor & weight) const { +@@ -1764,7 +1804,7 @@ inline std::tuple Tensor::prelu_backward(const Tensor & grad_outp + } + #else + static auto table = globalATenDispatch().getOpTable("aten::prelu_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor)"); +- return table->getOp (const Tensor &, const Tensor &, const Tensor &)>(type_set())(grad_output, const_cast(*this), weight); ++ return table->getOp (const Tensor &, const Tensor &, const Tensor &)>(at::detail::multi_dispatch_tensor_type_set(grad_output, *this, 
weight))(grad_output, const_cast(*this), weight); + #endif + } + inline Tensor Tensor::hardshrink(Scalar lambd) const { +@@ -1778,7 +1818,7 @@ inline Tensor Tensor::hardshrink(Scalar lambd) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), lambd); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), lambd); + #endif + } + inline Tensor Tensor::hardshrink_backward(const Tensor & grad_out, Scalar lambd) const { +@@ -1792,7 +1832,7 @@ inline Tensor Tensor::hardshrink_backward(const Tensor & grad_out, Scalar lambd) + } + #else + static auto table = globalATenDispatch().getOpTable("aten::hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor"); +- return table->getOp(type_set())(grad_out, const_cast(*this), lambd); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(grad_out, *this))(grad_out, const_cast(*this), lambd); + #endif + } + inline Tensor Tensor::rsqrt() const { +@@ -1800,7 +1840,7 @@ inline Tensor Tensor::rsqrt() const { + return TypeDefault::rsqrt(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::rsqrt(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::rsqrt_() const { +@@ -1814,7 +1854,7 @@ inline Tensor & Tensor::rsqrt_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::rsqrt_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -1823,7 +1863,7 @@ inline Tensor Tensor::select(Dimname dim, int64_t index) const { + return TypeDefault::select(const_cast(*this), dim, index); + #else + static auto table = globalATenDispatch().getOpTable("aten::select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), dim, index); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, index); + #endif + } + #endif +@@ -1832,7 +1872,7 @@ inline Tensor Tensor::select(int64_t dim, int64_t index) const { + return TypeDefault::select(const_cast(*this), dim, index); + #else + static auto table = globalATenDispatch().getOpTable("aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), dim, index); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, index); + #endif + } + inline Tensor Tensor::sigmoid() const { +@@ -1846,7 +1886,7 @@ inline Tensor Tensor::sigmoid() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::sigmoid(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::sigmoid_() const { +@@ -1860,7 +1900,7 @@ inline Tensor & Tensor::sigmoid_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::sigmoid_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::sin() const { +@@ -1868,7 +1908,7 @@ inline Tensor Tensor::sin() const { + return TypeDefault::sin(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::sin(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::sin_() const { +@@ -1882,7 +1922,7 @@ inline Tensor & Tensor::sin_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::sin_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::sinh() const { +@@ -1890,7 +1930,7 @@ inline Tensor Tensor::sinh() const { + return TypeDefault::sinh(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::sinh(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::sinh_() const { +@@ -1904,7 +1944,7 @@ inline Tensor & Tensor::sinh_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::sinh_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::detach() const { +@@ -1912,7 +1952,7 @@ inline Tensor Tensor::detach() const { + return TypeDefault::detach(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::detach(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::detach_() const { +@@ -1920,7 +1960,7 @@ inline Tensor & Tensor::detach_() const { + return TypeDefault::detach_(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::detach_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline int64_t Tensor::size(int64_t dim) const { +@@ -1928,7 +1968,7 @@ inline int64_t Tensor::size(int64_t dim) const { + return TypeDefault::size(const_cast(*this), dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::size.int(Tensor self, int dim) -> int"); +- return table->getOp(type_set())(const_cast(*this), dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -1937,7 +1977,7 @@ inline int64_t Tensor::size(Dimname dim) const { + return TypeDefault::size(const_cast(*this), dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::size.Dimname(Tensor self, Dimname dim) -> int"); +- return table->getOp(type_set())(const_cast(*this), dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); + #endif + } + #endif +@@ -1946,7 +1986,7 @@ inline Tensor Tensor::slice(int64_t 
dim, int64_t start, int64_t end, int64_t ste + return TypeDefault::slice(const_cast(*this), dim, start, end, step); + #else + static auto table = globalATenDispatch().getOpTable("aten::slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), dim, start, end, step); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, start, end, step); + #endif + } + inline std::tuple Tensor::slogdet() const { +@@ -1954,7 +1994,7 @@ inline std::tuple Tensor::slogdet() const { + return TypeDefault::slogdet(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)"); +- return table->getOp (const Tensor &)>(type_set())(const_cast(*this)); ++ return table->getOp (const Tensor &)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::smm(const Tensor & mat2) const { +@@ -1962,7 +2002,7 @@ inline Tensor Tensor::smm(const Tensor & mat2) const { + return TypeDefault::smm(const_cast(*this), mat2); + #else + static auto table = globalATenDispatch().getOpTable("aten::smm(Tensor self, Tensor mat2) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), mat2); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mat2))(const_cast(*this), mat2); + #endif + } + inline Tensor Tensor::softmax(int64_t dim, c10::optional dtype) const { +@@ -1970,7 +2010,7 @@ inline Tensor Tensor::softmax(int64_t dim, c10::optional dtype) cons + return TypeDefault::softmax(const_cast(*this), dim, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::softmax(Tensor self, int dim, ScalarType? 
dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dim, dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -1979,7 +2019,7 @@ inline Tensor Tensor::softmax(Dimname dim, c10::optional dtype) cons + return TypeDefault::softmax(const_cast(*this), dim, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::softmax(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dim, dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); + #endif + } + #endif +@@ -1988,7 +2028,7 @@ inline std::vector Tensor::split(int64_t split_size, int64_t dim) const + return TypeDefault::split(const_cast(*this), split_size, dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[]"); +- return table->getOp (const Tensor &, int64_t, int64_t)>(type_set())(const_cast(*this), split_size, dim); ++ return table->getOp (const Tensor &, int64_t, int64_t)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), split_size, dim); + #endif + } + inline std::vector Tensor::split_with_sizes(IntArrayRef split_sizes, int64_t dim) const { +@@ -1996,7 +2036,7 @@ inline std::vector Tensor::split_with_sizes(IntArrayRef split_sizes, int + return TypeDefault::split_with_sizes(const_cast(*this), split_sizes, dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]"); +- return table->getOp (const Tensor &, IntArrayRef, int64_t)>(type_set())(const_cast(*this), split_sizes, dim); ++ return table->getOp (const Tensor &, IntArrayRef, int64_t)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), split_sizes, dim); + #endif + } + 
inline Tensor Tensor::squeeze() const { +@@ -2004,7 +2044,7 @@ inline Tensor Tensor::squeeze() const { + return TypeDefault::squeeze(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::squeeze(Tensor(a) self) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::squeeze(int64_t dim) const { +@@ -2012,7 +2052,7 @@ inline Tensor Tensor::squeeze(int64_t dim) const { + return TypeDefault::squeeze(const_cast(*this), dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); + #endif + } + inline Tensor & Tensor::squeeze_() const { +@@ -2020,7 +2060,7 @@ inline Tensor & Tensor::squeeze_() const { + return TypeDefault::squeeze_(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::squeeze_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::squeeze_(int64_t dim) const { +@@ -2028,7 +2068,7 @@ inline Tensor & Tensor::squeeze_(int64_t dim) const { + return TypeDefault::squeeze_(const_cast(*this), dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::squeeze_.dim(Tensor(a!) 
self, int dim) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); + #endif + } + inline Tensor Tensor::sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const { +@@ -2036,7 +2076,7 @@ inline Tensor Tensor::sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar + return TypeDefault::sspaddmm(const_cast(*this), mat1, mat2, beta, alpha); + #else + static auto table = globalATenDispatch().getOpTable("aten::sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), mat1, mat2, beta, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mat1, mat2))(const_cast(*this), mat1, mat2, beta, alpha); + #endif + } + inline Tensor Tensor::stft(int64_t n_fft, c10::optional hop_length, c10::optional win_length, const Tensor & window, bool normalized, bool onesided) const { +@@ -2044,7 +2084,7 @@ inline Tensor Tensor::stft(int64_t n_fft, c10::optional hop_length, c10 + return TypeDefault::stft(const_cast(*this), n_fft, hop_length, win_length, window, normalized, onesided); + #else + static auto table = globalATenDispatch().getOpTable("aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? 
window=None, bool normalized=False, bool onesided=True) -> Tensor"); +- return table->getOp, c10::optional, const Tensor &, bool, bool)>(type_set())(const_cast(*this), n_fft, hop_length, win_length, window, normalized, onesided); ++ return table->getOp, c10::optional, const Tensor &, bool, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, window))(const_cast(*this), n_fft, hop_length, win_length, window, normalized, onesided); + #endif + } + inline int64_t Tensor::stride(int64_t dim) const { +@@ -2052,7 +2092,7 @@ inline int64_t Tensor::stride(int64_t dim) const { + return TypeDefault::stride(const_cast(*this), dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::stride.int(Tensor self, int dim) -> int"); +- return table->getOp(type_set())(const_cast(*this), dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -2061,7 +2101,7 @@ inline int64_t Tensor::stride(Dimname dim) const { + return TypeDefault::stride(const_cast(*this), dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::stride.Dimname(Tensor self, Dimname dim) -> int"); +- return table->getOp(type_set())(const_cast(*this), dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); + #endif + } + #endif +@@ -2070,7 +2110,7 @@ inline Tensor Tensor::sum(c10::optional dtype) const { + return TypeDefault::sum(const_cast(*this), dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::sum(Tensor self, *, ScalarType? 
dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dtype); + #endif + } + inline Tensor Tensor::sum(IntArrayRef dim, bool keepdim, c10::optional dtype) const { +@@ -2078,7 +2118,7 @@ inline Tensor Tensor::sum(IntArrayRef dim, bool keepdim, c10::optional(*this), dim, keepdim, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dim, keepdim, dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -2087,7 +2127,7 @@ inline Tensor Tensor::sum(DimnameList dim, bool keepdim, c10::optional(*this), dim, keepdim, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dim, keepdim, dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); + #endif + } + #endif +@@ -2096,7 +2136,7 @@ inline Tensor Tensor::sum_to_size(IntArrayRef size) const { + return TypeDefault::sum_to_size(const_cast(*this), size); + #else + static auto table = globalATenDispatch().getOpTable("aten::sum_to_size(Tensor self, int[] size) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), size); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size); + #endif + } + inline Tensor Tensor::sqrt() const { +@@ -2104,7 +2144,7 @@ inline Tensor Tensor::sqrt() const { + return TypeDefault::sqrt(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::sqrt(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::sqrt_() const { +@@ -2118,7 +2158,7 @@ inline Tensor & Tensor::sqrt_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::sqrt_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::std(bool unbiased) const { +@@ -2126,7 +2166,7 @@ inline Tensor Tensor::std(bool unbiased) const { + return TypeDefault::std(const_cast(*this), unbiased); + #else + static auto table = globalATenDispatch().getOpTable("aten::std(Tensor self, bool unbiased=True) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), unbiased); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), unbiased); + #endif + } + inline Tensor Tensor::std(IntArrayRef dim, bool unbiased, bool keepdim) const { +@@ -2134,7 +2174,7 @@ inline Tensor Tensor::std(IntArrayRef dim, bool unbiased, bool keepdim) const { + return TypeDefault::std(const_cast(*this), dim, unbiased, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::std.dim(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, unbiased, keepdim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, unbiased, keepdim); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -2143,7 +2183,7 @@ inline Tensor Tensor::std(DimnameList dim, bool unbiased, bool keepdim) const { + return TypeDefault::std(const_cast(*this), dim, unbiased, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, unbiased, keepdim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, unbiased, keepdim); + #endif + } + #endif +@@ -2152,7 +2192,7 @@ inline Tensor Tensor::prod(c10::optional dtype) const { + return TypeDefault::prod(const_cast(*this), 
dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dtype); + #endif + } + inline Tensor Tensor::prod(int64_t dim, bool keepdim, c10::optional dtype) const { +@@ -2160,7 +2200,7 @@ inline Tensor Tensor::prod(int64_t dim, bool keepdim, c10::optional + return TypeDefault::prod(const_cast(*this), dim, keepdim, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dim, keepdim, dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -2169,7 +2209,7 @@ inline Tensor Tensor::prod(Dimname dim, bool keepdim, c10::optional + return TypeDefault::prod(const_cast(*this), dim, keepdim, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), dim, keepdim, dtype); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); + #endif + } + #endif +@@ -2178,7 +2218,7 @@ inline Tensor Tensor::t() const { + return TypeDefault::t(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::t(Tensor(a) self) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::t_() const { +@@ -2186,7 +2226,7 @@ inline Tensor & Tensor::t_() const { + return TypeDefault::t_(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::t_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::tan() const { +@@ -2194,7 +2234,7 @@ inline Tensor Tensor::tan() const { + return TypeDefault::tan(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::tan(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::tan_() const { +@@ -2208,7 +2248,7 @@ inline Tensor & Tensor::tan_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::tan_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::tanh() const { +@@ -2216,7 +2256,7 @@ inline Tensor Tensor::tanh() const { + return TypeDefault::tanh(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::tanh(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::tanh_() const { +@@ -2230,7 +2270,7 @@ inline Tensor & Tensor::tanh_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::tanh_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::transpose(int64_t dim0, int64_t dim1) const { +@@ -2238,7 +2278,7 @@ inline Tensor Tensor::transpose(int64_t dim0, int64_t dim1) const { + return TypeDefault::transpose(const_cast(*this), dim0, dim1); + #else + static auto table = globalATenDispatch().getOpTable("aten::transpose(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), dim0, dim1); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim0, dim1); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -2247,7 +2287,7 @@ inline Tensor Tensor::transpose(Dimname dim0, Dimname dim1) const { + return TypeDefault::transpose(const_cast(*this), dim0, dim1); + #else + static auto table = globalATenDispatch().getOpTable("aten::transpose(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), dim0, dim1); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), 
dim0, dim1); + #endif + } + #endif +@@ -2256,7 +2296,7 @@ inline Tensor & Tensor::transpose_(int64_t dim0, int64_t dim1) const { + return TypeDefault::transpose_(const_cast(*this), dim0, dim1); + #else + static auto table = globalATenDispatch().getOpTable("aten::transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), dim0, dim1); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim0, dim1); + #endif + } + inline Tensor Tensor::flip(IntArrayRef dims) const { +@@ -2270,7 +2310,7 @@ inline Tensor Tensor::flip(IntArrayRef dims) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::flip(Tensor self, int[] dims) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dims); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dims); + #endif + } + inline Tensor Tensor::roll(IntArrayRef shifts, IntArrayRef dims) const { +@@ -2284,7 +2324,7 @@ inline Tensor Tensor::roll(IntArrayRef shifts, IntArrayRef dims) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), shifts, dims); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), shifts, dims); + #endif + } + inline Tensor Tensor::rot90(int64_t k, IntArrayRef dims) const { +@@ -2292,7 +2332,7 @@ inline Tensor Tensor::rot90(int64_t k, IntArrayRef dims) const { + return TypeDefault::rot90(const_cast(*this), k, dims); + #else + static auto table = globalATenDispatch().getOpTable("aten::rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), k, dims); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), k, dims); + #endif + } + inline Tensor Tensor::trunc() const { +@@ 
-2300,7 +2340,7 @@ inline Tensor Tensor::trunc() const { + return TypeDefault::trunc(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::trunc(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::trunc_() const { +@@ -2314,7 +2354,7 @@ inline Tensor & Tensor::trunc_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::trunc_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::type_as(const Tensor & other) const { +@@ -2322,7 +2362,7 @@ inline Tensor Tensor::type_as(const Tensor & other) const { + return TypeDefault::type_as(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::type_as(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::unsqueeze(int64_t dim) const { +@@ -2330,7 +2370,7 @@ inline Tensor Tensor::unsqueeze(int64_t dim) const { + return TypeDefault::unsqueeze(const_cast(*this), dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); + #endif + } + inline Tensor & Tensor::unsqueeze_(int64_t dim) const { +@@ -2338,7 +2378,7 @@ inline Tensor & Tensor::unsqueeze_(int64_t dim) const { + return TypeDefault::unsqueeze_(const_cast(*this), dim); + #else + static auto table = 
globalATenDispatch().getOpTable("aten::unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); + #endif + } + inline Tensor Tensor::var(bool unbiased) const { +@@ -2346,7 +2386,7 @@ inline Tensor Tensor::var(bool unbiased) const { + return TypeDefault::var(const_cast(*this), unbiased); + #else + static auto table = globalATenDispatch().getOpTable("aten::var(Tensor self, bool unbiased=True) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), unbiased); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), unbiased); + #endif + } + inline Tensor Tensor::var(IntArrayRef dim, bool unbiased, bool keepdim) const { +@@ -2354,7 +2394,7 @@ inline Tensor Tensor::var(IntArrayRef dim, bool unbiased, bool keepdim) const { + return TypeDefault::var(const_cast(*this), dim, unbiased, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::var.dim(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, unbiased, keepdim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, unbiased, keepdim); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -2363,7 +2403,7 @@ inline Tensor Tensor::var(DimnameList dim, bool unbiased, bool keepdim) const { + return TypeDefault::var(const_cast(*this), dim, unbiased, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, unbiased, keepdim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, unbiased, keepdim); + #endif + } + #endif +@@ -2372,7 +2412,7 @@ inline Tensor 
Tensor::view_as(const Tensor & other) const { + return TypeDefault::view_as(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::view_as(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::where(const Tensor & condition, const Tensor & other) const { +@@ -2380,7 +2420,7 @@ inline Tensor Tensor::where(const Tensor & condition, const Tensor & other) cons + return TypeDefault::where(condition, const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(condition, const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(condition, *this, other))(condition, const_cast(*this), other); + #endif + } + inline Tensor Tensor::norm(c10::optional p, ScalarType dtype) const { +@@ -2388,7 +2428,7 @@ inline Tensor Tensor::norm(c10::optional p, ScalarType dtype) const { + return TypeDefault::norm(const_cast(*this), p, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::norm.ScalarOpt_dtype(Tensor self, Scalar? 
p, *, ScalarType dtype) -> Tensor"); +- return table->getOp, ScalarType)>(type_set())(const_cast(*this), p, dtype); ++ return table->getOp, ScalarType)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dtype); + #endif + } + inline Tensor Tensor::norm(Scalar p) const { +@@ -2396,7 +2436,7 @@ inline Tensor Tensor::norm(Scalar p) const { + return TypeDefault::norm(const_cast(*this), p); + #else + static auto table = globalATenDispatch().getOpTable("aten::norm.Scalar(Tensor self, Scalar p=2) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), p); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p); + #endif + } + inline Tensor Tensor::norm(c10::optional p, IntArrayRef dim, bool keepdim, ScalarType dtype) const { +@@ -2404,7 +2444,7 @@ inline Tensor Tensor::norm(c10::optional p, IntArrayRef dim, bool keepdi + return TypeDefault::norm(const_cast(*this), p, dim, keepdim, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor"); +- return table->getOp, IntArrayRef, bool, ScalarType)>(type_set())(const_cast(*this), p, dim, keepdim, dtype); ++ return table->getOp, IntArrayRef, bool, ScalarType)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, keepdim, dtype); + #endif + } + inline Tensor Tensor::norm(c10::optional p, IntArrayRef dim, bool keepdim) const { +@@ -2412,7 +2452,7 @@ inline Tensor Tensor::norm(c10::optional p, IntArrayRef dim, bool keepdi + return TypeDefault::norm(const_cast(*this), p, dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::norm.ScalarOpt_dim(Tensor self, Scalar? 
p, int[1] dim, bool keepdim=False) -> Tensor"); +- return table->getOp, IntArrayRef, bool)>(type_set())(const_cast(*this), p, dim, keepdim); ++ return table->getOp, IntArrayRef, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, keepdim); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -2421,7 +2461,7 @@ inline Tensor Tensor::norm(c10::optional p, DimnameList dim, bool keepdi + return TypeDefault::norm(const_cast(*this), p, dim, keepdim, dtype); + #else + static auto table = globalATenDispatch().getOpTable("aten::norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor"); +- return table->getOp, DimnameList, bool, ScalarType)>(type_set())(const_cast(*this), p, dim, keepdim, dtype); ++ return table->getOp, DimnameList, bool, ScalarType)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, keepdim, dtype); + #endif + } + #endif +@@ -2431,7 +2471,7 @@ inline Tensor Tensor::norm(c10::optional p, DimnameList dim, bool keepdi + return TypeDefault::norm(const_cast(*this), p, dim, keepdim); + #else + static auto table = globalATenDispatch().getOpTable("aten::norm.names_ScalarOpt_dim(Tensor self, Scalar? 
p, Dimname[1] dim, bool keepdim=False) -> Tensor"); +- return table->getOp, DimnameList, bool)>(type_set())(const_cast(*this), p, dim, keepdim); ++ return table->getOp, DimnameList, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, keepdim); + #endif + } + #endif +@@ -2452,7 +2492,7 @@ inline Tensor Tensor::clone() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::clone(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::resize_as_(const Tensor & the_template) const { +@@ -2469,7 +2509,7 @@ inline Tensor & Tensor::resize_as_(const Tensor & the_template) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::resize_as_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), the_template); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, the_template))(const_cast(*this), the_template); + #endif + } + inline Tensor Tensor::pow(Scalar exponent) const { +@@ -2486,7 +2526,7 @@ inline Tensor Tensor::pow(Scalar exponent) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), exponent); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), exponent); + #endif + } + inline Tensor & Tensor::zero_() const { +@@ -2503,23 +2543,41 @@ inline Tensor & Tensor::zero_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::zero_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::sub(const Tensor & other, Scalar alpha) const { + #ifdef USE_STATIC_DISPATCH +- return TypeDefault::sub(const_cast(*this), other, alpha); ++ switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { ++ case Backend::CPU: ++ return CPUType::sub(const_cast(*this), other, alpha); ++ break; ++ case Backend::SparseCPU: ++ return SparseCPUType::sub(const_cast(*this), other, alpha); ++ break; ++ default: ++ AT_ERROR("sub not implemented for ", at::toString(type_set())); ++ } + #else + static auto table = globalATenDispatch().getOpTable("aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, alpha); + #endif + } + inline Tensor & Tensor::sub_(const Tensor & other, Scalar alpha) const { + #ifdef USE_STATIC_DISPATCH +- return TypeDefault::sub_(const_cast(*this), other, alpha); ++ switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { ++ case Backend::CPU: ++ return CPUType::sub_(const_cast(*this), other, alpha); ++ break; ++ case Backend::SparseCPU: ++ return SparseCPUType::sub_(const_cast(*this), other, alpha); ++ break; ++ default: ++ AT_ERROR("sub_ not implemented for ", at::toString(type_set())); ++ } + #else + static auto table = globalATenDispatch().getOpTable("aten::sub_.Tensor(Tensor(a!) 
self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, alpha); + #endif + } + inline Tensor Tensor::sub(Scalar other, Scalar alpha) const { +@@ -2527,7 +2585,7 @@ inline Tensor Tensor::sub(Scalar other, Scalar alpha) const { + return TypeDefault::sub(const_cast(*this), other, alpha); + #else + static auto table = globalATenDispatch().getOpTable("aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other, alpha); + #endif + } + inline Tensor & Tensor::sub_(Scalar other, Scalar alpha) const { +@@ -2535,23 +2593,41 @@ inline Tensor & Tensor::sub_(Scalar other, Scalar alpha) const { + return TypeDefault::sub_(const_cast(*this), other, alpha); + #else + static auto table = globalATenDispatch().getOpTable("aten::sub_.Scalar(Tensor(a!) 
self, Scalar other, Scalar alpha=1) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other, alpha); + #endif + } + inline Tensor Tensor::addmm(const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const { + #ifdef USE_STATIC_DISPATCH +- return TypeDefault::addmm(const_cast(*this), mat1, mat2, beta, alpha); ++ switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { ++ case Backend::CPU: ++ return CPUType::addmm(const_cast(*this), mat1, mat2, beta, alpha); ++ break; ++ case Backend::SparseCPU: ++ return SparseCPUType::addmm(const_cast(*this), mat1, mat2, beta, alpha); ++ break; ++ default: ++ AT_ERROR("addmm not implemented for ", at::toString(type_set())); ++ } + #else + static auto table = globalATenDispatch().getOpTable("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), mat1, mat2, beta, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mat1, mat2))(const_cast(*this), mat1, mat2, beta, alpha); + #endif + } + inline Tensor & Tensor::addmm_(const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const { + #ifdef USE_STATIC_DISPATCH +- return TypeDefault::addmm_(const_cast(*this), mat1, mat2, beta, alpha); ++ switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { ++ case Backend::CPU: ++ return CPUType::addmm_(const_cast(*this), mat1, mat2, beta, alpha); ++ break; ++ case Backend::SparseCPU: ++ return SparseCPUType::addmm_(const_cast(*this), mat1, mat2, beta, alpha); ++ break; ++ default: ++ AT_ERROR("addmm_ not implemented for ", at::toString(type_set())); ++ } + #else + static auto table = globalATenDispatch().getOpTable("aten::addmm_(Tensor(a!) 
self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), mat1, mat2, beta, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mat1, mat2))(const_cast(*this), mat1, mat2, beta, alpha); + #endif + } + inline Tensor & Tensor::sparse_resize_(IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const { +@@ -2565,7 +2641,7 @@ inline Tensor & Tensor::sparse_resize_(IntArrayRef size, int64_t sparse_dim, int + } + #else + static auto table = globalATenDispatch().getOpTable("aten::sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), size, sparse_dim, dense_dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size, sparse_dim, dense_dim); + #endif + } + inline Tensor & Tensor::sparse_resize_and_clear_(IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const { +@@ -2579,21 +2655,21 @@ inline Tensor & Tensor::sparse_resize_and_clear_(IntArrayRef size, int64_t spars + } + #else + static auto table = globalATenDispatch().getOpTable("aten::sparse_resize_and_clear_(Tensor(a!) 
self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), size, sparse_dim, dense_dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size, sparse_dim, dense_dim); + #endif + } + inline Tensor Tensor::sparse_mask(const Tensor & mask) const { + #ifdef USE_STATIC_DISPATCH + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { +- case Backend::CPU: +- return CPUType::sparse_mask(const_cast(*this), mask); ++ case Backend::SparseCPU: ++ return SparseCPUType::sparse_mask(const_cast(*this), mask); + break; + default: + AT_ERROR("sparse_mask not implemented for ", at::toString(type_set())); + } + #else + static auto table = globalATenDispatch().getOpTable("aten::sparse_mask(Tensor self, Tensor mask) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), mask); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask))(const_cast(*this), mask); + #endif + } + inline Tensor Tensor::to_dense() const { +@@ -2607,7 +2683,7 @@ inline Tensor Tensor::to_dense() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::to_dense(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline int64_t Tensor::sparse_dim() const { +@@ -2621,7 +2697,7 @@ inline int64_t Tensor::sparse_dim() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::sparse_dim(Tensor self) -> int"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline int64_t Tensor::_dimI() const { +@@ -2635,7 +2711,7 @@ inline int64_t Tensor::_dimI() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::_dimI(Tensor self) -> int"); +- return 
table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline int64_t Tensor::dense_dim() const { +@@ -2649,7 +2725,7 @@ inline int64_t Tensor::dense_dim() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::dense_dim(Tensor self) -> int"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline int64_t Tensor::_dimV() const { +@@ -2663,7 +2739,7 @@ inline int64_t Tensor::_dimV() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::_dimV(Tensor self) -> int"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline int64_t Tensor::_nnz() const { +@@ -2677,7 +2753,7 @@ inline int64_t Tensor::_nnz() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::_nnz(Tensor self) -> int"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::coalesce() const { +@@ -2691,7 +2767,7 @@ inline Tensor Tensor::coalesce() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::coalesce(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline bool Tensor::is_coalesced() const { +@@ -2705,7 +2781,7 @@ inline bool Tensor::is_coalesced() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::is_coalesced(Tensor self) -> bool"); +- return table->getOp(type_set())(const_cast(*this)); ++ return 
table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::_indices() const { +@@ -2719,7 +2795,7 @@ inline Tensor Tensor::_indices() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::_indices(Tensor(a) self) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::_values() const { +@@ -2733,7 +2809,7 @@ inline Tensor Tensor::_values() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::_values(Tensor(a) self) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::_coalesced_(bool coalesced) const { +@@ -2747,7 +2823,7 @@ inline Tensor & Tensor::_coalesced_(bool coalesced) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::_coalesced_(Tensor(a!) 
self, bool coalesced) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), coalesced); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), coalesced); + #endif + } + inline Tensor Tensor::indices() const { +@@ -2761,7 +2837,7 @@ inline Tensor Tensor::indices() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::indices(Tensor(a) self) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::values() const { +@@ -2775,7 +2851,7 @@ inline Tensor Tensor::values() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::values(Tensor(a) self) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline int64_t Tensor::numel() const { +@@ -2783,7 +2859,7 @@ inline int64_t Tensor::numel() const { + return TypeDefault::numel(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::numel(Tensor self) -> int"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline std::vector Tensor::unbind(int64_t dim) const { +@@ -2791,7 +2867,7 @@ inline std::vector Tensor::unbind(int64_t dim) const { + return TypeDefault::unbind(const_cast(*this), dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::unbind(Tensor(a) self, int dim=0) -> Tensor(a)[]"); +- return table->getOp (const Tensor &, int64_t)>(type_set())(const_cast(*this), dim); ++ return table->getOp (const Tensor &, int64_t)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); + #endif + } + #ifdef BUILD_NAMEDTENSOR +@@ -2800,7 +2876,7 @@ inline std::vector 
Tensor::unbind(Dimname dim) const { + return TypeDefault::unbind(const_cast(*this), dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::unbind(Tensor(a) self, Dimname dim) -> Tensor(a)[]"); +- return table->getOp (const Tensor &, Dimname)>(type_set())(const_cast(*this), dim); ++ return table->getOp (const Tensor &, Dimname)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); + #endif + } + #endif +@@ -2815,7 +2891,7 @@ inline Tensor Tensor::to_sparse(int64_t sparse_dim) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), sparse_dim); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), sparse_dim); + #endif + } + inline Tensor Tensor::to_sparse() const { +@@ -2829,7 +2905,7 @@ inline Tensor Tensor::to_sparse() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::to_sparse(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::to_mkldnn() const { +@@ -2843,7 +2919,7 @@ inline Tensor Tensor::to_mkldnn() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::to_mkldnn(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::dequantize() const { +@@ -2857,7 +2933,7 @@ inline Tensor Tensor::dequantize() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::dequantize(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + 
inline double Tensor::q_scale() const { +@@ -2871,7 +2947,7 @@ inline double Tensor::q_scale() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::q_scale(Tensor self) -> float"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline int64_t Tensor::q_zero_point() const { +@@ -2885,7 +2961,7 @@ inline int64_t Tensor::q_zero_point() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::q_zero_point(Tensor self) -> int"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::q_per_channel_scales() const { +@@ -2899,7 +2975,7 @@ inline Tensor Tensor::q_per_channel_scales() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::q_per_channel_scales(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::q_per_channel_zero_points() const { +@@ -2913,7 +2989,7 @@ inline Tensor Tensor::q_per_channel_zero_points() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::q_per_channel_zero_points(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::int_repr() const { +@@ -2927,7 +3003,7 @@ inline Tensor Tensor::int_repr() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::int_repr(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline QScheme 
Tensor::qscheme() const { +@@ -2941,7 +3017,7 @@ inline QScheme Tensor::qscheme() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::qscheme(Tensor self) -> QScheme"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::to(const TensorOptions & options, bool non_blocking, bool copy) const { +@@ -2949,7 +3025,7 @@ inline Tensor Tensor::to(const TensorOptions & options, bool non_blocking, bool + return TypeDefault::to(const_cast(*this), options, non_blocking, copy); + #else + static auto table = globalATenDispatch().getOpTable("aten::to.dtype_layout(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, bool non_blocking=False, bool copy=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), options, non_blocking, copy); ++ return table->getOp(type_set(/* HMMMM */))(const_cast(*this), options, non_blocking, copy); + #endif + } + inline Tensor Tensor::to(Device device, ScalarType dtype, bool non_blocking, bool copy) const { +@@ -2957,7 +3033,7 @@ inline Tensor Tensor::to(Device device, ScalarType dtype, bool non_blocking, boo + return TypeDefault::to(const_cast(*this), device, dtype, non_blocking, copy); + #else + static auto table = globalATenDispatch().getOpTable("aten::to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), device, dtype, non_blocking, copy); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), device, dtype, non_blocking, copy); + #endif + } + inline Tensor Tensor::to(ScalarType dtype, bool non_blocking, bool copy) const { +@@ -2965,7 +3041,7 @@ inline Tensor Tensor::to(ScalarType dtype, bool non_blocking, bool copy) const { + return TypeDefault::to(const_cast(*this), dtype, non_blocking, 
copy); + #else + static auto table = globalATenDispatch().getOpTable("aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dtype, non_blocking, copy); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dtype, non_blocking, copy); + #endif + } + inline Tensor Tensor::to(const Tensor & other, bool non_blocking, bool copy) const { +@@ -2973,7 +3049,7 @@ inline Tensor Tensor::to(const Tensor & other, bool non_blocking, bool copy) con + return TypeDefault::to(const_cast(*this), other, non_blocking, copy); + #else + static auto table = globalATenDispatch().getOpTable("aten::to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other, non_blocking, copy); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, non_blocking, copy); + #endif + } + inline Scalar Tensor::item() const { +@@ -2981,7 +3057,7 @@ inline Scalar Tensor::item() const { + return TypeDefault::item(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::item(Tensor self) -> Scalar"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::set_(Storage source) const { +@@ -2995,7 +3071,7 @@ inline Tensor & Tensor::set_(Storage source) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::set_.source_Storage(Tensor(a!) 
self, Storage source) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), source); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), source); + #endif + } + inline Tensor & Tensor::set_(Storage source, int64_t storage_offset, IntArrayRef size, IntArrayRef stride) const { +@@ -3012,7 +3088,7 @@ inline Tensor & Tensor::set_(Storage source, int64_t storage_offset, IntArrayRef + } + #else + static auto table = globalATenDispatch().getOpTable("aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), source, storage_offset, size, stride); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), source, storage_offset, size, stride); + #endif + } + inline Tensor & Tensor::set_(const Tensor & source) const { +@@ -3026,7 +3102,7 @@ inline Tensor & Tensor::set_(const Tensor & source) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), source); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, source))(const_cast(*this), source); + #endif + } + inline Tensor & Tensor::set_() const { +@@ -3040,7 +3116,7 @@ inline Tensor & Tensor::set_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::set_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::set_quantizer_(ConstQuantizerPtr quantizer) const { +@@ -3054,7 +3130,7 @@ inline Tensor & Tensor::set_quantizer_(ConstQuantizerPtr quantizer) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::set_quantizer_(Tensor(a!) self, ConstQuantizerPtr quantizer) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), quantizer); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), quantizer); + #endif + } + inline bool Tensor::is_set_to(const Tensor & tensor) const { +@@ -3068,7 +3144,7 @@ inline bool Tensor::is_set_to(const Tensor & tensor) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::is_set_to(Tensor self, Tensor tensor) -> bool"); +- return table->getOp(type_set())(const_cast(*this), tensor); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, tensor))(const_cast(*this), tensor); + #endif + } + inline Tensor & Tensor::masked_fill_(const Tensor & mask, Scalar value) const { +@@ -3082,7 +3158,7 @@ inline Tensor & Tensor::masked_fill_(const Tensor & mask, Scalar value) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::masked_fill_.Scalar(Tensor(a!) 
self, Tensor mask, Scalar value) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), mask, value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask))(const_cast(*this), mask, value); + #endif + } + inline Tensor Tensor::masked_fill(const Tensor & mask, Scalar value) const { +@@ -3090,7 +3166,7 @@ inline Tensor Tensor::masked_fill(const Tensor & mask, Scalar value) const { + return TypeDefault::masked_fill(const_cast(*this), mask, value); + #else + static auto table = globalATenDispatch().getOpTable("aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), mask, value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask))(const_cast(*this), mask, value); + #endif + } + inline Tensor & Tensor::masked_fill_(const Tensor & mask, const Tensor & value) const { +@@ -3104,7 +3180,7 @@ inline Tensor & Tensor::masked_fill_(const Tensor & mask, const Tensor & value) + } + #else + static auto table = globalATenDispatch().getOpTable("aten::masked_fill_.Tensor(Tensor(a!) 
self, Tensor mask, Tensor value) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), mask, value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask, value))(const_cast(*this), mask, value); + #endif + } + inline Tensor Tensor::masked_fill(const Tensor & mask, const Tensor & value) const { +@@ -3112,7 +3188,7 @@ inline Tensor Tensor::masked_fill(const Tensor & mask, const Tensor & value) con + return TypeDefault::masked_fill(const_cast(*this), mask, value); + #else + static auto table = globalATenDispatch().getOpTable("aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), mask, value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask, value))(const_cast(*this), mask, value); + #endif + } + inline Tensor & Tensor::masked_scatter_(const Tensor & mask, const Tensor & source) const { +@@ -3126,7 +3202,7 @@ inline Tensor & Tensor::masked_scatter_(const Tensor & mask, const Tensor & sour + } + #else + static auto table = globalATenDispatch().getOpTable("aten::masked_scatter_(Tensor(a!) 
self, Tensor mask, Tensor source) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), mask, source); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask, source))(const_cast(*this), mask, source); + #endif + } + inline Tensor Tensor::masked_scatter(const Tensor & mask, const Tensor & source) const { +@@ -3134,7 +3210,7 @@ inline Tensor Tensor::masked_scatter(const Tensor & mask, const Tensor & source) + return TypeDefault::masked_scatter(const_cast(*this), mask, source); + #else + static auto table = globalATenDispatch().getOpTable("aten::masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), mask, source); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask, source))(const_cast(*this), mask, source); + #endif + } + inline Tensor Tensor::view(IntArrayRef size) const { +@@ -3151,7 +3227,7 @@ inline Tensor Tensor::view(IntArrayRef size) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::view(Tensor(a) self, int[] size) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), size); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size); + #endif + } + inline Tensor & Tensor::put_(const Tensor & index, const Tensor & source, bool accumulate) const { +@@ -3165,7 +3241,7 @@ inline Tensor & Tensor::put_(const Tensor & index, const Tensor & source, bool a + } + #else + static auto table = globalATenDispatch().getOpTable("aten::put_(Tensor(a!) 
self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), index, source, accumulate); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, source))(const_cast(*this), index, source, accumulate); + #endif + } + inline Tensor & Tensor::index_add_(int64_t dim, const Tensor & index, const Tensor & source) const { +@@ -3179,7 +3255,7 @@ inline Tensor & Tensor::index_add_(int64_t dim, const Tensor & index, const Tens + } + #else + static auto table = globalATenDispatch().getOpTable("aten::index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), dim, index, source); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, source))(const_cast(*this), dim, index, source); + #endif + } + inline Tensor Tensor::index_add(int64_t dim, const Tensor & index, const Tensor & source) const { +@@ -3187,7 +3263,7 @@ inline Tensor Tensor::index_add(int64_t dim, const Tensor & index, const Tensor + return TypeDefault::index_add(const_cast(*this), dim, index, source); + #else + static auto table = globalATenDispatch().getOpTable("aten::index_add(Tensor self, int dim, Tensor index, Tensor source) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, index, source); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, source))(const_cast(*this), dim, index, source); + #endif + } + inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, Scalar value) const { +@@ -3201,7 +3277,7 @@ inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, Scalar va + } + #else + static auto table = globalATenDispatch().getOpTable("aten::index_fill_.Scalar(Tensor(a!) 
self, int dim, Tensor index, Scalar value) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), dim, index, value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index))(const_cast(*this), dim, index, value); + #endif + } + inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, Scalar value) const { +@@ -3209,7 +3285,7 @@ inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, Scalar value + return TypeDefault::index_fill(const_cast(*this), dim, index, value); + #else + static auto table = globalATenDispatch().getOpTable("aten::index_fill.Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, index, value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index))(const_cast(*this), dim, index, value); + #endif + } + inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, const Tensor & value) const { +@@ -3223,7 +3299,7 @@ inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, const Ten + } + #else + static auto table = globalATenDispatch().getOpTable("aten::index_fill_.Tensor(Tensor(a!) 
self, int dim, Tensor index, Tensor value) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), dim, index, value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, value))(const_cast(*this), dim, index, value); + #endif + } + inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, const Tensor & value) const { +@@ -3231,7 +3307,7 @@ inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, const Tensor + return TypeDefault::index_fill(const_cast(*this), dim, index, value); + #else + static auto table = globalATenDispatch().getOpTable("aten::index_fill.Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, index, value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, value))(const_cast(*this), dim, index, value); + #endif + } + inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, const Tensor & src) const { +@@ -3245,7 +3321,7 @@ inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, const Tensor + } + #else + static auto table = globalATenDispatch().getOpTable("aten::scatter_.src(Tensor(a!) 
self, int dim, Tensor index, Tensor src) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), dim, index, src); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, src))(const_cast(*this), dim, index, src); + #endif + } + inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, const Tensor & src) const { +@@ -3253,7 +3329,7 @@ inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, const Tensor & + return TypeDefault::scatter(const_cast(*this), dim, index, src); + #else + static auto table = globalATenDispatch().getOpTable("aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, index, src); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, src))(const_cast(*this), dim, index, src); + #endif + } + inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, Scalar value) const { +@@ -3267,7 +3343,7 @@ inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, Scalar value + } + #else + static auto table = globalATenDispatch().getOpTable("aten::scatter_.value(Tensor(a!) 
self, int dim, Tensor index, Scalar value) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), dim, index, value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index))(const_cast(*this), dim, index, value); + #endif + } + inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, Scalar value) const { +@@ -3275,7 +3351,7 @@ inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, Scalar value) c + return TypeDefault::scatter(const_cast(*this), dim, index, value); + #else + static auto table = globalATenDispatch().getOpTable("aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, index, value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index))(const_cast(*this), dim, index, value); + #endif + } + inline Tensor & Tensor::scatter_add_(int64_t dim, const Tensor & index, const Tensor & src) const { +@@ -3289,7 +3365,7 @@ inline Tensor & Tensor::scatter_add_(int64_t dim, const Tensor & index, const Te + } + #else + static auto table = globalATenDispatch().getOpTable("aten::scatter_add_(Tensor(a!) 
self, int dim, Tensor index, Tensor src) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), dim, index, src); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, src))(const_cast(*this), dim, index, src); + #endif + } + inline Tensor Tensor::scatter_add(int64_t dim, const Tensor & index, const Tensor & src) const { +@@ -3297,7 +3373,7 @@ inline Tensor Tensor::scatter_add(int64_t dim, const Tensor & index, const Tenso + return TypeDefault::scatter_add(const_cast(*this), dim, index, src); + #else + static auto table = globalATenDispatch().getOpTable("aten::scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, index, src); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, src))(const_cast(*this), dim, index, src); + #endif + } + inline Tensor & Tensor::lt_(Scalar other) const { +@@ -3311,7 +3387,7 @@ inline Tensor & Tensor::lt_(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::lt_(const Tensor & other) const { +@@ -3325,7 +3401,7 @@ inline Tensor & Tensor::lt_(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::gt_(Scalar other) const { +@@ -3339,7 +3415,7 @@ inline Tensor & Tensor::gt_(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::gt_.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::gt_(const Tensor & other) const { +@@ -3353,7 +3429,7 @@ inline Tensor & Tensor::gt_(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::le_(Scalar other) const { +@@ -3367,7 +3443,7 @@ inline Tensor & Tensor::le_(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::le_(const Tensor & other) const { +@@ -3381,7 +3457,7 @@ inline Tensor & Tensor::le_(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::ge_(Scalar other) const { +@@ -3395,7 +3471,7 @@ inline Tensor & Tensor::ge_(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::ge_.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::ge_(const Tensor & other) const { +@@ -3409,7 +3485,7 @@ inline Tensor & Tensor::ge_(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::eq_(Scalar other) const { +@@ -3423,7 +3499,7 @@ inline Tensor & Tensor::eq_(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::eq_(const Tensor & other) const { +@@ -3437,7 +3513,7 @@ inline Tensor & Tensor::eq_(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::ne_(Scalar other) const { +@@ -3451,7 +3527,7 @@ inline Tensor & Tensor::ne_(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::ne_.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::ne_(const Tensor & other) const { +@@ -3465,7 +3541,7 @@ inline Tensor & Tensor::ne_(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::__and__(Scalar other) const { +@@ -3479,7 +3555,7 @@ inline Tensor Tensor::__and__(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::__and__(const Tensor & other) const { +@@ -3493,7 +3569,7 @@ inline Tensor Tensor::__and__(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::__iand__(Scalar other) const { +@@ -3507,7 +3583,7 @@ inline Tensor & Tensor::__iand__(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__iand__.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::__iand__(const Tensor & other) const { +@@ -3521,7 +3597,7 @@ inline Tensor & Tensor::__iand__(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::__or__(Scalar other) const { +@@ -3535,7 +3611,7 @@ inline Tensor Tensor::__or__(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::__or__(const Tensor & other) const { +@@ -3549,7 +3625,7 @@ inline Tensor Tensor::__or__(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::__ior__(Scalar other) const { +@@ -3563,7 +3639,7 @@ inline Tensor & Tensor::__ior__(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__ior__.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::__ior__(const Tensor & other) const { +@@ -3577,7 +3653,7 @@ inline Tensor & Tensor::__ior__(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::__xor__(Scalar other) const { +@@ -3591,7 +3667,7 @@ inline Tensor Tensor::__xor__(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::__xor__(const Tensor & other) const { +@@ -3605,7 +3681,7 @@ inline Tensor Tensor::__xor__(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::__ixor__(Scalar other) const { +@@ -3619,7 +3695,7 @@ inline Tensor & Tensor::__ixor__(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__ixor__.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::__ixor__(const Tensor & other) const { +@@ -3633,7 +3709,7 @@ inline Tensor & Tensor::__ixor__(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::__lshift__(Scalar other) const { +@@ -3647,7 +3723,7 @@ inline Tensor Tensor::__lshift__(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::__lshift__(const Tensor & other) const { +@@ -3661,7 +3737,7 @@ inline Tensor Tensor::__lshift__(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::__ilshift__(Scalar other) const { +@@ -3675,7 +3751,7 @@ inline Tensor & Tensor::__ilshift__(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__ilshift__.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::__ilshift__(const Tensor & other) const { +@@ -3689,7 +3765,7 @@ inline Tensor & Tensor::__ilshift__(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::__rshift__(Scalar other) const { +@@ -3703,7 +3779,7 @@ inline Tensor Tensor::__rshift__(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::__rshift__(const Tensor & other) const { +@@ -3717,7 +3793,7 @@ inline Tensor Tensor::__rshift__(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::__irshift__(Scalar other) const { +@@ -3731,7 +3807,7 @@ inline Tensor & Tensor::__irshift__(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__irshift__.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::__irshift__(const Tensor & other) const { +@@ -3745,7 +3821,7 @@ inline Tensor & Tensor::__irshift__(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::__irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::lgamma_() const { +@@ -3759,7 +3835,7 @@ inline Tensor & Tensor::lgamma_() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::lgamma_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::atan2_(const Tensor & other) const { +@@ -3767,7 +3843,7 @@ inline Tensor & Tensor::atan2_(const Tensor & other) const { + return TypeDefault::atan2_(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::tril_(int64_t diagonal) const { +@@ -3781,7 +3857,7 @@ inline Tensor & Tensor::tril_(int64_t diagonal) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::tril_(Tensor(a!) 
self, int diagonal=0) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), diagonal); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), diagonal); + #endif + } + inline Tensor & Tensor::triu_(int64_t diagonal) const { +@@ -3795,7 +3871,7 @@ inline Tensor & Tensor::triu_(int64_t diagonal) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), diagonal); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), diagonal); + #endif + } + inline Tensor & Tensor::digamma_() const { +@@ -3803,7 +3879,7 @@ inline Tensor & Tensor::digamma_() const { + return TypeDefault::digamma_(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::digamma_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::polygamma_(int64_t n) const { +@@ -3811,7 +3887,7 @@ inline Tensor & Tensor::polygamma_(int64_t n) const { + return TypeDefault::polygamma_(const_cast(*this), n); + #else + static auto table = globalATenDispatch().getOpTable("aten::polygamma_(Tensor(a!) self, int n) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), n); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), n); + #endif + } + inline Tensor & Tensor::renorm_(Scalar p, int64_t dim, Scalar maxnorm) const { +@@ -3825,7 +3901,7 @@ inline Tensor & Tensor::renorm_(Scalar p, int64_t dim, Scalar maxnorm) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::renorm_(Tensor(a!) 
self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), p, dim, maxnorm); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, maxnorm); + #endif + } + inline Tensor & Tensor::pow_(Scalar exponent) const { +@@ -3839,7 +3915,7 @@ inline Tensor & Tensor::pow_(Scalar exponent) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), exponent); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), exponent); + #endif + } + inline Tensor & Tensor::pow_(const Tensor & exponent) const { +@@ -3853,7 +3929,7 @@ inline Tensor & Tensor::pow_(const Tensor & exponent) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), exponent); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, exponent))(const_cast(*this), exponent); + #endif + } + inline Tensor & Tensor::lerp_(const Tensor & end, Scalar weight) const { +@@ -3867,7 +3943,7 @@ inline Tensor & Tensor::lerp_(const Tensor & end, Scalar weight) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), end, weight); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, end))(const_cast(*this), end, weight); + #endif + } + inline Tensor & Tensor::lerp_(const Tensor & end, const Tensor & weight) const { +@@ -3881,7 +3957,7 @@ inline Tensor & Tensor::lerp_(const Tensor & end, const Tensor & weight) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::lerp_.Tensor(Tensor(a!) 
self, Tensor end, Tensor weight) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), end, weight); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, end, weight))(const_cast(*this), end, weight); + #endif + } + inline Tensor & Tensor::fmod_(Scalar other) const { +@@ -3895,7 +3971,7 @@ inline Tensor & Tensor::fmod_(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::fmod_(const Tensor & other) const { +@@ -3909,7 +3985,7 @@ inline Tensor & Tensor::fmod_(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::remainder_(Scalar other) const { +@@ -3923,7 +3999,7 @@ inline Tensor & Tensor::remainder_(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::remainder_(const Tensor & other) const { +@@ -3937,7 +4013,7 @@ inline Tensor & Tensor::remainder_(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::remainder_.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor & Tensor::addbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { +@@ -3951,7 +4027,7 @@ inline Tensor & Tensor::addbmm_(const Tensor & batch1, const Tensor & batch2, Sc + } + #else + static auto table = globalATenDispatch().getOpTable("aten::addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), batch1, batch2, beta, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2))(const_cast(*this), batch1, batch2, beta, alpha); + #endif + } + inline Tensor Tensor::addbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { +@@ -3965,7 +4041,7 @@ inline Tensor Tensor::addbmm(const Tensor & batch1, const Tensor & batch2, Scala + } + #else + static auto table = globalATenDispatch().getOpTable("aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), batch1, batch2, beta, alpha); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2))(const_cast(*this), batch1, batch2, beta, alpha); + #endif + } + inline Tensor & Tensor::addcdiv_(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { +@@ -3973,7 +4049,7 @@ inline Tensor & Tensor::addcdiv_(const Tensor & tensor1, const Tensor & tensor2, + return TypeDefault::addcdiv_(const_cast(*this), tensor1, tensor2, value); + #else + static auto table = globalATenDispatch().getOpTable("aten::addcdiv_(Tensor(a!) 
self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), tensor1, tensor2, value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2))(const_cast(*this), tensor1, tensor2, value); + #endif + } + inline Tensor & Tensor::random_(int64_t from, int64_t to, Generator * generator) const { +@@ -3987,7 +4063,7 @@ inline Tensor & Tensor::random_(int64_t from, int64_t to, Generator * generator) + } + #else + static auto table = globalATenDispatch().getOpTable("aten::random_.from(Tensor(a!) self, int from, int to, *, Generator? generator=None) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), from, to, generator); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), from, to, generator); + #endif + } + inline Tensor & Tensor::random_(int64_t to, Generator * generator) const { +@@ -4001,7 +4077,7 @@ inline Tensor & Tensor::random_(int64_t to, Generator * generator) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), to, generator); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), to, generator); + #endif + } + inline Tensor & Tensor::random_(Generator * generator) const { +@@ -4015,7 +4091,7 @@ inline Tensor & Tensor::random_(Generator * generator) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::random_(Tensor(a!) self, *, Generator? 
generator=None) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), generator); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), generator); + #endif + } + inline Tensor & Tensor::uniform_(double from, double to, Generator * generator) const { +@@ -4029,7 +4105,7 @@ inline Tensor & Tensor::uniform_(double from, double to, Generator * generator) + } + #else + static auto table = globalATenDispatch().getOpTable("aten::uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), from, to, generator); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), from, to, generator); + #endif + } + inline Tensor & Tensor::normal_(double mean, double std, Generator * generator) const { +@@ -4043,7 +4119,7 @@ inline Tensor & Tensor::normal_(double mean, double std, Generator * generator) + } + #else + static auto table = globalATenDispatch().getOpTable("aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), mean, std, generator); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), mean, std, generator); + #endif + } + inline Tensor & Tensor::cauchy_(double median, double sigma, Generator * generator) const { +@@ -4057,7 +4133,7 @@ inline Tensor & Tensor::cauchy_(double median, double sigma, Generator * generat + } + #else + static auto table = globalATenDispatch().getOpTable("aten::cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? 
generator=None) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), median, sigma, generator); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), median, sigma, generator); + #endif + } + inline Tensor & Tensor::log_normal_(double mean, double std, Generator * generator) const { +@@ -4071,7 +4147,7 @@ inline Tensor & Tensor::log_normal_(double mean, double std, Generator * generat + } + #else + static auto table = globalATenDispatch().getOpTable("aten::log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), mean, std, generator); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), mean, std, generator); + #endif + } + inline Tensor & Tensor::exponential_(double lambd, Generator * generator) const { +@@ -4085,7 +4161,7 @@ inline Tensor & Tensor::exponential_(double lambd, Generator * generator) const + } + #else + static auto table = globalATenDispatch().getOpTable("aten::exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), lambd, generator); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), lambd, generator); + #endif + } + inline Tensor & Tensor::geometric_(double p, Generator * generator) const { +@@ -4099,7 +4175,7 @@ inline Tensor & Tensor::geometric_(double p, Generator * generator) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::geometric_(Tensor(a!) self, float p, *, Generator? 
generator=None) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), p, generator); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, generator); + #endif + } + inline Tensor Tensor::diag(int64_t diagonal) const { +@@ -4113,7 +4189,7 @@ inline Tensor Tensor::diag(int64_t diagonal) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::diag(Tensor self, int diagonal=0) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), diagonal); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), diagonal); + #endif + } + inline Tensor Tensor::cross(const Tensor & other, c10::optional dim) const { +@@ -4121,7 +4197,7 @@ inline Tensor Tensor::cross(const Tensor & other, c10::optional dim) co + return TypeDefault::cross(const_cast(*this), other, dim); + #else + static auto table = globalATenDispatch().getOpTable("aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor"); +- return table->getOp)>(type_set())(const_cast(*this), other, dim); ++ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, dim); + #endif + } + inline Tensor Tensor::triu(int64_t diagonal) const { +@@ -4129,7 +4205,7 @@ inline Tensor Tensor::triu(int64_t diagonal) const { + return TypeDefault::triu(const_cast(*this), diagonal); + #else + static auto table = globalATenDispatch().getOpTable("aten::triu(Tensor self, int diagonal=0) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), diagonal); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), diagonal); + #endif + } + inline Tensor Tensor::tril(int64_t diagonal) const { +@@ -4137,7 +4213,7 @@ inline Tensor Tensor::tril(int64_t diagonal) const { + return TypeDefault::tril(const_cast(*this), diagonal); + #else + static auto table = globalATenDispatch().getOpTable("aten::tril(Tensor self, int diagonal=0) -> 
Tensor"); +- return table->getOp(type_set())(const_cast(*this), diagonal); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), diagonal); + #endif + } + inline Tensor Tensor::trace() const { +@@ -4151,7 +4227,7 @@ inline Tensor Tensor::trace() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::trace(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::ne(Scalar other) const { +@@ -4168,7 +4244,7 @@ inline Tensor Tensor::ne(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::ne.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::ne(const Tensor & other) const { +@@ -4185,7 +4261,7 @@ inline Tensor Tensor::ne(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::ne.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::eq(Scalar other) const { +@@ -4202,7 +4278,7 @@ inline Tensor Tensor::eq(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::eq.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::eq(const Tensor & other) const { +@@ -4219,7 +4295,7 @@ inline Tensor Tensor::eq(const Tensor & other) const { + } + #else + static auto 
table = globalATenDispatch().getOpTable("aten::eq.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::ge(Scalar other) const { +@@ -4236,7 +4312,7 @@ inline Tensor Tensor::ge(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::ge.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::ge(const Tensor & other) const { +@@ -4253,7 +4329,7 @@ inline Tensor Tensor::ge(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::ge.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::le(Scalar other) const { +@@ -4270,7 +4346,7 @@ inline Tensor Tensor::le(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::le.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::le(const Tensor & other) const { +@@ -4287,7 +4363,7 @@ inline Tensor Tensor::le(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::le.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + 
inline Tensor Tensor::gt(Scalar other) const { +@@ -4304,7 +4380,7 @@ inline Tensor Tensor::gt(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::gt.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::gt(const Tensor & other) const { +@@ -4321,7 +4397,7 @@ inline Tensor Tensor::gt(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::gt.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::lt(Scalar other) const { +@@ -4338,7 +4414,7 @@ inline Tensor Tensor::lt(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::lt.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::lt(const Tensor & other) const { +@@ -4355,7 +4431,7 @@ inline Tensor Tensor::lt(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::lt.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::take(const Tensor & index) const { +@@ -4369,7 +4445,7 @@ inline Tensor Tensor::take(const Tensor & index) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::take(Tensor self, Tensor index) -> Tensor"); +- return 
table->getOp(type_set())(const_cast(*this), index); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index))(const_cast(*this), index); + #endif + } + inline Tensor Tensor::index_select(int64_t dim, const Tensor & index) const { +@@ -4386,7 +4462,7 @@ inline Tensor Tensor::index_select(int64_t dim, const Tensor & index) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::index_select(Tensor self, int dim, Tensor index) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, index); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index))(const_cast(*this), dim, index); + #endif + } + inline Tensor Tensor::masked_select(const Tensor & mask) const { +@@ -4400,7 +4476,7 @@ inline Tensor Tensor::masked_select(const Tensor & mask) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::masked_select(Tensor self, Tensor mask) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), mask); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask))(const_cast(*this), mask); + #endif + } + inline Tensor Tensor::nonzero() const { +@@ -4414,7 +4490,7 @@ inline Tensor Tensor::nonzero() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::nonzero(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline std::vector Tensor::nonzero_numpy() const { +@@ -4422,7 +4498,7 @@ inline std::vector Tensor::nonzero_numpy() const { + return TypeDefault::nonzero_numpy(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::nonzero_numpy(Tensor self) -> Tensor[]"); +- return table->getOp (const Tensor &)>(type_set())(const_cast(*this)); ++ return table->getOp (const Tensor &)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + 
#endif + } + inline Tensor Tensor::gather(int64_t dim, const Tensor & index, bool sparse_grad) const { +@@ -4436,7 +4512,7 @@ inline Tensor Tensor::gather(int64_t dim, const Tensor & index, bool sparse_grad + } + #else + static auto table = globalATenDispatch().getOpTable("aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, index, sparse_grad); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index))(const_cast(*this), dim, index, sparse_grad); + #endif + } + inline Tensor Tensor::addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { +@@ -4444,7 +4520,7 @@ inline Tensor Tensor::addcmul(const Tensor & tensor1, const Tensor & tensor2, Sc + return TypeDefault::addcmul(const_cast(*this), tensor1, tensor2, value); + #else + static auto table = globalATenDispatch().getOpTable("aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), tensor1, tensor2, value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2))(const_cast(*this), tensor1, tensor2, value); + #endif + } + inline Tensor & Tensor::addcmul_(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { +@@ -4452,7 +4528,7 @@ inline Tensor & Tensor::addcmul_(const Tensor & tensor1, const Tensor & tensor2, + return TypeDefault::addcmul_(const_cast(*this), tensor1, tensor2, value); + #else + static auto table = globalATenDispatch().getOpTable("aten::addcmul_(Tensor(a!) 
self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this), tensor1, tensor2, value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2))(const_cast(*this), tensor1, tensor2, value); + #endif + } + inline Tensor Tensor::addcdiv(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { +@@ -4460,7 +4536,7 @@ inline Tensor Tensor::addcdiv(const Tensor & tensor1, const Tensor & tensor2, Sc + return TypeDefault::addcdiv(const_cast(*this), tensor1, tensor2, value); + #else + static auto table = globalATenDispatch().getOpTable("aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), tensor1, tensor2, value); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2))(const_cast(*this), tensor1, tensor2, value); + #endif + } + inline std::tuple Tensor::lstsq(const Tensor & A) const { +@@ -4474,7 +4550,7 @@ inline std::tuple Tensor::lstsq(const Tensor & A) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::lstsq(Tensor self, Tensor A) -> (Tensor solution, Tensor QR)"); +- return table->getOp (const Tensor &, const Tensor &)>(type_set())(const_cast(*this), A); ++ return table->getOp (const Tensor &, const Tensor &)>(at::detail::multi_dispatch_tensor_type_set(*this, A))(const_cast(*this), A); + #endif + } + inline std::tuple Tensor::triangular_solve(const Tensor & A, bool upper, bool transpose, bool unitriangular) const { +@@ -4482,7 +4558,7 @@ inline std::tuple Tensor::triangular_solve(const Tensor & A, bool + return TypeDefault::triangular_solve(const_cast(*this), A, upper, transpose, unitriangular); + #else + static auto table = globalATenDispatch().getOpTable("aten::triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor 
cloned_coefficient)"); +- return table->getOp (const Tensor &, const Tensor &, bool, bool, bool)>(type_set())(const_cast(*this), A, upper, transpose, unitriangular); ++ return table->getOp (const Tensor &, const Tensor &, bool, bool, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, A))(const_cast(*this), A, upper, transpose, unitriangular); + #endif + } + inline std::tuple Tensor::symeig(bool eigenvectors, bool upper) const { +@@ -4490,7 +4566,7 @@ inline std::tuple Tensor::symeig(bool eigenvectors, bool upper) c + return TypeDefault::symeig(const_cast(*this), eigenvectors, upper); + #else + static auto table = globalATenDispatch().getOpTable("aten::symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)"); +- return table->getOp (const Tensor &, bool, bool)>(type_set())(const_cast(*this), eigenvectors, upper); ++ return table->getOp (const Tensor &, bool, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), eigenvectors, upper); + #endif + } + inline std::tuple Tensor::eig(bool eigenvectors) const { +@@ -4504,7 +4580,7 @@ inline std::tuple Tensor::eig(bool eigenvectors) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::eig(Tensor self, bool eigenvectors=False) -> (Tensor eigenvalues, Tensor eigenvectors)"); +- return table->getOp (const Tensor &, bool)>(type_set())(const_cast(*this), eigenvectors); ++ return table->getOp (const Tensor &, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), eigenvectors); + #endif + } + inline std::tuple Tensor::svd(bool some, bool compute_uv) const { +@@ -4512,7 +4588,7 @@ inline std::tuple Tensor::svd(bool some, bool compute_uv) + return TypeDefault::svd(const_cast(*this), some, compute_uv); + #else + static auto table = globalATenDispatch().getOpTable("aten::svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)"); +- return table->getOp (const Tensor &, bool, 
bool)>(type_set())(const_cast(*this), some, compute_uv); ++ return table->getOp (const Tensor &, bool, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), some, compute_uv); + #endif + } + inline Tensor Tensor::cholesky(bool upper) const { +@@ -4520,7 +4596,7 @@ inline Tensor Tensor::cholesky(bool upper) const { + return TypeDefault::cholesky(const_cast(*this), upper); + #else + static auto table = globalATenDispatch().getOpTable("aten::cholesky(Tensor self, bool upper=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), upper); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), upper); + #endif + } + inline Tensor Tensor::cholesky_solve(const Tensor & input2, bool upper) const { +@@ -4528,7 +4604,7 @@ inline Tensor Tensor::cholesky_solve(const Tensor & input2, bool upper) const { + return TypeDefault::cholesky_solve(const_cast(*this), input2, upper); + #else + static auto table = globalATenDispatch().getOpTable("aten::cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), input2, upper); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, input2))(const_cast(*this), input2, upper); + #endif + } + inline std::tuple Tensor::solve(const Tensor & A) const { +@@ -4536,7 +4612,7 @@ inline std::tuple Tensor::solve(const Tensor & A) const { + return TypeDefault::solve(const_cast(*this), A); + #else + static auto table = globalATenDispatch().getOpTable("aten::solve(Tensor self, Tensor A) -> (Tensor solution, Tensor LU)"); +- return table->getOp (const Tensor &, const Tensor &)>(type_set())(const_cast(*this), A); ++ return table->getOp (const Tensor &, const Tensor &)>(at::detail::multi_dispatch_tensor_type_set(*this, A))(const_cast(*this), A); + #endif + } + inline Tensor Tensor::cholesky_inverse(bool upper) const { +@@ -4550,7 +4626,7 @@ inline Tensor Tensor::cholesky_inverse(bool upper) const 
{ + } + #else + static auto table = globalATenDispatch().getOpTable("aten::cholesky_inverse(Tensor self, bool upper=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), upper); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), upper); + #endif + } + inline std::tuple Tensor::qr(bool some) const { +@@ -4558,7 +4634,7 @@ inline std::tuple Tensor::qr(bool some) const { + return TypeDefault::qr(const_cast(*this), some); + #else + static auto table = globalATenDispatch().getOpTable("aten::qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R)"); +- return table->getOp (const Tensor &, bool)>(type_set())(const_cast(*this), some); ++ return table->getOp (const Tensor &, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), some); + #endif + } + inline std::tuple Tensor::geqrf() const { +@@ -4572,7 +4648,7 @@ inline std::tuple Tensor::geqrf() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::geqrf(Tensor self) -> (Tensor a, Tensor tau)"); +- return table->getOp (const Tensor &)>(type_set())(const_cast(*this)); ++ return table->getOp (const Tensor &)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::orgqr(const Tensor & input2) const { +@@ -4586,7 +4662,7 @@ inline Tensor Tensor::orgqr(const Tensor & input2) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::orgqr(Tensor self, Tensor input2) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), input2); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, input2))(const_cast(*this), input2); + #endif + } + inline Tensor Tensor::ormqr(const Tensor & input2, const Tensor & input3, bool left, bool transpose) const { +@@ -4600,7 +4676,7 @@ inline Tensor Tensor::ormqr(const Tensor & input2, const Tensor & input3, bool l + } + #else + static auto table = 
globalATenDispatch().getOpTable("aten::ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), input2, input3, left, transpose); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, input2, input3))(const_cast(*this), input2, input3, left, transpose); + #endif + } + inline Tensor Tensor::lu_solve(const Tensor & LU_data, const Tensor & LU_pivots) const { +@@ -4608,7 +4684,7 @@ inline Tensor Tensor::lu_solve(const Tensor & LU_data, const Tensor & LU_pivots) + return TypeDefault::lu_solve(const_cast(*this), LU_data, LU_pivots); + #else + static auto table = globalATenDispatch().getOpTable("aten::lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), LU_data, LU_pivots); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, LU_data, LU_pivots))(const_cast(*this), LU_data, LU_pivots); + #endif + } + inline Tensor Tensor::multinomial(int64_t num_samples, bool replacement, Generator * generator) const { +@@ -4622,7 +4698,7 @@ inline Tensor Tensor::multinomial(int64_t num_samples, bool replacement, Generat + } + #else + static auto table = globalATenDispatch().getOpTable("aten::multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? 
generator=None) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), num_samples, replacement, generator); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), num_samples, replacement, generator); + #endif + } + inline Tensor Tensor::lgamma() const { +@@ -4636,7 +4712,7 @@ inline Tensor Tensor::lgamma() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::lgamma(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::digamma() const { +@@ -4644,7 +4720,7 @@ inline Tensor Tensor::digamma() const { + return TypeDefault::digamma(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::digamma(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::polygamma(int64_t n) const { +@@ -4652,7 +4728,7 @@ inline Tensor Tensor::polygamma(int64_t n) const { + return TypeDefault::polygamma(n, const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::polygamma(int n, Tensor self) -> Tensor"); +- return table->getOp(type_set())(n, const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(n, const_cast(*this)); + #endif + } + inline Tensor Tensor::erfinv() const { +@@ -4660,7 +4736,7 @@ inline Tensor Tensor::erfinv() const { + return TypeDefault::erfinv(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::erfinv(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::erfinv_() const { +@@ -4668,7 +4744,7 
@@ inline Tensor & Tensor::erfinv_() const { + return TypeDefault::erfinv_(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::erfinv_(Tensor(a!) self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::sign() const { +@@ -4676,7 +4752,7 @@ inline Tensor Tensor::sign() const { + return TypeDefault::sign(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::sign(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor & Tensor::sign_() const { +@@ -4684,7 +4760,7 @@ inline Tensor & Tensor::sign_() const { + return TypeDefault::sign_(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::sign_(Tensor(a!) 
self) -> Tensor(a!)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::dist(const Tensor & other, Scalar p) const { +@@ -4698,7 +4774,7 @@ inline Tensor Tensor::dist(const Tensor & other, Scalar p) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::dist(Tensor self, Tensor other, Scalar p=2) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other, p); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, p); + #endif + } + inline Tensor Tensor::atan2(const Tensor & other) const { +@@ -4706,7 +4782,7 @@ inline Tensor Tensor::atan2(const Tensor & other) const { + return TypeDefault::atan2(const_cast(*this), other); + #else + static auto table = globalATenDispatch().getOpTable("aten::atan2(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::lerp(const Tensor & end, Scalar weight) const { +@@ -4720,7 +4796,7 @@ inline Tensor Tensor::lerp(const Tensor & end, Scalar weight) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), end, weight); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, end))(const_cast(*this), end, weight); + #endif + } + inline Tensor Tensor::lerp(const Tensor & end, const Tensor & weight) const { +@@ -4734,7 +4810,7 @@ inline Tensor Tensor::lerp(const Tensor & end, const Tensor & weight) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor"); +- return 
table->getOp(type_set())(const_cast(*this), end, weight); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, end, weight))(const_cast(*this), end, weight); + #endif + } + inline Tensor Tensor::histc(int64_t bins, Scalar min, Scalar max) const { +@@ -4748,7 +4824,7 @@ inline Tensor Tensor::histc(int64_t bins, Scalar min, Scalar max) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), bins, min, max); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), bins, min, max); + #endif + } + inline Tensor Tensor::fmod(Scalar other) const { +@@ -4762,7 +4838,7 @@ inline Tensor Tensor::fmod(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::fmod(const Tensor & other) const { +@@ -4776,7 +4852,7 @@ inline Tensor Tensor::fmod(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::remainder(Scalar other) const { +@@ -4790,7 +4866,7 @@ inline Tensor Tensor::remainder(Scalar other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), 
other); + #endif + } + inline Tensor Tensor::remainder(const Tensor & other) const { +@@ -4804,7 +4880,7 @@ inline Tensor Tensor::remainder(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::min(const Tensor & other) const { +@@ -4818,7 +4894,7 @@ inline Tensor Tensor::min(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::min.other(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::min() const { +@@ -4835,7 +4911,7 @@ inline Tensor Tensor::min() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::min(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::max(const Tensor & other) const { +@@ -4849,7 +4925,7 @@ inline Tensor Tensor::max(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::max.other(Tensor self, Tensor other) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::max() const { +@@ -4866,7 +4942,7 @@ inline Tensor Tensor::max() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::max(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return 
table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::median() const { +@@ -4880,7 +4956,7 @@ inline Tensor Tensor::median() const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::median(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline std::tuple Tensor::sort(int64_t dim, bool descending) const { +@@ -4897,7 +4973,7 @@ inline std::tuple Tensor::sort(int64_t dim, bool descending) cons + } + #else + static auto table = globalATenDispatch().getOpTable("aten::sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)"); +- return table->getOp (const Tensor &, int64_t, bool)>(type_set())(const_cast(*this), dim, descending); ++ return table->getOp (const Tensor &, int64_t, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, descending); + #endif + } + inline Tensor Tensor::argsort(int64_t dim, bool descending) const { +@@ -4905,7 +4981,7 @@ inline Tensor Tensor::argsort(int64_t dim, bool descending) const { + return TypeDefault::argsort(const_cast(*this), dim, descending); + #else + static auto table = globalATenDispatch().getOpTable("aten::argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), dim, descending); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, descending); + #endif + } + inline std::tuple Tensor::topk(int64_t k, int64_t dim, bool largest, bool sorted) const { +@@ -4913,7 +4989,7 @@ inline std::tuple Tensor::topk(int64_t k, int64_t dim, bool large + return TypeDefault::topk(const_cast(*this), k, dim, largest, sorted); + #else + static auto table = globalATenDispatch().getOpTable("aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool 
sorted=True) -> (Tensor values, Tensor indices)"); +- return table->getOp (const Tensor &, int64_t, int64_t, bool, bool)>(type_set())(const_cast(*this), k, dim, largest, sorted); ++ return table->getOp (const Tensor &, int64_t, int64_t, bool, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), k, dim, largest, sorted); + #endif + } + inline Tensor Tensor::all() const { +@@ -4921,7 +4997,7 @@ inline Tensor Tensor::all() const { + return TypeDefault::all(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::all(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::any() const { +@@ -4929,7 +5005,7 @@ inline Tensor Tensor::any() const { + return TypeDefault::any(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::any(Tensor self) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + inline Tensor Tensor::renorm(Scalar p, int64_t dim, Scalar maxnorm) const { +@@ -4943,7 +5019,7 @@ inline Tensor Tensor::renorm(Scalar p, int64_t dim, Scalar maxnorm) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), p, dim, maxnorm); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, maxnorm); + #endif + } + inline Tensor Tensor::unfold(int64_t dimension, int64_t size, int64_t step) const { +@@ -4957,7 +5033,7 @@ inline Tensor Tensor::unfold(int64_t dimension, int64_t size, int64_t step) cons + } + #else + static auto table = globalATenDispatch().getOpTable("aten::unfold(Tensor(a) self, int dimension, int size, int 
step) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this), dimension, size, step); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dimension, size, step); + #endif + } + inline bool Tensor::equal(const Tensor & other) const { +@@ -4974,7 +5050,7 @@ inline bool Tensor::equal(const Tensor & other) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::equal(Tensor self, Tensor other) -> bool"); +- return table->getOp(type_set())(const_cast(*this), other); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); + #endif + } + inline Tensor Tensor::pow(const Tensor & exponent) const { +@@ -4988,7 +5064,7 @@ inline Tensor Tensor::pow(const Tensor & exponent) const { + } + #else + static auto table = globalATenDispatch().getOpTable("aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor"); +- return table->getOp(type_set())(const_cast(*this), exponent); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, exponent))(const_cast(*this), exponent); + #endif + } + inline Tensor Tensor::alias() const { +@@ -4996,7 +5072,7 @@ inline Tensor Tensor::alias() const { + return TypeDefault::alias(const_cast(*this)); + #else + static auto table = globalATenDispatch().getOpTable("aten::alias(Tensor(a) self) -> Tensor(a)"); +- return table->getOp(type_set())(const_cast(*this)); ++ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); + #endif + } + +diff --git a/aten/src/ATen/core/Variadic.h b/aten/src/ATen/core/Variadic.h +new file mode 100644 +index 0000000000..f1078ca9ce +--- /dev/null ++++ b/aten/src/ATen/core/Variadic.h +@@ -0,0 +1,74 @@ ++#pragma once ++ ++#include ++#include ++#include ++#include ++ ++namespace at { ++ ++// This class allows you to write variadic functions which ++// call a (possibly overloaded) function on each argument, ++// in order. 
This is most commonly used in autogenerated code, ++// where it is convenient to have a function that can uniformly ++// take arguments of different types. If your arguments ++// are homogenous consider using a std::initializer_list instead. ++// ++// For examples of this in use, see torch/csrc/utils/variadic.h ++template ++struct IterArgs { ++ template ++ inline F& apply() { ++ return self(); ++ } ++ ++ // NB: Use perfect forwarding here, otherwise we'll make value ++ // copies of all arguments! ++ template ++ inline F& apply(T&& arg, Args&&... args) { ++ self()(std::forward(arg)); ++ if (self().short_circuit()) { ++ return self(); ++ } else { ++ return apply(std::forward(args)...); ++ } ++ } ++ ++ // Here are some handy overloads which provide sensible ++ // defaults for container-like structures that one might ++ // be interested in recursing into. You can enable them ++ // by adding: ++ // ++ // using IterArgs::operator() ++ // ++ // to your struct. These are not enabled by default because ++ // you may be able to process these structures more efficiently ++ // than handling them one-by-one. ++ ++ template ++ void operator()(at::ArrayRef args) { ++ for (const auto& arg : args) { ++ self()(arg); ++ if (short_circuit()) ++ return; ++ } ++ } ++ ++ // NB: we need to specify std::vector manually as C++ won't ++ // do an implicit conversion to make a template deduction go through. 
++ template ++ void operator()(const std::vector& args) { ++ self()(at::ArrayRef{args}); ++ } ++ ++ bool short_circuit() { ++ return false; ++ } ++ ++ private: ++ inline F& self() { ++ return *static_cast(this); ++ } ++}; ++ ++} // namespace torch +diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h +index e1feb58b80..dd07afe0e6 100644 +--- a/aten/src/ATen/core/aten_interned_strings.h ++++ b/aten/src/ATen/core/aten_interned_strings.h +@@ -579,7 +579,6 @@ _(aten, rrelu_with_noise) \ + _(aten, rrelu_with_noise_backward) \ + _(aten, rrelu_with_noise_forward) \ + _(aten, rsqrt) \ +-_(aten, s_native_addmm) \ + _(aten, scatter) \ + _(aten, scatter_add) \ + _(aten, select) \ +diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py +index 633c0ed1d9..597f0aad97 100644 +--- a/aten/src/ATen/function_wrapper.py ++++ b/aten/src/ATen/function_wrapper.py +@@ -129,14 +129,13 @@ TENSOR_METHOD_DECLARATION = CodeTemplate("""\ + ${return_type} ${api_name}(${method_formals_with_defaults}) const; + """) + # add non-virtual declaration to Tensor.cpp +-# TODO: This will need to be adjusted for multiple dispatch + TENSOR_METHOD_DEFINITION = CodeTemplate("""\ + inline ${return_type} Tensor::${api_name}(${method_formals}) const { + #ifdef USE_STATIC_DISPATCH + ${static_dispatch_method_body} + #else + static auto table = globalATenDispatch().getOpTable("${schema_string}"); +- return table->getOp<${return_type} (${formals_types})>(type_set())(${method_actuals}); ++ return table->getOp<${return_type} (${formals_types})>(${inferred_type_set})(${method_actuals}); + #endif + } + """) +@@ -835,6 +834,19 @@ def create_generic(top_env, declarations): + + return None + ++ def find_multidispatch_tensors(formals): ++ # type: (List[AtFormal]) -> List[str] ++ # Compute the list of all tensor arguments which should be considered ++ # for multiple dispatch. 
Note that this doesn't completely replace ++ # find_dispatch_tensor because we use the "dispatch tensor" to determine ++ # device guards. This is ONLY used for multiple dispatch in ++ # ATenDispatch.h ++ r = [] ++ for formal in formals: ++ if 'TensorList' == formal['dynamic_type'] or is_any_tensor_type(formal): ++ r.append(formal['name']) ++ return r ++ + def format_formal(f): + # type: (AtFormal) -> str + return '{} {}'.format(f['type'], f['name']) +@@ -903,6 +915,8 @@ def create_generic(top_env, declarations): + + def process_option(option): + # type: (FunctionOption) -> None ++ # Mutably populate option with derived values computed from values ++ # passed in to option. + option['inplace'] = re.search( + '(^__i|[^_]_$)', option['api_name']) is not None + +@@ -1090,8 +1104,23 @@ def create_generic(top_env, declarations): + def has_named_tensor_formals(formals): + return any(['Dimname' in formal['dynamic_type'] for formal in formals]) + +- def gen_tensor_method(option): +- # type: (Any) -> FunctionCode ++ def gen_tensor_method(option, multidispatch_tensors): ++ # type: (Any, Optional[List[str]]) -> FunctionCode ++ # TODO: Swing this shared code to top level ++ if multidispatch_tensors: ++ def swizzle_self(t): # blegh ++ if t == 'self': ++ return '*this' ++ else: ++ return t ++ option['inferred_type_set'] = 'at::detail::multi_dispatch_tensor_type_set({})'.format( ++ ', '.join(swizzle_self(t) for t in multidispatch_tensors) ++ ) ++ else: ++ # TODO: Err, what? If we didn't trigger multidispatch_tensors ++ # codepath... how?! This is a method, surely something must be ++ # dispatching! 
++ option['inferred_type_set'] = 'type_set(/* HMMMM */)' + if isinstance(type_method_dispatch, dict): + static_dispatch_function_switches = [] + # NB: As this code is currently written, there will NEVER be +@@ -1125,10 +1154,11 @@ def create_generic(top_env, declarations): + declaration=TENSOR_METHOD_DECLARATION.substitute(option, static_dispatch_method_body=static_dispatch_method_body), + definition=TENSOR_METHOD_DEFINITION.substitute(option, static_dispatch_method_body=static_dispatch_method_body)) + +- def gen_namespace_function(option, dispatch_tensor, dispatch_options): +- # type: (Any, Optional[str], Any) -> FunctionCode +- if dispatch_tensor: +- option['inferred_type_set'] = 'at::detail::infer_tensor_type_set({})'.format(dispatch_tensor) ++ def gen_namespace_function(option, multidispatch_tensors, dispatch_options): ++ # type: (Any, Optional[List[str]], Any) -> FunctionCode ++ if multidispatch_tensors: ++ option['inferred_type_set'] = ( ++ 'at::detail::multi_dispatch_tensor_type_set({})'.format(', '.join(multidispatch_tensors))) + elif dispatch_options: + option['inferred_type_set'] = '{}.type_set()'.format(dispatch_options['name']) + else: +@@ -1190,6 +1220,9 @@ def create_generic(top_env, declarations): + # Only dispatch via tensor if there is no Options argument + dispatch_tensor = None if dispatch_options else find_dispatch_tensor(formals) + ++ # TODO: Not entirely clear what to do about TensorOptions ++ multidispatch_tensors = None if dispatch_options else find_multidispatch_tensors(formals) ++ + option['type_method_formals'] = [format_formal(f) for f in formals] + option['type_method_actuals'] = [f['name'] for f in formals] + option['native_actuals'] = [f['name'] for f in formals] +@@ -1257,7 +1290,7 @@ def create_generic(top_env, declarations): + + method_of = ['Type'] + if is_method: +- code = gen_tensor_method(option) ++ code = gen_tensor_method(option, multidispatch_tensors) + if is_named_tensor_only: + code = add_namedtensor_enabled_macro(code) + 
top_env['tensor_method_declarations'].append(code.declaration) +@@ -1265,7 +1298,7 @@ def create_generic(top_env, declarations): + method_of.append('Tensor') + + if is_namespace_function: +- code = gen_namespace_function(option, dispatch_tensor, dispatch_options) ++ code = gen_namespace_function(option, multidispatch_tensors, dispatch_options) + if is_named_tensor_only: + code = add_namedtensor_enabled_macro(code) + top_env['function_definitions'].append(code.definition) +@@ -1304,7 +1337,7 @@ def create_generic(top_env, declarations): + option["schema_string"] = declaration["schema_string"] + try: + if option['mode'] != 'native': +- # XXX: Does the following line do anything meaningful? ++ # Mutably populate option with values + process_option(option) + else: + output_option = process_native(option) +diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp +index 680cbcd9c7..c6ece4dabc 100644 +--- a/aten/src/ATen/native/BinaryOps.cpp ++++ b/aten/src/ATen/native/BinaryOps.cpp +@@ -20,16 +20,6 @@ static constexpr char alpha_mismatch_err[] = + "For integral input tensors, argument alpha must not be a floating point number."; + + Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { +- if (other.is_sparse()) { +- if (self.is_sparse()) { +- at::_sparse_add_out(result, self, other, alpha); +- } else { +- at::_sparse_dense_add_out(result, self, other, alpha); +- } +- return result; +- } else if (self.is_sparse()) { +- AT_ERROR("add(sparse, dense) is not supported. Use add(dense, sparse) instead."); +- } + auto iter = TensorIterator::binary_op(result, self, other, + /*check_mem_overlap=*/true); + TORCH_CHECK(! 
alpha.isBoolean() || iter.dtype() == ScalarType::Bool, "Boolean alpha only supported for boolean results"); +@@ -41,10 +31,6 @@ Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar + + Tensor add(const Tensor& self, const Tensor& other, Scalar alpha) { + Tensor result; +- if (other.is_sparse()) { +- result = at::empty({0}, self.options()); +- return native::add_out(result, self, other, alpha); +- } + auto iter = TensorIterator::binary_op(result, self, other); + TORCH_CHECK(! alpha.isBoolean() || iter.dtype() == ScalarType::Bool, "Boolean alpha only supported for boolean results"); + TORCH_CHECK(isFloatingType(iter.dtype()) || alpha.isIntegral(true), alpha_mismatch_err); +@@ -57,13 +43,6 @@ Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) { + } + + Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) { +- if (self.is_sparse()) { +- if (other.dim() != 0) { +- AT_ERROR("div(): sparse division only supports division by a scalar ", +- "(got shape ", other.sizes(), " for argument 'other')"); +- } +- return at::_sparse_div_zerodim_out(result, self, other); +- } + auto iter = TensorIterator::binary_op(result, self, other, + /*check_mem_overlap=*/true); + div_stub(iter.device_type(), iter); +@@ -72,10 +51,6 @@ Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) { + + Tensor div(const Tensor& self, const Tensor& other) { + Tensor result; +- if (self.is_sparse()) { +- result = at::empty({0}, self.options()); +- return native::div_out(result, self, other); +- } + auto iter = TensorIterator::binary_op(result, self, other); + div_stub(iter.device_type(), iter); + return iter.output(); +@@ -86,9 +61,6 @@ Tensor& div_(Tensor& self, const Tensor& other) { + } + + Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) { +- if (self.is_sparse() || other.is_sparse()) { +- return at::_sparse_mul_out(result, self, other); +- } + auto iter = TensorIterator::binary_op(result, self, other, + 
/*check_mem_overlap=*/true); + mul_stub(iter.device_type(), iter); +@@ -97,10 +69,6 @@ Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) { + + Tensor mul(const Tensor& self, const Tensor& other) { + Tensor result; +- if (self.is_sparse() || other.is_sparse()) { +- result = at::empty({0}, self.options()); +- return native::mul_out(result, self, other); +- } + auto iter = TensorIterator::binary_op(result, self, other); + mul_stub(iter.device_type(), iter); + return iter.output(); +@@ -122,19 +90,6 @@ static inline void sub_check(const Tensor& self, const Tensor& other) { + + Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { + sub_check(self, other); +- if (other.is_sparse()) { +- if (!self.sizes().equals(other.sizes())) { +- AT_ERROR("sizes do not match"); +- } +- if (self.is_sparse()) { +- at::_sparse_add_out(result, self, other, -alpha); +- } else { +- at::_sparse_dense_add_out(result, self, other, -alpha); +- } +- return result; +- } else if (self.is_sparse()) { +- AT_ERROR("sub(sparse, dense) is not supported. 
Use sub(dense, sparse) instead."); +- } + auto iter = TensorIterator::binary_op(result, self, other, + /*check_mem_overlap=*/true); + TORCH_CHECK(isFloatingType(iter.dtype()) || alpha.isIntegral(false), alpha_mismatch_err); +@@ -146,10 +101,6 @@ Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar + Tensor sub(const Tensor& self, const Tensor& other, Scalar alpha) { + sub_check(self, other); + Tensor result; +- if (other.is_sparse()) { +- result = at::empty({0}, self.options()); +- return native::sub_out(result, self, other, alpha); +- } + auto iter = TensorIterator::binary_op(result, self, other); + TORCH_CHECK(isFloatingType(iter.dtype()) || alpha.isIntegral(false), alpha_mismatch_err); + sub_stub(iter.device_type(), iter, alpha); +@@ -197,12 +148,18 @@ Tensor& add_(Tensor& self, Scalar other, Scalar alpha) { + return native::add_(self, wrapped_scalar_tensor(other), alpha); + } + ++// WARNING: There doesn't appear to be any testing for this function ++// with sparse self input. + Tensor div(const Tensor& self, Scalar other) { +- return native::div(self, wrapped_scalar_tensor(other)); ++ return self.div(wrapped_scalar_tensor(other)); // redispatch! + } + ++// WARNING: This function, with a sparse self, is currently only ++// exercised by DistributedDataParallelTest.test_sparse_gradients ++// (you need to exercise it from C++, because this overload is never ++// used for Python) + Tensor& div_(Tensor& self, Scalar other) { +- return native::div_(self, wrapped_scalar_tensor(other)); ++ return self.div_(wrapped_scalar_tensor(other)); // redispatch! 
+ } + + Tensor mul(const Tensor& self, Scalar other) { +diff --git a/aten/src/ATen/native/LegacyBridge.cpp b/aten/src/ATen/native/LegacyBridge.cpp +index 3e544cb14d..e69de29bb2 100644 +--- a/aten/src/ATen/native/LegacyBridge.cpp ++++ b/aten/src/ATen/native/LegacyBridge.cpp +@@ -1,79 +0,0 @@ +-#include +-#include +-#include +- +-namespace at { namespace native { +- +-// Note [Multiple dispatch to sparse] +-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +-// In an ideal world, we would use direct support for multiple dispatch to +-// say that add(Dense, Dense) should dispatch to one function, while +-// add(Dense, Sparse) should dispatch to another function. +-// +-// In a world where we only have single dispatch, we can single dispatch on +-// the first function, and then do an is_sparse() test on the second argument +-// to direct ourselves to the correct argument. +-// +-// We are in neither of those worlds. Instead, we have a _th_addmm function +-// which has legacy implementations in the single dispatch world, BUT our +-// actual addmm function needs to call s_native_addmm if the function *would have* +-// utilized a sparse kernel that is natively implemented. +-// +-// _th_addmm is "good old single dispatch" which internally handles the is_sparse() +-// test and also handles broadcasting. s_native_addmm works asymmetrically: +-// it doesn't handle broadcasting at all, and it ASSUMES that the relevant +-// argument is a sparse tensor. Why the asymmetry? It turns out it is not +-// so easy to figure out if a kernel is implemented in THS; it's not as simple +-// as testing if the first argument is sparse, because, e.g., +-// in addmm(Dense, Sparse), the sparse kernel is in the second argument. So, +-// the trampoline function is going to know about the overloads *anyway*; it +-// might as well also handle is_sparse() and broadcasting while it's at it. +-// +-// Why not change TH to follow this new scheme? We could... 
but since it's +-// all going away when we finish porting the TH functions to ATen, we haven't +-// done it. +- +-// NB: You may be tempted to implement addmm and addmm_ just as calls to addmm_out, but +-// calling the actual implementing function matters, because broadcast +-// will be handled differently depending on if you call addmm_ or (a seemingly +-// equivalent) add_out. Arguably this mismatch in treatment is a bug, +-// c.f., https://github.com/pytorch/pytorch/issues/8308 but fixing this +-// bug would involve changing a lot of other places, so we leave it +-// alone for now. +- +-Tensor& addmm_out(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) { +- // See Note [Multiple dispatch to sparse] +- auto mat1_sparse = mat1.is_sparse(); +- if (mat1_sparse) { +- Tensor b_self; +- std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); +- return s_native_addmm_out(result, b_self, mat1, mat2, beta, alpha); +- } else { +- return at::_addmm_out(result, self, mat1, mat2, beta, alpha); +- } +-} +- +-Tensor addmm(const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) { +- // See Note [Multiple dispatch to sparse] +- auto mat1_sparse = mat1.is_sparse(); +- if (mat1_sparse) { +- Tensor b_self; +- std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm"); +- return s_native_addmm(b_self, mat1, mat2, beta, alpha); +- } else { +- return at::_addmm(self, mat1, mat2, beta, alpha); +- } +-} +- +-Tensor& addmm_(Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) { +- // See Note [Multiple dispatch to sparse] +- auto mat1_sparse = mat1.is_sparse(); +- if (mat1_sparse) { +- // inplace is not broadcasting +- return s_native_addmm_(self, mat1, mat2, beta, alpha); +- } else { +- return at::_addmm_(self, mat1, mat2, beta, alpha); +- } +-} +- +-}} // namespace at::native +diff --git a/aten/src/ATen/native/native_functions.yaml 
b/aten/src/ATen/native/native_functions.yaml +index a7b5981144..18b2c97e8c 100644 +--- a/aten/src/ATen/native/native_functions.yaml ++++ b/aten/src/ATen/native/native_functions.yaml +@@ -171,8 +171,8 @@ + dispatch: + CPU: add + CUDA: add +- SparseCPU: add +- SparseCUDA: add ++ SparseCPU: add_sparse ++ SparseCUDA: add_sparse + MkldnnCPU: mkldnn_add + named_guard: False + +@@ -181,8 +181,8 @@ + dispatch: + CPU: add_ + CUDA: add_ +- SparseCPU: add_ +- SparseCUDA: add_ ++ SparseCPU: add_sparse_ ++ SparseCUDA: add_sparse_ + MkldnnCPU: mkldnn_add_ + named_guard: False + +@@ -190,8 +190,8 @@ + dispatch: + CPU: add_out + CUDA: add_out +- SparseCPU: add_out +- SparseCUDA: add_out ++ SparseCPU: add_out_sparse_cpu ++ SparseCUDA: add_out_sparse_cuda + MkldnnCPU: mkldnn_add_out + named_guard: False + +@@ -734,13 +734,28 @@ + + - func: div.Tensor(Tensor self, Tensor other) -> Tensor + variants: function, method ++ dispatch: ++ CPU: div ++ CUDA: div ++ SparseCPU: div_sparse ++ SparseCUDA: div_sparse + named_guard: False + + - func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + variants: method ++ dispatch: ++ CPU: div_ ++ CUDA: div_ ++ SparseCPU: div_sparse_ ++ SparseCUDA: div_sparse_ + named_guard: False + + - func: div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) ++ dispatch: ++ CPU: div_out ++ CUDA: div_out ++ SparseCPU: div_out_sparse_zerodim ++ SparseCUDA: div_out_sparse_zerodim + named_guard: False + + # For C++ only, until we have conversion from C++ numbers to Tensor +@@ -1555,19 +1570,18 @@ + dispatch: + CPU: mul + CUDA: mul +- SparseCPU: mul +- SparseCUDA: mul ++ SparseCPU: mul_sparse ++ SparseCUDA: mul_sparse + MkldnnCPU: mkldnn_mul + named_guard: False + +- + - func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ variants: method + dispatch: + CPU: mul_ + CUDA: mul_ +- SparseCPU: mul_ +- SparseCUDA: mul_ ++ SparseCPU: mul_sparse_ ++ SparseCUDA: mul_sparse_ + MkldnnCPU: mkldnn_mul_ + named_guard: False + +@@ -1575,8 +1589,8 @@ + dispatch: + CPU: mul_out + CUDA: mul_out +- SparseCPU: mul_out +- SparseCUDA: mul_out ++ SparseCPU: mul_out_sparse_cpu ++ SparseCUDA: mul_out_sparse_cuda + MkldnnCPU: mkldnn_mul_out + named_guard: False + +@@ -2085,41 +2099,6 @@ + CPU: softmax_backward_cpu + CUDA: softmax_backward_cuda + +-- func: _sparse_add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +- dispatch: +- SparseCPU: add_out_sparse_cpu +- SparseCUDA: add_out_sparse_cuda +- +-- func: _sparse_dense_add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +- dispatch: +- CPU: add_out_dense_sparse_cpu +- CUDA: add_out_dense_sparse_cuda +- +-- func: _sparse_div_zerodim.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) +- dispatch: +- SparseCPU: div_out_sparse_zerodim +- SparseCUDA: div_out_sparse_zerodim +- +-- func: _sparse_div_scalar.out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) +- dispatch: +- SparseCPU: div_out_sparse_scalar +- SparseCUDA: div_out_sparse_scalar +- +-- func: _sparse_mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) +- dispatch: +- SparseCPU: mul_out_sparse_cpu +- SparseCUDA: mul_out_sparse_cuda +- +-- func: _sparse_mul_zerodim.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) +- dispatch: +- SparseCPU: mul_out_sparse_zerodim +- SparseCUDA: mul_out_sparse_zerodim +- +-- func: _sparse_mul_scalar.out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+- dispatch: +- SparseCPU: mul_out_sparse_scalar +- SparseCUDA: mul_out_sparse_scalar +- + - func: split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[] + variants: function, method + device_guard: False +@@ -2671,14 +2650,29 @@ + MkldnnCPU: mkldnn_zero_ + + - func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) ++ dispatch: ++ CPU: sub_out ++ CUDA: sub_out ++ SparseCPU: sub_out_sparse ++ SparseCUDA: sub_out_sparse + named_guard: False + + - func: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + variants: function, method ++ dispatch: ++ CPU: sub ++ CUDA: sub ++ SparseCPU: sub_sparse ++ SparseCUDA: sub_sparse + named_guard: False + + - func: sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + variants: method ++ dispatch: ++ CPU: sub_ ++ CUDA: sub_ ++ SparseCPU: sub_sparse_ ++ SparseCUDA: sub_sparse_ + named_guard: False + + # For C++ only, until we have conversion from C++ numbers to Tensor +@@ -2699,32 +2693,37 @@ + variants: function + named_guard: False + +-- func: s_native_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +- dispatch: +- CPU: s_addmm_out_sparse_dense_cpu +- CUDA: s_addmm_out_sparse_dense_cuda +- +-- func: s_native_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +- dispatch: +- CPU: s_addmm_sparse_dense_cpu +- CUDA: s_addmm_sparse_dense_cuda +- +-- func: s_native_addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) +- dispatch: +- CPU: s_addmm_sparse_dense_cpu_ +- CUDA: s_addmm_sparse_dense_cuda_ +- ++# Functionally the same as addmm, but we give it a different derivative formula ++# that doesn't propagate gradients to non-present entries on sparse. 
+ - func: _sparse_addmm(Tensor self, Tensor sparse, Tensor dense, *, Scalar beta=1, Scalar alpha=1) -> Tensor ++ named_guard: False + + - func: addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) ++ dispatch: ++ CPU: legacy::cpu::_th_addmm_out ++ CUDA: legacy::cuda::_th_addmm_out ++ SparseCPU: addmm_out_sparse_dense_cpu ++ SparseCUDA: addmm_out_sparse_dense_cuda + named_guard: False + + - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + variants: function, method ++ dispatch: ++ CPU: legacy::cpu::_th_addmm ++ CUDA: legacy::cuda::_th_addmm ++ SparseCPU: addmm_sparse_dense_cpu ++ SparseCUDA: addmm_sparse_dense_cuda + named_guard: False + + - func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + variants: method ++ dispatch: ++ CPU: legacy::cpu::_th_addmm_ ++ CUDA: legacy::cuda::_th_addmm_ ++ # Warning! For whatever reason, the inplace sparse addmm is NON ++ # broadcasting ++ SparseCPU: s_addmm_sparse_dense_cpu_ ++ SparseCUDA: s_addmm_sparse_dense_cuda_ + named_guard: False + + +@@ -2884,8 +2883,8 @@ + - func: sparse_mask(Tensor self, Tensor mask) -> Tensor + variants: method + dispatch: +- CPU: sparse_mask_cpu +- CUDA: sparse_mask_cuda ++ SparseCPU: sparse_mask_cpu ++ SparseCUDA: sparse_mask_cuda + requires_tensor: True + + +@@ -4556,24 +4555,6 @@ + CUDA: legacy::cuda::_th_std + named_guard: False + +-- func: _addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +- dispatch: +- CPU: legacy::cpu::_th_addmm_out +- CUDA: legacy::cuda::_th_addmm_out +- named_guard: False +- +-- func: _addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +- dispatch: +- CPU: legacy::cpu::_th_addmm +- CUDA: legacy::cuda::_th_addmm +- named_guard: False +- +-- func: _addmm_(Tensor(a!) 
self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) +- dispatch: +- CPU: legacy::cpu::_th_addmm_ +- CUDA: legacy::cuda::_th_addmm_ +- named_guard: False +- + - func: _cat(Tensor[] tensors, int dim=0) -> Tensor + dispatch: + CPU: legacy::cpu::_th_cat +diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp +index f9cbe6c96c..b01166245a 100644 +--- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp ++++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp +@@ -1,3 +1,5 @@ ++#include ++ + #include + #include + #include +@@ -148,10 +150,23 @@ SparseTensor pow_sparse_scalar(const SparseTensor& t, Scalar value) { + // div(SparseTensor, Scalar) + // -------------------------------------------------------------------- + ++SparseTensor& div_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, const Tensor& value); ++ ++Tensor div_sparse(const Tensor& self, const Tensor& value) { ++ Tensor result = at::empty({0}, self.options()); ++ return div_out_sparse_zerodim(result, self, value); ++} ++ ++Tensor& div_sparse_(Tensor& self, const Tensor& value) { ++ return div_out_sparse_zerodim(self, self, value); ++} ++ + SparseTensor& div_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, const Tensor& value) { ++ TORCH_CHECK(value.dim() == 0, "sparse division only supports division by a scalar (got shape ", ++ value.sizes(), " for argument 'other')"); ++ + AT_ASSERT(r.is_sparse()); + AT_ASSERT(t.is_sparse()); +- AT_ASSERT(value.dim() == 0); + + if (is_same_tensor(r, t)) { + r._values().div_(value); +@@ -187,9 +202,40 @@ Tensor norm_sparse(const SparseTensor& self, Scalar value) { + // add(SparseTensor, SparseTensor, Scalar) [broadcasts] + // -------------------------------------------------------------------- + ++Tensor add_sparse(const Tensor& self, const Tensor& other, Scalar alpha) { ++ // TODO: Why?! Can't we just flip the order here... 
++ TORCH_CHECK(!(self.is_sparse() && !other.is_sparse()), ++ "add(sparse, dense) is not supported. Use add(dense, sparse) instead."); ++ Tensor result = at::empty({0}, self.options()); ++ return at::add_out(result, self, other, alpha); // redispatch! ++} ++ ++Tensor& add_sparse_(Tensor& self, const Tensor& other, Scalar alpha) { ++ return at::add_out(self, self, other, alpha); // redispatch! ++} ++ ++// There's actually nothing sparse specific about these implementations ++ ++Tensor sub_sparse(const Tensor& self, const Tensor& other, Scalar alpha) { ++ return native::add_sparse(self, other, -alpha); ++} ++ ++Tensor& sub_sparse_(Tensor& self, const Tensor& other, Scalar alpha) { ++ return native::add_sparse_(self, other, -alpha); ++} ++ ++Tensor& sub_out_sparse(Tensor& r, const Tensor& self, const Tensor& other, Scalar alpha) { ++ return at::add_out(r, self, other, -alpha); // redispatch! ++} ++ ++Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTensor& sparse_, Scalar value); ++ + SparseTensor& add_out_sparse_cpu(SparseTensor& r, const SparseTensor& t, const SparseTensor& src, Scalar value) { +- AT_ASSERT(r.is_sparse()); +- AT_ASSERT(t.is_sparse()); ++ if (!t.is_sparse()) { ++ return add_out_dense_sparse_cpu(r, t, src, value); ++ } ++ // TODO: This test seems a bit goofy ++ TORCH_CHECK(src.is_sparse(), "add(sparse, dense) is not supported. 
Use add(dense, sparse) instead."); + AT_ASSERT(!t.is_cuda()); // the dispatch argument + TORCH_CHECK(!r.is_cuda(), "add: expected 'out' to be CPU tensor, but got CUDA tensor"); + TORCH_CHECK(!src.is_cuda(), "add: expected 'other' to be a CPU tensor, but got a CUDA tensor"); +@@ -375,6 +421,15 @@ Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTen + // mul(SparseTensor, SparseTensor) [broadcasts] + // -------------------------------------------------------------------- + ++Tensor mul_sparse(const Tensor& self, const Tensor& other) { ++ Tensor result = at::empty({0}, self.options()); ++ return at::mul_out(result, self, other); // redispatch! ++} ++ ++Tensor& mul_sparse_(Tensor& self, const Tensor& other) { ++ return at::mul_out(self, self, other); // redispatch! ++} ++ + SparseTensor& mul_out_sparse_cpu(SparseTensor& r, const Tensor& t_, const Tensor& src_) { + if (src_.dim() == 0) { + return mul_out_sparse_zerodim(r, t_, src_); +@@ -576,6 +631,19 @@ Tensor& s_addmm_out_sparse_dense_cpu( + + } + ++Tensor& addmm_out_sparse_dense_cpu( ++ Tensor& result, ++ const Tensor& self, ++ const SparseTensor& mat1, ++ const Tensor& mat2, ++ Scalar beta, ++ Scalar alpha ++) { ++ Tensor b_self; ++ std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); ++ return s_addmm_out_sparse_dense_cpu(result, b_self, mat1, mat2, beta, alpha); ++} ++ + Tensor s_addmm_sparse_dense_cpu( + const Tensor& t, + const SparseTensor& sparse, +@@ -588,6 +656,18 @@ Tensor s_addmm_sparse_dense_cpu( + return r; + } + ++Tensor addmm_sparse_dense_cpu( ++ const Tensor& self, ++ const SparseTensor& mat1, ++ const Tensor& mat2, ++ Scalar beta, ++ Scalar alpha ++) { ++ Tensor b_self; ++ std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); ++ return s_addmm_sparse_dense_cpu(b_self, mat1, mat2, beta, alpha); ++} ++ + Tensor& s_addmm_sparse_dense_cpu_( + Tensor& t, + const SparseTensor& sparse, +@@ -598,6 +678,8 @@ Tensor& 
s_addmm_sparse_dense_cpu_( + return s_addmm_out_sparse_dense_cpu(t, t, sparse, dense, beta, alpha); + } + ++// NB: Purposely no broadcasting version of addmm inplace ++ + Tensor _sparse_addmm( + const Tensor& t, + const SparseTensor& sparse, +@@ -605,9 +687,10 @@ Tensor _sparse_addmm( + Scalar beta, + Scalar alpha + ) { +- Tensor b_t; +- std::tie(b_t) = expand_size(t, {sparse.size(0), dense.size(1)}, "addmm"); +- return at::s_native_addmm(b_t, sparse, dense, beta, alpha); ++ // _sparse_addmm forward is functionally equivalent to addmm; it's ++ // just the backward that is different. This technically does an ++ // unnecessary redispatch, I was too lazy to make it not do that ++ return at::addmm(t, sparse, dense, beta, alpha); + } + + Tensor _sparse_mm( +@@ -615,16 +698,19 @@ Tensor _sparse_mm( + const Tensor& dense + ) { + Tensor t = at::zeros({}, dense.options()); +- return at::_sparse_addmm(t, sparse, dense, 0, 1); ++ return at::_sparse_addmm(t, sparse, dense, 0, 1); // redispatch! + } + ++// NB: Despite its suggestive name, this actually only exists so that ++// we can redispatch to addmm_out; this is NOT an implementation of ++// the sparse masking version of mm + SparseTensor& _sparse_mm_out( + SparseTensor& result, + const SparseTensor& sparse, + const Tensor& dense + ) { + Tensor t = at::zeros({}, dense.options()); +- return at::addmm_out(result, t, sparse, dense, 0, 1); ++ return at::addmm_out(result, t, sparse, dense, 0, 1); // redispatch! 
+ } + + // -------------------------------------------------------------------- +diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.h b/aten/src/ATen/native/sparse/SparseTensorMath.h +new file mode 100644 +index 0000000000..514f84fd8e +--- /dev/null ++++ b/aten/src/ATen/native/sparse/SparseTensorMath.h +@@ -0,0 +1,11 @@ ++#pragma once ++ ++#include ++#include ++ ++namespace at { namespace native { ++ ++sparse::SparseTensor& mul_out_sparse_scalar(sparse::SparseTensor& r, const sparse::SparseTensor& t, Scalar value); ++sparse::SparseTensor& mul_out_sparse_zerodim(sparse::SparseTensor& r, const sparse::SparseTensor& t, const Tensor& value); ++ ++}} +diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp +index 927d91797d..4d43150547 100644 +--- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp ++++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp +@@ -12,7 +12,7 @@ SparseTensor& sparse_mask_out_cuda(SparseTensor& r, const Tensor& t, const Spars + TORCH_CHECK(mask.is_coalesced(), "sparse_mask: mask is uncoalesced"); + TORCH_CHECK(mask.sizes().equals(t.sizes()), "sparse_mask: operands have incompatible sizes; self has size ", + t.sizes(), " but mask has size ", mask.sizes()); +- AT_ASSERT(t.is_cuda()); // dispatch argument ++ TORCH_CHECK(t.is_cuda(), "sparse_mask: expected 'self' to be CUDA, but got CPU"); + TORCH_CHECK(mask.is_cuda(), "sparse_mask: expected 'mask' to be CUDA, but got CPU"); + TORCH_CHECK(r.is_cuda(), "sparse_mask: expected 'out' to be CUDA, but got CPU"); + TORCH_CHECK(cuda::check_device({r, t, mask}), +diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +index fdc17e4dfd..f681f16ce4 100644 +--- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu ++++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +@@ -2,12 +2,14 @@ + #include + #include + #include ++#include + #include + 
#include + #include + #include + #include + #include ++#include + + #include + #include +@@ -51,7 +53,7 @@ namespace { + // -------------------------------------------------------------------- + + Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseTensor& sparse_, const Tensor& dense, Scalar beta, Scalar alpha) { +- AT_ASSERT(t.is_cuda()); // dispatch argument ++ TORCH_CHECK(t.is_cuda(), "addmm: expected 'self' to be CUDA, but got CPU"); + TORCH_CHECK(r_.is_cuda(), "addmm: expected 'out' to be CUDA, but got CPU"); + TORCH_CHECK(sparse_.is_cuda(), "addmm: expected 'mat1' to be CUDA, but got CPU"); + TORCH_CHECK(dense.is_cuda(), "addmm: expected 'mat2' to be CUDA, but got CPU"); +@@ -151,6 +153,19 @@ Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseT + return r_; + } + ++Tensor& addmm_out_sparse_dense_cuda( ++ Tensor& result, ++ const Tensor& self, ++ const SparseTensor& mat1, ++ const Tensor& mat2, ++ Scalar beta, ++ Scalar alpha ++) { ++ Tensor b_self; ++ std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); ++ return s_addmm_out_sparse_dense_cuda(result, b_self, mat1, mat2, beta, alpha); ++} ++ + Tensor s_addmm_sparse_dense_cuda( + const Tensor& t, + const SparseTensor& sparse, +@@ -163,6 +178,18 @@ Tensor s_addmm_sparse_dense_cuda( + return r; + } + ++Tensor addmm_sparse_dense_cuda( ++ const Tensor& self, ++ const SparseTensor& mat1, ++ const Tensor& mat2, ++ Scalar beta, ++ Scalar alpha ++) { ++ Tensor b_self; ++ std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); ++ return s_addmm_sparse_dense_cuda(b_self, mat1, mat2, beta, alpha); ++} ++ + Tensor& s_addmm_sparse_dense_cuda_( + Tensor& t, + const SparseTensor& sparse, +@@ -173,6 +200,8 @@ Tensor& s_addmm_sparse_dense_cuda_( + return s_addmm_out_sparse_dense_cuda(t, t, sparse, dense, beta, alpha); + } + ++// NB: Purposely no broadcasting version of addmm inplace ++ + // Deleted sspaddmm (sparse, 
dense) -> sparse + + // -------------------------------------------------------------------- +@@ -180,7 +209,7 @@ Tensor& s_addmm_sparse_dense_cuda_( + // -------------------------------------------------------------------- + + SparseTensor& hspmm_out_sparse_cuda(SparseTensor& r_, const SparseTensor& sparse_, const Tensor& dense/* , Scalar alpha */) { +- AT_ASSERT(sparse_.is_cuda()); // dispatch argument ++ TORCH_CHECK(sparse_.is_cuda(), "hspmm: expected 'self' to be CUDA, but got CPU"); + TORCH_CHECK(r_.is_cuda(), "hspmm: expected 'out' to be CUDA, but got CPU"); + TORCH_CHECK(dense.is_cuda(), "hspmm: expected 'mat2' to be CUDA, but got CPU"); + +@@ -249,9 +278,9 @@ SparseTensor hspmm_sparse_cuda(const SparseTensor& sparse, const Tensor& dense) + // -------------------------------------------------------------------- + + Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseTensor& sparse, at::Scalar value) { +- AT_ASSERT(dense.is_cuda()); // dispatch argument +- TORCH_CHECK(sparse.is_cuda(), "add: expected 'other' to be CUDA, but got CPU"); +- TORCH_CHECK(r_.is_cuda(), "add: expected 'out' to be CUDA, but got CPU"); ++ TORCH_CHECK(dense.is_cuda(), "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"); ++ TORCH_CHECK(sparse.is_cuda(), "add: expected 'other' to be a CUDA tensor, but got a CPU tensor"); ++ TORCH_CHECK(r_.is_cuda(), "add: expected 'out' to be a CUDA tensor, but got a CPU tensor"); + + TORCH_CHECK(cuda::check_device({sparse, r_, dense})); + +@@ -350,8 +379,17 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT + // add(SparseTensor, SparseTensor, Scalar) [broadcasts] + // -------------------------------------------------------------------- + ++Tensor& add_out_dense_sparse_cuda(Tensor& r, const Tensor& dense, const SparseTensor& sparse_, Scalar value); ++ + SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const SparseTensor& src, Scalar value) { +- 
AT_ASSERT(t.is_cuda()); // dispatch argument ++ if (!t.is_sparse()) { ++ return add_out_dense_sparse_cuda(r_, t, src, value); ++ } ++ ++ // TODO: This test seems a bit goofy ++ TORCH_CHECK(src.is_sparse(), "add(sparse, dense) is not supported. Use add(dense, sparse) instead."); ++ ++ TORCH_CHECK(t.is_cuda(), "add: expected 'self' to be CUDA, but got CPU"); + TORCH_CHECK(src.is_cuda(), "add: expected 'other' to be CUDA, but got CPU"); + TORCH_CHECK(r_.is_cuda(), "add: expected 'out' to be CUDA, but got CPU"); + +@@ -410,7 +448,7 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons + return mul_out_sparse_zerodim(r_, src_, t_); + } + +- AT_ASSERT(t_.is_cuda()); // dispatch argument ++ TORCH_CHECK(t_.is_cuda(), "mul: expected 'self' to be CUDA, but got CPU"); + TORCH_CHECK(src_.is_cuda(), "mul: expected 'other' to be CUDA, but got CPU"); + TORCH_CHECK(r_.is_cuda(), "mul: expected 'out' to be CUDA, but got CPU"); + TORCH_CHECK(cuda::check_device({r_, t_, src_})); +diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h +index 03a272f12c..40c251e69d 100644 +--- a/aten/src/ATen/templates/TensorBody.h ++++ b/aten/src/ATen/templates/TensorBody.h +@@ -419,7 +419,7 @@ protected: + }; + + namespace detail { +-// Helper creator for Tensor clas which doesn't requires the users to pass ++// Helper creator for Tensor class which doesn't requires the users to pass + // in an intrusive_ptr instead it just converts the argument passed to + // requested intrusive_ptr type. + template +@@ -427,15 +427,6 @@ Tensor make_tensor(Args&&... 
args) { + return Tensor(c10::make_intrusive(std::forward(args)...)); + } + +-inline TensorTypeSet infer_tensor_type_set(const Tensor & tl) { +- return tl.type_set(); +-} +- +-inline TensorTypeSet infer_tensor_type_set(TensorList tl) { +- TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors"); +- return tl[0].type_set(); +-} +- + } // namespace detail + + static inline TensorTypeId legacyExtractTypeId(const Tensor& t) { +diff --git a/aten/src/ATen/templates/TensorMethods.h b/aten/src/ATen/templates/TensorMethods.h +index 0afa00c19a..c5cae10e4f 100644 +--- a/aten/src/ATen/templates/TensorMethods.h ++++ b/aten/src/ATen/templates/TensorMethods.h +@@ -11,6 +11,7 @@ + #if !defined(CAFFE2_IS_XPLAT_BUILD) + #include + #endif ++#include + #include + #ifdef USE_STATIC_DISPATCH + #include +@@ -21,6 +22,27 @@ + + namespace at { + ++namespace detail { ++ ++struct MultiDispatchTensorTypeSet : IterArgs { ++ TensorTypeSet ts; ++ void operator()(const at::Tensor& x) { ++ ts = ts | x.type_set(); ++ } ++ void operator()(at::ArrayRef xs) { ++ for (const auto& x : xs) { ++ ts = ts | x.type_set(); ++ } ++ } ++}; ++ ++template ++TensorTypeSet multi_dispatch_tensor_type_set(Args&&... 
args) { ++ return MultiDispatchTensorTypeSet().apply(std::forward(args)...).ts; ++} ++ ++} ++ + struct Quantizer; + // This is temporary typedef to enable Quantizer in aten native function API + // we'll remove them when we are actually exposing Quantizer class +diff --git a/c10/core/TensorTypeId.h b/c10/core/TensorTypeId.h +index eb99881296..d01ee9d5f3 100644 +--- a/c10/core/TensorTypeId.h ++++ b/c10/core/TensorTypeId.h +@@ -25,8 +25,6 @@ enum class TensorTypeId : uint8_t { + // the hierarchy for convenience and performance + CPUTensorId, // PyTorch/Caffe2 supported + CUDATensorId, // PyTorch/Caffe2 supported +- SparseCPUTensorId, // PyTorch only +- SparseCUDATensorId, // PyTorch only + MKLDNNTensorId, // Caffe2 only + OpenGLTensorId, // Caffe2 only + OpenCLTensorId, // Caffe2 only +@@ -40,6 +38,10 @@ enum class TensorTypeId : uint8_t { + ComplexCPUTensorId, // PyTorch only + ComplexCUDATensorId, // PyTorch only + ++ // Sparse has multi-dispatch with dense; handle it first ++ SparseCPUTensorId, // PyTorch only ++ SparseCUDATensorId, // PyTorch only ++ + // WARNING! If you add more "wrapper" style tensor ids (tensor + // ids which don't get kernels directly defined in native_functions.yaml; + // examples are tracing or profiling) here, you need to also adjust +diff --git a/test/test_nn.py b/test/test_nn.py +index e5d28d3253..e5d388f798 100644 +--- a/test/test_nn.py ++++ b/test/test_nn.py +@@ -1804,7 +1804,7 @@ class TestNN(NNTestCase): + # Without using `torch.no_grad()`, this will leak CUDA memory. 
+ # (Issue is filed at https://github.com/pytorch/pytorch/issues/21875) + mw[0][0] = 5 +- with self.assertRaisesRegex(RuntimeError, "Expected object of backend CPU but got backend CUDA"): ++ with self.assertRaisesRegex(RuntimeError, "Expected object of backend CUDA but got backend CPU"): + mw[0][0] == mw._base[0][0] + + try: +@@ -2958,6 +2958,7 @@ class TestNN(NNTestCase): + x = torch.tensor([], device=device, dtype=torch.long) + for sparse in [True, False]: + Embed = torch.nn.EmbeddingBag(m, n, sparse=sparse) ++ Embed.to(device) + + output = Embed(input=x, offsets=torch.tensor([0], device=device, dtype=torch.long)) + self.assertEqual(output, torch.zeros_like(output)) +diff --git a/test/test_sparse.py b/test/test_sparse.py +index 1243103e6e..f7795c6804 100644 +--- a/test/test_sparse.py ++++ b/test/test_sparse.py +@@ -2107,21 +2107,21 @@ class TestSparseOneOff(TestCase): + sparse_y = torch.cuda.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(), + torch.randn(4, 4, 4).cuda(), + [3, 4, 4]) +- with self.assertRaisesRegex(RuntimeError, "add: expected 'other' to be a CPU tensor\\, but got a CUDA tensor"): ++ with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"): + x + sparse_y + + x = torch.zeros(3, 4, 4, 0) + sparse_y = torch.cuda.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(), + torch.randn(4, 4, 4, 0).cuda(), + [3, 4, 4, 0]) +- with self.assertRaisesRegex(RuntimeError, "add: expected 'other' to be a CPU tensor\\, but got a CUDA tensor"): ++ with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"): + x + sparse_y + + x = torch.zeros(0, 4, 4, 0) + sparse_y = torch.cuda.sparse.FloatTensor(torch.LongTensor(1, 0).cuda(), + torch.randn(0, 4, 4, 0).cuda(), + [0, 4, 4, 0]) +- with self.assertRaisesRegex(RuntimeError, "add: expected 'other' to be a CPU tensor\\, but got a CUDA tensor"): ++ with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA 
tensor, but got a CPU tensor"): + x + sparse_y + + +diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp +index cfc362d083..358d29dde4 100644 +--- a/tools/autograd/templates/Functions.cpp ++++ b/tools/autograd/templates/Functions.cpp +@@ -528,7 +528,7 @@ Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntArrayRef si + int64_t out_cols = grad.size(1); + Tensor t = at::zeros({}, grad.options()).expand({out_rows, out_cols}, true); + Tensor r = at::empty({out_cols, out_rows}, grad.options()).t(); +- at::s_native_addmm_out(r, t, mat1.t(), grad, alpha, 1); ++ at::addmm_out(r, t, mat1.t(), grad, alpha, 1); + return r; + } + return maybe_multiply(grad.t().mm(mat1).t(), alpha); +diff --git a/torch/csrc/utils/variadic.h b/torch/csrc/utils/variadic.h +index 3a924a9db9..63f34afbc3 100644 +--- a/torch/csrc/utils/variadic.h ++++ b/torch/csrc/utils/variadic.h +@@ -2,6 +2,7 @@ + + #include + #include ++#include + + #include + #include +@@ -10,67 +11,7 @@ + + namespace torch { + +-// This class allows you to write variadic functions which +-// call a (possibly overloaded) function on each argument, +-// in order. This is most commonly used in autogenerated code, +-// where it is convenient to have a function that can uniformly +-// take arguments of different types. If your arguments +-// are homogenous consider using a std::initializer_list instead. +-template +-struct IterArgs { +- template +- inline F& apply() { +- return self(); +- } +- +- // NB: Use perfect forwarding here, otherwise we'll make value +- // copies of all arguments! +- template +- inline F& apply(T&& arg, Args&&... args) { +- self()(std::forward(arg)); +- if (self().short_circuit()) { +- return self(); +- } else { +- return apply(std::forward(args)...); +- } +- } +- +- // Here are some handy overloads which provide sensible +- // defaults for container-like structures that one might +- // be interested in recursing into. 
You can enable them +- // by adding: +- // +- // using IterArgs::operator() +- // +- // to your struct. These are not enabled by default because +- // you may be able to process these structures more efficiently +- // than handling them one-by-one. +- +- template +- void operator()(at::ArrayRef args) { +- for (const auto& arg : args) { +- self()(arg); +- if (short_circuit()) +- return; +- } +- } +- +- // NB: we need to specify std::vector manually as C++ won't +- // do an implicit conversion to make a template deduction go through. +- template +- void operator()(const std::vector& args) { +- self()(at::ArrayRef{args}); +- } +- +- bool short_circuit() { +- return false; +- } +- +- private: +- inline F& self() { +- return *static_cast(this); +- } +-}; ++using at::IterArgs; + + struct CountTensors : IterArgs { + size_t out = 0; +@@ -194,4 +135,5 @@ template ) { + return ReturnType(function(accessor.template operator()(Is)...)); + } ++ + } // namespace torch +diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp +index 18c3ee45ef..9a032dca2a 100644 +--- a/torch/lib/c10d/ProcessGroupGloo.cpp ++++ b/torch/lib/c10d/ProcessGroupGloo.cpp +@@ -792,6 +792,14 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { + // we run allgather on the nnz, and then allgather with max(nnz). + // We could use an allgatherv for this, if it were available. + at::Tensor allreduce(std::vector& tensors) { ++ // TODO: This is a massive hack! There is some confusion about ++ // Variable/Tensor inside the body of this function. Turning off ++ // grad smooths over the confusion for now. 
This fixes ++ // test/test_c10d.py ProcessGroupGlooTest.test_sparse_allreduce_basics ++ // ++ // The correct fix is to stop allocating tensors that are not variables, ++ // but to conveniently do this c10d must depend on torch not ATen ++ at::AutoNonVariableTypeMode _no_grad(true); + auto input = tensors[0]; + + // Perform local reduction if we have multiple inputs. +-- +2.13.5 + From 3997b6e1f38eb24cf87173e59ab25029da561316 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 13 Sep 2019 13:03:40 -0400 Subject: [PATCH 3/5] bugfix Signed-off-by: Edward Z. Yang --- torch_xla/csrc/aten_xla_type.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 02f0564e2f20..00fe694fc989 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -870,14 +870,11 @@ at::Tensor& AtenXlaType::copy_(at::Tensor& self, const at::Tensor& src, c10::optional src_tensor = bridge::TryGetXlaTensor(src); if (!src_tensor) { - TORCH_ASSERT(self_tensor); + XLA_CHECK(self_tensor); self_tensor.SetTensor(CopyTensor(src, self_tensor->scalar_type())); } else if (!self_tensor) { // TODO: Is self_tensor good enough? I don't think so... therefore // the hack below: - // - // Do not mark the tensor creation as writeable to not discard the XLA tensor - // device context, but make a copy to avoid core data to be shared. std::vector tensors = {src}; auto xla_tensors = bridge::XlaCreateTensorList(tensors); // Hack in an overwrite of a const tensor. From 6599c9c9bf3b8fe90574cd65a9044879eaddf230 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 13 Sep 2019 13:49:12 -0400 Subject: [PATCH 4/5] bugfixes Signed-off-by: Edward Z. 
Yang --- torch_xla/csrc/aten_xla_type.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 00fe694fc989..1df54a668542 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -871,7 +871,7 @@ at::Tensor& AtenXlaType::copy_(at::Tensor& self, const at::Tensor& src, if (!src_tensor) { XLA_CHECK(self_tensor); - self_tensor.SetTensor(CopyTensor(src, self_tensor->scalar_type())); + self_tensor->SetTensor(CopyTensor(src, self.scalar_type())); } else if (!self_tensor) { // TODO: Is self_tensor good enough? I don't think so... therefore // the hack below: @@ -882,7 +882,7 @@ at::Tensor& AtenXlaType::copy_(at::Tensor& self, const at::Tensor& src, const_cast(self).unsafeGetTensorImpl()->shallow_copy_from( t.getIntrusivePtr()); } else { - XLATensor::copy_(self_tensor, *src_tensor); + XLATensor::copy_(*self_tensor, *src_tensor); } return self; } From aefaf81015f994af1a7772e25c939f0622d49075 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 13 Sep 2019 14:52:40 -0400 Subject: [PATCH 5/5] Remove patch Signed-off-by: Edward Z. Yang --- torch_patches/25653.diff | 5558 -------------------------------------- 1 file changed, 5558 deletions(-) delete mode 100644 torch_patches/25653.diff diff --git a/torch_patches/25653.diff b/torch_patches/25653.diff deleted file mode 100644 index 40a277fbeebc..000000000000 --- a/torch_patches/25653.diff +++ /dev/null @@ -1,5558 +0,0 @@ -From c0f199b932527a79f9f48674c7345b789a37967a Mon Sep 17 00:00:00 2001 -From: "Edward Z. Yang" -Date: Fri, 13 Sep 2019 11:05:42 -0400 -Subject: [PATCH] Implement multiple dispatch - -Instead of considering only the TensorTypeSet of the first argument, -we collect all Tensor and TensorList arguments and union them together -before computing the dispatch type id. 
- -A minor bit of refactoring I had to do to get here was move the IterArgs -functionality in torch/csrc/utils/variadic.h into ATen/core. There's -some refactoring due on that file too (it has copies of some C++ helper -pieces which already live in c10). - -There is a little bit of a hack in the code generator to turn 'self' -arguments into '*this'. I think this may be duplicated with some -logic somewhere else but I have to double check. - -Signed-off-by: Edward Z. Yang - -ghstack-source-id: 18d37ab48c63a7ee04ce19ecd3b803ba7096afca -Pull Request resolved: https://github.com/pytorch/pytorch/pull/25653 ---- - aten/src/ATen/SparseTensorUtils.h | 2 + - aten/src/ATen/core/ATenDispatch.cpp | 28 + - aten/src/ATen/core/ATenDispatch.h | 6 +- - aten/src/ATen/core/TensorBody.h | 11 +- - aten/src/ATen/core/TensorMethods.h | 982 +++++++++++---------- - aten/src/ATen/core/Variadic.h | 74 ++ - aten/src/ATen/core/aten_interned_strings.h | 1 - - aten/src/ATen/function_wrapper.py | 55 +- - aten/src/ATen/native/BinaryOps.cpp | 59 +- - aten/src/ATen/native/LegacyBridge.cpp | 79 -- - aten/src/ATen/native/native_functions.yaml | 147 ++- - aten/src/ATen/native/sparse/SparseTensorMath.cpp | 102 ++- - aten/src/ATen/native/sparse/SparseTensorMath.h | 11 + - .../ATen/native/sparse/cuda/SparseCUDATensor.cpp | 2 +- - .../native/sparse/cuda/SparseCUDATensorMath.cu | 52 +- - aten/src/ATen/templates/TensorBody.h | 11 +- - aten/src/ATen/templates/TensorMethods.h | 22 + - c10/core/TensorTypeId.h | 6 +- - test/test_nn.py | 3 +- - test/test_sparse.py | 6 +- - tools/autograd/templates/Functions.cpp | 2 +- - torch/csrc/utils/variadic.h | 64 +- - torch/lib/c10d/ProcessGroupGloo.cpp | 8 + - 23 files changed, 948 insertions(+), 785 deletions(-) - create mode 100644 aten/src/ATen/core/Variadic.h - create mode 100644 aten/src/ATen/native/sparse/SparseTensorMath.h - -diff --git a/aten/src/ATen/SparseTensorUtils.h b/aten/src/ATen/SparseTensorUtils.h -index ecc52b2cb3..45aa79eef9 100644 ---- 
a/aten/src/ATen/SparseTensorUtils.h -+++ b/aten/src/ATen/SparseTensorUtils.h -@@ -1,3 +1,5 @@ -+#pragma once -+ - #include - #include - -diff --git a/aten/src/ATen/core/ATenDispatch.cpp b/aten/src/ATen/core/ATenDispatch.cpp -index 26deaef09a..50ade874af 100644 ---- a/aten/src/ATen/core/ATenDispatch.cpp -+++ b/aten/src/ATen/core/ATenDispatch.cpp -@@ -7,4 +7,32 @@ ATenDispatch & globalATenDispatch() { - return singleton; - } - -+void* ATenOpTable::getFallbackOp(TensorTypeId tid) const { -+ // TODO: an alternate strategy here would be to mask out the dead key -+ // and then redispatch gain (automatic delegation). I haven't done this -+ // for now to make it easier to smoke out error cases. -+ if (function_table_[static_cast(TensorTypeId::UndefinedTensorId)] == nullptr) { -+ // If there is no fallback dispatch, and dispatch failed because we didn't -+ // find any valid keys to dispatch on, this usually means the user gave -+ // us a non-empty list of tensors. So report a better error in this case. 
-+ // TODO: Maybe we should reword this error message -+ if (tid == TensorTypeId::UndefinedTensorId) { -+ TORCH_CHECK(false, "expected a non-empty list of Tensors") -+ } -+ std::ostringstream oss; -+ bool first = true; -+ for (int64_t i = 0; i < static_cast(TensorTypeId::NumTensorIds); i++) { -+ if (function_table_[i] != nullptr) { -+ if (!first) oss << ", "; -+ oss << toString(static_cast(i)); -+ first = false; -+ } -+ } -+ TORCH_CHECK(false, -+ "No function is registered for schema ", schema_, " on tensor type ", toString(tid), -+ "; available functions are ", oss.str()); -+ } -+ return function_table_[static_cast(TensorTypeId::UndefinedTensorId)]; -+} -+ - } // namespace at -diff --git a/aten/src/ATen/core/ATenDispatch.h b/aten/src/ATen/core/ATenDispatch.h -index ba2940feb3..af16f880fb 100644 ---- a/aten/src/ATen/core/ATenDispatch.h -+++ b/aten/src/ATen/core/ATenDispatch.h -@@ -58,6 +58,8 @@ class CAFFE2_API ATenOpTable { - function_table_[static_cast(tid)] = fn; - } - -+ void* getFallbackOp(TensorTypeId tid) const; -+ - void* getOp(TensorTypeId tid) const { - // You might think we can minorly optimize this further by maintaining a - // bitmask of registered operator keys, so we don't select dispatch ids -@@ -65,9 +67,7 @@ class CAFFE2_API ATenOpTable { - // get a Variable CPUTensor, if there is no variable registration, you'll - // fall back to the CPU implementation. Is this what you want? Unlikely... 
- if (function_table_[static_cast(tid)] == nullptr) { -- TORCH_CHECK(function_table_[static_cast(TensorTypeId::UndefinedTensorId)] != nullptr, -- "No function is registered for schema ", schema_, " on tensor type ", toString(tid)); -- return function_table_[static_cast(TensorTypeId::UndefinedTensorId)]; -+ return getFallbackOp(tid); - } - return function_table_[static_cast(tid)]; - } -diff --git a/aten/src/ATen/core/TensorBody.h b/aten/src/ATen/core/TensorBody.h -index ded8489538..d150bb57bd 100644 ---- a/aten/src/ATen/core/TensorBody.h -+++ b/aten/src/ATen/core/TensorBody.h -@@ -919,7 +919,7 @@ protected: - }; - - namespace detail { --// Helper creator for Tensor clas which doesn't requires the users to pass -+// Helper creator for Tensor class which doesn't requires the users to pass - // in an intrusive_ptr instead it just converts the argument passed to - // requested intrusive_ptr type. - template -@@ -927,15 +927,6 @@ Tensor make_tensor(Args&&... args) { - return Tensor(c10::make_intrusive(std::forward(args)...)); - } - --inline TensorTypeSet infer_tensor_type_set(const Tensor & tl) { -- return tl.type_set(); --} -- --inline TensorTypeSet infer_tensor_type_set(TensorList tl) { -- TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors"); -- return tl[0].type_set(); --} -- - } // namespace detail - - static inline TensorTypeId legacyExtractTypeId(const Tensor& t) { -diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h -index 5c9cdda828..03b7de6675 100644 ---- a/aten/src/ATen/core/TensorMethods.h -+++ b/aten/src/ATen/core/TensorMethods.h -@@ -11,6 +11,7 @@ - #if !defined(CAFFE2_IS_XPLAT_BUILD) - #include - #endif -+#include - #include - #ifdef USE_STATIC_DISPATCH - #include -@@ -21,6 +22,27 @@ - - namespace at { - -+namespace detail { -+ -+struct MultiDispatchTensorTypeSet : IterArgs { -+ TensorTypeSet ts; -+ void operator()(const at::Tensor& x) { -+ ts = ts | x.type_set(); -+ } -+ void operator()(at::ArrayRef xs) { -+ 
for (const auto& x : xs) { -+ ts = ts | x.type_set(); -+ } -+ } -+}; -+ -+template -+TensorTypeSet multi_dispatch_tensor_type_set(Args&&... args) { -+ return MultiDispatchTensorTypeSet().apply(std::forward(args)...).ts; -+} -+ -+} -+ - struct Quantizer; - // This is temporary typedef to enable Quantizer in aten native function API - // we'll remove them when we are actually exposing Quantizer class -@@ -62,7 +84,7 @@ inline void Tensor::backward(const Tensor & gradient, bool keep_graph, bool crea - TypeDefault::backward(const_cast(*this), gradient, keep_graph, create_graph); - #else - static auto table = globalATenDispatch().getOpTable("aten::backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> void"); -- return table->getOp(type_set())(const_cast(*this), gradient, keep_graph, create_graph); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, gradient))(const_cast(*this), gradient, keep_graph, create_graph); - #endif - } - inline void Tensor::set_data(const Tensor & new_data) const { -@@ -70,7 +92,7 @@ inline void Tensor::set_data(const Tensor & new_data) const { - TypeDefault::set_data(const_cast(*this), new_data); - #else - static auto table = globalATenDispatch().getOpTable("aten::set_data(Tensor(a!) 
self, Tensor new_data) -> void"); -- return table->getOp(type_set())(const_cast(*this), new_data); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, new_data))(const_cast(*this), new_data); - #endif - } - inline Tensor Tensor::data() const { -@@ -78,7 +100,7 @@ inline Tensor Tensor::data() const { - return TypeDefault::data(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::data(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -87,7 +109,7 @@ inline Tensor & Tensor::names_(c10::optional names) const { - return TypeDefault::names_(const_cast(*this), names); - #else - static auto table = globalATenDispatch().getOpTable("aten::names_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)"); -- return table->getOp)>(type_set())(const_cast(*this), names); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), names); - #endif - } - #endif -@@ -97,7 +119,7 @@ inline Tensor Tensor::renamed(c10::optional names) const { - return TypeDefault::renamed(const_cast(*this), names); - #else - static auto table = globalATenDispatch().getOpTable("aten::renamed(Tensor(a) self, Dimname[]? 
names) -> Tensor(a)"); -- return table->getOp)>(type_set())(const_cast(*this), names); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), names); - #endif - } - #endif -@@ -107,7 +129,7 @@ inline Tensor Tensor::align_to(DimnameList names) const { - return TypeDefault::align_to(const_cast(*this), names); - #else - static auto table = globalATenDispatch().getOpTable("aten::align_to(Tensor(a) self, DimnameList names) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), names); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), names); - #endif - } - #endif -@@ -117,7 +139,7 @@ inline Tensor Tensor::align_as(const Tensor & other) const { - return TypeDefault::align_as(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::align_as(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - #endif -@@ -127,7 +149,7 @@ inline Tensor Tensor::refine_names(DimnameList names) const { - return TypeDefault::refine_names(const_cast(*this), names); - #else - static auto table = globalATenDispatch().getOpTable("aten::refine_names(Tensor(a) self, DimnameList names) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), names); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), names); - #endif - } - #endif -@@ -136,7 +158,7 @@ inline Tensor Tensor::abs() const { - return TypeDefault::abs(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::abs(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::abs_() const { -@@ -150,7 +172,7 
@@ inline Tensor & Tensor::abs_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::abs_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::acos() const { -@@ -158,7 +180,7 @@ inline Tensor Tensor::acos() const { - return TypeDefault::acos(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::acos(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::acos_() const { -@@ -172,7 +194,7 @@ inline Tensor & Tensor::acos_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::acos_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::add(const Tensor & other, Scalar alpha) const { -@@ -189,7 +211,7 @@ inline Tensor Tensor::add(const Tensor & other, Scalar alpha) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, alpha); - #endif - } - inline Tensor & Tensor::add_(const Tensor & other, Scalar alpha) const { -@@ -206,7 +228,7 @@ inline Tensor & Tensor::add_(const Tensor & other, Scalar alpha) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::add_.Tensor(Tensor(a!) 
self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, alpha); - #endif - } - inline Tensor Tensor::add(Scalar other, Scalar alpha) const { -@@ -214,7 +236,7 @@ inline Tensor Tensor::add(Scalar other, Scalar alpha) const { - return TypeDefault::add(const_cast(*this), other, alpha); - #else - static auto table = globalATenDispatch().getOpTable("aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other, alpha); - #endif - } - inline Tensor & Tensor::add_(Scalar other, Scalar alpha) const { -@@ -222,7 +244,7 @@ inline Tensor & Tensor::add_(Scalar other, Scalar alpha) const { - return TypeDefault::add_(const_cast(*this), other, alpha); - #else - static auto table = globalATenDispatch().getOpTable("aten::add_.Scalar(Tensor(a!) 
self, Scalar other, Scalar alpha=1) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other, alpha); - #endif - } - inline Tensor Tensor::addmv(const Tensor & mat, const Tensor & vec, Scalar beta, Scalar alpha) const { -@@ -236,7 +258,7 @@ inline Tensor Tensor::addmv(const Tensor & mat, const Tensor & vec, Scalar beta, - } - #else - static auto table = globalATenDispatch().getOpTable("aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), mat, vec, beta, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mat, vec))(const_cast(*this), mat, vec, beta, alpha); - #endif - } - inline Tensor & Tensor::addmv_(const Tensor & mat, const Tensor & vec, Scalar beta, Scalar alpha) const { -@@ -250,7 +272,7 @@ inline Tensor & Tensor::addmv_(const Tensor & mat, const Tensor & vec, Scalar be - } - #else - static auto table = globalATenDispatch().getOpTable("aten::addmv_(Tensor(a!) 
self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), mat, vec, beta, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mat, vec))(const_cast(*this), mat, vec, beta, alpha); - #endif - } - inline Tensor Tensor::addr(const Tensor & vec1, const Tensor & vec2, Scalar beta, Scalar alpha) const { -@@ -258,7 +280,7 @@ inline Tensor Tensor::addr(const Tensor & vec1, const Tensor & vec2, Scalar beta - return TypeDefault::addr(const_cast(*this), vec1, vec2, beta, alpha); - #else - static auto table = globalATenDispatch().getOpTable("aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), vec1, vec2, beta, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, vec1, vec2))(const_cast(*this), vec1, vec2, beta, alpha); - #endif - } - inline Tensor & Tensor::addr_(const Tensor & vec1, const Tensor & vec2, Scalar beta, Scalar alpha) const { -@@ -266,7 +288,7 @@ inline Tensor & Tensor::addr_(const Tensor & vec1, const Tensor & vec2, Scalar b - return TypeDefault::addr_(const_cast(*this), vec1, vec2, beta, alpha); - #else - static auto table = globalATenDispatch().getOpTable("aten::addr_(Tensor(a!) 
self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), vec1, vec2, beta, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, vec1, vec2))(const_cast(*this), vec1, vec2, beta, alpha); - #endif - } - inline Tensor Tensor::all(int64_t dim, bool keepdim) const { -@@ -274,7 +296,7 @@ inline Tensor Tensor::all(int64_t dim, bool keepdim) const { - return TypeDefault::all(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - inline bool Tensor::allclose(const Tensor & other, double rtol, double atol, bool equal_nan) const { -@@ -282,7 +304,7 @@ inline bool Tensor::allclose(const Tensor & other, double rtol, double atol, boo - return TypeDefault::allclose(const_cast(*this), other, rtol, atol, equal_nan); - #else - static auto table = globalATenDispatch().getOpTable("aten::allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool"); -- return table->getOp(type_set())(const_cast(*this), other, rtol, atol, equal_nan); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, rtol, atol, equal_nan); - #endif - } - inline Tensor Tensor::any(int64_t dim, bool keepdim) const { -@@ -290,7 +312,7 @@ inline Tensor Tensor::any(int64_t dim, bool keepdim) const { - return TypeDefault::any(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, keepdim); -+ return 
table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - inline Tensor Tensor::argmax(c10::optional dim, bool keepdim) const { -@@ -298,7 +320,7 @@ inline Tensor Tensor::argmax(c10::optional dim, bool keepdim) const { - return TypeDefault::argmax(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"); -- return table->getOp, bool)>(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - inline Tensor Tensor::argmin(c10::optional dim, bool keepdim) const { -@@ -306,7 +328,7 @@ inline Tensor Tensor::argmin(c10::optional dim, bool keepdim) const { - return TypeDefault::argmin(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"); -- return table->getOp, bool)>(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - inline Tensor Tensor::as_strided(IntArrayRef size, IntArrayRef stride, c10::optional storage_offset) const { -@@ -323,7 +345,7 @@ inline Tensor Tensor::as_strided(IntArrayRef size, IntArrayRef stride, c10::opti - } - #else - static auto table = globalATenDispatch().getOpTable("aten::as_strided(Tensor(a) self, int[] size, int[] stride, int? 
storage_offset=None) -> Tensor(a)"); -- return table->getOp)>(type_set())(const_cast(*this), size, stride, storage_offset); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size, stride, storage_offset); - #endif - } - inline Tensor & Tensor::as_strided_(IntArrayRef size, IntArrayRef stride, c10::optional storage_offset) const { -@@ -331,7 +353,7 @@ inline Tensor & Tensor::as_strided_(IntArrayRef size, IntArrayRef stride, c10::o - return TypeDefault::as_strided_(const_cast(*this), size, stride, storage_offset); - #else - static auto table = globalATenDispatch().getOpTable("aten::as_strided_(Tensor(a!) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a!)"); -- return table->getOp)>(type_set())(const_cast(*this), size, stride, storage_offset); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size, stride, storage_offset); - #endif - } - inline Tensor Tensor::asin() const { -@@ -339,7 +361,7 @@ inline Tensor Tensor::asin() const { - return TypeDefault::asin(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::asin(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::asin_() const { -@@ -353,7 +375,7 @@ inline Tensor & Tensor::asin_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::asin_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::atan() const { -@@ -361,7 +383,7 @@ inline Tensor Tensor::atan() const { - return TypeDefault::atan(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::atan(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::atan_() const { -@@ -375,7 +397,7 @@ inline Tensor & Tensor::atan_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::atan_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::baddbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { -@@ -389,7 +411,7 @@ inline Tensor Tensor::baddbmm(const Tensor & batch1, const Tensor & batch2, Scal - } - #else - static auto table = globalATenDispatch().getOpTable("aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), batch1, batch2, beta, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2))(const_cast(*this), batch1, batch2, beta, alpha); - #endif - } - inline Tensor & Tensor::baddbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { -@@ -403,7 +425,7 @@ inline Tensor & Tensor::baddbmm_(const Tensor & batch1, const Tensor & batch2, S - } - #else - static auto table = globalATenDispatch().getOpTable("aten::baddbmm_(Tensor(a!) 
self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), batch1, batch2, beta, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2))(const_cast(*this), batch1, batch2, beta, alpha); - #endif - } - inline Tensor Tensor::bernoulli(Generator * generator) const { -@@ -411,7 +433,7 @@ inline Tensor Tensor::bernoulli(Generator * generator) const { - return TypeDefault::bernoulli(const_cast(*this), generator); - #else - static auto table = globalATenDispatch().getOpTable("aten::bernoulli(Tensor self, *, Generator? generator=None) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), generator); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), generator); - #endif - } - inline Tensor & Tensor::bernoulli_(const Tensor & p, Generator * generator) const { -@@ -425,7 +447,7 @@ inline Tensor & Tensor::bernoulli_(const Tensor & p, Generator * generator) cons - } - #else - static auto table = globalATenDispatch().getOpTable("aten::bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), p, generator); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, p))(const_cast(*this), p, generator); - #endif - } - inline Tensor & Tensor::bernoulli_(double p, Generator * generator) const { -@@ -439,7 +461,7 @@ inline Tensor & Tensor::bernoulli_(double p, Generator * generator) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? 
generator=None) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), p, generator); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, generator); - #endif - } - inline Tensor Tensor::bernoulli(double p, Generator * generator) const { -@@ -447,7 +469,7 @@ inline Tensor Tensor::bernoulli(double p, Generator * generator) const { - return TypeDefault::bernoulli(const_cast(*this), p, generator); - #else - static auto table = globalATenDispatch().getOpTable("aten::bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), p, generator); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, generator); - #endif - } - inline Tensor Tensor::bincount(const Tensor & weights, int64_t minlength) const { -@@ -461,7 +483,7 @@ inline Tensor Tensor::bincount(const Tensor & weights, int64_t minlength) const - } - #else - static auto table = globalATenDispatch().getOpTable("aten::bincount(Tensor self, Tensor? 
weights=None, int minlength=0) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), weights, minlength); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, weights))(const_cast(*this), weights, minlength); - #endif - } - inline Tensor Tensor::bitwise_not() const { -@@ -469,7 +491,7 @@ inline Tensor Tensor::bitwise_not() const { - return TypeDefault::bitwise_not(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::bitwise_not(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::bitwise_not_() const { -@@ -477,7 +499,7 @@ inline Tensor & Tensor::bitwise_not_() const { - return TypeDefault::bitwise_not_(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::bitwise_not_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::logical_not() const { -@@ -485,7 +507,7 @@ inline Tensor Tensor::logical_not() const { - return TypeDefault::logical_not(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::logical_not(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::logical_not_() const { -@@ -493,7 +515,7 @@ inline Tensor & Tensor::logical_not_() const { - return TypeDefault::logical_not_(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::logical_not_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::logical_xor(const Tensor & other) const { -@@ -501,7 +523,7 @@ inline Tensor Tensor::logical_xor(const Tensor & other) const { - return TypeDefault::logical_xor(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::logical_xor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::logical_xor_(const Tensor & other) const { -@@ -509,7 +531,7 @@ inline Tensor & Tensor::logical_xor_(const Tensor & other) const { - return TypeDefault::logical_xor_(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::logical_xor_(Tensor(a!) 
self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::bmm(const Tensor & mat2) const { -@@ -523,7 +545,7 @@ inline Tensor Tensor::bmm(const Tensor & mat2) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::bmm(Tensor self, Tensor mat2) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), mat2); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mat2))(const_cast(*this), mat2); - #endif - } - inline Tensor Tensor::ceil() const { -@@ -531,7 +553,7 @@ inline Tensor Tensor::ceil() const { - return TypeDefault::ceil(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::ceil(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::ceil_() const { -@@ -539,7 +561,7 @@ inline Tensor & Tensor::ceil_() const { - return TypeDefault::ceil_(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::ceil_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline std::vector Tensor::chunk(int64_t chunks, int64_t dim) const { -@@ -547,7 +569,7 @@ inline std::vector Tensor::chunk(int64_t chunks, int64_t dim) const { - return TypeDefault::chunk(const_cast(*this), chunks, dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::chunk(Tensor(a) self, int chunks, int dim=0) -> Tensor(a)[]"); -- return table->getOp (const Tensor &, int64_t, int64_t)>(type_set())(const_cast(*this), chunks, dim); -+ return table->getOp (const Tensor &, int64_t, int64_t)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), chunks, dim); - #endif - } - inline Tensor Tensor::clamp(c10::optional min, c10::optional max) const { -@@ -555,7 +577,7 @@ inline Tensor Tensor::clamp(c10::optional min, c10::optional max - return TypeDefault::clamp(const_cast(*this), min, max); - #else - static auto table = globalATenDispatch().getOpTable("aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"); -- return table->getOp, c10::optional)>(type_set())(const_cast(*this), min, max); -+ return table->getOp, c10::optional)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), min, max); - #endif - } - inline Tensor & Tensor::clamp_(c10::optional min, c10::optional max) const { -@@ -569,7 +591,7 @@ inline Tensor & Tensor::clamp_(c10::optional min, c10::optional - } - #else - static auto table = globalATenDispatch().getOpTable("aten::clamp_(Tensor(a!) self, Scalar? min=None, Scalar? 
max=None) -> Tensor(a!)"); -- return table->getOp, c10::optional)>(type_set())(const_cast(*this), min, max); -+ return table->getOp, c10::optional)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), min, max); - #endif - } - inline Tensor Tensor::clamp_max(Scalar max) const { -@@ -577,7 +599,7 @@ inline Tensor Tensor::clamp_max(Scalar max) const { - return TypeDefault::clamp_max(const_cast(*this), max); - #else - static auto table = globalATenDispatch().getOpTable("aten::clamp_max(Tensor self, Scalar max) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), max); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), max); - #endif - } - inline Tensor & Tensor::clamp_max_(Scalar max) const { -@@ -591,7 +613,7 @@ inline Tensor & Tensor::clamp_max_(Scalar max) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), max); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), max); - #endif - } - inline Tensor Tensor::clamp_min(Scalar min) const { -@@ -599,7 +621,7 @@ inline Tensor Tensor::clamp_min(Scalar min) const { - return TypeDefault::clamp_min(const_cast(*this), min); - #else - static auto table = globalATenDispatch().getOpTable("aten::clamp_min(Tensor self, Scalar min) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), min); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), min); - #endif - } - inline Tensor & Tensor::clamp_min_(Scalar min) const { -@@ -613,7 +635,7 @@ inline Tensor & Tensor::clamp_min_(Scalar min) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::clamp_min_(Tensor(a!) 
self, Scalar min) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), min); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), min); - #endif - } - inline Tensor Tensor::contiguous(MemoryFormat memory_format) const { -@@ -621,7 +643,7 @@ inline Tensor Tensor::contiguous(MemoryFormat memory_format) const { - return TypeDefault::contiguous(const_cast(*this), memory_format); - #else - static auto table = globalATenDispatch().getOpTable("aten::contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), memory_format); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), memory_format); - #endif - } - inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const { -@@ -629,7 +651,7 @@ inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const { - return TypeDefault::copy_(const_cast(*this), src, non_blocking); - #else - static auto table = globalATenDispatch().getOpTable("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), src, non_blocking); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, src))(const_cast(*this), src, non_blocking); - #endif - } - inline Tensor Tensor::cos() const { -@@ -637,7 +659,7 @@ inline Tensor Tensor::cos() const { - return TypeDefault::cos(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::cos(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::cos_() const { -@@ -651,7 +673,7 @@ inline Tensor & Tensor::cos_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::cos_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::cosh() const { -@@ -659,7 +681,7 @@ inline Tensor Tensor::cosh() const { - return TypeDefault::cosh(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::cosh(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::cosh_() const { -@@ -673,7 +695,7 @@ inline Tensor & Tensor::cosh_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::cosh_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::cumsum(int64_t dim, c10::optional dtype) const { -@@ -681,7 +703,7 @@ inline Tensor Tensor::cumsum(int64_t dim, c10::optional dtype) const - return TypeDefault::cumsum(const_cast(*this), dim, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dim, dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); - #endif - } - inline Tensor Tensor::cumprod(int64_t dim, c10::optional dtype) const { -@@ -689,7 +711,7 @@ inline Tensor Tensor::cumprod(int64_t dim, c10::optional dtype) cons - return TypeDefault::cumprod(const_cast(*this), dim, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::cumprod(Tensor self, int dim, *, ScalarType? 
dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dim, dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); - #endif - } - inline Tensor Tensor::det() const { -@@ -697,7 +719,7 @@ inline Tensor Tensor::det() const { - return TypeDefault::det(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::det(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::diag_embed(int64_t offset, int64_t dim1, int64_t dim2) const { -@@ -705,7 +727,7 @@ inline Tensor Tensor::diag_embed(int64_t offset, int64_t dim1, int64_t dim2) con - return TypeDefault::diag_embed(const_cast(*this), offset, dim1, dim2); - #else - static auto table = globalATenDispatch().getOpTable("aten::diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), offset, dim1, dim2); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), offset, dim1, dim2); - #endif - } - inline Tensor Tensor::diagflat(int64_t offset) const { -@@ -713,7 +735,7 @@ inline Tensor Tensor::diagflat(int64_t offset) const { - return TypeDefault::diagflat(const_cast(*this), offset); - #else - static auto table = globalATenDispatch().getOpTable("aten::diagflat(Tensor self, int offset=0) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), offset); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), offset); - #endif - } - inline Tensor Tensor::diagonal(int64_t offset, int64_t dim1, int64_t dim2) const { -@@ -721,7 +743,7 @@ inline Tensor Tensor::diagonal(int64_t offset, int64_t dim1, int64_t dim2) const - return TypeDefault::diagonal(const_cast(*this), offset, dim1, dim2); - #else - static auto 
table = globalATenDispatch().getOpTable("aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), offset, dim1, dim2); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), offset, dim1, dim2); - #endif - } - inline Tensor & Tensor::fill_diagonal_(Scalar fill_value, bool wrap) const { -@@ -729,23 +751,41 @@ inline Tensor & Tensor::fill_diagonal_(Scalar fill_value, bool wrap) const { - return TypeDefault::fill_diagonal_(const_cast(*this), fill_value, wrap); - #else - static auto table = globalATenDispatch().getOpTable("aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), fill_value, wrap); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), fill_value, wrap); - #endif - } - inline Tensor Tensor::div(const Tensor & other) const { - #ifdef USE_STATIC_DISPATCH -- return TypeDefault::div(const_cast(*this), other); -+ switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { -+ case Backend::CPU: -+ return CPUType::div(const_cast(*this), other); -+ break; -+ case Backend::SparseCPU: -+ return SparseCPUType::div(const_cast(*this), other); -+ break; -+ default: -+ AT_ERROR("div not implemented for ", at::toString(type_set())); -+ } - #else - static auto table = globalATenDispatch().getOpTable("aten::div.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::div_(const Tensor & other) const { - #ifdef USE_STATIC_DISPATCH -- return TypeDefault::div_(const_cast(*this), other); -+ switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { -+ case Backend::CPU: -+ return CPUType::div_(const_cast(*this), other); -+ 
break; -+ case Backend::SparseCPU: -+ return SparseCPUType::div_(const_cast(*this), other); -+ break; -+ default: -+ AT_ERROR("div_ not implemented for ", at::toString(type_set())); -+ } - #else - static auto table = globalATenDispatch().getOpTable("aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::div(Scalar other) const { -@@ -753,7 +793,7 @@ inline Tensor Tensor::div(Scalar other) const { - return TypeDefault::div(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::div.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::div_(Scalar other) const { -@@ -761,7 +801,7 @@ inline Tensor & Tensor::div_(Scalar other) const { - return TypeDefault::div_(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::div_.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::dot(const Tensor & tensor) const { -@@ -775,7 +815,7 @@ inline Tensor Tensor::dot(const Tensor & tensor) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::dot(Tensor self, Tensor tensor) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), tensor); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, tensor))(const_cast(*this), tensor); - #endif - } - inline Tensor Tensor::new_empty(IntArrayRef size, const TensorOptions & options) const { -@@ -783,7 +823,7 @@ inline Tensor Tensor::new_empty(IntArrayRef size, const TensorOptions & options) - return TypeDefault::new_empty(const_cast(*this), size, options); - #else - static auto table = globalATenDispatch().getOpTable("aten::new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), size, options); -+ return table->getOp(type_set(/* HMMMM */))(const_cast(*this), size, options); - #endif - } - inline Tensor Tensor::new_full(IntArrayRef size, Scalar fill_value, const TensorOptions & options) const { -@@ -791,7 +831,7 @@ inline Tensor Tensor::new_full(IntArrayRef size, Scalar fill_value, const Tensor - return TypeDefault::new_full(const_cast(*this), size, fill_value, options); - #else - static auto table = globalATenDispatch().getOpTable("aten::new_full(Tensor self, int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), size, fill_value, options); -+ return table->getOp(type_set(/* HMMMM */))(const_cast(*this), size, fill_value, options); - #endif - } - inline Tensor & Tensor::resize_(IntArrayRef size) const { -@@ -805,7 +845,7 @@ inline Tensor & Tensor::resize_(IntArrayRef size) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::resize_(Tensor(a!) self, int[] size) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), size); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size); - #endif - } - inline Tensor Tensor::erf() const { -@@ -813,7 +853,7 @@ inline Tensor Tensor::erf() const { - return TypeDefault::erf(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::erf(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::erf_() const { -@@ -827,7 +867,7 @@ inline Tensor & Tensor::erf_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::erf_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::erfc() const { -@@ -835,7 +875,7 @@ inline Tensor Tensor::erfc() const { - return TypeDefault::erfc(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::erfc(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::erfc_() const { -@@ -849,7 +889,7 @@ inline Tensor & Tensor::erfc_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::erfc_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::exp() const { -@@ -857,7 +897,7 @@ inline Tensor Tensor::exp() const { - return TypeDefault::exp(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::exp(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::exp_() const { -@@ -871,7 +911,7 @@ inline Tensor & Tensor::exp_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::exp_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::expm1() const { -@@ -879,7 +919,7 @@ inline Tensor Tensor::expm1() const { - return TypeDefault::expm1(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::expm1(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::expm1_() const { -@@ -893,7 +933,7 @@ inline Tensor & Tensor::expm1_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::expm1_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::expand(IntArrayRef size, bool implicit) const { -@@ -901,7 +941,7 @@ inline Tensor Tensor::expand(IntArrayRef size, bool implicit) const { - return TypeDefault::expand(const_cast(*this), size, implicit); - #else - static auto table = globalATenDispatch().getOpTable("aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), size, implicit); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size, implicit); - #endif - } - inline Tensor Tensor::expand_as(const Tensor & other) const { -@@ -909,7 +949,7 @@ inline Tensor Tensor::expand_as(const Tensor & other) const { - return TypeDefault::expand_as(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::expand_as(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return 
table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim) const { -@@ -917,7 +957,7 @@ inline Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim) const { - return TypeDefault::flatten(const_cast(*this), start_dim, end_dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::flatten(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), start_dim, end_dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), start_dim, end_dim); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -926,7 +966,7 @@ inline Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim, Dimname out_di - return TypeDefault::flatten(const_cast(*this), start_dim, end_dim, out_dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::flatten(Tensor self, int start_dim, int end_dim, Dimname out_dim) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), start_dim, end_dim, out_dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), start_dim, end_dim, out_dim); - #endif - } - #endif -@@ -936,7 +976,7 @@ inline Tensor Tensor::flatten(Dimname start_dim, Dimname end_dim, Dimname out_di - return TypeDefault::flatten(const_cast(*this), start_dim, end_dim, out_dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::flatten(Tensor self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), start_dim, end_dim, out_dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), start_dim, end_dim, out_dim); - #endif - } - #endif -@@ -946,7 +986,7 @@ inline Tensor Tensor::flatten(DimnameList dims, Dimname out_dim) const { - return TypeDefault::flatten(const_cast(*this), dims, out_dim); 
- #else - static auto table = globalATenDispatch().getOpTable("aten::flatten(Tensor self, DimnameList dims, Dimname out_dim) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dims, out_dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dims, out_dim); - #endif - } - #endif -@@ -955,7 +995,7 @@ inline Tensor & Tensor::fill_(Scalar value) const { - return TypeDefault::fill_(const_cast(*this), value); - #else - static auto table = globalATenDispatch().getOpTable("aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), value); - #endif - } - inline Tensor & Tensor::fill_(const Tensor & value) const { -@@ -963,7 +1003,7 @@ inline Tensor & Tensor::fill_(const Tensor & value) const { - return TypeDefault::fill_(const_cast(*this), value); - #else - static auto table = globalATenDispatch().getOpTable("aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, value))(const_cast(*this), value); - #endif - } - inline Tensor Tensor::floor() const { -@@ -971,7 +1011,7 @@ inline Tensor Tensor::floor() const { - return TypeDefault::floor(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::floor(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::floor_() const { -@@ -985,7 +1025,7 @@ inline Tensor & Tensor::floor_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::floor_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::frac() const { -@@ -993,7 +1033,7 @@ inline Tensor Tensor::frac() const { - return TypeDefault::frac(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::frac(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::frac_() const { -@@ -1007,7 +1047,7 @@ inline Tensor & Tensor::frac_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::frac_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::ger(const Tensor & vec2) const { -@@ -1021,7 +1061,7 @@ inline Tensor Tensor::ger(const Tensor & vec2) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::ger(Tensor self, Tensor vec2) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), vec2); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, vec2))(const_cast(*this), vec2); - #endif - } - inline Tensor Tensor::fft(int64_t signal_ndim, bool normalized) const { -@@ -1029,7 +1069,7 @@ inline Tensor Tensor::fft(int64_t signal_ndim, bool normalized) const { - return TypeDefault::fft(const_cast(*this), signal_ndim, normalized); - #else - static auto table = globalATenDispatch().getOpTable("aten::fft(Tensor self, int signal_ndim, bool normalized=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), signal_ndim, normalized); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), signal_ndim, normalized); - #endif - } - 
inline Tensor Tensor::ifft(int64_t signal_ndim, bool normalized) const { -@@ -1037,7 +1077,7 @@ inline Tensor Tensor::ifft(int64_t signal_ndim, bool normalized) const { - return TypeDefault::ifft(const_cast(*this), signal_ndim, normalized); - #else - static auto table = globalATenDispatch().getOpTable("aten::ifft(Tensor self, int signal_ndim, bool normalized=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), signal_ndim, normalized); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), signal_ndim, normalized); - #endif - } - inline Tensor Tensor::rfft(int64_t signal_ndim, bool normalized, bool onesided) const { -@@ -1045,7 +1085,7 @@ inline Tensor Tensor::rfft(int64_t signal_ndim, bool normalized, bool onesided) - return TypeDefault::rfft(const_cast(*this), signal_ndim, normalized, onesided); - #else - static auto table = globalATenDispatch().getOpTable("aten::rfft(Tensor self, int signal_ndim, bool normalized=False, bool onesided=True) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), signal_ndim, normalized, onesided); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), signal_ndim, normalized, onesided); - #endif - } - inline Tensor Tensor::irfft(int64_t signal_ndim, bool normalized, bool onesided, IntArrayRef signal_sizes) const { -@@ -1053,7 +1093,7 @@ inline Tensor Tensor::irfft(int64_t signal_ndim, bool normalized, bool onesided, - return TypeDefault::irfft(const_cast(*this), signal_ndim, normalized, onesided, signal_sizes); - #else - static auto table = globalATenDispatch().getOpTable("aten::irfft(Tensor self, int signal_ndim, bool normalized=False, bool onesided=True, int[] signal_sizes=[]) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), signal_ndim, normalized, onesided, signal_sizes); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), signal_ndim, normalized, onesided, 
signal_sizes); - #endif - } - inline Tensor Tensor::index(TensorList indices) const { -@@ -1061,7 +1101,7 @@ inline Tensor Tensor::index(TensorList indices) const { - return TypeDefault::index(const_cast(*this), indices); - #else - static auto table = globalATenDispatch().getOpTable("aten::index(Tensor self, Tensor?[] indices) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), indices); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, indices))(const_cast(*this), indices); - #endif - } - inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Tensor & source) const { -@@ -1069,7 +1109,7 @@ inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Ten - return TypeDefault::index_copy_(const_cast(*this), dim, index, source); - #else - static auto table = globalATenDispatch().getOpTable("aten::index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), dim, index, source); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, source))(const_cast(*this), dim, index, source); - #endif - } - inline Tensor Tensor::index_copy(int64_t dim, const Tensor & index, const Tensor & source) const { -@@ -1077,7 +1117,7 @@ inline Tensor Tensor::index_copy(int64_t dim, const Tensor & index, const Tensor - return TypeDefault::index_copy(const_cast(*this), dim, index, source); - #else - static auto table = globalATenDispatch().getOpTable("aten::index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, index, source); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, source))(const_cast(*this), dim, index, source); - #endif - } - inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values, bool accumulate) const { -@@ -1085,7 +1125,7 @@ inline Tensor & Tensor::index_put_(TensorList 
indices, const Tensor & values, bo - return TypeDefault::index_put_(const_cast(*this), indices, values, accumulate); - #else - static auto table = globalATenDispatch().getOpTable("aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), indices, values, accumulate); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, indices, values))(const_cast(*this), indices, values, accumulate); - #endif - } - inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const { -@@ -1093,7 +1133,7 @@ inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool - return TypeDefault::index_put(const_cast(*this), indices, values, accumulate); - #else - static auto table = globalATenDispatch().getOpTable("aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), indices, values, accumulate); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, indices, values))(const_cast(*this), indices, values, accumulate); - #endif - } - inline Tensor Tensor::inverse() const { -@@ -1101,7 +1141,7 @@ inline Tensor Tensor::inverse() const { - return TypeDefault::inverse(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::inverse(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::isclose(const Tensor & other, double rtol, double atol, bool equal_nan) const { -@@ -1109,7 +1149,7 @@ inline Tensor Tensor::isclose(const Tensor & other, double rtol, double atol, bo - return TypeDefault::isclose(const_cast(*this), other, rtol, atol, equal_nan); - #else - static auto table = 
globalATenDispatch().getOpTable("aten::isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other, rtol, atol, equal_nan); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, rtol, atol, equal_nan); - #endif - } - inline bool Tensor::is_distributed() const { -@@ -1117,7 +1157,7 @@ inline bool Tensor::is_distributed() const { - return TypeDefault::is_distributed(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::is_distributed(Tensor self) -> bool"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline bool Tensor::is_floating_point() const { -@@ -1125,7 +1165,7 @@ inline bool Tensor::is_floating_point() const { - return TypeDefault::is_floating_point(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::is_floating_point(Tensor self) -> bool"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline bool Tensor::is_complex() const { -@@ -1133,7 +1173,7 @@ inline bool Tensor::is_complex() const { - return TypeDefault::is_complex(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::is_complex(Tensor self) -> bool"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline bool Tensor::is_nonzero() const { -@@ -1141,7 +1181,7 @@ inline bool Tensor::is_nonzero() const { - return TypeDefault::is_nonzero(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::is_nonzero(Tensor self) -> bool"); -- return 
table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline bool Tensor::is_same_size(const Tensor & other) const { -@@ -1149,7 +1189,7 @@ inline bool Tensor::is_same_size(const Tensor & other) const { - return TypeDefault::is_same_size(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::is_same_size(Tensor self, Tensor other) -> bool"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline bool Tensor::is_signed() const { -@@ -1157,7 +1197,7 @@ inline bool Tensor::is_signed() const { - return TypeDefault::is_signed(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::is_signed(Tensor self) -> bool"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline std::tuple Tensor::kthvalue(int64_t k, int64_t dim, bool keepdim) const { -@@ -1165,7 +1205,7 @@ inline std::tuple Tensor::kthvalue(int64_t k, int64_t dim, bool k - return TypeDefault::kthvalue(const_cast(*this), k, dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)"); -- return table->getOp (const Tensor &, int64_t, int64_t, bool)>(type_set())(const_cast(*this), k, dim, keepdim); -+ return table->getOp (const Tensor &, int64_t, int64_t, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), k, dim, keepdim); - #endif - } - inline Tensor Tensor::log() const { -@@ -1173,7 +1213,7 @@ inline Tensor Tensor::log() const { - return TypeDefault::log(const_cast(*this)); - #else - static auto table = 
globalATenDispatch().getOpTable("aten::log(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::log_() const { -@@ -1187,7 +1227,7 @@ inline Tensor & Tensor::log_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::log_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::log10() const { -@@ -1195,7 +1235,7 @@ inline Tensor Tensor::log10() const { - return TypeDefault::log10(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::log10(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::log10_() const { -@@ -1209,7 +1249,7 @@ inline Tensor & Tensor::log10_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::log10_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::log1p() const { -@@ -1217,7 +1257,7 @@ inline Tensor Tensor::log1p() const { - return TypeDefault::log1p(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::log1p(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::log1p_() const { -@@ -1234,7 +1274,7 @@ inline Tensor & Tensor::log1p_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::log1p_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::log2() const { -@@ -1242,7 +1282,7 @@ inline Tensor Tensor::log2() const { - return TypeDefault::log2(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::log2(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::log2_() const { -@@ -1256,7 +1296,7 @@ inline Tensor & Tensor::log2_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::log2_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::logdet() const { -@@ -1264,7 +1304,7 @@ inline Tensor Tensor::logdet() const { - return TypeDefault::logdet(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::logdet(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::log_softmax(int64_t dim, c10::optional dtype) const { -@@ -1272,7 +1312,7 @@ inline Tensor Tensor::log_softmax(int64_t dim, c10::optional dtype) - return TypeDefault::log_softmax(const_cast(*this), dim, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::log_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dim, dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -1281,7 +1321,7 @@ inline Tensor Tensor::log_softmax(Dimname dim, c10::optional dtype) - return TypeDefault::log_softmax(const_cast(*this), dim, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::log_softmax(Tensor self, Dimname dim, *, ScalarType? 
dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dim, dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); - #endif - } - #endif -@@ -1290,7 +1330,7 @@ inline Tensor Tensor::logsumexp(IntArrayRef dim, bool keepdim) const { - return TypeDefault::logsumexp(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -1299,7 +1339,7 @@ inline Tensor Tensor::logsumexp(DimnameList dim, bool keepdim) const { - return TypeDefault::logsumexp(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - #endif -@@ -1308,7 +1348,7 @@ inline Tensor Tensor::matmul(const Tensor & other) const { - return TypeDefault::matmul(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::matmul(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::matrix_power(int64_t n) const { -@@ -1316,7 +1356,7 @@ inline Tensor Tensor::matrix_power(int64_t n) const { - return TypeDefault::matrix_power(const_cast(*this), n); - #else - static auto table = globalATenDispatch().getOpTable("aten::matrix_power(Tensor self, int n) -> Tensor"); -- return 
table->getOp(type_set())(const_cast(*this), n); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), n); - #endif - } - inline std::tuple Tensor::max(int64_t dim, bool keepdim) const { -@@ -1324,7 +1364,7 @@ inline std::tuple Tensor::max(int64_t dim, bool keepdim) const { - return TypeDefault::max(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); -- return table->getOp (const Tensor &, int64_t, bool)>(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp (const Tensor &, int64_t, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - inline Tensor Tensor::max_values(IntArrayRef dim, bool keepdim) const { -@@ -1332,7 +1372,7 @@ inline Tensor Tensor::max_values(IntArrayRef dim, bool keepdim) const { - return TypeDefault::max_values(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::max_values(Tensor self, int[1] dim, bool keepdim=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -1341,7 +1381,7 @@ inline std::tuple Tensor::max(Dimname dim, bool keepdim) const { - return TypeDefault::max(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); -- return table->getOp (const Tensor &, Dimname, bool)>(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp (const Tensor &, Dimname, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - #endif -@@ -1351,7 +1391,7 @@ inline Tensor 
Tensor::max_values(DimnameList dim, bool keepdim) const { - return TypeDefault::max_values(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::max_values.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - #endif -@@ -1360,7 +1400,7 @@ inline Tensor Tensor::mean(c10::optional dtype) const { - return TypeDefault::mean(const_cast(*this), dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dtype); - #endif - } - inline Tensor Tensor::mean(IntArrayRef dim, bool keepdim, c10::optional dtype) const { -@@ -1368,7 +1408,7 @@ inline Tensor Tensor::mean(IntArrayRef dim, bool keepdim, c10::optional(*this), dim, keepdim, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dim, keepdim, dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -1377,7 +1417,7 @@ inline Tensor Tensor::mean(DimnameList dim, bool keepdim, c10::optional(*this), dim, keepdim, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dim, keepdim, dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); - #endif - } - #endif -@@ -1386,7 +1426,7 @@ inline std::tuple Tensor::median(int64_t dim, bool keepdim) const - return TypeDefault::median(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); -- return table->getOp (const Tensor &, int64_t, bool)>(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp (const Tensor &, int64_t, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -1395,7 +1435,7 @@ inline std::tuple Tensor::median(Dimname dim, bool keepdim) const - return TypeDefault::median(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); -- return table->getOp (const Tensor &, Dimname, bool)>(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp (const Tensor &, Dimname, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - #endif -@@ -1404,7 +1444,7 @@ inline std::tuple Tensor::min(int64_t dim, bool keepdim) const { - return TypeDefault::min(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); -- return table->getOp (const Tensor &, int64_t, bool)>(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp (const Tensor &, int64_t, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - inline 
Tensor Tensor::min_values(IntArrayRef dim, bool keepdim) const { -@@ -1412,7 +1452,7 @@ inline Tensor Tensor::min_values(IntArrayRef dim, bool keepdim) const { - return TypeDefault::min_values(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::min_values(Tensor self, int[1] dim, bool keepdim=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -1421,7 +1461,7 @@ inline std::tuple Tensor::min(Dimname dim, bool keepdim) const { - return TypeDefault::min(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); -- return table->getOp (const Tensor &, Dimname, bool)>(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp (const Tensor &, Dimname, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - #endif -@@ -1431,7 +1471,7 @@ inline Tensor Tensor::min_values(DimnameList dim, bool keepdim) const { - return TypeDefault::min_values(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::min_values.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - #endif -@@ -1449,7 +1489,7 @@ inline Tensor Tensor::mm(const Tensor & mat2) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::mm(Tensor self, Tensor mat2) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), mat2); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, 
mat2))(const_cast(*this), mat2); - #endif - } - inline std::tuple Tensor::mode(int64_t dim, bool keepdim) const { -@@ -1457,7 +1497,7 @@ inline std::tuple Tensor::mode(int64_t dim, bool keepdim) const { - return TypeDefault::mode(const_cast(*this), dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)"); -- return table->getOp (const Tensor &, int64_t, bool)>(type_set())(const_cast(*this), dim, keepdim); -+ return table->getOp (const Tensor &, int64_t, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); - #endif - } - inline Tensor Tensor::mul(const Tensor & other) const { -@@ -1474,7 +1514,7 @@ inline Tensor Tensor::mul(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::mul.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::mul_(const Tensor & other) const { -@@ -1491,7 +1531,7 @@ inline Tensor & Tensor::mul_(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::mul_.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::mul(Scalar other) const { -@@ -1499,7 +1539,7 @@ inline Tensor Tensor::mul(Scalar other) const { - return TypeDefault::mul(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::mul.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::mul_(Scalar other) const { -@@ -1507,7 +1547,7 @@ inline Tensor & Tensor::mul_(Scalar other) const { - return TypeDefault::mul_(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::mv(const Tensor & vec) const { -@@ -1521,7 +1561,7 @@ inline Tensor Tensor::mv(const Tensor & vec) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::mv(Tensor self, Tensor vec) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), vec); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, vec))(const_cast(*this), vec); - #endif - } - inline Tensor Tensor::mvlgamma(int64_t p) const { -@@ -1529,7 +1569,7 @@ inline Tensor Tensor::mvlgamma(int64_t p) const { - return TypeDefault::mvlgamma(const_cast(*this), p); - #else - static auto table = globalATenDispatch().getOpTable("aten::mvlgamma(Tensor self, int p) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), p); -+ return 
table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p); - #endif - } - inline Tensor & Tensor::mvlgamma_(int64_t p) const { -@@ -1537,7 +1577,7 @@ inline Tensor & Tensor::mvlgamma_(int64_t p) const { - return TypeDefault::mvlgamma_(const_cast(*this), p); - #else - static auto table = globalATenDispatch().getOpTable("aten::mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), p); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p); - #endif - } - inline Tensor Tensor::narrow_copy(int64_t dim, int64_t start, int64_t length) const { -@@ -1554,7 +1594,7 @@ inline Tensor Tensor::narrow_copy(int64_t dim, int64_t start, int64_t length) co - } - #else - static auto table = globalATenDispatch().getOpTable("aten::narrow_copy(Tensor self, int dim, int start, int length) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, start, length); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, start, length); - #endif - } - inline Tensor Tensor::narrow(int64_t dim, int64_t start, int64_t length) const { -@@ -1562,7 +1602,7 @@ inline Tensor Tensor::narrow(int64_t dim, int64_t start, int64_t length) const { - return TypeDefault::narrow(const_cast(*this), dim, start, length); - #else - static auto table = globalATenDispatch().getOpTable("aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), dim, start, length); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, start, length); - #endif - } - inline Tensor Tensor::permute(IntArrayRef dims) const { -@@ -1570,7 +1610,7 @@ inline Tensor Tensor::permute(IntArrayRef dims) const { - return TypeDefault::permute(const_cast(*this), dims); - #else - static auto table = globalATenDispatch().getOpTable("aten::permute(Tensor(a) self, 
int[] dims) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), dims); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dims); - #endif - } - inline Tensor Tensor::numpy_T() const { -@@ -1578,7 +1618,7 @@ inline Tensor Tensor::numpy_T() const { - return TypeDefault::numpy_T(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::numpy_T(Tensor(a) self) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline bool Tensor::is_pinned() const { -@@ -1586,7 +1626,7 @@ inline bool Tensor::is_pinned() const { - return TypeDefault::is_pinned(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::is_pinned(Tensor self) -> bool"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::pin_memory() const { -@@ -1594,7 +1634,7 @@ inline Tensor Tensor::pin_memory() const { - return TypeDefault::pin_memory(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::pin_memory(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::pinverse(double rcond) const { -@@ -1602,7 +1642,7 @@ inline Tensor Tensor::pinverse(double rcond) const { - return TypeDefault::pinverse(const_cast(*this), rcond); - #else - static auto table = globalATenDispatch().getOpTable("aten::pinverse(Tensor self, float rcond=1e-15) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), rcond); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), rcond); - #endif - } - inline Tensor 
Tensor::reciprocal() const { -@@ -1610,7 +1650,7 @@ inline Tensor Tensor::reciprocal() const { - return TypeDefault::reciprocal(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::reciprocal(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::reciprocal_() const { -@@ -1624,7 +1664,7 @@ inline Tensor & Tensor::reciprocal_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::reciprocal_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::neg() const { -@@ -1632,7 +1672,7 @@ inline Tensor Tensor::neg() const { - return TypeDefault::neg(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::neg(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::neg_() const { -@@ -1640,7 +1680,7 @@ inline Tensor & Tensor::neg_() const { - return TypeDefault::neg_(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::neg_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::repeat(IntArrayRef repeats) const { -@@ -1648,7 +1688,7 @@ inline Tensor Tensor::repeat(IntArrayRef repeats) const { - return TypeDefault::repeat(const_cast(*this), repeats); - #else - static auto table = globalATenDispatch().getOpTable("aten::repeat(Tensor self, int[] repeats) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), repeats); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), repeats); - #endif - } - inline Tensor Tensor::repeat_interleave(const Tensor & repeats, c10::optional dim) const { -@@ -1656,7 +1696,7 @@ inline Tensor Tensor::repeat_interleave(const Tensor & repeats, c10::optional(*this), repeats, dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), repeats, dim); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this, repeats))(const_cast(*this), repeats, dim); - #endif - } - inline Tensor Tensor::repeat_interleave(int64_t repeats, c10::optional dim) const { -@@ -1664,7 +1704,7 @@ inline Tensor Tensor::repeat_interleave(int64_t repeats, c10::optional - return TypeDefault::repeat_interleave(const_cast(*this), repeats, dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::repeat_interleave.self_int(Tensor self, int repeats, int? 
dim=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), repeats, dim); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), repeats, dim); - #endif - } - inline Tensor Tensor::reshape(IntArrayRef shape) const { -@@ -1672,7 +1712,7 @@ inline Tensor Tensor::reshape(IntArrayRef shape) const { - return TypeDefault::reshape(const_cast(*this), shape); - #else - static auto table = globalATenDispatch().getOpTable("aten::reshape(Tensor self, int[] shape) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), shape); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), shape); - #endif - } - inline Tensor Tensor::reshape_as(const Tensor & other) const { -@@ -1680,7 +1720,7 @@ inline Tensor Tensor::reshape_as(const Tensor & other) const { - return TypeDefault::reshape_as(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::reshape_as(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::round() const { -@@ -1688,7 +1728,7 @@ inline Tensor Tensor::round() const { - return TypeDefault::round(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::round(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::round_() const { -@@ -1702,7 +1742,7 @@ inline Tensor & Tensor::round_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::round_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::relu() const { -@@ -1719,7 +1759,7 @@ inline Tensor Tensor::relu() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::relu(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::relu_() const { -@@ -1736,7 +1776,7 @@ inline Tensor & Tensor::relu_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::relu_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::prelu(const Tensor & weight) const { -@@ -1750,7 +1790,7 @@ inline Tensor Tensor::prelu(const Tensor & weight) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::prelu(Tensor self, Tensor weight) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), weight); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, weight))(const_cast(*this), weight); - #endif - } - inline std::tuple Tensor::prelu_backward(const Tensor & grad_output, const Tensor & weight) const { -@@ -1764,7 +1804,7 @@ inline std::tuple Tensor::prelu_backward(const Tensor & grad_outp - } - #else - static auto table = globalATenDispatch().getOpTable("aten::prelu_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor)"); -- return table->getOp (const Tensor &, const Tensor &, const Tensor &)>(type_set())(grad_output, const_cast(*this), weight); -+ return table->getOp (const Tensor &, const Tensor &, const Tensor &)>(at::detail::multi_dispatch_tensor_type_set(grad_output, *this, 
weight))(grad_output, const_cast(*this), weight); - #endif - } - inline Tensor Tensor::hardshrink(Scalar lambd) const { -@@ -1778,7 +1818,7 @@ inline Tensor Tensor::hardshrink(Scalar lambd) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), lambd); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), lambd); - #endif - } - inline Tensor Tensor::hardshrink_backward(const Tensor & grad_out, Scalar lambd) const { -@@ -1792,7 +1832,7 @@ inline Tensor Tensor::hardshrink_backward(const Tensor & grad_out, Scalar lambd) - } - #else - static auto table = globalATenDispatch().getOpTable("aten::hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor"); -- return table->getOp(type_set())(grad_out, const_cast(*this), lambd); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(grad_out, *this))(grad_out, const_cast(*this), lambd); - #endif - } - inline Tensor Tensor::rsqrt() const { -@@ -1800,7 +1840,7 @@ inline Tensor Tensor::rsqrt() const { - return TypeDefault::rsqrt(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::rsqrt(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::rsqrt_() const { -@@ -1814,7 +1854,7 @@ inline Tensor & Tensor::rsqrt_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::rsqrt_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -1823,7 +1863,7 @@ inline Tensor Tensor::select(Dimname dim, int64_t index) const { - return TypeDefault::select(const_cast(*this), dim, index); - #else - static auto table = globalATenDispatch().getOpTable("aten::select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), dim, index); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, index); - #endif - } - #endif -@@ -1832,7 +1872,7 @@ inline Tensor Tensor::select(int64_t dim, int64_t index) const { - return TypeDefault::select(const_cast(*this), dim, index); - #else - static auto table = globalATenDispatch().getOpTable("aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), dim, index); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, index); - #endif - } - inline Tensor Tensor::sigmoid() const { -@@ -1846,7 +1886,7 @@ inline Tensor Tensor::sigmoid() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::sigmoid(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::sigmoid_() const { -@@ -1860,7 +1900,7 @@ inline Tensor & Tensor::sigmoid_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::sigmoid_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::sin() const { -@@ -1868,7 +1908,7 @@ inline Tensor Tensor::sin() const { - return TypeDefault::sin(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::sin(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::sin_() const { -@@ -1882,7 +1922,7 @@ inline Tensor & Tensor::sin_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::sin_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::sinh() const { -@@ -1890,7 +1930,7 @@ inline Tensor Tensor::sinh() const { - return TypeDefault::sinh(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::sinh(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::sinh_() const { -@@ -1904,7 +1944,7 @@ inline Tensor & Tensor::sinh_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::sinh_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::detach() const { -@@ -1912,7 +1952,7 @@ inline Tensor Tensor::detach() const { - return TypeDefault::detach(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::detach(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::detach_() const { -@@ -1920,7 +1960,7 @@ inline Tensor & Tensor::detach_() const { - return TypeDefault::detach_(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::detach_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline int64_t Tensor::size(int64_t dim) const { -@@ -1928,7 +1968,7 @@ inline int64_t Tensor::size(int64_t dim) const { - return TypeDefault::size(const_cast(*this), dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::size.int(Tensor self, int dim) -> int"); -- return table->getOp(type_set())(const_cast(*this), dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -1937,7 +1977,7 @@ inline int64_t Tensor::size(Dimname dim) const { - return TypeDefault::size(const_cast(*this), dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::size.Dimname(Tensor self, Dimname dim) -> int"); -- return table->getOp(type_set())(const_cast(*this), dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); - #endif - } - #endif -@@ -1946,7 +1986,7 @@ inline Tensor Tensor::slice(int64_t 
dim, int64_t start, int64_t end, int64_t ste - return TypeDefault::slice(const_cast(*this), dim, start, end, step); - #else - static auto table = globalATenDispatch().getOpTable("aten::slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), dim, start, end, step); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, start, end, step); - #endif - } - inline std::tuple Tensor::slogdet() const { -@@ -1954,7 +1994,7 @@ inline std::tuple Tensor::slogdet() const { - return TypeDefault::slogdet(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)"); -- return table->getOp (const Tensor &)>(type_set())(const_cast(*this)); -+ return table->getOp (const Tensor &)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::smm(const Tensor & mat2) const { -@@ -1962,7 +2002,7 @@ inline Tensor Tensor::smm(const Tensor & mat2) const { - return TypeDefault::smm(const_cast(*this), mat2); - #else - static auto table = globalATenDispatch().getOpTable("aten::smm(Tensor self, Tensor mat2) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), mat2); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mat2))(const_cast(*this), mat2); - #endif - } - inline Tensor Tensor::softmax(int64_t dim, c10::optional dtype) const { -@@ -1970,7 +2010,7 @@ inline Tensor Tensor::softmax(int64_t dim, c10::optional dtype) cons - return TypeDefault::softmax(const_cast(*this), dim, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::softmax(Tensor self, int dim, ScalarType? 
dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dim, dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -1979,7 +2019,7 @@ inline Tensor Tensor::softmax(Dimname dim, c10::optional dtype) cons - return TypeDefault::softmax(const_cast(*this), dim, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::softmax(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dim, dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); - #endif - } - #endif -@@ -1988,7 +2028,7 @@ inline std::vector Tensor::split(int64_t split_size, int64_t dim) const - return TypeDefault::split(const_cast(*this), split_size, dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[]"); -- return table->getOp (const Tensor &, int64_t, int64_t)>(type_set())(const_cast(*this), split_size, dim); -+ return table->getOp (const Tensor &, int64_t, int64_t)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), split_size, dim); - #endif - } - inline std::vector Tensor::split_with_sizes(IntArrayRef split_sizes, int64_t dim) const { -@@ -1996,7 +2036,7 @@ inline std::vector Tensor::split_with_sizes(IntArrayRef split_sizes, int - return TypeDefault::split_with_sizes(const_cast(*this), split_sizes, dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]"); -- return table->getOp (const Tensor &, IntArrayRef, int64_t)>(type_set())(const_cast(*this), split_sizes, dim); -+ return table->getOp (const Tensor &, IntArrayRef, int64_t)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), split_sizes, dim); - #endif - } - 
inline Tensor Tensor::squeeze() const { -@@ -2004,7 +2044,7 @@ inline Tensor Tensor::squeeze() const { - return TypeDefault::squeeze(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::squeeze(Tensor(a) self) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::squeeze(int64_t dim) const { -@@ -2012,7 +2052,7 @@ inline Tensor Tensor::squeeze(int64_t dim) const { - return TypeDefault::squeeze(const_cast(*this), dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); - #endif - } - inline Tensor & Tensor::squeeze_() const { -@@ -2020,7 +2060,7 @@ inline Tensor & Tensor::squeeze_() const { - return TypeDefault::squeeze_(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::squeeze_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::squeeze_(int64_t dim) const { -@@ -2028,7 +2068,7 @@ inline Tensor & Tensor::squeeze_(int64_t dim) const { - return TypeDefault::squeeze_(const_cast(*this), dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::squeeze_.dim(Tensor(a!) 
self, int dim) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); - #endif - } - inline Tensor Tensor::sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const { -@@ -2036,7 +2076,7 @@ inline Tensor Tensor::sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar - return TypeDefault::sspaddmm(const_cast(*this), mat1, mat2, beta, alpha); - #else - static auto table = globalATenDispatch().getOpTable("aten::sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), mat1, mat2, beta, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mat1, mat2))(const_cast(*this), mat1, mat2, beta, alpha); - #endif - } - inline Tensor Tensor::stft(int64_t n_fft, c10::optional hop_length, c10::optional win_length, const Tensor & window, bool normalized, bool onesided) const { -@@ -2044,7 +2084,7 @@ inline Tensor Tensor::stft(int64_t n_fft, c10::optional hop_length, c10 - return TypeDefault::stft(const_cast(*this), n_fft, hop_length, win_length, window, normalized, onesided); - #else - static auto table = globalATenDispatch().getOpTable("aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? 
window=None, bool normalized=False, bool onesided=True) -> Tensor"); -- return table->getOp, c10::optional, const Tensor &, bool, bool)>(type_set())(const_cast(*this), n_fft, hop_length, win_length, window, normalized, onesided); -+ return table->getOp, c10::optional, const Tensor &, bool, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, window))(const_cast(*this), n_fft, hop_length, win_length, window, normalized, onesided); - #endif - } - inline int64_t Tensor::stride(int64_t dim) const { -@@ -2052,7 +2092,7 @@ inline int64_t Tensor::stride(int64_t dim) const { - return TypeDefault::stride(const_cast(*this), dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::stride.int(Tensor self, int dim) -> int"); -- return table->getOp(type_set())(const_cast(*this), dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -2061,7 +2101,7 @@ inline int64_t Tensor::stride(Dimname dim) const { - return TypeDefault::stride(const_cast(*this), dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::stride.Dimname(Tensor self, Dimname dim) -> int"); -- return table->getOp(type_set())(const_cast(*this), dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); - #endif - } - #endif -@@ -2070,7 +2110,7 @@ inline Tensor Tensor::sum(c10::optional dtype) const { - return TypeDefault::sum(const_cast(*this), dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::sum(Tensor self, *, ScalarType? 
dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dtype); - #endif - } - inline Tensor Tensor::sum(IntArrayRef dim, bool keepdim, c10::optional dtype) const { -@@ -2078,7 +2118,7 @@ inline Tensor Tensor::sum(IntArrayRef dim, bool keepdim, c10::optional(*this), dim, keepdim, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dim, keepdim, dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -2087,7 +2127,7 @@ inline Tensor Tensor::sum(DimnameList dim, bool keepdim, c10::optional(*this), dim, keepdim, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dim, keepdim, dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); - #endif - } - #endif -@@ -2096,7 +2136,7 @@ inline Tensor Tensor::sum_to_size(IntArrayRef size) const { - return TypeDefault::sum_to_size(const_cast(*this), size); - #else - static auto table = globalATenDispatch().getOpTable("aten::sum_to_size(Tensor self, int[] size) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), size); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size); - #endif - } - inline Tensor Tensor::sqrt() const { -@@ -2104,7 +2144,7 @@ inline Tensor Tensor::sqrt() const { - return TypeDefault::sqrt(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::sqrt(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::sqrt_() const { -@@ -2118,7 +2158,7 @@ inline Tensor & Tensor::sqrt_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::sqrt_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::std(bool unbiased) const { -@@ -2126,7 +2166,7 @@ inline Tensor Tensor::std(bool unbiased) const { - return TypeDefault::std(const_cast(*this), unbiased); - #else - static auto table = globalATenDispatch().getOpTable("aten::std(Tensor self, bool unbiased=True) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), unbiased); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), unbiased); - #endif - } - inline Tensor Tensor::std(IntArrayRef dim, bool unbiased, bool keepdim) const { -@@ -2134,7 +2174,7 @@ inline Tensor Tensor::std(IntArrayRef dim, bool unbiased, bool keepdim) const { - return TypeDefault::std(const_cast(*this), dim, unbiased, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::std.dim(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, unbiased, keepdim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, unbiased, keepdim); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -2143,7 +2183,7 @@ inline Tensor Tensor::std(DimnameList dim, bool unbiased, bool keepdim) const { - return TypeDefault::std(const_cast(*this), dim, unbiased, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, unbiased, keepdim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, unbiased, keepdim); - #endif - } - #endif -@@ -2152,7 +2192,7 @@ inline Tensor Tensor::prod(c10::optional dtype) const { - return TypeDefault::prod(const_cast(*this), 
dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dtype); - #endif - } - inline Tensor Tensor::prod(int64_t dim, bool keepdim, c10::optional dtype) const { -@@ -2160,7 +2200,7 @@ inline Tensor Tensor::prod(int64_t dim, bool keepdim, c10::optional - return TypeDefault::prod(const_cast(*this), dim, keepdim, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dim, keepdim, dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -2169,7 +2209,7 @@ inline Tensor Tensor::prod(Dimname dim, bool keepdim, c10::optional - return TypeDefault::prod(const_cast(*this), dim, keepdim, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), dim, keepdim, dtype); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); - #endif - } - #endif -@@ -2178,7 +2218,7 @@ inline Tensor Tensor::t() const { - return TypeDefault::t(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::t(Tensor(a) self) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::t_() const { -@@ -2186,7 +2226,7 @@ inline Tensor & Tensor::t_() const { - return TypeDefault::t_(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::t_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::tan() const { -@@ -2194,7 +2234,7 @@ inline Tensor Tensor::tan() const { - return TypeDefault::tan(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::tan(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::tan_() const { -@@ -2208,7 +2248,7 @@ inline Tensor & Tensor::tan_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::tan_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::tanh() const { -@@ -2216,7 +2256,7 @@ inline Tensor Tensor::tanh() const { - return TypeDefault::tanh(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::tanh(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::tanh_() const { -@@ -2230,7 +2270,7 @@ inline Tensor & Tensor::tanh_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::tanh_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::transpose(int64_t dim0, int64_t dim1) const { -@@ -2238,7 +2278,7 @@ inline Tensor Tensor::transpose(int64_t dim0, int64_t dim1) const { - return TypeDefault::transpose(const_cast(*this), dim0, dim1); - #else - static auto table = globalATenDispatch().getOpTable("aten::transpose(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), dim0, dim1); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim0, dim1); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -2247,7 +2287,7 @@ inline Tensor Tensor::transpose(Dimname dim0, Dimname dim1) const { - return TypeDefault::transpose(const_cast(*this), dim0, dim1); - #else - static auto table = globalATenDispatch().getOpTable("aten::transpose(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), dim0, dim1); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), 
dim0, dim1); - #endif - } - #endif -@@ -2256,7 +2296,7 @@ inline Tensor & Tensor::transpose_(int64_t dim0, int64_t dim1) const { - return TypeDefault::transpose_(const_cast(*this), dim0, dim1); - #else - static auto table = globalATenDispatch().getOpTable("aten::transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), dim0, dim1); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim0, dim1); - #endif - } - inline Tensor Tensor::flip(IntArrayRef dims) const { -@@ -2270,7 +2310,7 @@ inline Tensor Tensor::flip(IntArrayRef dims) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::flip(Tensor self, int[] dims) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dims); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dims); - #endif - } - inline Tensor Tensor::roll(IntArrayRef shifts, IntArrayRef dims) const { -@@ -2284,7 +2324,7 @@ inline Tensor Tensor::roll(IntArrayRef shifts, IntArrayRef dims) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), shifts, dims); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), shifts, dims); - #endif - } - inline Tensor Tensor::rot90(int64_t k, IntArrayRef dims) const { -@@ -2292,7 +2332,7 @@ inline Tensor Tensor::rot90(int64_t k, IntArrayRef dims) const { - return TypeDefault::rot90(const_cast(*this), k, dims); - #else - static auto table = globalATenDispatch().getOpTable("aten::rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), k, dims); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), k, dims); - #endif - } - inline Tensor Tensor::trunc() const { -@@ 
-2300,7 +2340,7 @@ inline Tensor Tensor::trunc() const { - return TypeDefault::trunc(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::trunc(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::trunc_() const { -@@ -2314,7 +2354,7 @@ inline Tensor & Tensor::trunc_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::trunc_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::type_as(const Tensor & other) const { -@@ -2322,7 +2362,7 @@ inline Tensor Tensor::type_as(const Tensor & other) const { - return TypeDefault::type_as(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::type_as(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::unsqueeze(int64_t dim) const { -@@ -2330,7 +2370,7 @@ inline Tensor Tensor::unsqueeze(int64_t dim) const { - return TypeDefault::unsqueeze(const_cast(*this), dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); - #endif - } - inline Tensor & Tensor::unsqueeze_(int64_t dim) const { -@@ -2338,7 +2378,7 @@ inline Tensor & Tensor::unsqueeze_(int64_t dim) const { - return TypeDefault::unsqueeze_(const_cast(*this), dim); - #else - static auto table = 
globalATenDispatch().getOpTable("aten::unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); - #endif - } - inline Tensor Tensor::var(bool unbiased) const { -@@ -2346,7 +2386,7 @@ inline Tensor Tensor::var(bool unbiased) const { - return TypeDefault::var(const_cast(*this), unbiased); - #else - static auto table = globalATenDispatch().getOpTable("aten::var(Tensor self, bool unbiased=True) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), unbiased); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), unbiased); - #endif - } - inline Tensor Tensor::var(IntArrayRef dim, bool unbiased, bool keepdim) const { -@@ -2354,7 +2394,7 @@ inline Tensor Tensor::var(IntArrayRef dim, bool unbiased, bool keepdim) const { - return TypeDefault::var(const_cast(*this), dim, unbiased, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::var.dim(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, unbiased, keepdim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, unbiased, keepdim); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -2363,7 +2403,7 @@ inline Tensor Tensor::var(DimnameList dim, bool unbiased, bool keepdim) const { - return TypeDefault::var(const_cast(*this), dim, unbiased, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, unbiased, keepdim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, unbiased, keepdim); - #endif - } - #endif -@@ -2372,7 +2412,7 @@ inline Tensor 
Tensor::view_as(const Tensor & other) const { - return TypeDefault::view_as(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::view_as(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::where(const Tensor & condition, const Tensor & other) const { -@@ -2380,7 +2420,7 @@ inline Tensor Tensor::where(const Tensor & condition, const Tensor & other) cons - return TypeDefault::where(condition, const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(condition, const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(condition, *this, other))(condition, const_cast(*this), other); - #endif - } - inline Tensor Tensor::norm(c10::optional p, ScalarType dtype) const { -@@ -2388,7 +2428,7 @@ inline Tensor Tensor::norm(c10::optional p, ScalarType dtype) const { - return TypeDefault::norm(const_cast(*this), p, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::norm.ScalarOpt_dtype(Tensor self, Scalar? 
p, *, ScalarType dtype) -> Tensor"); -- return table->getOp, ScalarType)>(type_set())(const_cast(*this), p, dtype); -+ return table->getOp, ScalarType)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dtype); - #endif - } - inline Tensor Tensor::norm(Scalar p) const { -@@ -2396,7 +2436,7 @@ inline Tensor Tensor::norm(Scalar p) const { - return TypeDefault::norm(const_cast(*this), p); - #else - static auto table = globalATenDispatch().getOpTable("aten::norm.Scalar(Tensor self, Scalar p=2) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), p); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p); - #endif - } - inline Tensor Tensor::norm(c10::optional p, IntArrayRef dim, bool keepdim, ScalarType dtype) const { -@@ -2404,7 +2444,7 @@ inline Tensor Tensor::norm(c10::optional p, IntArrayRef dim, bool keepdi - return TypeDefault::norm(const_cast(*this), p, dim, keepdim, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor"); -- return table->getOp, IntArrayRef, bool, ScalarType)>(type_set())(const_cast(*this), p, dim, keepdim, dtype); -+ return table->getOp, IntArrayRef, bool, ScalarType)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, keepdim, dtype); - #endif - } - inline Tensor Tensor::norm(c10::optional p, IntArrayRef dim, bool keepdim) const { -@@ -2412,7 +2452,7 @@ inline Tensor Tensor::norm(c10::optional p, IntArrayRef dim, bool keepdi - return TypeDefault::norm(const_cast(*this), p, dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::norm.ScalarOpt_dim(Tensor self, Scalar? 
p, int[1] dim, bool keepdim=False) -> Tensor"); -- return table->getOp, IntArrayRef, bool)>(type_set())(const_cast(*this), p, dim, keepdim); -+ return table->getOp, IntArrayRef, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, keepdim); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -2421,7 +2461,7 @@ inline Tensor Tensor::norm(c10::optional p, DimnameList dim, bool keepdi - return TypeDefault::norm(const_cast(*this), p, dim, keepdim, dtype); - #else - static auto table = globalATenDispatch().getOpTable("aten::norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor"); -- return table->getOp, DimnameList, bool, ScalarType)>(type_set())(const_cast(*this), p, dim, keepdim, dtype); -+ return table->getOp, DimnameList, bool, ScalarType)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, keepdim, dtype); - #endif - } - #endif -@@ -2431,7 +2471,7 @@ inline Tensor Tensor::norm(c10::optional p, DimnameList dim, bool keepdi - return TypeDefault::norm(const_cast(*this), p, dim, keepdim); - #else - static auto table = globalATenDispatch().getOpTable("aten::norm.names_ScalarOpt_dim(Tensor self, Scalar? 
p, Dimname[1] dim, bool keepdim=False) -> Tensor"); -- return table->getOp, DimnameList, bool)>(type_set())(const_cast(*this), p, dim, keepdim); -+ return table->getOp, DimnameList, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, keepdim); - #endif - } - #endif -@@ -2452,7 +2492,7 @@ inline Tensor Tensor::clone() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::clone(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::resize_as_(const Tensor & the_template) const { -@@ -2469,7 +2509,7 @@ inline Tensor & Tensor::resize_as_(const Tensor & the_template) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::resize_as_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), the_template); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, the_template))(const_cast(*this), the_template); - #endif - } - inline Tensor Tensor::pow(Scalar exponent) const { -@@ -2486,7 +2526,7 @@ inline Tensor Tensor::pow(Scalar exponent) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), exponent); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), exponent); - #endif - } - inline Tensor & Tensor::zero_() const { -@@ -2503,23 +2543,41 @@ inline Tensor & Tensor::zero_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::zero_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::sub(const Tensor & other, Scalar alpha) const { - #ifdef USE_STATIC_DISPATCH -- return TypeDefault::sub(const_cast(*this), other, alpha); -+ switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { -+ case Backend::CPU: -+ return CPUType::sub(const_cast(*this), other, alpha); -+ break; -+ case Backend::SparseCPU: -+ return SparseCPUType::sub(const_cast(*this), other, alpha); -+ break; -+ default: -+ AT_ERROR("sub not implemented for ", at::toString(type_set())); -+ } - #else - static auto table = globalATenDispatch().getOpTable("aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, alpha); - #endif - } - inline Tensor & Tensor::sub_(const Tensor & other, Scalar alpha) const { - #ifdef USE_STATIC_DISPATCH -- return TypeDefault::sub_(const_cast(*this), other, alpha); -+ switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { -+ case Backend::CPU: -+ return CPUType::sub_(const_cast(*this), other, alpha); -+ break; -+ case Backend::SparseCPU: -+ return SparseCPUType::sub_(const_cast(*this), other, alpha); -+ break; -+ default: -+ AT_ERROR("sub_ not implemented for ", at::toString(type_set())); -+ } - #else - static auto table = globalATenDispatch().getOpTable("aten::sub_.Tensor(Tensor(a!) 
self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, alpha); - #endif - } - inline Tensor Tensor::sub(Scalar other, Scalar alpha) const { -@@ -2527,7 +2585,7 @@ inline Tensor Tensor::sub(Scalar other, Scalar alpha) const { - return TypeDefault::sub(const_cast(*this), other, alpha); - #else - static auto table = globalATenDispatch().getOpTable("aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other, alpha); - #endif - } - inline Tensor & Tensor::sub_(Scalar other, Scalar alpha) const { -@@ -2535,23 +2593,41 @@ inline Tensor & Tensor::sub_(Scalar other, Scalar alpha) const { - return TypeDefault::sub_(const_cast(*this), other, alpha); - #else - static auto table = globalATenDispatch().getOpTable("aten::sub_.Scalar(Tensor(a!) 
self, Scalar other, Scalar alpha=1) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other, alpha); - #endif - } - inline Tensor Tensor::addmm(const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const { - #ifdef USE_STATIC_DISPATCH -- return TypeDefault::addmm(const_cast(*this), mat1, mat2, beta, alpha); -+ switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { -+ case Backend::CPU: -+ return CPUType::addmm(const_cast(*this), mat1, mat2, beta, alpha); -+ break; -+ case Backend::SparseCPU: -+ return SparseCPUType::addmm(const_cast(*this), mat1, mat2, beta, alpha); -+ break; -+ default: -+ AT_ERROR("addmm not implemented for ", at::toString(type_set())); -+ } - #else - static auto table = globalATenDispatch().getOpTable("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), mat1, mat2, beta, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mat1, mat2))(const_cast(*this), mat1, mat2, beta, alpha); - #endif - } - inline Tensor & Tensor::addmm_(const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const { - #ifdef USE_STATIC_DISPATCH -- return TypeDefault::addmm_(const_cast(*this), mat1, mat2, beta, alpha); -+ switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { -+ case Backend::CPU: -+ return CPUType::addmm_(const_cast(*this), mat1, mat2, beta, alpha); -+ break; -+ case Backend::SparseCPU: -+ return SparseCPUType::addmm_(const_cast(*this), mat1, mat2, beta, alpha); -+ break; -+ default: -+ AT_ERROR("addmm_ not implemented for ", at::toString(type_set())); -+ } - #else - static auto table = globalATenDispatch().getOpTable("aten::addmm_(Tensor(a!) 
self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), mat1, mat2, beta, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mat1, mat2))(const_cast(*this), mat1, mat2, beta, alpha); - #endif - } - inline Tensor & Tensor::sparse_resize_(IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const { -@@ -2565,7 +2641,7 @@ inline Tensor & Tensor::sparse_resize_(IntArrayRef size, int64_t sparse_dim, int - } - #else - static auto table = globalATenDispatch().getOpTable("aten::sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), size, sparse_dim, dense_dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size, sparse_dim, dense_dim); - #endif - } - inline Tensor & Tensor::sparse_resize_and_clear_(IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const { -@@ -2579,21 +2655,21 @@ inline Tensor & Tensor::sparse_resize_and_clear_(IntArrayRef size, int64_t spars - } - #else - static auto table = globalATenDispatch().getOpTable("aten::sparse_resize_and_clear_(Tensor(a!) 
self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), size, sparse_dim, dense_dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size, sparse_dim, dense_dim); - #endif - } - inline Tensor Tensor::sparse_mask(const Tensor & mask) const { - #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { -- case Backend::CPU: -- return CPUType::sparse_mask(const_cast(*this), mask); -+ case Backend::SparseCPU: -+ return SparseCPUType::sparse_mask(const_cast(*this), mask); - break; - default: - AT_ERROR("sparse_mask not implemented for ", at::toString(type_set())); - } - #else - static auto table = globalATenDispatch().getOpTable("aten::sparse_mask(Tensor self, Tensor mask) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), mask); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask))(const_cast(*this), mask); - #endif - } - inline Tensor Tensor::to_dense() const { -@@ -2607,7 +2683,7 @@ inline Tensor Tensor::to_dense() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::to_dense(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline int64_t Tensor::sparse_dim() const { -@@ -2621,7 +2697,7 @@ inline int64_t Tensor::sparse_dim() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::sparse_dim(Tensor self) -> int"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline int64_t Tensor::_dimI() const { -@@ -2635,7 +2711,7 @@ inline int64_t Tensor::_dimI() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::_dimI(Tensor self) -> int"); -- return 
table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline int64_t Tensor::dense_dim() const { -@@ -2649,7 +2725,7 @@ inline int64_t Tensor::dense_dim() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::dense_dim(Tensor self) -> int"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline int64_t Tensor::_dimV() const { -@@ -2663,7 +2739,7 @@ inline int64_t Tensor::_dimV() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::_dimV(Tensor self) -> int"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline int64_t Tensor::_nnz() const { -@@ -2677,7 +2753,7 @@ inline int64_t Tensor::_nnz() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::_nnz(Tensor self) -> int"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::coalesce() const { -@@ -2691,7 +2767,7 @@ inline Tensor Tensor::coalesce() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::coalesce(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline bool Tensor::is_coalesced() const { -@@ -2705,7 +2781,7 @@ inline bool Tensor::is_coalesced() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::is_coalesced(Tensor self) -> bool"); -- return table->getOp(type_set())(const_cast(*this)); -+ return 
table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::_indices() const { -@@ -2719,7 +2795,7 @@ inline Tensor Tensor::_indices() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::_indices(Tensor(a) self) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::_values() const { -@@ -2733,7 +2809,7 @@ inline Tensor Tensor::_values() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::_values(Tensor(a) self) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::_coalesced_(bool coalesced) const { -@@ -2747,7 +2823,7 @@ inline Tensor & Tensor::_coalesced_(bool coalesced) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::_coalesced_(Tensor(a!) 
self, bool coalesced) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), coalesced); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), coalesced); - #endif - } - inline Tensor Tensor::indices() const { -@@ -2761,7 +2837,7 @@ inline Tensor Tensor::indices() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::indices(Tensor(a) self) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::values() const { -@@ -2775,7 +2851,7 @@ inline Tensor Tensor::values() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::values(Tensor(a) self) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline int64_t Tensor::numel() const { -@@ -2783,7 +2859,7 @@ inline int64_t Tensor::numel() const { - return TypeDefault::numel(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::numel(Tensor self) -> int"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline std::vector Tensor::unbind(int64_t dim) const { -@@ -2791,7 +2867,7 @@ inline std::vector Tensor::unbind(int64_t dim) const { - return TypeDefault::unbind(const_cast(*this), dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::unbind(Tensor(a) self, int dim=0) -> Tensor(a)[]"); -- return table->getOp (const Tensor &, int64_t)>(type_set())(const_cast(*this), dim); -+ return table->getOp (const Tensor &, int64_t)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); - #endif - } - #ifdef BUILD_NAMEDTENSOR -@@ -2800,7 +2876,7 @@ inline std::vector 
Tensor::unbind(Dimname dim) const { - return TypeDefault::unbind(const_cast(*this), dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::unbind(Tensor(a) self, Dimname dim) -> Tensor(a)[]"); -- return table->getOp (const Tensor &, Dimname)>(type_set())(const_cast(*this), dim); -+ return table->getOp (const Tensor &, Dimname)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); - #endif - } - #endif -@@ -2815,7 +2891,7 @@ inline Tensor Tensor::to_sparse(int64_t sparse_dim) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), sparse_dim); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), sparse_dim); - #endif - } - inline Tensor Tensor::to_sparse() const { -@@ -2829,7 +2905,7 @@ inline Tensor Tensor::to_sparse() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::to_sparse(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::to_mkldnn() const { -@@ -2843,7 +2919,7 @@ inline Tensor Tensor::to_mkldnn() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::to_mkldnn(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::dequantize() const { -@@ -2857,7 +2933,7 @@ inline Tensor Tensor::dequantize() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::dequantize(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - 
inline double Tensor::q_scale() const { -@@ -2871,7 +2947,7 @@ inline double Tensor::q_scale() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::q_scale(Tensor self) -> float"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline int64_t Tensor::q_zero_point() const { -@@ -2885,7 +2961,7 @@ inline int64_t Tensor::q_zero_point() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::q_zero_point(Tensor self) -> int"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::q_per_channel_scales() const { -@@ -2899,7 +2975,7 @@ inline Tensor Tensor::q_per_channel_scales() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::q_per_channel_scales(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::q_per_channel_zero_points() const { -@@ -2913,7 +2989,7 @@ inline Tensor Tensor::q_per_channel_zero_points() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::q_per_channel_zero_points(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::int_repr() const { -@@ -2927,7 +3003,7 @@ inline Tensor Tensor::int_repr() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::int_repr(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline QScheme 
Tensor::qscheme() const { -@@ -2941,7 +3017,7 @@ inline QScheme Tensor::qscheme() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::qscheme(Tensor self) -> QScheme"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::to(const TensorOptions & options, bool non_blocking, bool copy) const { -@@ -2949,7 +3025,7 @@ inline Tensor Tensor::to(const TensorOptions & options, bool non_blocking, bool - return TypeDefault::to(const_cast(*this), options, non_blocking, copy); - #else - static auto table = globalATenDispatch().getOpTable("aten::to.dtype_layout(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, bool non_blocking=False, bool copy=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), options, non_blocking, copy); -+ return table->getOp(type_set(/* HMMMM */))(const_cast(*this), options, non_blocking, copy); - #endif - } - inline Tensor Tensor::to(Device device, ScalarType dtype, bool non_blocking, bool copy) const { -@@ -2957,7 +3033,7 @@ inline Tensor Tensor::to(Device device, ScalarType dtype, bool non_blocking, boo - return TypeDefault::to(const_cast(*this), device, dtype, non_blocking, copy); - #else - static auto table = globalATenDispatch().getOpTable("aten::to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), device, dtype, non_blocking, copy); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), device, dtype, non_blocking, copy); - #endif - } - inline Tensor Tensor::to(ScalarType dtype, bool non_blocking, bool copy) const { -@@ -2965,7 +3041,7 @@ inline Tensor Tensor::to(ScalarType dtype, bool non_blocking, bool copy) const { - return TypeDefault::to(const_cast(*this), dtype, non_blocking, 
copy); - #else - static auto table = globalATenDispatch().getOpTable("aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dtype, non_blocking, copy); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dtype, non_blocking, copy); - #endif - } - inline Tensor Tensor::to(const Tensor & other, bool non_blocking, bool copy) const { -@@ -2973,7 +3049,7 @@ inline Tensor Tensor::to(const Tensor & other, bool non_blocking, bool copy) con - return TypeDefault::to(const_cast(*this), other, non_blocking, copy); - #else - static auto table = globalATenDispatch().getOpTable("aten::to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other, non_blocking, copy); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, non_blocking, copy); - #endif - } - inline Scalar Tensor::item() const { -@@ -2981,7 +3057,7 @@ inline Scalar Tensor::item() const { - return TypeDefault::item(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::item(Tensor self) -> Scalar"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::set_(Storage source) const { -@@ -2995,7 +3071,7 @@ inline Tensor & Tensor::set_(Storage source) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::set_.source_Storage(Tensor(a!) 
self, Storage source) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), source); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), source); - #endif - } - inline Tensor & Tensor::set_(Storage source, int64_t storage_offset, IntArrayRef size, IntArrayRef stride) const { -@@ -3012,7 +3088,7 @@ inline Tensor & Tensor::set_(Storage source, int64_t storage_offset, IntArrayRef - } - #else - static auto table = globalATenDispatch().getOpTable("aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), source, storage_offset, size, stride); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), source, storage_offset, size, stride); - #endif - } - inline Tensor & Tensor::set_(const Tensor & source) const { -@@ -3026,7 +3102,7 @@ inline Tensor & Tensor::set_(const Tensor & source) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), source); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, source))(const_cast(*this), source); - #endif - } - inline Tensor & Tensor::set_() const { -@@ -3040,7 +3116,7 @@ inline Tensor & Tensor::set_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::set_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::set_quantizer_(ConstQuantizerPtr quantizer) const { -@@ -3054,7 +3130,7 @@ inline Tensor & Tensor::set_quantizer_(ConstQuantizerPtr quantizer) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::set_quantizer_(Tensor(a!) self, ConstQuantizerPtr quantizer) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), quantizer); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), quantizer); - #endif - } - inline bool Tensor::is_set_to(const Tensor & tensor) const { -@@ -3068,7 +3144,7 @@ inline bool Tensor::is_set_to(const Tensor & tensor) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::is_set_to(Tensor self, Tensor tensor) -> bool"); -- return table->getOp(type_set())(const_cast(*this), tensor); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, tensor))(const_cast(*this), tensor); - #endif - } - inline Tensor & Tensor::masked_fill_(const Tensor & mask, Scalar value) const { -@@ -3082,7 +3158,7 @@ inline Tensor & Tensor::masked_fill_(const Tensor & mask, Scalar value) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::masked_fill_.Scalar(Tensor(a!) 
self, Tensor mask, Scalar value) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), mask, value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask))(const_cast(*this), mask, value); - #endif - } - inline Tensor Tensor::masked_fill(const Tensor & mask, Scalar value) const { -@@ -3090,7 +3166,7 @@ inline Tensor Tensor::masked_fill(const Tensor & mask, Scalar value) const { - return TypeDefault::masked_fill(const_cast(*this), mask, value); - #else - static auto table = globalATenDispatch().getOpTable("aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), mask, value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask))(const_cast(*this), mask, value); - #endif - } - inline Tensor & Tensor::masked_fill_(const Tensor & mask, const Tensor & value) const { -@@ -3104,7 +3180,7 @@ inline Tensor & Tensor::masked_fill_(const Tensor & mask, const Tensor & value) - } - #else - static auto table = globalATenDispatch().getOpTable("aten::masked_fill_.Tensor(Tensor(a!) 
self, Tensor mask, Tensor value) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), mask, value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask, value))(const_cast(*this), mask, value); - #endif - } - inline Tensor Tensor::masked_fill(const Tensor & mask, const Tensor & value) const { -@@ -3112,7 +3188,7 @@ inline Tensor Tensor::masked_fill(const Tensor & mask, const Tensor & value) con - return TypeDefault::masked_fill(const_cast(*this), mask, value); - #else - static auto table = globalATenDispatch().getOpTable("aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), mask, value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask, value))(const_cast(*this), mask, value); - #endif - } - inline Tensor & Tensor::masked_scatter_(const Tensor & mask, const Tensor & source) const { -@@ -3126,7 +3202,7 @@ inline Tensor & Tensor::masked_scatter_(const Tensor & mask, const Tensor & sour - } - #else - static auto table = globalATenDispatch().getOpTable("aten::masked_scatter_(Tensor(a!) 
self, Tensor mask, Tensor source) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), mask, source); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask, source))(const_cast(*this), mask, source); - #endif - } - inline Tensor Tensor::masked_scatter(const Tensor & mask, const Tensor & source) const { -@@ -3134,7 +3210,7 @@ inline Tensor Tensor::masked_scatter(const Tensor & mask, const Tensor & source) - return TypeDefault::masked_scatter(const_cast(*this), mask, source); - #else - static auto table = globalATenDispatch().getOpTable("aten::masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), mask, source); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask, source))(const_cast(*this), mask, source); - #endif - } - inline Tensor Tensor::view(IntArrayRef size) const { -@@ -3151,7 +3227,7 @@ inline Tensor Tensor::view(IntArrayRef size) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::view(Tensor(a) self, int[] size) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), size); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), size); - #endif - } - inline Tensor & Tensor::put_(const Tensor & index, const Tensor & source, bool accumulate) const { -@@ -3165,7 +3241,7 @@ inline Tensor & Tensor::put_(const Tensor & index, const Tensor & source, bool a - } - #else - static auto table = globalATenDispatch().getOpTable("aten::put_(Tensor(a!) 
self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), index, source, accumulate); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, source))(const_cast(*this), index, source, accumulate); - #endif - } - inline Tensor & Tensor::index_add_(int64_t dim, const Tensor & index, const Tensor & source) const { -@@ -3179,7 +3255,7 @@ inline Tensor & Tensor::index_add_(int64_t dim, const Tensor & index, const Tens - } - #else - static auto table = globalATenDispatch().getOpTable("aten::index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), dim, index, source); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, source))(const_cast(*this), dim, index, source); - #endif - } - inline Tensor Tensor::index_add(int64_t dim, const Tensor & index, const Tensor & source) const { -@@ -3187,7 +3263,7 @@ inline Tensor Tensor::index_add(int64_t dim, const Tensor & index, const Tensor - return TypeDefault::index_add(const_cast(*this), dim, index, source); - #else - static auto table = globalATenDispatch().getOpTable("aten::index_add(Tensor self, int dim, Tensor index, Tensor source) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, index, source); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, source))(const_cast(*this), dim, index, source); - #endif - } - inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, Scalar value) const { -@@ -3201,7 +3277,7 @@ inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, Scalar va - } - #else - static auto table = globalATenDispatch().getOpTable("aten::index_fill_.Scalar(Tensor(a!) 
self, int dim, Tensor index, Scalar value) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), dim, index, value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index))(const_cast(*this), dim, index, value); - #endif - } - inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, Scalar value) const { -@@ -3209,7 +3285,7 @@ inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, Scalar value - return TypeDefault::index_fill(const_cast(*this), dim, index, value); - #else - static auto table = globalATenDispatch().getOpTable("aten::index_fill.Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, index, value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index))(const_cast(*this), dim, index, value); - #endif - } - inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, const Tensor & value) const { -@@ -3223,7 +3299,7 @@ inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, const Ten - } - #else - static auto table = globalATenDispatch().getOpTable("aten::index_fill_.Tensor(Tensor(a!) 
self, int dim, Tensor index, Tensor value) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), dim, index, value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, value))(const_cast(*this), dim, index, value); - #endif - } - inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, const Tensor & value) const { -@@ -3231,7 +3307,7 @@ inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, const Tensor - return TypeDefault::index_fill(const_cast(*this), dim, index, value); - #else - static auto table = globalATenDispatch().getOpTable("aten::index_fill.Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, index, value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, value))(const_cast(*this), dim, index, value); - #endif - } - inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, const Tensor & src) const { -@@ -3245,7 +3321,7 @@ inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, const Tensor - } - #else - static auto table = globalATenDispatch().getOpTable("aten::scatter_.src(Tensor(a!) 
self, int dim, Tensor index, Tensor src) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), dim, index, src); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, src))(const_cast(*this), dim, index, src); - #endif - } - inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, const Tensor & src) const { -@@ -3253,7 +3329,7 @@ inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, const Tensor & - return TypeDefault::scatter(const_cast(*this), dim, index, src); - #else - static auto table = globalATenDispatch().getOpTable("aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, index, src); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, src))(const_cast(*this), dim, index, src); - #endif - } - inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, Scalar value) const { -@@ -3267,7 +3343,7 @@ inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, Scalar value - } - #else - static auto table = globalATenDispatch().getOpTable("aten::scatter_.value(Tensor(a!) 
self, int dim, Tensor index, Scalar value) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), dim, index, value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index))(const_cast(*this), dim, index, value); - #endif - } - inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, Scalar value) const { -@@ -3275,7 +3351,7 @@ inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, Scalar value) c - return TypeDefault::scatter(const_cast(*this), dim, index, value); - #else - static auto table = globalATenDispatch().getOpTable("aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, index, value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index))(const_cast(*this), dim, index, value); - #endif - } - inline Tensor & Tensor::scatter_add_(int64_t dim, const Tensor & index, const Tensor & src) const { -@@ -3289,7 +3365,7 @@ inline Tensor & Tensor::scatter_add_(int64_t dim, const Tensor & index, const Te - } - #else - static auto table = globalATenDispatch().getOpTable("aten::scatter_add_(Tensor(a!) 
self, int dim, Tensor index, Tensor src) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), dim, index, src); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, src))(const_cast(*this), dim, index, src); - #endif - } - inline Tensor Tensor::scatter_add(int64_t dim, const Tensor & index, const Tensor & src) const { -@@ -3297,7 +3373,7 @@ inline Tensor Tensor::scatter_add(int64_t dim, const Tensor & index, const Tenso - return TypeDefault::scatter_add(const_cast(*this), dim, index, src); - #else - static auto table = globalATenDispatch().getOpTable("aten::scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, index, src); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index, src))(const_cast(*this), dim, index, src); - #endif - } - inline Tensor & Tensor::lt_(Scalar other) const { -@@ -3311,7 +3387,7 @@ inline Tensor & Tensor::lt_(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::lt_(const Tensor & other) const { -@@ -3325,7 +3401,7 @@ inline Tensor & Tensor::lt_(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::gt_(Scalar other) const { -@@ -3339,7 +3415,7 @@ inline Tensor & Tensor::gt_(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::gt_.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::gt_(const Tensor & other) const { -@@ -3353,7 +3429,7 @@ inline Tensor & Tensor::gt_(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::le_(Scalar other) const { -@@ -3367,7 +3443,7 @@ inline Tensor & Tensor::le_(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::le_(const Tensor & other) const { -@@ -3381,7 +3457,7 @@ inline Tensor & Tensor::le_(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::ge_(Scalar other) const { -@@ -3395,7 +3471,7 @@ inline Tensor & Tensor::ge_(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::ge_.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::ge_(const Tensor & other) const { -@@ -3409,7 +3485,7 @@ inline Tensor & Tensor::ge_(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::eq_(Scalar other) const { -@@ -3423,7 +3499,7 @@ inline Tensor & Tensor::eq_(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::eq_(const Tensor & other) const { -@@ -3437,7 +3513,7 @@ inline Tensor & Tensor::eq_(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::ne_(Scalar other) const { -@@ -3451,7 +3527,7 @@ inline Tensor & Tensor::ne_(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::ne_.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::ne_(const Tensor & other) const { -@@ -3465,7 +3541,7 @@ inline Tensor & Tensor::ne_(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::__and__(Scalar other) const { -@@ -3479,7 +3555,7 @@ inline Tensor Tensor::__and__(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::__and__(const Tensor & other) const { -@@ -3493,7 +3569,7 @@ inline Tensor Tensor::__and__(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::__iand__(Scalar other) const { -@@ -3507,7 +3583,7 @@ inline Tensor & Tensor::__iand__(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__iand__.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::__iand__(const Tensor & other) const { -@@ -3521,7 +3597,7 @@ inline Tensor & Tensor::__iand__(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::__or__(Scalar other) const { -@@ -3535,7 +3611,7 @@ inline Tensor Tensor::__or__(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::__or__(const Tensor & other) const { -@@ -3549,7 +3625,7 @@ inline Tensor Tensor::__or__(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::__ior__(Scalar other) const { -@@ -3563,7 +3639,7 @@ inline Tensor & Tensor::__ior__(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__ior__.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::__ior__(const Tensor & other) const { -@@ -3577,7 +3653,7 @@ inline Tensor & Tensor::__ior__(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::__xor__(Scalar other) const { -@@ -3591,7 +3667,7 @@ inline Tensor Tensor::__xor__(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::__xor__(const Tensor & other) const { -@@ -3605,7 +3681,7 @@ inline Tensor Tensor::__xor__(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::__ixor__(Scalar other) const { -@@ -3619,7 +3695,7 @@ inline Tensor & Tensor::__ixor__(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__ixor__.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::__ixor__(const Tensor & other) const { -@@ -3633,7 +3709,7 @@ inline Tensor & Tensor::__ixor__(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::__lshift__(Scalar other) const { -@@ -3647,7 +3723,7 @@ inline Tensor Tensor::__lshift__(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::__lshift__(const Tensor & other) const { -@@ -3661,7 +3737,7 @@ inline Tensor Tensor::__lshift__(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::__ilshift__(Scalar other) const { -@@ -3675,7 +3751,7 @@ inline Tensor & Tensor::__ilshift__(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__ilshift__.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::__ilshift__(const Tensor & other) const { -@@ -3689,7 +3765,7 @@ inline Tensor & Tensor::__ilshift__(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::__rshift__(Scalar other) const { -@@ -3703,7 +3779,7 @@ inline Tensor Tensor::__rshift__(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::__rshift__(const Tensor & other) const { -@@ -3717,7 +3793,7 @@ inline Tensor Tensor::__rshift__(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::__irshift__(Scalar other) const { -@@ -3731,7 +3807,7 @@ inline Tensor & Tensor::__irshift__(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__irshift__.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::__irshift__(const Tensor & other) const { -@@ -3745,7 +3821,7 @@ inline Tensor & Tensor::__irshift__(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::__irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::lgamma_() const { -@@ -3759,7 +3835,7 @@ inline Tensor & Tensor::lgamma_() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::lgamma_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::atan2_(const Tensor & other) const { -@@ -3767,7 +3843,7 @@ inline Tensor & Tensor::atan2_(const Tensor & other) const { - return TypeDefault::atan2_(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::tril_(int64_t diagonal) const { -@@ -3781,7 +3857,7 @@ inline Tensor & Tensor::tril_(int64_t diagonal) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::tril_(Tensor(a!) 
self, int diagonal=0) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), diagonal); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), diagonal); - #endif - } - inline Tensor & Tensor::triu_(int64_t diagonal) const { -@@ -3795,7 +3871,7 @@ inline Tensor & Tensor::triu_(int64_t diagonal) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), diagonal); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), diagonal); - #endif - } - inline Tensor & Tensor::digamma_() const { -@@ -3803,7 +3879,7 @@ inline Tensor & Tensor::digamma_() const { - return TypeDefault::digamma_(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::digamma_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::polygamma_(int64_t n) const { -@@ -3811,7 +3887,7 @@ inline Tensor & Tensor::polygamma_(int64_t n) const { - return TypeDefault::polygamma_(const_cast(*this), n); - #else - static auto table = globalATenDispatch().getOpTable("aten::polygamma_(Tensor(a!) self, int n) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), n); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), n); - #endif - } - inline Tensor & Tensor::renorm_(Scalar p, int64_t dim, Scalar maxnorm) const { -@@ -3825,7 +3901,7 @@ inline Tensor & Tensor::renorm_(Scalar p, int64_t dim, Scalar maxnorm) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::renorm_(Tensor(a!) 
self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), p, dim, maxnorm); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, maxnorm); - #endif - } - inline Tensor & Tensor::pow_(Scalar exponent) const { -@@ -3839,7 +3915,7 @@ inline Tensor & Tensor::pow_(Scalar exponent) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), exponent); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), exponent); - #endif - } - inline Tensor & Tensor::pow_(const Tensor & exponent) const { -@@ -3853,7 +3929,7 @@ inline Tensor & Tensor::pow_(const Tensor & exponent) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), exponent); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, exponent))(const_cast(*this), exponent); - #endif - } - inline Tensor & Tensor::lerp_(const Tensor & end, Scalar weight) const { -@@ -3867,7 +3943,7 @@ inline Tensor & Tensor::lerp_(const Tensor & end, Scalar weight) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), end, weight); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, end))(const_cast(*this), end, weight); - #endif - } - inline Tensor & Tensor::lerp_(const Tensor & end, const Tensor & weight) const { -@@ -3881,7 +3957,7 @@ inline Tensor & Tensor::lerp_(const Tensor & end, const Tensor & weight) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::lerp_.Tensor(Tensor(a!) 
self, Tensor end, Tensor weight) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), end, weight); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, end, weight))(const_cast(*this), end, weight); - #endif - } - inline Tensor & Tensor::fmod_(Scalar other) const { -@@ -3895,7 +3971,7 @@ inline Tensor & Tensor::fmod_(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::fmod_(const Tensor & other) const { -@@ -3909,7 +3985,7 @@ inline Tensor & Tensor::fmod_(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::remainder_(Scalar other) const { -@@ -3923,7 +3999,7 @@ inline Tensor & Tensor::remainder_(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::remainder_(const Tensor & other) const { -@@ -3937,7 +4013,7 @@ inline Tensor & Tensor::remainder_(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::remainder_.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor & Tensor::addbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { -@@ -3951,7 +4027,7 @@ inline Tensor & Tensor::addbmm_(const Tensor & batch1, const Tensor & batch2, Sc - } - #else - static auto table = globalATenDispatch().getOpTable("aten::addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), batch1, batch2, beta, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2))(const_cast(*this), batch1, batch2, beta, alpha); - #endif - } - inline Tensor Tensor::addbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { -@@ -3965,7 +4041,7 @@ inline Tensor Tensor::addbmm(const Tensor & batch1, const Tensor & batch2, Scala - } - #else - static auto table = globalATenDispatch().getOpTable("aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), batch1, batch2, beta, alpha); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2))(const_cast(*this), batch1, batch2, beta, alpha); - #endif - } - inline Tensor & Tensor::addcdiv_(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { -@@ -3973,7 +4049,7 @@ inline Tensor & Tensor::addcdiv_(const Tensor & tensor1, const Tensor & tensor2, - return TypeDefault::addcdiv_(const_cast(*this), tensor1, tensor2, value); - #else - static auto table = globalATenDispatch().getOpTable("aten::addcdiv_(Tensor(a!) 
self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), tensor1, tensor2, value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2))(const_cast(*this), tensor1, tensor2, value); - #endif - } - inline Tensor & Tensor::random_(int64_t from, int64_t to, Generator * generator) const { -@@ -3987,7 +4063,7 @@ inline Tensor & Tensor::random_(int64_t from, int64_t to, Generator * generator) - } - #else - static auto table = globalATenDispatch().getOpTable("aten::random_.from(Tensor(a!) self, int from, int to, *, Generator? generator=None) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), from, to, generator); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), from, to, generator); - #endif - } - inline Tensor & Tensor::random_(int64_t to, Generator * generator) const { -@@ -4001,7 +4077,7 @@ inline Tensor & Tensor::random_(int64_t to, Generator * generator) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), to, generator); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), to, generator); - #endif - } - inline Tensor & Tensor::random_(Generator * generator) const { -@@ -4015,7 +4091,7 @@ inline Tensor & Tensor::random_(Generator * generator) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::random_(Tensor(a!) self, *, Generator? 
generator=None) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), generator); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), generator); - #endif - } - inline Tensor & Tensor::uniform_(double from, double to, Generator * generator) const { -@@ -4029,7 +4105,7 @@ inline Tensor & Tensor::uniform_(double from, double to, Generator * generator) - } - #else - static auto table = globalATenDispatch().getOpTable("aten::uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), from, to, generator); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), from, to, generator); - #endif - } - inline Tensor & Tensor::normal_(double mean, double std, Generator * generator) const { -@@ -4043,7 +4119,7 @@ inline Tensor & Tensor::normal_(double mean, double std, Generator * generator) - } - #else - static auto table = globalATenDispatch().getOpTable("aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), mean, std, generator); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), mean, std, generator); - #endif - } - inline Tensor & Tensor::cauchy_(double median, double sigma, Generator * generator) const { -@@ -4057,7 +4133,7 @@ inline Tensor & Tensor::cauchy_(double median, double sigma, Generator * generat - } - #else - static auto table = globalATenDispatch().getOpTable("aten::cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? 
generator=None) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), median, sigma, generator); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), median, sigma, generator); - #endif - } - inline Tensor & Tensor::log_normal_(double mean, double std, Generator * generator) const { -@@ -4071,7 +4147,7 @@ inline Tensor & Tensor::log_normal_(double mean, double std, Generator * generat - } - #else - static auto table = globalATenDispatch().getOpTable("aten::log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), mean, std, generator); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), mean, std, generator); - #endif - } - inline Tensor & Tensor::exponential_(double lambd, Generator * generator) const { -@@ -4085,7 +4161,7 @@ inline Tensor & Tensor::exponential_(double lambd, Generator * generator) const - } - #else - static auto table = globalATenDispatch().getOpTable("aten::exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), lambd, generator); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), lambd, generator); - #endif - } - inline Tensor & Tensor::geometric_(double p, Generator * generator) const { -@@ -4099,7 +4175,7 @@ inline Tensor & Tensor::geometric_(double p, Generator * generator) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::geometric_(Tensor(a!) self, float p, *, Generator? 
generator=None) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), p, generator); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, generator); - #endif - } - inline Tensor Tensor::diag(int64_t diagonal) const { -@@ -4113,7 +4189,7 @@ inline Tensor Tensor::diag(int64_t diagonal) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::diag(Tensor self, int diagonal=0) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), diagonal); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), diagonal); - #endif - } - inline Tensor Tensor::cross(const Tensor & other, c10::optional dim) const { -@@ -4121,7 +4197,7 @@ inline Tensor Tensor::cross(const Tensor & other, c10::optional dim) co - return TypeDefault::cross(const_cast(*this), other, dim); - #else - static auto table = globalATenDispatch().getOpTable("aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor"); -- return table->getOp)>(type_set())(const_cast(*this), other, dim); -+ return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, dim); - #endif - } - inline Tensor Tensor::triu(int64_t diagonal) const { -@@ -4129,7 +4205,7 @@ inline Tensor Tensor::triu(int64_t diagonal) const { - return TypeDefault::triu(const_cast(*this), diagonal); - #else - static auto table = globalATenDispatch().getOpTable("aten::triu(Tensor self, int diagonal=0) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), diagonal); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), diagonal); - #endif - } - inline Tensor Tensor::tril(int64_t diagonal) const { -@@ -4137,7 +4213,7 @@ inline Tensor Tensor::tril(int64_t diagonal) const { - return TypeDefault::tril(const_cast(*this), diagonal); - #else - static auto table = globalATenDispatch().getOpTable("aten::tril(Tensor self, int diagonal=0) -> 
Tensor"); -- return table->getOp(type_set())(const_cast(*this), diagonal); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), diagonal); - #endif - } - inline Tensor Tensor::trace() const { -@@ -4151,7 +4227,7 @@ inline Tensor Tensor::trace() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::trace(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::ne(Scalar other) const { -@@ -4168,7 +4244,7 @@ inline Tensor Tensor::ne(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::ne.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::ne(const Tensor & other) const { -@@ -4185,7 +4261,7 @@ inline Tensor Tensor::ne(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::ne.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::eq(Scalar other) const { -@@ -4202,7 +4278,7 @@ inline Tensor Tensor::eq(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::eq.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::eq(const Tensor & other) const { -@@ -4219,7 +4295,7 @@ inline Tensor Tensor::eq(const Tensor & other) const { - } - #else - static auto 
table = globalATenDispatch().getOpTable("aten::eq.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::ge(Scalar other) const { -@@ -4236,7 +4312,7 @@ inline Tensor Tensor::ge(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::ge.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::ge(const Tensor & other) const { -@@ -4253,7 +4329,7 @@ inline Tensor Tensor::ge(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::ge.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::le(Scalar other) const { -@@ -4270,7 +4346,7 @@ inline Tensor Tensor::le(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::le.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::le(const Tensor & other) const { -@@ -4287,7 +4363,7 @@ inline Tensor Tensor::le(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::le.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - 
inline Tensor Tensor::gt(Scalar other) const { -@@ -4304,7 +4380,7 @@ inline Tensor Tensor::gt(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::gt.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::gt(const Tensor & other) const { -@@ -4321,7 +4397,7 @@ inline Tensor Tensor::gt(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::gt.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::lt(Scalar other) const { -@@ -4338,7 +4414,7 @@ inline Tensor Tensor::lt(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::lt.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::lt(const Tensor & other) const { -@@ -4355,7 +4431,7 @@ inline Tensor Tensor::lt(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::lt.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::take(const Tensor & index) const { -@@ -4369,7 +4445,7 @@ inline Tensor Tensor::take(const Tensor & index) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::take(Tensor self, Tensor index) -> Tensor"); -- return 
table->getOp(type_set())(const_cast(*this), index); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index))(const_cast(*this), index); - #endif - } - inline Tensor Tensor::index_select(int64_t dim, const Tensor & index) const { -@@ -4386,7 +4462,7 @@ inline Tensor Tensor::index_select(int64_t dim, const Tensor & index) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::index_select(Tensor self, int dim, Tensor index) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, index); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index))(const_cast(*this), dim, index); - #endif - } - inline Tensor Tensor::masked_select(const Tensor & mask) const { -@@ -4400,7 +4476,7 @@ inline Tensor Tensor::masked_select(const Tensor & mask) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::masked_select(Tensor self, Tensor mask) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), mask); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, mask))(const_cast(*this), mask); - #endif - } - inline Tensor Tensor::nonzero() const { -@@ -4414,7 +4490,7 @@ inline Tensor Tensor::nonzero() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::nonzero(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline std::vector Tensor::nonzero_numpy() const { -@@ -4422,7 +4498,7 @@ inline std::vector Tensor::nonzero_numpy() const { - return TypeDefault::nonzero_numpy(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::nonzero_numpy(Tensor self) -> Tensor[]"); -- return table->getOp (const Tensor &)>(type_set())(const_cast(*this)); -+ return table->getOp (const Tensor &)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - 
#endif - } - inline Tensor Tensor::gather(int64_t dim, const Tensor & index, bool sparse_grad) const { -@@ -4436,7 +4512,7 @@ inline Tensor Tensor::gather(int64_t dim, const Tensor & index, bool sparse_grad - } - #else - static auto table = globalATenDispatch().getOpTable("aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, index, sparse_grad); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, index))(const_cast(*this), dim, index, sparse_grad); - #endif - } - inline Tensor Tensor::addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { -@@ -4444,7 +4520,7 @@ inline Tensor Tensor::addcmul(const Tensor & tensor1, const Tensor & tensor2, Sc - return TypeDefault::addcmul(const_cast(*this), tensor1, tensor2, value); - #else - static auto table = globalATenDispatch().getOpTable("aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), tensor1, tensor2, value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2))(const_cast(*this), tensor1, tensor2, value); - #endif - } - inline Tensor & Tensor::addcmul_(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { -@@ -4452,7 +4528,7 @@ inline Tensor & Tensor::addcmul_(const Tensor & tensor1, const Tensor & tensor2, - return TypeDefault::addcmul_(const_cast(*this), tensor1, tensor2, value); - #else - static auto table = globalATenDispatch().getOpTable("aten::addcmul_(Tensor(a!) 
self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this), tensor1, tensor2, value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2))(const_cast(*this), tensor1, tensor2, value); - #endif - } - inline Tensor Tensor::addcdiv(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { -@@ -4460,7 +4536,7 @@ inline Tensor Tensor::addcdiv(const Tensor & tensor1, const Tensor & tensor2, Sc - return TypeDefault::addcdiv(const_cast(*this), tensor1, tensor2, value); - #else - static auto table = globalATenDispatch().getOpTable("aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), tensor1, tensor2, value); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2))(const_cast(*this), tensor1, tensor2, value); - #endif - } - inline std::tuple Tensor::lstsq(const Tensor & A) const { -@@ -4474,7 +4550,7 @@ inline std::tuple Tensor::lstsq(const Tensor & A) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::lstsq(Tensor self, Tensor A) -> (Tensor solution, Tensor QR)"); -- return table->getOp (const Tensor &, const Tensor &)>(type_set())(const_cast(*this), A); -+ return table->getOp (const Tensor &, const Tensor &)>(at::detail::multi_dispatch_tensor_type_set(*this, A))(const_cast(*this), A); - #endif - } - inline std::tuple Tensor::triangular_solve(const Tensor & A, bool upper, bool transpose, bool unitriangular) const { -@@ -4482,7 +4558,7 @@ inline std::tuple Tensor::triangular_solve(const Tensor & A, bool - return TypeDefault::triangular_solve(const_cast(*this), A, upper, transpose, unitriangular); - #else - static auto table = globalATenDispatch().getOpTable("aten::triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor 
cloned_coefficient)"); -- return table->getOp (const Tensor &, const Tensor &, bool, bool, bool)>(type_set())(const_cast(*this), A, upper, transpose, unitriangular); -+ return table->getOp (const Tensor &, const Tensor &, bool, bool, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, A))(const_cast(*this), A, upper, transpose, unitriangular); - #endif - } - inline std::tuple Tensor::symeig(bool eigenvectors, bool upper) const { -@@ -4490,7 +4566,7 @@ inline std::tuple Tensor::symeig(bool eigenvectors, bool upper) c - return TypeDefault::symeig(const_cast(*this), eigenvectors, upper); - #else - static auto table = globalATenDispatch().getOpTable("aten::symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)"); -- return table->getOp (const Tensor &, bool, bool)>(type_set())(const_cast(*this), eigenvectors, upper); -+ return table->getOp (const Tensor &, bool, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), eigenvectors, upper); - #endif - } - inline std::tuple Tensor::eig(bool eigenvectors) const { -@@ -4504,7 +4580,7 @@ inline std::tuple Tensor::eig(bool eigenvectors) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::eig(Tensor self, bool eigenvectors=False) -> (Tensor eigenvalues, Tensor eigenvectors)"); -- return table->getOp (const Tensor &, bool)>(type_set())(const_cast(*this), eigenvectors); -+ return table->getOp (const Tensor &, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), eigenvectors); - #endif - } - inline std::tuple Tensor::svd(bool some, bool compute_uv) const { -@@ -4512,7 +4588,7 @@ inline std::tuple Tensor::svd(bool some, bool compute_uv) - return TypeDefault::svd(const_cast(*this), some, compute_uv); - #else - static auto table = globalATenDispatch().getOpTable("aten::svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)"); -- return table->getOp (const Tensor &, bool, 
bool)>(type_set())(const_cast(*this), some, compute_uv); -+ return table->getOp (const Tensor &, bool, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), some, compute_uv); - #endif - } - inline Tensor Tensor::cholesky(bool upper) const { -@@ -4520,7 +4596,7 @@ inline Tensor Tensor::cholesky(bool upper) const { - return TypeDefault::cholesky(const_cast(*this), upper); - #else - static auto table = globalATenDispatch().getOpTable("aten::cholesky(Tensor self, bool upper=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), upper); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), upper); - #endif - } - inline Tensor Tensor::cholesky_solve(const Tensor & input2, bool upper) const { -@@ -4528,7 +4604,7 @@ inline Tensor Tensor::cholesky_solve(const Tensor & input2, bool upper) const { - return TypeDefault::cholesky_solve(const_cast(*this), input2, upper); - #else - static auto table = globalATenDispatch().getOpTable("aten::cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), input2, upper); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, input2))(const_cast(*this), input2, upper); - #endif - } - inline std::tuple Tensor::solve(const Tensor & A) const { -@@ -4536,7 +4612,7 @@ inline std::tuple Tensor::solve(const Tensor & A) const { - return TypeDefault::solve(const_cast(*this), A); - #else - static auto table = globalATenDispatch().getOpTable("aten::solve(Tensor self, Tensor A) -> (Tensor solution, Tensor LU)"); -- return table->getOp (const Tensor &, const Tensor &)>(type_set())(const_cast(*this), A); -+ return table->getOp (const Tensor &, const Tensor &)>(at::detail::multi_dispatch_tensor_type_set(*this, A))(const_cast(*this), A); - #endif - } - inline Tensor Tensor::cholesky_inverse(bool upper) const { -@@ -4550,7 +4626,7 @@ inline Tensor Tensor::cholesky_inverse(bool upper) const 
{ - } - #else - static auto table = globalATenDispatch().getOpTable("aten::cholesky_inverse(Tensor self, bool upper=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), upper); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), upper); - #endif - } - inline std::tuple Tensor::qr(bool some) const { -@@ -4558,7 +4634,7 @@ inline std::tuple Tensor::qr(bool some) const { - return TypeDefault::qr(const_cast(*this), some); - #else - static auto table = globalATenDispatch().getOpTable("aten::qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R)"); -- return table->getOp (const Tensor &, bool)>(type_set())(const_cast(*this), some); -+ return table->getOp (const Tensor &, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), some); - #endif - } - inline std::tuple Tensor::geqrf() const { -@@ -4572,7 +4648,7 @@ inline std::tuple Tensor::geqrf() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::geqrf(Tensor self) -> (Tensor a, Tensor tau)"); -- return table->getOp (const Tensor &)>(type_set())(const_cast(*this)); -+ return table->getOp (const Tensor &)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::orgqr(const Tensor & input2) const { -@@ -4586,7 +4662,7 @@ inline Tensor Tensor::orgqr(const Tensor & input2) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::orgqr(Tensor self, Tensor input2) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), input2); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, input2))(const_cast(*this), input2); - #endif - } - inline Tensor Tensor::ormqr(const Tensor & input2, const Tensor & input3, bool left, bool transpose) const { -@@ -4600,7 +4676,7 @@ inline Tensor Tensor::ormqr(const Tensor & input2, const Tensor & input3, bool l - } - #else - static auto table = 
globalATenDispatch().getOpTable("aten::ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), input2, input3, left, transpose); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, input2, input3))(const_cast(*this), input2, input3, left, transpose); - #endif - } - inline Tensor Tensor::lu_solve(const Tensor & LU_data, const Tensor & LU_pivots) const { -@@ -4608,7 +4684,7 @@ inline Tensor Tensor::lu_solve(const Tensor & LU_data, const Tensor & LU_pivots) - return TypeDefault::lu_solve(const_cast(*this), LU_data, LU_pivots); - #else - static auto table = globalATenDispatch().getOpTable("aten::lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), LU_data, LU_pivots); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, LU_data, LU_pivots))(const_cast(*this), LU_data, LU_pivots); - #endif - } - inline Tensor Tensor::multinomial(int64_t num_samples, bool replacement, Generator * generator) const { -@@ -4622,7 +4698,7 @@ inline Tensor Tensor::multinomial(int64_t num_samples, bool replacement, Generat - } - #else - static auto table = globalATenDispatch().getOpTable("aten::multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? 
generator=None) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), num_samples, replacement, generator); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), num_samples, replacement, generator); - #endif - } - inline Tensor Tensor::lgamma() const { -@@ -4636,7 +4712,7 @@ inline Tensor Tensor::lgamma() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::lgamma(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::digamma() const { -@@ -4644,7 +4720,7 @@ inline Tensor Tensor::digamma() const { - return TypeDefault::digamma(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::digamma(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::polygamma(int64_t n) const { -@@ -4652,7 +4728,7 @@ inline Tensor Tensor::polygamma(int64_t n) const { - return TypeDefault::polygamma(n, const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::polygamma(int n, Tensor self) -> Tensor"); -- return table->getOp(type_set())(n, const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(n, const_cast(*this)); - #endif - } - inline Tensor Tensor::erfinv() const { -@@ -4660,7 +4736,7 @@ inline Tensor Tensor::erfinv() const { - return TypeDefault::erfinv(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::erfinv(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::erfinv_() const { -@@ -4668,7 +4744,7 
@@ inline Tensor & Tensor::erfinv_() const { - return TypeDefault::erfinv_(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::erfinv_(Tensor(a!) self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::sign() const { -@@ -4676,7 +4752,7 @@ inline Tensor Tensor::sign() const { - return TypeDefault::sign(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::sign(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor & Tensor::sign_() const { -@@ -4684,7 +4760,7 @@ inline Tensor & Tensor::sign_() const { - return TypeDefault::sign_(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::sign_(Tensor(a!) 
self) -> Tensor(a!)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::dist(const Tensor & other, Scalar p) const { -@@ -4698,7 +4774,7 @@ inline Tensor Tensor::dist(const Tensor & other, Scalar p) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::dist(Tensor self, Tensor other, Scalar p=2) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other, p); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other, p); - #endif - } - inline Tensor Tensor::atan2(const Tensor & other) const { -@@ -4706,7 +4782,7 @@ inline Tensor Tensor::atan2(const Tensor & other) const { - return TypeDefault::atan2(const_cast(*this), other); - #else - static auto table = globalATenDispatch().getOpTable("aten::atan2(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::lerp(const Tensor & end, Scalar weight) const { -@@ -4720,7 +4796,7 @@ inline Tensor Tensor::lerp(const Tensor & end, Scalar weight) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), end, weight); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, end))(const_cast(*this), end, weight); - #endif - } - inline Tensor Tensor::lerp(const Tensor & end, const Tensor & weight) const { -@@ -4734,7 +4810,7 @@ inline Tensor Tensor::lerp(const Tensor & end, const Tensor & weight) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor"); -- return 
table->getOp(type_set())(const_cast(*this), end, weight); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, end, weight))(const_cast(*this), end, weight); - #endif - } - inline Tensor Tensor::histc(int64_t bins, Scalar min, Scalar max) const { -@@ -4748,7 +4824,7 @@ inline Tensor Tensor::histc(int64_t bins, Scalar min, Scalar max) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), bins, min, max); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), bins, min, max); - #endif - } - inline Tensor Tensor::fmod(Scalar other) const { -@@ -4762,7 +4838,7 @@ inline Tensor Tensor::fmod(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::fmod(const Tensor & other) const { -@@ -4776,7 +4852,7 @@ inline Tensor Tensor::fmod(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::remainder(Scalar other) const { -@@ -4790,7 +4866,7 @@ inline Tensor Tensor::remainder(Scalar other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), 
other); - #endif - } - inline Tensor Tensor::remainder(const Tensor & other) const { -@@ -4804,7 +4880,7 @@ inline Tensor Tensor::remainder(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::min(const Tensor & other) const { -@@ -4818,7 +4894,7 @@ inline Tensor Tensor::min(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::min.other(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::min() const { -@@ -4835,7 +4911,7 @@ inline Tensor Tensor::min() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::min(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::max(const Tensor & other) const { -@@ -4849,7 +4925,7 @@ inline Tensor Tensor::max(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::max.other(Tensor self, Tensor other) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::max() const { -@@ -4866,7 +4942,7 @@ inline Tensor Tensor::max() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::max(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return 
table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::median() const { -@@ -4880,7 +4956,7 @@ inline Tensor Tensor::median() const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::median(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline std::tuple Tensor::sort(int64_t dim, bool descending) const { -@@ -4897,7 +4973,7 @@ inline std::tuple Tensor::sort(int64_t dim, bool descending) cons - } - #else - static auto table = globalATenDispatch().getOpTable("aten::sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)"); -- return table->getOp (const Tensor &, int64_t, bool)>(type_set())(const_cast(*this), dim, descending); -+ return table->getOp (const Tensor &, int64_t, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, descending); - #endif - } - inline Tensor Tensor::argsort(int64_t dim, bool descending) const { -@@ -4905,7 +4981,7 @@ inline Tensor Tensor::argsort(int64_t dim, bool descending) const { - return TypeDefault::argsort(const_cast(*this), dim, descending); - #else - static auto table = globalATenDispatch().getOpTable("aten::argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), dim, descending); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, descending); - #endif - } - inline std::tuple Tensor::topk(int64_t k, int64_t dim, bool largest, bool sorted) const { -@@ -4913,7 +4989,7 @@ inline std::tuple Tensor::topk(int64_t k, int64_t dim, bool large - return TypeDefault::topk(const_cast(*this), k, dim, largest, sorted); - #else - static auto table = globalATenDispatch().getOpTable("aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool 
sorted=True) -> (Tensor values, Tensor indices)"); -- return table->getOp (const Tensor &, int64_t, int64_t, bool, bool)>(type_set())(const_cast(*this), k, dim, largest, sorted); -+ return table->getOp (const Tensor &, int64_t, int64_t, bool, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), k, dim, largest, sorted); - #endif - } - inline Tensor Tensor::all() const { -@@ -4921,7 +4997,7 @@ inline Tensor Tensor::all() const { - return TypeDefault::all(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::all(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::any() const { -@@ -4929,7 +5005,7 @@ inline Tensor Tensor::any() const { - return TypeDefault::any(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::any(Tensor self) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - inline Tensor Tensor::renorm(Scalar p, int64_t dim, Scalar maxnorm) const { -@@ -4943,7 +5019,7 @@ inline Tensor Tensor::renorm(Scalar p, int64_t dim, Scalar maxnorm) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), p, dim, maxnorm); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, maxnorm); - #endif - } - inline Tensor Tensor::unfold(int64_t dimension, int64_t size, int64_t step) const { -@@ -4957,7 +5033,7 @@ inline Tensor Tensor::unfold(int64_t dimension, int64_t size, int64_t step) cons - } - #else - static auto table = globalATenDispatch().getOpTable("aten::unfold(Tensor(a) self, int dimension, int size, int 
step) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this), dimension, size, step); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dimension, size, step); - #endif - } - inline bool Tensor::equal(const Tensor & other) const { -@@ -4974,7 +5050,7 @@ inline bool Tensor::equal(const Tensor & other) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::equal(Tensor self, Tensor other) -> bool"); -- return table->getOp(type_set())(const_cast(*this), other); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, other))(const_cast(*this), other); - #endif - } - inline Tensor Tensor::pow(const Tensor & exponent) const { -@@ -4988,7 +5064,7 @@ inline Tensor Tensor::pow(const Tensor & exponent) const { - } - #else - static auto table = globalATenDispatch().getOpTable("aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor"); -- return table->getOp(type_set())(const_cast(*this), exponent); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, exponent))(const_cast(*this), exponent); - #endif - } - inline Tensor Tensor::alias() const { -@@ -4996,7 +5072,7 @@ inline Tensor Tensor::alias() const { - return TypeDefault::alias(const_cast(*this)); - #else - static auto table = globalATenDispatch().getOpTable("aten::alias(Tensor(a) self) -> Tensor(a)"); -- return table->getOp(type_set())(const_cast(*this)); -+ return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); - #endif - } - -diff --git a/aten/src/ATen/core/Variadic.h b/aten/src/ATen/core/Variadic.h -new file mode 100644 -index 0000000000..f1078ca9ce ---- /dev/null -+++ b/aten/src/ATen/core/Variadic.h -@@ -0,0 +1,74 @@ -+#pragma once -+ -+#include -+#include -+#include -+#include -+ -+namespace at { -+ -+// This class allows you to write variadic functions which -+// call a (possibly overloaded) function on each argument, -+// in order. 
This is most commonly used in autogenerated code, -+// where it is convenient to have a function that can uniformly -+// take arguments of different types. If your arguments -+// are homogenous consider using a std::initializer_list instead. -+// -+// For examples of this in use, see torch/csrc/utils/variadic.h -+template -+struct IterArgs { -+ template -+ inline F& apply() { -+ return self(); -+ } -+ -+ // NB: Use perfect forwarding here, otherwise we'll make value -+ // copies of all arguments! -+ template -+ inline F& apply(T&& arg, Args&&... args) { -+ self()(std::forward(arg)); -+ if (self().short_circuit()) { -+ return self(); -+ } else { -+ return apply(std::forward(args)...); -+ } -+ } -+ -+ // Here are some handy overloads which provide sensible -+ // defaults for container-like structures that one might -+ // be interested in recursing into. You can enable them -+ // by adding: -+ // -+ // using IterArgs::operator() -+ // -+ // to your struct. These are not enabled by default because -+ // you may be able to process these structures more efficiently -+ // than handling them one-by-one. -+ -+ template -+ void operator()(at::ArrayRef args) { -+ for (const auto& arg : args) { -+ self()(arg); -+ if (short_circuit()) -+ return; -+ } -+ } -+ -+ // NB: we need to specify std::vector manually as C++ won't -+ // do an implicit conversion to make a template deduction go through. 
-+ template -+ void operator()(const std::vector& args) { -+ self()(at::ArrayRef{args}); -+ } -+ -+ bool short_circuit() { -+ return false; -+ } -+ -+ private: -+ inline F& self() { -+ return *static_cast(this); -+ } -+}; -+ -+} // namespace torch -diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h -index e1feb58b80..dd07afe0e6 100644 ---- a/aten/src/ATen/core/aten_interned_strings.h -+++ b/aten/src/ATen/core/aten_interned_strings.h -@@ -579,7 +579,6 @@ _(aten, rrelu_with_noise) \ - _(aten, rrelu_with_noise_backward) \ - _(aten, rrelu_with_noise_forward) \ - _(aten, rsqrt) \ --_(aten, s_native_addmm) \ - _(aten, scatter) \ - _(aten, scatter_add) \ - _(aten, select) \ -diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py -index 633c0ed1d9..597f0aad97 100644 ---- a/aten/src/ATen/function_wrapper.py -+++ b/aten/src/ATen/function_wrapper.py -@@ -129,14 +129,13 @@ TENSOR_METHOD_DECLARATION = CodeTemplate("""\ - ${return_type} ${api_name}(${method_formals_with_defaults}) const; - """) - # add non-virtual declaration to Tensor.cpp --# TODO: This will need to be adjusted for multiple dispatch - TENSOR_METHOD_DEFINITION = CodeTemplate("""\ - inline ${return_type} Tensor::${api_name}(${method_formals}) const { - #ifdef USE_STATIC_DISPATCH - ${static_dispatch_method_body} - #else - static auto table = globalATenDispatch().getOpTable("${schema_string}"); -- return table->getOp<${return_type} (${formals_types})>(type_set())(${method_actuals}); -+ return table->getOp<${return_type} (${formals_types})>(${inferred_type_set})(${method_actuals}); - #endif - } - """) -@@ -835,6 +834,19 @@ def create_generic(top_env, declarations): - - return None - -+ def find_multidispatch_tensors(formals): -+ # type: (List[AtFormal]) -> List[str] -+ # Compute the list of all tensor arguments which should be considered -+ # for multiple dispatch. 
Note that this doesn't completely replace -+ # find_dispatch_tensor because we use the "dispatch tensor" to determine -+ # device guards. This is ONLY used for multiple dispatch in -+ # ATenDispatch.h -+ r = [] -+ for formal in formals: -+ if 'TensorList' == formal['dynamic_type'] or is_any_tensor_type(formal): -+ r.append(formal['name']) -+ return r -+ - def format_formal(f): - # type: (AtFormal) -> str - return '{} {}'.format(f['type'], f['name']) -@@ -903,6 +915,8 @@ def create_generic(top_env, declarations): - - def process_option(option): - # type: (FunctionOption) -> None -+ # Mutably populate option with derived values computed from values -+ # passed in to option. - option['inplace'] = re.search( - '(^__i|[^_]_$)', option['api_name']) is not None - -@@ -1090,8 +1104,23 @@ def create_generic(top_env, declarations): - def has_named_tensor_formals(formals): - return any(['Dimname' in formal['dynamic_type'] for formal in formals]) - -- def gen_tensor_method(option): -- # type: (Any) -> FunctionCode -+ def gen_tensor_method(option, multidispatch_tensors): -+ # type: (Any, Optional[List[str]]) -> FunctionCode -+ # TODO: Swing this shared code to top level -+ if multidispatch_tensors: -+ def swizzle_self(t): # blegh -+ if t == 'self': -+ return '*this' -+ else: -+ return t -+ option['inferred_type_set'] = 'at::detail::multi_dispatch_tensor_type_set({})'.format( -+ ', '.join(swizzle_self(t) for t in multidispatch_tensors) -+ ) -+ else: -+ # TODO: Err, what? If we didn't trigger multidispatch_tensors -+ # codepath... how?! This is a method, surely something must be -+ # dispatching! 
-+ option['inferred_type_set'] = 'type_set(/* HMMMM */)' - if isinstance(type_method_dispatch, dict): - static_dispatch_function_switches = [] - # NB: As this code is currently written, there will NEVER be -@@ -1125,10 +1154,11 @@ def create_generic(top_env, declarations): - declaration=TENSOR_METHOD_DECLARATION.substitute(option, static_dispatch_method_body=static_dispatch_method_body), - definition=TENSOR_METHOD_DEFINITION.substitute(option, static_dispatch_method_body=static_dispatch_method_body)) - -- def gen_namespace_function(option, dispatch_tensor, dispatch_options): -- # type: (Any, Optional[str], Any) -> FunctionCode -- if dispatch_tensor: -- option['inferred_type_set'] = 'at::detail::infer_tensor_type_set({})'.format(dispatch_tensor) -+ def gen_namespace_function(option, multidispatch_tensors, dispatch_options): -+ # type: (Any, Optional[List[str]], Any) -> FunctionCode -+ if multidispatch_tensors: -+ option['inferred_type_set'] = ( -+ 'at::detail::multi_dispatch_tensor_type_set({})'.format(', '.join(multidispatch_tensors))) - elif dispatch_options: - option['inferred_type_set'] = '{}.type_set()'.format(dispatch_options['name']) - else: -@@ -1190,6 +1220,9 @@ def create_generic(top_env, declarations): - # Only dispatch via tensor if there is no Options argument - dispatch_tensor = None if dispatch_options else find_dispatch_tensor(formals) - -+ # TODO: Not entirely clear what to do about TensorOptions -+ multidispatch_tensors = None if dispatch_options else find_multidispatch_tensors(formals) -+ - option['type_method_formals'] = [format_formal(f) for f in formals] - option['type_method_actuals'] = [f['name'] for f in formals] - option['native_actuals'] = [f['name'] for f in formals] -@@ -1257,7 +1290,7 @@ def create_generic(top_env, declarations): - - method_of = ['Type'] - if is_method: -- code = gen_tensor_method(option) -+ code = gen_tensor_method(option, multidispatch_tensors) - if is_named_tensor_only: - code = add_namedtensor_enabled_macro(code) - 
top_env['tensor_method_declarations'].append(code.declaration) -@@ -1265,7 +1298,7 @@ def create_generic(top_env, declarations): - method_of.append('Tensor') - - if is_namespace_function: -- code = gen_namespace_function(option, dispatch_tensor, dispatch_options) -+ code = gen_namespace_function(option, multidispatch_tensors, dispatch_options) - if is_named_tensor_only: - code = add_namedtensor_enabled_macro(code) - top_env['function_definitions'].append(code.definition) -@@ -1304,7 +1337,7 @@ def create_generic(top_env, declarations): - option["schema_string"] = declaration["schema_string"] - try: - if option['mode'] != 'native': -- # XXX: Does the following line do anything meaningful? -+ # Mutably populate option with values - process_option(option) - else: - output_option = process_native(option) -diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp -index 680cbcd9c7..c6ece4dabc 100644 ---- a/aten/src/ATen/native/BinaryOps.cpp -+++ b/aten/src/ATen/native/BinaryOps.cpp -@@ -20,16 +20,6 @@ static constexpr char alpha_mismatch_err[] = - "For integral input tensors, argument alpha must not be a floating point number."; - - Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { -- if (other.is_sparse()) { -- if (self.is_sparse()) { -- at::_sparse_add_out(result, self, other, alpha); -- } else { -- at::_sparse_dense_add_out(result, self, other, alpha); -- } -- return result; -- } else if (self.is_sparse()) { -- AT_ERROR("add(sparse, dense) is not supported. Use add(dense, sparse) instead."); -- } - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); - TORCH_CHECK(! 
alpha.isBoolean() || iter.dtype() == ScalarType::Bool, "Boolean alpha only supported for boolean results"); -@@ -41,10 +31,6 @@ Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar - - Tensor add(const Tensor& self, const Tensor& other, Scalar alpha) { - Tensor result; -- if (other.is_sparse()) { -- result = at::empty({0}, self.options()); -- return native::add_out(result, self, other, alpha); -- } - auto iter = TensorIterator::binary_op(result, self, other); - TORCH_CHECK(! alpha.isBoolean() || iter.dtype() == ScalarType::Bool, "Boolean alpha only supported for boolean results"); - TORCH_CHECK(isFloatingType(iter.dtype()) || alpha.isIntegral(true), alpha_mismatch_err); -@@ -57,13 +43,6 @@ Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) { - } - - Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) { -- if (self.is_sparse()) { -- if (other.dim() != 0) { -- AT_ERROR("div(): sparse division only supports division by a scalar ", -- "(got shape ", other.sizes(), " for argument 'other')"); -- } -- return at::_sparse_div_zerodim_out(result, self, other); -- } - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); - div_stub(iter.device_type(), iter); -@@ -72,10 +51,6 @@ Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) { - - Tensor div(const Tensor& self, const Tensor& other) { - Tensor result; -- if (self.is_sparse()) { -- result = at::empty({0}, self.options()); -- return native::div_out(result, self, other); -- } - auto iter = TensorIterator::binary_op(result, self, other); - div_stub(iter.device_type(), iter); - return iter.output(); -@@ -86,9 +61,6 @@ Tensor& div_(Tensor& self, const Tensor& other) { - } - - Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) { -- if (self.is_sparse() || other.is_sparse()) { -- return at::_sparse_mul_out(result, self, other); -- } - auto iter = TensorIterator::binary_op(result, self, other, - 
/*check_mem_overlap=*/true); - mul_stub(iter.device_type(), iter); -@@ -97,10 +69,6 @@ Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) { - - Tensor mul(const Tensor& self, const Tensor& other) { - Tensor result; -- if (self.is_sparse() || other.is_sparse()) { -- result = at::empty({0}, self.options()); -- return native::mul_out(result, self, other); -- } - auto iter = TensorIterator::binary_op(result, self, other); - mul_stub(iter.device_type(), iter); - return iter.output(); -@@ -122,19 +90,6 @@ static inline void sub_check(const Tensor& self, const Tensor& other) { - - Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { - sub_check(self, other); -- if (other.is_sparse()) { -- if (!self.sizes().equals(other.sizes())) { -- AT_ERROR("sizes do not match"); -- } -- if (self.is_sparse()) { -- at::_sparse_add_out(result, self, other, -alpha); -- } else { -- at::_sparse_dense_add_out(result, self, other, -alpha); -- } -- return result; -- } else if (self.is_sparse()) { -- AT_ERROR("sub(sparse, dense) is not supported. 
Use sub(dense, sparse) instead."); -- } - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); - TORCH_CHECK(isFloatingType(iter.dtype()) || alpha.isIntegral(false), alpha_mismatch_err); -@@ -146,10 +101,6 @@ Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar - Tensor sub(const Tensor& self, const Tensor& other, Scalar alpha) { - sub_check(self, other); - Tensor result; -- if (other.is_sparse()) { -- result = at::empty({0}, self.options()); -- return native::sub_out(result, self, other, alpha); -- } - auto iter = TensorIterator::binary_op(result, self, other); - TORCH_CHECK(isFloatingType(iter.dtype()) || alpha.isIntegral(false), alpha_mismatch_err); - sub_stub(iter.device_type(), iter, alpha); -@@ -197,12 +148,18 @@ Tensor& add_(Tensor& self, Scalar other, Scalar alpha) { - return native::add_(self, wrapped_scalar_tensor(other), alpha); - } - -+// WARNING: There doesn't appear to be any testing for this function -+// with sparse self input. - Tensor div(const Tensor& self, Scalar other) { -- return native::div(self, wrapped_scalar_tensor(other)); -+ return self.div(wrapped_scalar_tensor(other)); // redispatch! - } - -+// WARNING: This function, with a sparse self, is currently only -+// exercised by DistributedDataParallelTest.test_sparse_gradients -+// (you need to exercise it from C++, because this overload is never -+// used for Python) - Tensor& div_(Tensor& self, Scalar other) { -- return native::div_(self, wrapped_scalar_tensor(other)); -+ return self.div_(wrapped_scalar_tensor(other)); // redispatch! 
- } - - Tensor mul(const Tensor& self, Scalar other) { -diff --git a/aten/src/ATen/native/LegacyBridge.cpp b/aten/src/ATen/native/LegacyBridge.cpp -index 3e544cb14d..e69de29bb2 100644 ---- a/aten/src/ATen/native/LegacyBridge.cpp -+++ b/aten/src/ATen/native/LegacyBridge.cpp -@@ -1,79 +0,0 @@ --#include --#include --#include -- --namespace at { namespace native { -- --// Note [Multiple dispatch to sparse] --// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ --// In an ideal world, we would use direct support for multiple dispatch to --// say that add(Dense, Dense) should dispatch to one function, while --// add(Dense, Sparse) should dispatch to another function. --// --// In a world where we only have single dispatch, we can single dispatch on --// the first function, and then do an is_sparse() test on the second argument --// to direct ourselves to the correct argument. --// --// We are in neither of those worlds. Instead, we have a _th_addmm function --// which has legacy implementations in the single dispatch world, BUT our --// actual addmm function needs to call s_native_addmm if the function *would have* --// utilized a sparse kernel that is natively implemented. --// --// _th_addmm is "good old single dispatch" which internally handles the is_sparse() --// test and also handles broadcasting. s_native_addmm works asymmetrically: --// it doesn't handle broadcasting at all, and it ASSUMES that the relevant --// argument is a sparse tensor. Why the asymmetry? It turns out it is not --// so easy to figure out if a kernel is implemented in THS; it's not as simple --// as testing if the first argument is sparse, because, e.g., --// in addmm(Dense, Sparse), the sparse kernel is in the second argument. So, --// the trampoline function is going to know about the overloads *anyway*; it --// might as well also handle is_sparse() and broadcasting while it's at it. --// --// Why not change TH to follow this new scheme? We could... 
but since it's --// all going away when we finish porting the TH functions to ATen, we haven't --// done it. -- --// NB: You may be tempted to implement addmm and addmm_ just as calls to addmm_out, but --// calling the actual implementing function matters, because broadcast --// will be handled differently depending on if you call addmm_ or (a seemingly --// equivalent) add_out. Arguably this mismatch in treatment is a bug, --// c.f., https://github.com/pytorch/pytorch/issues/8308 but fixing this --// bug would involve changing a lot of other places, so we leave it --// alone for now. -- --Tensor& addmm_out(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) { -- // See Note [Multiple dispatch to sparse] -- auto mat1_sparse = mat1.is_sparse(); -- if (mat1_sparse) { -- Tensor b_self; -- std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); -- return s_native_addmm_out(result, b_self, mat1, mat2, beta, alpha); -- } else { -- return at::_addmm_out(result, self, mat1, mat2, beta, alpha); -- } --} -- --Tensor addmm(const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) { -- // See Note [Multiple dispatch to sparse] -- auto mat1_sparse = mat1.is_sparse(); -- if (mat1_sparse) { -- Tensor b_self; -- std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm"); -- return s_native_addmm(b_self, mat1, mat2, beta, alpha); -- } else { -- return at::_addmm(self, mat1, mat2, beta, alpha); -- } --} -- --Tensor& addmm_(Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) { -- // See Note [Multiple dispatch to sparse] -- auto mat1_sparse = mat1.is_sparse(); -- if (mat1_sparse) { -- // inplace is not broadcasting -- return s_native_addmm_(self, mat1, mat2, beta, alpha); -- } else { -- return at::_addmm_(self, mat1, mat2, beta, alpha); -- } --} -- --}} // namespace at::native -diff --git a/aten/src/ATen/native/native_functions.yaml 
b/aten/src/ATen/native/native_functions.yaml -index a7b5981144..18b2c97e8c 100644 ---- a/aten/src/ATen/native/native_functions.yaml -+++ b/aten/src/ATen/native/native_functions.yaml -@@ -171,8 +171,8 @@ - dispatch: - CPU: add - CUDA: add -- SparseCPU: add -- SparseCUDA: add -+ SparseCPU: add_sparse -+ SparseCUDA: add_sparse - MkldnnCPU: mkldnn_add - named_guard: False - -@@ -181,8 +181,8 @@ - dispatch: - CPU: add_ - CUDA: add_ -- SparseCPU: add_ -- SparseCUDA: add_ -+ SparseCPU: add_sparse_ -+ SparseCUDA: add_sparse_ - MkldnnCPU: mkldnn_add_ - named_guard: False - -@@ -190,8 +190,8 @@ - dispatch: - CPU: add_out - CUDA: add_out -- SparseCPU: add_out -- SparseCUDA: add_out -+ SparseCPU: add_out_sparse_cpu -+ SparseCUDA: add_out_sparse_cuda - MkldnnCPU: mkldnn_add_out - named_guard: False - -@@ -734,13 +734,28 @@ - - - func: div.Tensor(Tensor self, Tensor other) -> Tensor - variants: function, method -+ dispatch: -+ CPU: div -+ CUDA: div -+ SparseCPU: div_sparse -+ SparseCUDA: div_sparse - named_guard: False - - - func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) - variants: method -+ dispatch: -+ CPU: div_ -+ CUDA: div_ -+ SparseCPU: div_sparse_ -+ SparseCUDA: div_sparse_ - named_guard: False - - - func: div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) -+ dispatch: -+ CPU: div_out -+ CUDA: div_out -+ SparseCPU: div_out_sparse_zerodim -+ SparseCUDA: div_out_sparse_zerodim - named_guard: False - - # For C++ only, until we have conversion from C++ numbers to Tensor -@@ -1555,19 +1570,18 @@ - dispatch: - CPU: mul - CUDA: mul -- SparseCPU: mul -- SparseCUDA: mul -+ SparseCPU: mul_sparse -+ SparseCUDA: mul_sparse - MkldnnCPU: mkldnn_mul - named_guard: False - -- - - func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
- variants: method - dispatch: - CPU: mul_ - CUDA: mul_ -- SparseCPU: mul_ -- SparseCUDA: mul_ -+ SparseCPU: mul_sparse_ -+ SparseCUDA: mul_sparse_ - MkldnnCPU: mkldnn_mul_ - named_guard: False - -@@ -1575,8 +1589,8 @@ - dispatch: - CPU: mul_out - CUDA: mul_out -- SparseCPU: mul_out -- SparseCUDA: mul_out -+ SparseCPU: mul_out_sparse_cpu -+ SparseCUDA: mul_out_sparse_cuda - MkldnnCPU: mkldnn_mul_out - named_guard: False - -@@ -2085,41 +2099,6 @@ - CPU: softmax_backward_cpu - CUDA: softmax_backward_cuda - --- func: _sparse_add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) -- dispatch: -- SparseCPU: add_out_sparse_cpu -- SparseCUDA: add_out_sparse_cuda -- --- func: _sparse_dense_add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) -- dispatch: -- CPU: add_out_dense_sparse_cpu -- CUDA: add_out_dense_sparse_cuda -- --- func: _sparse_div_zerodim.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) -- dispatch: -- SparseCPU: div_out_sparse_zerodim -- SparseCUDA: div_out_sparse_zerodim -- --- func: _sparse_div_scalar.out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) -- dispatch: -- SparseCPU: div_out_sparse_scalar -- SparseCUDA: div_out_sparse_scalar -- --- func: _sparse_mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) -- dispatch: -- SparseCPU: mul_out_sparse_cpu -- SparseCUDA: mul_out_sparse_cuda -- --- func: _sparse_mul_zerodim.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) -- dispatch: -- SparseCPU: mul_out_sparse_zerodim -- SparseCUDA: mul_out_sparse_zerodim -- --- func: _sparse_mul_scalar.out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
-- dispatch: -- SparseCPU: mul_out_sparse_scalar -- SparseCUDA: mul_out_sparse_scalar -- - - func: split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[] - variants: function, method - device_guard: False -@@ -2671,14 +2650,29 @@ - MkldnnCPU: mkldnn_zero_ - - - func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) -+ dispatch: -+ CPU: sub_out -+ CUDA: sub_out -+ SparseCPU: sub_out_sparse -+ SparseCUDA: sub_out_sparse - named_guard: False - - - func: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor - variants: function, method -+ dispatch: -+ CPU: sub -+ CUDA: sub -+ SparseCPU: sub_sparse -+ SparseCUDA: sub_sparse - named_guard: False - - - func: sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) - variants: method -+ dispatch: -+ CPU: sub_ -+ CUDA: sub_ -+ SparseCPU: sub_sparse_ -+ SparseCUDA: sub_sparse_ - named_guard: False - - # For C++ only, until we have conversion from C++ numbers to Tensor -@@ -2699,32 +2693,37 @@ - variants: function - named_guard: False - --- func: s_native_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) -- dispatch: -- CPU: s_addmm_out_sparse_dense_cpu -- CUDA: s_addmm_out_sparse_dense_cuda -- --- func: s_native_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor -- dispatch: -- CPU: s_addmm_sparse_dense_cpu -- CUDA: s_addmm_sparse_dense_cuda -- --- func: s_native_addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) -- dispatch: -- CPU: s_addmm_sparse_dense_cpu_ -- CUDA: s_addmm_sparse_dense_cuda_ -- -+# Functionally the same as addmm, but we give it a different derivative formula -+# that doesn't propagate gradients to non-present entries on sparse. 
- - func: _sparse_addmm(Tensor self, Tensor sparse, Tensor dense, *, Scalar beta=1, Scalar alpha=1) -> Tensor -+ named_guard: False - - - func: addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) -+ dispatch: -+ CPU: legacy::cpu::_th_addmm_out -+ CUDA: legacy::cuda::_th_addmm_out -+ SparseCPU: addmm_out_sparse_dense_cpu -+ SparseCUDA: addmm_out_sparse_dense_cuda - named_guard: False - - - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor - variants: function, method -+ dispatch: -+ CPU: legacy::cpu::_th_addmm -+ CUDA: legacy::cuda::_th_addmm -+ SparseCPU: addmm_sparse_dense_cpu -+ SparseCUDA: addmm_sparse_dense_cuda - named_guard: False - - - func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) - variants: method -+ dispatch: -+ CPU: legacy::cpu::_th_addmm_ -+ CUDA: legacy::cuda::_th_addmm_ -+ # Warning! For whatever reason, the inplace sparse addmm is NON -+ # broadcasting -+ SparseCPU: s_addmm_sparse_dense_cpu_ -+ SparseCUDA: s_addmm_sparse_dense_cuda_ - named_guard: False - - -@@ -2884,8 +2883,8 @@ - - func: sparse_mask(Tensor self, Tensor mask) -> Tensor - variants: method - dispatch: -- CPU: sparse_mask_cpu -- CUDA: sparse_mask_cuda -+ SparseCPU: sparse_mask_cpu -+ SparseCUDA: sparse_mask_cuda - requires_tensor: True - - -@@ -4556,24 +4555,6 @@ - CUDA: legacy::cuda::_th_std - named_guard: False - --- func: _addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) -- dispatch: -- CPU: legacy::cpu::_th_addmm_out -- CUDA: legacy::cuda::_th_addmm_out -- named_guard: False -- --- func: _addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor -- dispatch: -- CPU: legacy::cpu::_th_addmm -- CUDA: legacy::cuda::_th_addmm -- named_guard: False -- --- func: _addmm_(Tensor(a!) 
self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) -- dispatch: -- CPU: legacy::cpu::_th_addmm_ -- CUDA: legacy::cuda::_th_addmm_ -- named_guard: False -- - - func: _cat(Tensor[] tensors, int dim=0) -> Tensor - dispatch: - CPU: legacy::cpu::_th_cat -diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp -index f9cbe6c96c..b01166245a 100644 ---- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp -+++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp -@@ -1,3 +1,5 @@ -+#include -+ - #include - #include - #include -@@ -148,10 +150,23 @@ SparseTensor pow_sparse_scalar(const SparseTensor& t, Scalar value) { - // div(SparseTensor, Scalar) - // -------------------------------------------------------------------- - -+SparseTensor& div_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, const Tensor& value); -+ -+Tensor div_sparse(const Tensor& self, const Tensor& value) { -+ Tensor result = at::empty({0}, self.options()); -+ return div_out_sparse_zerodim(result, self, value); -+} -+ -+Tensor& div_sparse_(Tensor& self, const Tensor& value) { -+ return div_out_sparse_zerodim(self, self, value); -+} -+ - SparseTensor& div_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, const Tensor& value) { -+ TORCH_CHECK(value.dim() == 0, "sparse division only supports division by a scalar (got shape ", -+ value.sizes(), " for argument 'other')"); -+ - AT_ASSERT(r.is_sparse()); - AT_ASSERT(t.is_sparse()); -- AT_ASSERT(value.dim() == 0); - - if (is_same_tensor(r, t)) { - r._values().div_(value); -@@ -187,9 +202,40 @@ Tensor norm_sparse(const SparseTensor& self, Scalar value) { - // add(SparseTensor, SparseTensor, Scalar) [broadcasts] - // -------------------------------------------------------------------- - -+Tensor add_sparse(const Tensor& self, const Tensor& other, Scalar alpha) { -+ // TODO: Why?! Can't we just flip the order here... 
-+ TORCH_CHECK(!(self.is_sparse() && !other.is_sparse()), -+ "add(sparse, dense) is not supported. Use add(dense, sparse) instead."); -+ Tensor result = at::empty({0}, self.options()); -+ return at::add_out(result, self, other, alpha); // redispatch! -+} -+ -+Tensor& add_sparse_(Tensor& self, const Tensor& other, Scalar alpha) { -+ return at::add_out(self, self, other, alpha); // redispatch! -+} -+ -+// There's actually nothing sparse specific about these implementations -+ -+Tensor sub_sparse(const Tensor& self, const Tensor& other, Scalar alpha) { -+ return native::add_sparse(self, other, -alpha); -+} -+ -+Tensor& sub_sparse_(Tensor& self, const Tensor& other, Scalar alpha) { -+ return native::add_sparse_(self, other, -alpha); -+} -+ -+Tensor& sub_out_sparse(Tensor& r, const Tensor& self, const Tensor& other, Scalar alpha) { -+ return at::add_out(r, self, other, -alpha); // redispatch! -+} -+ -+Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTensor& sparse_, Scalar value); -+ - SparseTensor& add_out_sparse_cpu(SparseTensor& r, const SparseTensor& t, const SparseTensor& src, Scalar value) { -- AT_ASSERT(r.is_sparse()); -- AT_ASSERT(t.is_sparse()); -+ if (!t.is_sparse()) { -+ return add_out_dense_sparse_cpu(r, t, src, value); -+ } -+ // TODO: This test seems a bit goofy -+ TORCH_CHECK(src.is_sparse(), "add(sparse, dense) is not supported. 
Use add(dense, sparse) instead."); - AT_ASSERT(!t.is_cuda()); // the dispatch argument - TORCH_CHECK(!r.is_cuda(), "add: expected 'out' to be CPU tensor, but got CUDA tensor"); - TORCH_CHECK(!src.is_cuda(), "add: expected 'other' to be a CPU tensor, but got a CUDA tensor"); -@@ -375,6 +421,15 @@ Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTen - // mul(SparseTensor, SparseTensor) [broadcasts] - // -------------------------------------------------------------------- - -+Tensor mul_sparse(const Tensor& self, const Tensor& other) { -+ Tensor result = at::empty({0}, self.options()); -+ return at::mul_out(result, self, other); // redispatch! -+} -+ -+Tensor& mul_sparse_(Tensor& self, const Tensor& other) { -+ return at::mul_out(self, self, other); // redispatch! -+} -+ - SparseTensor& mul_out_sparse_cpu(SparseTensor& r, const Tensor& t_, const Tensor& src_) { - if (src_.dim() == 0) { - return mul_out_sparse_zerodim(r, t_, src_); -@@ -576,6 +631,19 @@ Tensor& s_addmm_out_sparse_dense_cpu( - - } - -+Tensor& addmm_out_sparse_dense_cpu( -+ Tensor& result, -+ const Tensor& self, -+ const SparseTensor& mat1, -+ const Tensor& mat2, -+ Scalar beta, -+ Scalar alpha -+) { -+ Tensor b_self; -+ std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); -+ return s_addmm_out_sparse_dense_cpu(result, b_self, mat1, mat2, beta, alpha); -+} -+ - Tensor s_addmm_sparse_dense_cpu( - const Tensor& t, - const SparseTensor& sparse, -@@ -588,6 +656,18 @@ Tensor s_addmm_sparse_dense_cpu( - return r; - } - -+Tensor addmm_sparse_dense_cpu( -+ const Tensor& self, -+ const SparseTensor& mat1, -+ const Tensor& mat2, -+ Scalar beta, -+ Scalar alpha -+) { -+ Tensor b_self; -+ std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); -+ return s_addmm_sparse_dense_cpu(b_self, mat1, mat2, beta, alpha); -+} -+ - Tensor& s_addmm_sparse_dense_cpu_( - Tensor& t, - const SparseTensor& sparse, -@@ -598,6 +678,8 @@ Tensor& 
s_addmm_sparse_dense_cpu_( - return s_addmm_out_sparse_dense_cpu(t, t, sparse, dense, beta, alpha); - } - -+// NB: Purposely no broadcasting version of addmm inplace -+ - Tensor _sparse_addmm( - const Tensor& t, - const SparseTensor& sparse, -@@ -605,9 +687,10 @@ Tensor _sparse_addmm( - Scalar beta, - Scalar alpha - ) { -- Tensor b_t; -- std::tie(b_t) = expand_size(t, {sparse.size(0), dense.size(1)}, "addmm"); -- return at::s_native_addmm(b_t, sparse, dense, beta, alpha); -+ // _sparse_addmm forward is functionally equivalent to addmm; it's -+ // just the backward that is different. This technically does an -+ // unnecessary redispatch, I was too lazy to make it not do that -+ return at::addmm(t, sparse, dense, beta, alpha); - } - - Tensor _sparse_mm( -@@ -615,16 +698,19 @@ Tensor _sparse_mm( - const Tensor& dense - ) { - Tensor t = at::zeros({}, dense.options()); -- return at::_sparse_addmm(t, sparse, dense, 0, 1); -+ return at::_sparse_addmm(t, sparse, dense, 0, 1); // redispatch! - } - -+// NB: Despite its suggestive name, this actually only exists so that -+// we can redispatch to addmm_out; this is NOT an implementation of -+// the sparse masking version of mm - SparseTensor& _sparse_mm_out( - SparseTensor& result, - const SparseTensor& sparse, - const Tensor& dense - ) { - Tensor t = at::zeros({}, dense.options()); -- return at::addmm_out(result, t, sparse, dense, 0, 1); -+ return at::addmm_out(result, t, sparse, dense, 0, 1); // redispatch! 
- } - - // -------------------------------------------------------------------- -diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.h b/aten/src/ATen/native/sparse/SparseTensorMath.h -new file mode 100644 -index 0000000000..514f84fd8e ---- /dev/null -+++ b/aten/src/ATen/native/sparse/SparseTensorMath.h -@@ -0,0 +1,11 @@ -+#pragma once -+ -+#include -+#include -+ -+namespace at { namespace native { -+ -+sparse::SparseTensor& mul_out_sparse_scalar(sparse::SparseTensor& r, const sparse::SparseTensor& t, Scalar value); -+sparse::SparseTensor& mul_out_sparse_zerodim(sparse::SparseTensor& r, const sparse::SparseTensor& t, const Tensor& value); -+ -+}} -diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp -index 927d91797d..4d43150547 100644 ---- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp -+++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp -@@ -12,7 +12,7 @@ SparseTensor& sparse_mask_out_cuda(SparseTensor& r, const Tensor& t, const Spars - TORCH_CHECK(mask.is_coalesced(), "sparse_mask: mask is uncoalesced"); - TORCH_CHECK(mask.sizes().equals(t.sizes()), "sparse_mask: operands have incompatible sizes; self has size ", - t.sizes(), " but mask has size ", mask.sizes()); -- AT_ASSERT(t.is_cuda()); // dispatch argument -+ TORCH_CHECK(t.is_cuda(), "sparse_mask: expected 'self' to be CUDA, but got CPU"); - TORCH_CHECK(mask.is_cuda(), "sparse_mask: expected 'mask' to be CUDA, but got CPU"); - TORCH_CHECK(r.is_cuda(), "sparse_mask: expected 'out' to be CUDA, but got CPU"); - TORCH_CHECK(cuda::check_device({r, t, mask}), -diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu -index fdc17e4dfd..f681f16ce4 100644 ---- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu -+++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu -@@ -2,12 +2,14 @@ - #include - #include - #include -+#include - #include - 
#include - #include - #include - #include - #include -+#include - - #include - #include -@@ -51,7 +53,7 @@ namespace { - // -------------------------------------------------------------------- - - Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseTensor& sparse_, const Tensor& dense, Scalar beta, Scalar alpha) { -- AT_ASSERT(t.is_cuda()); // dispatch argument -+ TORCH_CHECK(t.is_cuda(), "addmm: expected 'self' to be CUDA, but got CPU"); - TORCH_CHECK(r_.is_cuda(), "addmm: expected 'out' to be CUDA, but got CPU"); - TORCH_CHECK(sparse_.is_cuda(), "addmm: expected 'mat1' to be CUDA, but got CPU"); - TORCH_CHECK(dense.is_cuda(), "addmm: expected 'mat2' to be CUDA, but got CPU"); -@@ -151,6 +153,19 @@ Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseT - return r_; - } - -+Tensor& addmm_out_sparse_dense_cuda( -+ Tensor& result, -+ const Tensor& self, -+ const SparseTensor& mat1, -+ const Tensor& mat2, -+ Scalar beta, -+ Scalar alpha -+) { -+ Tensor b_self; -+ std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); -+ return s_addmm_out_sparse_dense_cuda(result, b_self, mat1, mat2, beta, alpha); -+} -+ - Tensor s_addmm_sparse_dense_cuda( - const Tensor& t, - const SparseTensor& sparse, -@@ -163,6 +178,18 @@ Tensor s_addmm_sparse_dense_cuda( - return r; - } - -+Tensor addmm_sparse_dense_cuda( -+ const Tensor& self, -+ const SparseTensor& mat1, -+ const Tensor& mat2, -+ Scalar beta, -+ Scalar alpha -+) { -+ Tensor b_self; -+ std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); -+ return s_addmm_sparse_dense_cuda(b_self, mat1, mat2, beta, alpha); -+} -+ - Tensor& s_addmm_sparse_dense_cuda_( - Tensor& t, - const SparseTensor& sparse, -@@ -173,6 +200,8 @@ Tensor& s_addmm_sparse_dense_cuda_( - return s_addmm_out_sparse_dense_cuda(t, t, sparse, dense, beta, alpha); - } - -+// NB: Purposely no broadcasting version of addmm inplace -+ - // Deleted sspaddmm (sparse, 
dense) -> sparse - - // -------------------------------------------------------------------- -@@ -180,7 +209,7 @@ Tensor& s_addmm_sparse_dense_cuda_( - // -------------------------------------------------------------------- - - SparseTensor& hspmm_out_sparse_cuda(SparseTensor& r_, const SparseTensor& sparse_, const Tensor& dense/* , Scalar alpha */) { -- AT_ASSERT(sparse_.is_cuda()); // dispatch argument -+ TORCH_CHECK(sparse_.is_cuda(), "hspmm: expected 'self' to be CUDA, but got CPU"); - TORCH_CHECK(r_.is_cuda(), "hspmm: expected 'out' to be CUDA, but got CPU"); - TORCH_CHECK(dense.is_cuda(), "hspmm: expected 'mat2' to be CUDA, but got CPU"); - -@@ -249,9 +278,9 @@ SparseTensor hspmm_sparse_cuda(const SparseTensor& sparse, const Tensor& dense) - // -------------------------------------------------------------------- - - Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseTensor& sparse, at::Scalar value) { -- AT_ASSERT(dense.is_cuda()); // dispatch argument -- TORCH_CHECK(sparse.is_cuda(), "add: expected 'other' to be CUDA, but got CPU"); -- TORCH_CHECK(r_.is_cuda(), "add: expected 'out' to be CUDA, but got CPU"); -+ TORCH_CHECK(dense.is_cuda(), "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"); -+ TORCH_CHECK(sparse.is_cuda(), "add: expected 'other' to be a CUDA tensor, but got a CPU tensor"); -+ TORCH_CHECK(r_.is_cuda(), "add: expected 'out' to be a CUDA tensor, but got a CPU tensor"); - - TORCH_CHECK(cuda::check_device({sparse, r_, dense})); - -@@ -350,8 +379,17 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT - // add(SparseTensor, SparseTensor, Scalar) [broadcasts] - // -------------------------------------------------------------------- - -+Tensor& add_out_dense_sparse_cuda(Tensor& r, const Tensor& dense, const SparseTensor& sparse_, Scalar value); -+ - SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const SparseTensor& src, Scalar value) { -- 
AT_ASSERT(t.is_cuda()); // dispatch argument -+ if (!t.is_sparse()) { -+ return add_out_dense_sparse_cuda(r_, t, src, value); -+ } -+ -+ // TODO: This test seems a bit goofy -+ TORCH_CHECK(src.is_sparse(), "add(sparse, dense) is not supported. Use add(dense, sparse) instead."); -+ -+ TORCH_CHECK(t.is_cuda(), "add: expected 'self' to be CUDA, but got CPU"); - TORCH_CHECK(src.is_cuda(), "add: expected 'other' to be CUDA, but got CPU"); - TORCH_CHECK(r_.is_cuda(), "add: expected 'out' to be CUDA, but got CPU"); - -@@ -410,7 +448,7 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons - return mul_out_sparse_zerodim(r_, src_, t_); - } - -- AT_ASSERT(t_.is_cuda()); // dispatch argument -+ TORCH_CHECK(t_.is_cuda(), "mul: expected 'self' to be CUDA, but got CPU"); - TORCH_CHECK(src_.is_cuda(), "mul: expected 'other' to be CUDA, but got CPU"); - TORCH_CHECK(r_.is_cuda(), "mul: expected 'out' to be CUDA, but got CPU"); - TORCH_CHECK(cuda::check_device({r_, t_, src_})); -diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h -index 03a272f12c..40c251e69d 100644 ---- a/aten/src/ATen/templates/TensorBody.h -+++ b/aten/src/ATen/templates/TensorBody.h -@@ -419,7 +419,7 @@ protected: - }; - - namespace detail { --// Helper creator for Tensor clas which doesn't requires the users to pass -+// Helper creator for Tensor class which doesn't requires the users to pass - // in an intrusive_ptr instead it just converts the argument passed to - // requested intrusive_ptr type. - template -@@ -427,15 +427,6 @@ Tensor make_tensor(Args&&... 
args) { - return Tensor(c10::make_intrusive(std::forward(args)...)); - } - --inline TensorTypeSet infer_tensor_type_set(const Tensor & tl) { -- return tl.type_set(); --} -- --inline TensorTypeSet infer_tensor_type_set(TensorList tl) { -- TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors"); -- return tl[0].type_set(); --} -- - } // namespace detail - - static inline TensorTypeId legacyExtractTypeId(const Tensor& t) { -diff --git a/aten/src/ATen/templates/TensorMethods.h b/aten/src/ATen/templates/TensorMethods.h -index 0afa00c19a..c5cae10e4f 100644 ---- a/aten/src/ATen/templates/TensorMethods.h -+++ b/aten/src/ATen/templates/TensorMethods.h -@@ -11,6 +11,7 @@ - #if !defined(CAFFE2_IS_XPLAT_BUILD) - #include - #endif -+#include - #include - #ifdef USE_STATIC_DISPATCH - #include -@@ -21,6 +22,27 @@ - - namespace at { - -+namespace detail { -+ -+struct MultiDispatchTensorTypeSet : IterArgs { -+ TensorTypeSet ts; -+ void operator()(const at::Tensor& x) { -+ ts = ts | x.type_set(); -+ } -+ void operator()(at::ArrayRef xs) { -+ for (const auto& x : xs) { -+ ts = ts | x.type_set(); -+ } -+ } -+}; -+ -+template -+TensorTypeSet multi_dispatch_tensor_type_set(Args&&... 
args) { -+ return MultiDispatchTensorTypeSet().apply(std::forward(args)...).ts; -+} -+ -+} -+ - struct Quantizer; - // This is temporary typedef to enable Quantizer in aten native function API - // we'll remove them when we are actually exposing Quantizer class -diff --git a/c10/core/TensorTypeId.h b/c10/core/TensorTypeId.h -index eb99881296..d01ee9d5f3 100644 ---- a/c10/core/TensorTypeId.h -+++ b/c10/core/TensorTypeId.h -@@ -25,8 +25,6 @@ enum class TensorTypeId : uint8_t { - // the hierarchy for convenience and performance - CPUTensorId, // PyTorch/Caffe2 supported - CUDATensorId, // PyTorch/Caffe2 supported -- SparseCPUTensorId, // PyTorch only -- SparseCUDATensorId, // PyTorch only - MKLDNNTensorId, // Caffe2 only - OpenGLTensorId, // Caffe2 only - OpenCLTensorId, // Caffe2 only -@@ -40,6 +38,10 @@ enum class TensorTypeId : uint8_t { - ComplexCPUTensorId, // PyTorch only - ComplexCUDATensorId, // PyTorch only - -+ // Sparse has multi-dispatch with dense; handle it first -+ SparseCPUTensorId, // PyTorch only -+ SparseCUDATensorId, // PyTorch only -+ - // WARNING! If you add more "wrapper" style tensor ids (tensor - // ids which don't get kernels directly defined in native_functions.yaml; - // examples are tracing or profiling) here, you need to also adjust -diff --git a/test/test_nn.py b/test/test_nn.py -index e5d28d3253..e5d388f798 100644 ---- a/test/test_nn.py -+++ b/test/test_nn.py -@@ -1804,7 +1804,7 @@ class TestNN(NNTestCase): - # Without using `torch.no_grad()`, this will leak CUDA memory. 
- # (Issue is filed at https://github.com/pytorch/pytorch/issues/21875) - mw[0][0] = 5 -- with self.assertRaisesRegex(RuntimeError, "Expected object of backend CPU but got backend CUDA"): -+ with self.assertRaisesRegex(RuntimeError, "Expected object of backend CUDA but got backend CPU"): - mw[0][0] == mw._base[0][0] - - try: -@@ -2958,6 +2958,7 @@ class TestNN(NNTestCase): - x = torch.tensor([], device=device, dtype=torch.long) - for sparse in [True, False]: - Embed = torch.nn.EmbeddingBag(m, n, sparse=sparse) -+ Embed.to(device) - - output = Embed(input=x, offsets=torch.tensor([0], device=device, dtype=torch.long)) - self.assertEqual(output, torch.zeros_like(output)) -diff --git a/test/test_sparse.py b/test/test_sparse.py -index 1243103e6e..f7795c6804 100644 ---- a/test/test_sparse.py -+++ b/test/test_sparse.py -@@ -2107,21 +2107,21 @@ class TestSparseOneOff(TestCase): - sparse_y = torch.cuda.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(), - torch.randn(4, 4, 4).cuda(), - [3, 4, 4]) -- with self.assertRaisesRegex(RuntimeError, "add: expected 'other' to be a CPU tensor\\, but got a CUDA tensor"): -+ with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"): - x + sparse_y - - x = torch.zeros(3, 4, 4, 0) - sparse_y = torch.cuda.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(), - torch.randn(4, 4, 4, 0).cuda(), - [3, 4, 4, 0]) -- with self.assertRaisesRegex(RuntimeError, "add: expected 'other' to be a CPU tensor\\, but got a CUDA tensor"): -+ with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"): - x + sparse_y - - x = torch.zeros(0, 4, 4, 0) - sparse_y = torch.cuda.sparse.FloatTensor(torch.LongTensor(1, 0).cuda(), - torch.randn(0, 4, 4, 0).cuda(), - [0, 4, 4, 0]) -- with self.assertRaisesRegex(RuntimeError, "add: expected 'other' to be a CPU tensor\\, but got a CUDA tensor"): -+ with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA 
tensor, but got a CPU tensor"): - x + sparse_y - - -diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp -index cfc362d083..358d29dde4 100644 ---- a/tools/autograd/templates/Functions.cpp -+++ b/tools/autograd/templates/Functions.cpp -@@ -528,7 +528,7 @@ Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntArrayRef si - int64_t out_cols = grad.size(1); - Tensor t = at::zeros({}, grad.options()).expand({out_rows, out_cols}, true); - Tensor r = at::empty({out_cols, out_rows}, grad.options()).t(); -- at::s_native_addmm_out(r, t, mat1.t(), grad, alpha, 1); -+ at::addmm_out(r, t, mat1.t(), grad, alpha, 1); - return r; - } - return maybe_multiply(grad.t().mm(mat1).t(), alpha); -diff --git a/torch/csrc/utils/variadic.h b/torch/csrc/utils/variadic.h -index 3a924a9db9..63f34afbc3 100644 ---- a/torch/csrc/utils/variadic.h -+++ b/torch/csrc/utils/variadic.h -@@ -2,6 +2,7 @@ - - #include - #include -+#include - - #include - #include -@@ -10,67 +11,7 @@ - - namespace torch { - --// This class allows you to write variadic functions which --// call a (possibly overloaded) function on each argument, --// in order. This is most commonly used in autogenerated code, --// where it is convenient to have a function that can uniformly --// take arguments of different types. If your arguments --// are homogenous consider using a std::initializer_list instead. --template --struct IterArgs { -- template -- inline F& apply() { -- return self(); -- } -- -- // NB: Use perfect forwarding here, otherwise we'll make value -- // copies of all arguments! -- template -- inline F& apply(T&& arg, Args&&... args) { -- self()(std::forward(arg)); -- if (self().short_circuit()) { -- return self(); -- } else { -- return apply(std::forward(args)...); -- } -- } -- -- // Here are some handy overloads which provide sensible -- // defaults for container-like structures that one might -- // be interested in recursing into. 
You can enable them -- // by adding: -- // -- // using IterArgs::operator() -- // -- // to your struct. These are not enabled by default because -- // you may be able to process these structures more efficiently -- // than handling them one-by-one. -- -- template -- void operator()(at::ArrayRef args) { -- for (const auto& arg : args) { -- self()(arg); -- if (short_circuit()) -- return; -- } -- } -- -- // NB: we need to specify std::vector manually as C++ won't -- // do an implicit conversion to make a template deduction go through. -- template -- void operator()(const std::vector& args) { -- self()(at::ArrayRef{args}); -- } -- -- bool short_circuit() { -- return false; -- } -- -- private: -- inline F& self() { -- return *static_cast(this); -- } --}; -+using at::IterArgs; - - struct CountTensors : IterArgs { - size_t out = 0; -@@ -194,4 +135,5 @@ template ) { - return ReturnType(function(accessor.template operator()(Is)...)); - } -+ - } // namespace torch -diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp -index 18c3ee45ef..9a032dca2a 100644 ---- a/torch/lib/c10d/ProcessGroupGloo.cpp -+++ b/torch/lib/c10d/ProcessGroupGloo.cpp -@@ -792,6 +792,14 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { - // we run allgather on the nnz, and then allgather with max(nnz). - // We could use an allgatherv for this, if it were available. - at::Tensor allreduce(std::vector& tensors) { -+ // TODO: This is a massive hack! There is some confusion about -+ // Variable/Tensor inside the body of this function. Turning off -+ // grad smooths over the confusion for now. 
This fixes -+ // test/test_c10d.py ProcessGroupGlooTest.test_sparse_allreduce_basics -+ // -+ // The correct fix is to stop allocating tensors that are not variables, -+ // but to conveniently do this c10d must depend on torch not ATen -+ at::AutoNonVariableTypeMode _no_grad(true); - auto input = tensors[0]; - - // Perform local reduction if we have multiple inputs. --- -2.13.5 -