[WIP] Move catchAll to Math (#45939)
Summary: Pull Request resolved: #45939

Test Plan: Imported from OSS

Reviewed By: bhosmer

Differential Revision: D24165890

Pulled By: ailzhang

fbshipit-source-id: 72fe71ea95a738251b2fafc9eea4ab3831cf426b
Ailing Zhang authored and facebook-github-bot committed Oct 16, 2020
1 parent d1ca7ef commit 8c629ec
Showing 7 changed files with 97 additions and 125 deletions.
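
Before the diffs, a minimal sketch of the user-visible effect of this change, assuming a libtorch build. The namespace "myns" and the operator "mul_two" are hypothetical and used only for illustration; the semantics follow the updated tests in op_registration_test.cpp below.

#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical operator used only to illustrate the change.
at::Tensor mul_two(const at::Tensor& x) {
  return x * 2;
}

TORCH_LIBRARY(myns, m) {
  m.def("mul_two(Tensor x) -> Tensor");
  // No dispatch key given: this used to create a "catchAll" entry. With this
  // change the registration is stored under DispatchKey::Math instead, so it
  // serves both backend keys (CPU, CUDA, ...) and their Autograd keys, and it
  // takes precedence over a per-backend fallback (see the updated tests below).
  m.impl("mul_two", mul_two);
}

Previously a catchAll kernel lost to a backend fallback and to a registered Autograd kernel; after this change the Math entry wins in both cases, which is what the modified tests assert.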
32 changes: 17 additions & 15 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -105,7 +105,8 @@ std::list<AnnotatedKernel>::iterator OperatorEntry::registerKernel(

// Add the kernel to the kernels list,
// possibly creating the list if this is the first kernel.
auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : catchAllKernel_;
// Redirect catchAll registrations to Math.
auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::Math];

if (k.size() > 0) {
TORCH_WARN("Registering a kernel (", debug, ") for operator ", name_, " for dispatch key ", toString(dispatch_key), " that overwrote a previously registered kernel with the same dispatch key for the same operator.");
@@ -132,20 +133,17 @@ void OperatorEntry::deregisterKernel_(
c10::optional<DispatchKey> dispatch_key,
std::list<AnnotatedKernel>::iterator kernel
) {
if (dispatch_key.has_value()) {
auto found = kernels_.find(*dispatch_key);
TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator is ", toString(name_));
auto& k = found->second;
k.erase(kernel);
if (k.empty()) {
// the invariant says we don't want empty lists but instead remove the list from the map
kernels_.erase(found);
}
updateDispatchTable_(dispatcher, *dispatch_key);
} else {
catchAllKernel_.erase(kernel);
updateDispatchTableFull_(dispatcher);
}
// Redirect catchAll deregistrations to Math.
DispatchKey dk = dispatch_key.has_value() ? *dispatch_key : DispatchKey::Math;
auto found = kernels_.find(dk);
TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator is ", toString(name_));
auto& k = found->second;
k.erase(kernel);
if (k.empty()) {
// the invariant says we don't want empty lists but instead remove the list from the map
kernels_.erase(found);
}
updateDispatchTable_(dispatcher, dk);
}

void OperatorEntry::updateFallback(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
@@ -259,6 +257,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// fit 2.1 and we can remove 2.4 entirely.
if (!has_backend_kernel && !catchAllKernel_.empty()) {
TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid());
// Prepare for catchAll removal, make sure it's not used in dispatchTable
TORCH_INTERNAL_ASSERT(false);
return {catchAllKernel_.front(), "catch all"};
}
}
@@ -272,6 +272,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// 4. Catch all
if (!catchAllKernel_.empty()) {
TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid());
// Prepare for catchAll removal, make sure it's not used in dispatchTable
TORCH_INTERNAL_ASSERT(false);
return {catchAllKernel_.front(), "catch all"};
}

82 changes: 46 additions & 36 deletions aten/src/ATen/core/op_registration/op_registration_test.cpp
@@ -777,22 +777,6 @@ TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndRegularKer
EXPECT_TRUE(called);
}

TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndCatchallKernelForSameBackend_thenCallsFallbackKernel) {
auto registrar = c10::Dispatcher::singleton().registerFallback(c10::DispatchKey::CPU, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>(), "");

auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()", c10::RegisterOperators::options()
.catchAllKernel([] (Tensor, std::string) {
called = true;
}));
auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
ASSERT_TRUE(op.has_value());

called = false;
auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
EXPECT_FALSE(called);
EXPECT_EQ("hello _test::dummy", stack[1].toString()->string());
}

bool called_autograd = false;
bool called_nonautograd = false;

@@ -835,20 +819,6 @@ TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_th
EXPECT_TRUE(called_autograd);
}

TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallRegularKernel) {
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
.kernel<decltype(nonautograd_kernel), nonautograd_kernel>(DispatchKey::CPU)
.kernel<decltype(autograd_kernel), &autograd_kernel>(DispatchKey::Autograd));

auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
ASSERT_TRUE(op.has_value());

called_nonautograd = called_autograd = false;
op->typed<void (Tensor)>().call(dummyTensor(DispatchKey::CPU));
EXPECT_TRUE(called_nonautograd);
EXPECT_FALSE(called_autograd);
}

TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_thenCanCallAutogradKernel) {
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
.catchAllKernel<decltype(nonautograd_kernel), nonautograd_kernel>()
@@ -857,10 +827,11 @@ TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_t
auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
ASSERT_TRUE(op.has_value());

// catchAll now maps to Math which has higher precedence than Autograd
called_nonautograd = called_autograd = false;
op->typed<void (Tensor)>().call(dummyTensor(DispatchKey::CPU, /*requires_grad=*/true));
EXPECT_FALSE(called_nonautograd);
EXPECT_TRUE(called_autograd);
EXPECT_TRUE(called_nonautograd);
EXPECT_FALSE(called_autograd);
}

TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_thenCanCallCatchallKernel) {
@@ -1627,6 +1598,39 @@ TEST(NewOperatorRegistrationTest, schema) {
ASSERT_TRUE(Dispatcher::singleton().findSchema({"test::def4", ""})->schema().isDefaultAliasAnalysisKind());
}

TEST(NewOperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndCatchallKernelForSameBackend_thenCallsFallbackKernel) {
auto m1 = MAKE_TORCH_LIBRARY_IMPL(_, CPU);
m1.fallback(CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());

bool called = false;
auto m = MAKE_TORCH_LIBRARY(test);
m.def("fn(Tensor t, str input) -> ()");
m.impl("fn", [&] (Tensor, std::string) { called = true; });

auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
ASSERT_TRUE(op.has_value());

called = false;
auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello ");
// CatchAll now maps to Math and has higher precedence than backend fallback.
EXPECT_TRUE(called);
}

TEST(NewOperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallRegularKernel) {
auto m = MAKE_TORCH_LIBRARY(test);
m.def("fn(Tensor dummy) -> ()");
m.impl("fn", c10::DispatchKey::CPU, nonautograd_kernel);
m.impl("fn", c10::DispatchKey::Autograd, autograd_kernel);

auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
ASSERT_TRUE(op.has_value());

called_nonautograd = called_autograd = false;
callOp(*op, dummyTensor(DispatchKey::CPU));
EXPECT_TRUE(called_nonautograd);
EXPECT_FALSE(called_autograd);
}

TEST(NewOperatorRegistrationTest, dispatchWithMathKernel) {
bool math_called = false;
auto m = MAKE_TORCH_LIBRARY(test);
@@ -1708,18 +1712,20 @@ TEST(NewOperatorRegistrationTest, dispatchWithMathAndCatchAllKernel) {
auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
ASSERT_TRUE(op.has_value());

// catchAll now maps to Math, which means we have two registrations to Math key.
// The last registration is used.
{
catchall_called = math_called = false;
callOp(*op, dummyTensor(c10::DispatchKey::CPU));
ASSERT_TRUE(math_called);
ASSERT_FALSE(catchall_called);
ASSERT_FALSE(math_called);
ASSERT_TRUE(catchall_called);
}

{
catchall_called = math_called = false;
callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
ASSERT_TRUE(math_called);
ASSERT_FALSE(catchall_called);
ASSERT_FALSE(math_called);
ASSERT_TRUE(catchall_called);
}
}

@@ -2055,6 +2061,10 @@ TEST(NewOperatorRegistrationTest, throwsWhenRegisterToBackendMapsToAutogradOther
TEST(NewOperatorRegistrationTest, dispatchMultipleTensors) {
bool privateuse1_called = false;
bool catchall_called = false;
// Similar to in-tree AutogradCPU/AutogradCUDA etc, out-of-tree backends usually register
// a fallthrough kernel for AutogradPrivateUse1.
auto m1 = MAKE_TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1);
m1.fallback(CppFunction::makeFallthrough());

auto m = MAKE_TORCH_LIBRARY(test);
m.def("fn", torch::dispatch(c10::DispatchKey::PrivateUse1, [&](const Tensor& x, const Tensor& y) { privateuse1_called = true; return x; }));
5 changes: 0 additions & 5 deletions aten/src/ATen/templates/TypeDefault.cpp
@@ -67,9 +67,4 @@ TORCH_LIBRARY(aten, m) {
TORCH_LIBRARY_IMPL(aten, DefaultBackend, m) {
${default_backend_function_registrations};
}

TORCH_LIBRARY_IMPL(aten, Math, m) {
${math_function_registrations};
}

} // namespace at
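
A note on the removal above, with a hedged sketch (libtorch build assumed; "myops::relu_clone" is a hypothetical operator): after the redirect in OperatorEntry.cpp, a key-less impl() and an explicit Math-key block record the kernel in the same kernels_[DispatchKey::Math] slot, so only one of the two spellings should be used.

#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical op implemented in terms of other ATen ops, so the Math entry
// also provides autograd via the ops it calls.
at::Tensor relu_clone(const at::Tensor& x) {
  return x.clamp_min(0);
}

TORCH_LIBRARY(myops, m) {
  m.def("relu_clone(Tensor x) -> Tensor");
  // Key-less registration: the former "catchAll" path, now recorded under
  // DispatchKey::Math by OperatorEntry::registerKernel.
  m.impl("relu_clone", relu_clone);
}

// Equivalent explicit spelling, shown for comparison only; combining it with
// the key-less impl above would hit the duplicate-registration warning:
//
// TORCH_LIBRARY_IMPL(myops, Math, m) {
//   m.impl("relu_clone", relu_clone);
// }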
