Improve calling backward() and grad() inside vmap error messages (#42876)

Summary:
Pull Request resolved: #42876

Previously, the error messages were uninformative. This PR adds clear
error messages for the following cases (both sketched below):
- user attempts to call .backward() inside vmap for any reason
  whatsoever
- user attempts to call autograd.grad(outputs, inputs, grad_outputs),
  where outputs or inputs is being vmapped over (so they are
  BatchedTensors).
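For illustration, a minimal sketch of the now-rejected patterns (this mirrors
the new tests below; it assumes the prototype vmap is importable from torch,
as in test/test_vmap.py):

    import torch
    from torch import vmap  # prototype vmap, as used in test/test_vmap.py

    x = torch.randn(3, requires_grad=True)

    def backward_inside_vmap(xi):
        # Any .backward() call under vmap now raises:
        # "backward() called inside torch.vmap"
        xi.sum().backward()

    def grad_with_vmapped_tensors(xi):
        out = (xi ** 2).sum()  # out is a BatchedTensor because xi is vmapped
        # autograd.grad with vmapped outputs/inputs now raises a descriptive error
        return torch.autograd.grad([out], [xi])[0]

    # Each of the following raises RuntimeError with the new messages:
    # vmap(backward_inside_vmap)(x)
    # vmap(grad_with_vmapped_tensors)(x)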

The case we do support is calling autograd.grad(outputs, inputs,
grad_outputs) where `grad_outputs` is being vmapped over. This is the
case for batched gradient support (e.g., user passes in a batched
grad_output).
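A minimal sketch of the supported pattern, using a batched grad_outputs to
recover a full Jacobian row by row (this mirrors the new
test_batched_gradient_basic test; same torch.vmap assumption as above):

    import torch
    from torch import vmap  # prototype vmap, as used in test/test_vmap.py

    N = 3
    x = torch.randn(N, requires_grad=True)
    y = torch.randn(N)

    def vjp_mul(v):
        # Only grad_outputs is vmapped here, which is the supported case.
        return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0]

    # vmap over the rows of the identity matrix to get the Jacobian of x * y w.r.t. x.
    jacobian = vmap(vjp_mul)(torch.eye(N))
    assert torch.allclose(jacobian, torch.diagflat(y))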

Test Plan:
- new tests: `pytest test/test_vmap.py -v`

Reviewed By: ezyang

Differential Revision: D23059836

Pulled By: zou3519

fbshipit-source-id: 2fd4e3fd93f558e67e2f0941b18f0d00d8ab439f
zou3519 authored and facebook-github-bot committed Aug 12, 2020
1 parent 5c39146 commit bda0007
Showing 2 changed files with 80 additions and 2 deletions.
57 changes: 57 additions & 0 deletions test/test_vmap.py
@@ -519,6 +519,63 @@ def test_fallback_multiple_returns(self):
        expected = torch.var_mean(tensor, dim=3)
        self.assertEqual(result, expected)

    def test_backward_unsupported_interaction(self):
        x = torch.randn(3, requires_grad=True)
        y = torch.randn(5)
        grad = torch.randn_like(x)
        err_msg = r'backward\(\) called inside torch.vmap'

        def backward_on_vmapped_tensor(x):
            x.sum().backward()

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(backward_on_vmapped_tensor)(x)

        def backward_with_vmapped_grad(x, grad):
            x.backward(grad)

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(backward_with_vmapped_grad)(x, grad)

        def completely_unrelated_backward(y):
            x.sum().backward()

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(completely_unrelated_backward)(y)

    def test_grad_unsupported_interaction(self):
        input_tensor = torch.randn(3, requires_grad=True)
        err_msg = 'autograd.grad.* called inside torch.vmap'

        captured = torch.randn(3, requires_grad=True)

        def output_to_grad_is_vmapped(input_tensor):
            output = (captured * input_tensor).sum()
            return torch.autograd.grad([output], [captured])[0]

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(output_to_grad_is_vmapped)(input_tensor)

        output = (input_tensor ** 2).sum()

        def input_to_grad_is_vmapped(input_tensor):
            return torch.autograd.grad([output], [input_tensor])[0]

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(input_to_grad_is_vmapped)(input_tensor)

    def test_batched_gradient_basic(self):
        N = 3
        x = torch.randn(N, requires_grad=True)
        y = torch.randn(N)

        def vjp_mul(v):
            return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0]

        batched_v = torch.eye(N)
        jacobian = vmap(vjp_mul)(batched_v)
        self.assertEqual(jacobian, torch.diagflat(y))


def slice_inputs(inputs, bdims, i):
    result = []
25 changes: 23 additions & 2 deletions torch/csrc/autograd/python_engine.cpp
@@ -8,6 +8,8 @@
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_function.h>
#include <ATen/BatchedTensorImpl.h>
#include <ATen/VmapMode.h>
#include <pybind11/pybind11.h>

#ifndef _WIN32
@@ -143,6 +145,13 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
  THPUtils_assert(num_tensors == num_gradients, "got %ld tensors and %ld "
      "gradients", num_tensors, num_gradients);

  // The user either called autograd.backward(...) or autograd.grad(...) to get here
  bool backward_api_called = inputs == nullptr;
  TORCH_CHECK(!backward_api_called || at::impl::VmapMode::current_vmap_level() == 0,
      "backward() called inside torch.vmap. This is not supported, "
      "please call backward() outside torch.vmap or instead use "
      "torch.autograd.grad inside torch.vmap");

  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
@@ -152,6 +161,12 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
    THPUtils_assert(THPVariable_Check(_tensor), "element %d of tensors "
        "tuple is not a Tensor", i);
    auto& variable = ((THPVariable*)_tensor)->cdata;
    TORCH_CHECK(!isBatchedTensor(variable),
        "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
        "torch.vmap. We do not support the case where any outputs are ",
        "vmapped tensors (output ", i, " is being vmapped over). Please "
        "call autograd.grad() outside torch.vmap or file a bug report "
        "with your use case.");
    if(variable.is_complex()) {
      TORCH_WARN_ONCE("Complex backward is not fully supported yet and could lead to wrong ",
          "gradients for functions we have not fixed yet");
@@ -181,14 +196,20 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
  }

  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
  if (!backward_api_called) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      PyObject *input = PyTuple_GET_ITEM(inputs, i);
      THPUtils_assert(THPVariable_Check(input),
          "all inputs have to be Tensors, but got %s", THPUtils_typename(input));
      THPVariable *input_var = (THPVariable*)input;
      TORCH_CHECK(!isBatchedTensor(input_var->cdata),
          "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
          "torch.vmap. We do not support the case where any inputs are ",
          "vmapped tensors (input ", i, " is being vmapped over). Please "
          "call autograd.grad() outside torch.vmap or file a bug report "
          "with your use case.");
      const auto output_nr = input_var->cdata.output_nr();
      auto grad_fn = input_var->cdata.grad_fn();
      if (!grad_fn) {
@@ -211,7 +232,7 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
    outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges);
  }

  if (inputs != nullptr) {
  if (!backward_api_called) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    THPObjectPtr py_outputs {PyTuple_New(num_inputs)};
    if (!py_outputs) return nullptr;
