Don't use NonVariableTypeMode in custom ops (#37355)
Summary:
Pull Request resolved: #37355

Potentially fixes #37306
ghstack-source-id: 103073537

Test Plan: waitforsandcastle

Differential Revision: D21261946

fbshipit-source-id: 454652b528dcf942bec5438f89201822de40bbf0
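
For context: the old Variable backend fallback (removed below) wrapped the redispatch in at::AutoNonVariableTypeMode, which meant that aten ops invoked from inside a custom op's kernel ran with autograd disabled and recorded no graph. Replacing the boxed fallback with a fallthrough lets dispatch skip the Variable key, so the inner aten ops go through VariableType as usual and gradients flow through compound custom ops. A minimal Python-side sketch of the behavior this enables, assuming a compound op registered under my::add as in the test added below (the op name comes from that test and is not an op that ships with PyTorch):

import torch
from torch.autograd import gradcheck

# my::add is assumed to be a custom op whose C++ kernel just computes x + y
# (see the test added in this commit).
a = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
b = torch.randn(5, 5, dtype=torch.double, requires_grad=True)

out = torch.ops.my.add(a, b)   # the inner aten add now records the autograd graph
out.sum().backward()
assert torch.equal(a.grad, torch.ones(5, 5, dtype=torch.double))

# Numerical gradient check, as in the new test
gradcheck(torch.ops.my.add, (a, b))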
smessmer committed Apr 30, 2020
1 parent 96f218d commit 85026e4
Showing 2 changed files with 28 additions and 7 deletions.
9 changes: 2 additions & 7 deletions aten/src/ATen/core/VariableFallbackKernel.cpp
@@ -27,14 +27,9 @@ using c10::KernelFunction;
 
 namespace {
 
-void variable_fallback_kernel(const OperatorHandle& op, Stack* stack) {
-  at::AutoNonVariableTypeMode _var_guard(true);
-  Dispatcher::singleton().callBoxed(op, stack);
-}
-
-static auto registry = Dispatcher::singleton().registerBackendFallbackKernel(
+static auto registry = c10::Dispatcher::singleton().registerBackendFallbackKernel(
     DispatchKey::VariableTensorId,
-    KernelFunction::makeFromBoxedFunction<&variable_fallback_kernel>()
+    KernelFunction::makeFallthrough()
 );
 
 }
26 changes: 26 additions & 0 deletions test/test_cpp_extensions_jit.py
@@ -13,6 +13,7 @@
 import torch.backends.cudnn
 import torch.utils.cpp_extension
 from torch.utils.cpp_extension import CUDA_HOME
+from torch.autograd.gradcheck import gradcheck
 
 
 TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
@@ -806,6 +807,31 @@ def backward(ctx, gx):
         loss = MyFn.apply(inp).sum()
         test_backward_deadlock.run_back_no_gil(loss)
 
+    def test_custom_compound_op_autograd(self):
+        # Test that a custom compound op (i.e. a custom op that just calls other aten ops)
+        # correctly returns gradients of those other ops
+
+        source = """
+        #include <torch/script.h>
+        torch::Tensor my_add(torch::Tensor x, torch::Tensor y) {
+          return x + y;
+        }
+        static auto registry = torch::import()
.def("add", &my_add);
"""

torch.utils.cpp_extension.load_inline(
name="is_python_module",
cpp_sources=source,
verbose=True,
is_python_module=False,
)

a = torch.randn(5, 5, requires_grad=True)
b = torch.randn(5, 5, requires_grad=True)

gradcheck(torch.ops.my.add, [a, b], eps=1e-2)


if __name__ == "__main__":
common.run_tests()
