Don't use NonVariableTypeMode in custom ops
Pull Request resolved: #37355

Potentially fixes #37306
ghstack-source-id: 103073537

Differential Revision: [D21261946](https://our.internmc.facebook.com/intern/diff/D21261946/)
smessmer committed Apr 28, 2020
1 parent 22ac071 commit 0ec1e0c
Showing 2 changed files with 27 additions and 39 deletions.
39 changes: 0 additions & 39 deletions aten/src/ATen/core/VariableFallbackKernel.cpp
@@ -28,47 +28,8 @@ using c10::KernelFunction;

namespace {

void variable_fallback_kernel(const OperatorHandle& op, Stack* stack) {
at::AutoNonVariableTypeMode _var_guard(true);
op.callBoxed(stack);
}

TORCH_LIBRARY_IMPL(_, Autograd, m) {
#ifdef C10_MOBILE
// As a custom mobile build might not include variable kernels, we need to
// leverage the variable fallback mechanism as well. The goals are:
// 1) don't break forward pass for inference-only mobile build;
// 2) don't break forward/backward pass for mobile build with necessary
// variable kernels registered;
//
// This `fallthrough` kernel is for #1 - because not all kernels support
// boxed call yet, registering `variable_fallback_kernel` might fail.
// When an op has a variable kernel registered explicitly, the dispatcher
// will call it instead of `fallthrough`, so `fallthrough` won't break
// dispatching to real variable kernels in case #2.
//
// The substantial difference between fallback and fallthrough is whether
// the AutoNonVariableTypeMode guard is applied. There are two downstream
// effects of the guard:
// a) stop calling variable kernels of other ops called by the current op;
// For case #1, there is no difference because no variable kernels are
// registered. For case #2, there is no difference as long as ALL used
// ops have real variable kernels registered, where the guard will be
// set properly in real variable kernels. There is a potential issue only
// when variable kernels are partially registered for used ops.
// b) the `variable_excluded_from_dispatch()` method returns the state of the
// NonVariableTypeMode. As of when this diff was written, ALL callers of
// the method assert that it returns true; the only exception is the
// deprecated `is_variable()` method. So we make the method always return
// true for mobile builds. It shouldn't break case #1/#2 as long as
// `is_variable()` is not used.
//
// We can remove this `fallthrough` kernel when all kernels support boxed
// calls.
m.fallback(torch::CppFunction::makeFallthrough());
#else
m.fallback(torch::CppFunction::makeFromBoxedFunction<&variable_fallback_kernel>());
#endif
}

}
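
The removed comment above lists two effects of the AutoNonVariableTypeMode guard; effect (a) is what the old boxed fallback imposed on custom ops that have no explicit autograd kernel, and it is the likely cause of #37306. Below is a minimal Python sketch of the user-visible difference, assuming the `my::add` compound op from the new test further down has already been built and loaded via `torch.utils.cpp_extension.load_inline`; the pre-fix behavior is inferred from the removed comment rather than reproduced here.

import torch

# Assumes torch.ops.my.add (a custom op that just calls aten's add) is loaded.
a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True)

out = torch.ops.my.add(a, b)

# With the old boxed fallback, the inner aten add ran under
# AutoNonVariableTypeMode, so its variable kernel was skipped and no autograd
# graph was recorded for the custom op's output. With this change, the inner
# add dispatches to its variable kernel as usual, so gradients flow:
assert out.requires_grad
out.sum().backward()
assert a.grad is not None and b.grad is not None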
27 changes: 27 additions & 0 deletions test/test_cpp_extensions_jit.py
@@ -13,6 +13,7 @@
import torch.backends.cudnn
import torch.utils.cpp_extension
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
from torch.autograd.gradcheck import gradcheck


TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
@@ -808,6 +809,32 @@ def backward(ctx, gx):
loss = MyFn.apply(inp).sum()
test_backward_deadlock.run_back_no_gil(loss)

def test_custom_compound_op_autograd(self):
# Test that a custom compound op (i.e. a custom op that just calls other aten ops)
# correctly propagates gradients through those other ops

source = """
#include <torch/library.h>
torch::Tensor my_add(torch::Tensor x, torch::Tensor y) {
return x + y;
}
TORCH_LIBRARY(my, m) {
m.def("add", &my_add);
}
"""

torch.utils.cpp_extension.load_inline(
name="is_python_module",
cpp_sources=source,
verbose=True,
is_python_module=False,
)

a = torch.randn(5, 5, requires_grad=True)
b = torch.randn(5, 5, requires_grad=True)

gradcheck(torch.ops.my.add, [a, b], eps=1e-2)


if __name__ == "__main__":
common.run_tests()
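
A possible follow-up to the new test (not part of this PR): gradcheck's default tolerances are designed for double-precision inputs, which is presumably why the float32 check above loosens eps to 1e-2. The same check with double inputs, again assuming the `my::add` extension is loaded, would be:

import torch
from torch.autograd.gradcheck import gradcheck

# Double-precision inputs let gradcheck run with its default eps/atol/rtol.
a64 = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
b64 = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
gradcheck(torch.ops.my.add, [a64, b64])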
