diff --git a/aten/src/ATen/core/NamedRegistrations.cpp b/aten/src/ATen/core/NamedRegistrations.cpp
index 187b217604ba..80aa20459c61 100644
--- a/aten/src/ATen/core/NamedRegistrations.cpp
+++ b/aten/src/ATen/core/NamedRegistrations.cpp
@@ -499,7 +499,7 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
   // supported because they were manually registered. I'm not sure
   // if these registrations are right or not, but they preserve old behavior
   // (and some of them are exercised by the test suite).
-  m.impl("backward", CppFunction::makeFallthrough());
+  m.impl("_backward", CppFunction::makeFallthrough());
   m.impl("set_data", CppFunction::makeFallthrough());
   m.impl("data", CppFunction::makeFallthrough());
   m.impl("is_leaf", CppFunction::makeFallthrough());
diff --git a/aten/src/ATen/native/VariableMethodStubs.cpp b/aten/src/ATen/native/VariableMethodStubs.cpp
index 7d5cea725cf1..d547e42156ea 100644
--- a/aten/src/ATen/native/VariableMethodStubs.cpp
+++ b/aten/src/ATen/native/VariableMethodStubs.cpp
@@ -8,7 +8,7 @@
 namespace at {
 namespace native {

-void backward(const Tensor& self, const Tensor& gradient, c10::optional<bool> keep_graph, bool create_graph) {
+void _backward(const Tensor& self, TensorList inputs, const Tensor& gradient, c10::optional<bool> keep_graph, bool create_graph) {
   AT_ERROR("backward is not implemented for Tensor");
 }

diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 3fdd90b88f90..f9b422d97c89 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -46,7 +46,7 @@
   variants: function

 # Computes the gradient of current tensor w.r.t. graph leaves.
-- func: backward(Tensor self, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
+- func: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
   use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   manual_kernel_registration: True
   variants: method
diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h
index 1fd7eb7e5d9f..d44c8e9ed4fc 100644
--- a/aten/src/ATen/templates/TensorBody.h
+++ b/aten/src/ATen/templates/TensorBody.h
@@ -495,7 +495,7 @@ class CAFFE2_API Tensor {
   ///     // f requires grad, has no operation creating it
   /// @endcode

-  /// \fn void backward(const Tensor & gradient={}, c10::optional<bool> retain_graph=c10::nullopt, bool create_graph=false) const;
+  /// \fn void backward(const Tensor & gradient={}, c10::optional<bool> retain_graph=c10::nullopt, bool create_graph=false, c10::optional<TensorList> inputs=c10::nullopt) const;
   ///
   /// Computes the gradient of current tensor with respect to graph leaves.
   ///
@@ -522,6 +522,23 @@
   /// \param create_graph If ``true``, graph of the derivative will
   ///   be constructed, allowing to compute higher order derivative
   ///   products. Defaults to ``false``.
+  /// \param inputs Inputs w.r.t. which the gradient will be accumulated into
+  ///   ``at::Tensor::grad``. All other Tensors will be ignored. If not
+  ///   provided, the gradient is accumulated into all the leaf Tensors
+  ///   that were used to compute the current tensor. All the provided inputs
+  ///   must be leaf Tensors.
+  void backward(const Tensor & gradient={}, c10::optional<bool> retain_graph=c10::nullopt, bool create_graph=false, c10::optional<TensorList> inputs=c10::nullopt) const {
+    // NB: Adding this wrapper to _backward here because we'd like our
+    // 'backwards' api to accept the 'inputs' argument optionally. Since code gen
+    // currently does not support optional of TensorList our approach is to replace
+    // backward in native_functions.yaml with _backward and call it here instead.
+    if (inputs.has_value()) {
+      TORCH_CHECK(inputs.value().size() > 0, "'inputs' argument to backward cannot be empty");
+      this->_backward(inputs.value(), gradient, retain_graph, create_graph);
+    } else {
+      this->_backward({}, gradient, retain_graph, create_graph);
+    }
+  }

   /// \fn Tensor detach() const;
   ///
diff --git a/test/cpp/api/autograd.cpp b/test/cpp/api/autograd.cpp
index bf72f2215972..81e530d0dbe7 100644
--- a/test/cpp/api/autograd.cpp
+++ b/test/cpp/api/autograd.cpp
@@ -762,6 +762,36 @@ TEST(CustomAutogradTest, HookNone) {
   ASSERT_TRUE(was_called);
 }

+TEST(CustomAutogradTest, BackwardWithInputs) {
+  Variable x = torch::randn({5,5}, torch::requires_grad());
+  Variable y = torch::randn({5,5}, torch::requires_grad());
+  Variable z = x * x + x * y + y * y;
+  Variable x_grad_expected = 2 * x + y;
+  Variable y_grad_expected = x + 2 * y;
+
+  z.backward(torch::ones({5, 5}), false, false, {x});
+
+  ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
+  ASSERT_FALSE(y.grad().defined());
+}
+
+TEST(CustomAutogradTest, BackwardWithEmptyInputs) {
+  Variable x = torch::randn({5,5}, torch::requires_grad());
+  Variable y = torch::randn({5,5}, torch::requires_grad());
+  Variable z = x * x + x * y + y * y;
+  Variable x_grad_expected = 2 * x + y;
+  Variable y_grad_expected = x + 2 * y;
+  ASSERT_THROWS_WITH(z.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{}), "cannot be empty");
+}
+
+TEST(CustomAutogradTest, BackwardWithNonLeafInputs) {
+  Variable x = torch::randn({5,5}, torch::requires_grad());
+  Variable y = torch::randn({5,5}, torch::requires_grad());
+  Variable z = x * x;
+  Variable w = z + x * y + y * y;
+  ASSERT_THROWS_WITH(w.backward(torch::ones({5, 5}), false, false, {z}), "is not a leaf Tensor");
+}
+
 // TODO add these tests if needed
 // test_once_differentiable
 // test_sparse_backward
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 45cc6bd2b18a..9fbf2ba5889c 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -113,7 +113,7 @@
 # You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
 MANUAL_BACKEND = set([
     'options', 'data', 'set_data', 'is_leaf', 'output_nr', '_version', 'retain_grad',
-    'backward', 'requires_grad_',
+    '_backward', 'requires_grad_',
 ])

 # For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys.
diff --git a/torch/csrc/autograd/TraceTypeManual.cpp b/torch/csrc/autograd/TraceTypeManual.cpp
index 625f3c40e9a1..7e1f762a96a9 100644
--- a/torch/csrc/autograd/TraceTypeManual.cpp
+++ b/torch/csrc/autograd/TraceTypeManual.cpp
@@ -131,7 +131,7 @@ TORCH_LIBRARY_IMPL(aten, Tracer, m) {
   m.impl("copy_", copy_);

   // Skip tracing for the following ops by registering fallthrough kernel explicitly.
- m.impl("backward", CppFunction::makeFallthrough()); + m.impl("_backward", CppFunction::makeFallthrough()); m.impl("set_data", CppFunction::makeFallthrough()); m.impl("data", CppFunction::makeFallthrough()); m.impl("is_leaf", CppFunction::makeFallthrough()); diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index bb0bed93e22c..6707741f5b8d 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -83,15 +83,17 @@ std::vector unpack(at::TensorList tl, const char *name, int pos) { namespace { -void backward( +void _backward( const Tensor& self, + TensorList inputs, const c10::optional& gradient, c10::optional keep_graph, bool create_graph) { // TODO torch::autograd::backward should take the c10::optional gradient directly // instead of us having to unwrap it to Tensor _gradient here. Tensor _gradient = gradient.has_value() ? *gradient : Tensor(); - torch::autograd::backward({self}, {_gradient}, std::move(keep_graph), create_graph); + std::vector input_vars(inputs.begin(), inputs.end()); + torch::autograd::backward({self}, {_gradient}, std::move(keep_graph), create_graph, input_vars); } void set_data(Tensor & self, const Tensor & new_data) { @@ -317,12 +319,12 @@ TORCH_LIBRARY_IMPL(aten, Autograd, m) { // tensor argument. // TODO Once callBoxed() supports optional tensor arguments, we can enable `use_c10_dispatcher: full` for backward() // and requires_grad_(), then remove the backend Autograd kernel here, only leaving the Math kernel. - m.impl("backward", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::backward))); + m.impl("_backward", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::_backward))); m.impl("requires_grad_", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::requires_grad_))); } TORCH_LIBRARY_IMPL(aten, DefaultBackend, m) { - m.impl("backward", torch::dispatch(DispatchKey::DefaultBackend, TORCH_FN(VariableType::backward))); + m.impl("_backward", torch::dispatch(DispatchKey::DefaultBackend, TORCH_FN(VariableType::_backward))); m.impl("requires_grad_", torch::dispatch(DispatchKey::DefaultBackend, TORCH_FN(VariableType::requires_grad_))); } diff --git a/torch/csrc/autograd/autograd.cpp b/torch/csrc/autograd/autograd.cpp index 9c2879aa460d..858b329979ef 100644 --- a/torch/csrc/autograd/autograd.cpp +++ b/torch/csrc/autograd/autograd.cpp @@ -93,6 +93,12 @@ variable_list run_backward( if (!grad_fn) { grad_fn = impl::try_get_grad_accumulator(input); } + if (accumulate_grad) { + TORCH_CHECK( + input.is_leaf(), + "One of the differentiated Tensors given as 'inputs' to backward is not a leaf Tensor" + ) + } TORCH_CHECK( input.requires_grad(), "One of the differentiated Tensors does not require grad"); @@ -125,12 +131,13 @@ void backward( const variable_list& tensors, const variable_list& grad_tensors, c10::optional retain_graph, - bool create_graph) { + bool create_graph, + const variable_list& inputs) { variable_list gradients = _make_grads(tensors, grad_tensors); if (!retain_graph) { retain_graph = create_graph; } - run_backward(tensors, gradients, retain_graph.value(), create_graph, {}, /*allow_unused=*/true, /*accumulate_grad=*/true); + run_backward(tensors, gradients, retain_graph.value(), create_graph, inputs, /*allow_unused=*/true, /*accumulate_grad=*/true); } variable_list grad( diff --git a/torch/csrc/autograd/autograd.h b/torch/csrc/autograd/autograd.h index 90b8fcd3c4b4..c1e788fd6131 100644 --- a/torch/csrc/autograd/autograd.h +++ 
b/torch/csrc/autograd/autograd.h @@ -32,11 +32,16 @@ namespace autograd { /// value of `create_graph`. /// \param create_graph If `true`, graph of the derivative will be constructed, allowing /// to compute higher order derivative products. Defaults to `false`. +/// \param inputs Inputs w.r.t. which the gradient will be accumulated into +/// `at::Tensor::grad`. All other Tensors will be ignored. If not provided, the gradient +/// is accumulated into all the leaf Tensors that were used to compute param `tensors`. +/// All the provided inputs must be leaf Tensors. TORCH_API void backward( const variable_list& tensors, const variable_list& grad_tensors = {}, c10::optional retain_graph = c10::nullopt, - bool create_graph = false); + bool create_graph = false, + const variable_list& inputs = {}); /// Computes and returns the sum of gradients of outputs with respect to the inputs. ///
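
For reference, the sketch below shows how the new `inputs` argument is intended to be used from C++. The calls are modeled on the `BackwardWithInputs` test added above and on the updated signatures in `TensorBody.h` and `torch/csrc/autograd/autograd.h`; the standalone program scaffolding (`main`, the `<torch/torch.h>` include, the printed output) is illustrative and not part of this patch.

#include <torch/torch.h>
#include <torch/csrc/autograd/autograd.h>
#include <iostream>

int main() {
  // Two leaf tensors that require grad.
  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto y = torch::randn({5, 5}, torch::requires_grad());
  auto z = x * x + x * y + y * y;

  // Accumulate the gradient of z only into x; y is skipped even though it is
  // a leaf that requires grad. Retain the graph so we can backprop again below.
  z.backward(torch::ones({5, 5}), /*retain_graph=*/true, /*create_graph=*/false, /*inputs=*/{x});
  std::cout << x.grad().defined() << "\n";  // 1 -- x.grad() equals 2 * x + y
  std::cout << y.grad().defined() << "\n";  // 0 -- y was not in `inputs`

  // The free function torch::autograd::backward() gains the same trailing
  // `inputs` parameter; the default {} keeps the old "all leaves" behavior.
  torch::autograd::backward({z}, {torch::ones({5, 5})}, /*retain_graph=*/false,
                            /*create_graph=*/false, /*inputs=*/{y});
  std::cout << y.grad().defined() << "\n";  // 1 -- y.grad() now equals x + 2 * y
  return 0;
}

Passing an empty `inputs` list or a non-leaf Tensor is rejected with the errors exercised by the `BackwardWithEmptyInputs` and `BackwardWithNonLeafInputs` tests.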