diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h index e3ab0a5ef328e..6ca97835b7276 100644 --- a/aten/src/ATen/core/dispatch/DispatchTable.h +++ b/aten/src/ATen/core/dispatch/DispatchTable.h @@ -103,7 +103,8 @@ class KernelTable_ final { class DispatchTable final { public: DispatchTable(const FunctionSchema& schema) - : kernels_(make_left()) + : kernels_() + , catchall_kernel_(c10::nullopt) , dispatch_strategy_(get_dispatch_strategy_(schema)) , operator_name_(schema.name()) {} @@ -117,8 +118,7 @@ class DispatchTable final { const DispatchTableEntry& kernel) { TORCH_INTERNAL_ASSERT(dispatch_key != TensorTypeId::UndefinedTensorId); TORCH_CHECK(dispatch_strategy_.is_valid_, "Tried to register a kernel with dispatch key ", toString(dispatch_key), " for operator ", operator_name_, " that doesn't have tensor arguments."); - TORCH_CHECK(kernels_.is_left(), "Tried to register a kernel with dispatch key ", toString(dispatch_key)," for operator ", operator_name_, ", which already has a catch-all kernel registered. An operator can only have either a catch-all kernel or kernels with dispatch keys."); - kernels_.left().set(dispatch_key, kernel, operator_name_); + kernels_.set(dispatch_key, kernel, operator_name_); } /** @@ -127,8 +127,7 @@ class DispatchTable final { * @param dispatch_key Dispatch key to unregister. */ void removeKernelIfExists(TensorTypeId dispatch_key) { - TORCH_INTERNAL_ASSERT(kernels_.is_left(), "Tried to remove the kernel for dispatch key ", toString(dispatch_key), " for operator ", operator_name_, ", which only has a catch-all kernel."); - kernels_.left().removeIfExists(dispatch_key, operator_name_); + kernels_.removeIfExists(dispatch_key, operator_name_); } /** @@ -138,20 +137,18 @@ class DispatchTable final { * dispatch keys, not both. */ void setCatchallKernel(const DispatchTableEntry& kernel) { - if (kernels_.is_right()) { + if (catchall_kernel_.has_value()) { TORCH_WARN("Registered a catch-all kernel for operator ", operator_name_," that overwrote a previously registered catch-all kernel for the same operator."); - } else { - TORCH_CHECK(0 == kernels_.left().size(), "Tried to register a catch-all kernel for operator ", operator_name_, " which already has kernels with dispatch keys. An operator can only have either a catch-all kernel or kernels with dispatch keys."); } - kernels_ = make_right(kernel); + catchall_kernel_ = kernel; } /** * Remove the catch-all kernel. */ void removeCatchallKernel() { - TORCH_INTERNAL_ASSERT(kernels_.is_right(), "Tried to remove the catch-all kernel for operator ", operator_name_," but there is no catch-all kernel registered."); - kernels_ = make_left(); + TORCH_INTERNAL_ASSERT(catchall_kernel_.has_value(), "Tried to remove the catch-all kernel for operator ", operator_name_," but there is no catch-all kernel registered."); + catchall_kernel_ = c10::nullopt; } /** @@ -162,28 +159,28 @@ class DispatchTable final { * @return Kernel function pointing to the right kernel for the given arguments. */ const DispatchTableEntry& lookup(const Stack* stack) const { - return lookup_([=] { - TORCH_INTERNAL_ASSERT(dispatch_strategy_.is_valid_, "Operator ", operator_name_, " has an invalid dispatch key but kernels registered."); + return lookup_([=] () -> c10::optional { + if (!dispatch_strategy_.is_valid_) { + return c10::nullopt; + } return dispatch_strategy_.get_dispatch_key(stack, operator_name_); }); } const DispatchTableEntry& lookup(TensorTypeId dispatchKey) const { - return lookup_([=] {return dispatchKey;}); + return lookup_([=] () -> c10::optional { return dispatchKey;}); } bool isEmpty() const { - return kernels_.map( - [] (const detail::KernelTable_& table) {return 0 == table.size();}, - [] (const DispatchTableEntry&) {return false;} - ); + return !catchall_kernel_.has_value() && kernels_.size() == 0; } std::string listAllDispatchKeys() const { - return kernels_.map( - [] (const detail::KernelTable_& table) {return table.list_all_dispatch_keys();}, - [] (const DispatchTableEntry&) {return "CATCH-ALL";} - ); + std::string result = kernels_.list_all_dispatch_keys(); + if (catchall_kernel_.has_value()) { + result += ", CATCH-ALL"; + } + return result; } private: @@ -243,30 +240,27 @@ class DispatchTable final { template const DispatchTableEntry& lookup_(const GetDispatchKeyFunc& getDispatchKey) const { - return kernels_.map( - [&] (const detail::KernelTable_& table) -> const DispatchTableEntry& { - // We have a dispatch table. Find the correct kernel for the inputs and return it. - TensorTypeId dispatch_key = getDispatchKey(); - auto found = table.lookup(dispatch_key); + c10::optional dispatch_key = getDispatchKey(); + if (dispatch_key.has_value()) { + const auto* found = kernels_.lookup(*dispatch_key); - TORCH_CHECK(nullptr != found, "Didn't find kernel to dispatch to for operator '", operator_name_, - "'. Tried to look up kernel for dispatch key '", toString(dispatch_key), - "'. Registered dispatch keys are: ", listAllDispatchKeys()); + if (nullptr != found) { + return *found; + } + } - return *found; - }, - [] (const DispatchTableEntry& entry) -> const DispatchTableEntry& { - // We have a catch-all kernel. Just return it. - return entry; + if (catchall_kernel_.has_value()) { + return *catchall_kernel_; } - ); + + const std::string dispatch_key_str = dispatch_key.has_value() ? toString(*dispatch_key) : "None"; + TORCH_CHECK(false, "Didn't find kernel to dispatch to for operator '", operator_name_, + "'. Tried to look up kernel for dispatch key '", dispatch_key_str, + "'. Registered dispatch keys are: ", listAllDispatchKeys()); } - // kernels_ either contains a dispatch table or - // a single catch-all kernel that is called for every backend - // The empty state (i.e. no kernels registered) is represented - // as an empty table. - either kernels_; + detail::KernelTable_ kernels_; + c10::optional catchall_kernel_; DispatchStrategy dispatch_strategy_; std::string operator_name_; }; diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 1fd8a5dfbc69a..f1f26b92e78c3 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -852,10 +852,6 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { c10::List(), [] (const c10::List& v) {EXPECT_EQ(0, v.size());}, c10::List(), [] (const IValue& v) {EXPECT_EQ(0, v.toGenericListRef().size());}, "(str[] a) -> str[]"); - testArgTypes>::test( - c10::List({}), [] (const c10::List& v) {EXPECT_EQ(0, v.size());}, - c10::List({}), [] (const IValue& v) {EXPECT_EQ(0, v.to>().size());}, - "(Tensor[] a) -> Tensor[]"); // list types (with non-empty list) @@ -906,10 +902,6 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { std::vector(), [] (const std::vector& v) {EXPECT_EQ(0, v.size());}, std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toGenericListRef().size());}, "(str[] a) -> str[]"); - testArgTypes>::test( - std::vector({}), [] (const std::vector& v) {EXPECT_EQ(0, v.size());}, - std::vector({}), [] (const IValue& v) {EXPECT_EQ(0, v.to>().size());}, - "(Tensor[] a) -> Tensor[]"); // deprecated list types (with non-empty list)