Add input argument to autograd.backward() cpp api (#47214)
Summary:
Helps fix #46373 for the C++ API.

Follow-up to #46855, which changed the API for Python only.
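
For illustration, the new call shape on the Tensor method looks like the minimal sketch below (it mirrors the BackwardWithInputs test added in this PR; the tensor names are just placeholders):

#include <torch/torch.h>

int main() {
  // Accumulate the gradient only into x; y.grad() stays undefined.
  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto y = torch::randn({5, 5}, torch::requires_grad());
  auto z = x * x + x * y + y * y;
  z.backward(torch::ones({5, 5}), /*retain_graph=*/false, /*create_graph=*/false, /*inputs=*/{x});
  return 0;
}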

Pull Request resolved: #47214

Reviewed By: agolynski

Differential Revision: D24716139

Pulled By: soulitzer

fbshipit-source-id: 3e1f35968e8dee132985b883481cfd0d1872ccdd
soulitzer authored and facebook-github-bot committed Nov 4, 2020
1 parent 6f60251 commit 2e5bfa9
Showing 10 changed files with 74 additions and 13 deletions.
aten/src/ATen/core/NamedRegistrations.cpp (1 addition, 1 deletion)
@@ -499,7 +499,7 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
// supported because they were manually registered. I'm not sure
// if these registrations are right or not, but they preserve old behavior
// (and some of them are exercised by the test suite).
- m.impl("backward", CppFunction::makeFallthrough());
+ m.impl("_backward", CppFunction::makeFallthrough());
m.impl("set_data", CppFunction::makeFallthrough());
m.impl("data", CppFunction::makeFallthrough());
m.impl("is_leaf", CppFunction::makeFallthrough());
aten/src/ATen/native/VariableMethodStubs.cpp (1 addition, 1 deletion)
@@ -8,7 +8,7 @@
namespace at {
namespace native {

- void backward(const Tensor& self, const Tensor& gradient, c10::optional<bool> keep_graph, bool create_graph) {
+ void _backward(const Tensor& self, TensorList inputs, const Tensor& gradient, c10::optional<bool> keep_graph, bool create_graph) {
AT_ERROR("backward is not implemented for Tensor");
}

aten/src/ATen/native/native_functions.yaml (1 addition, 1 deletion)
@@ -46,7 +46,7 @@
variants: function

# Computes the gradient of current tensor w.r.t. graph leaves.
- - func: backward(Tensor self, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
+ - func: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
manual_kernel_registration: True
variants: method
aten/src/ATen/templates/TensorBody.h (18 additions, 1 deletion)
@@ -495,7 +495,7 @@ class CAFFE2_API Tensor {
/// // f requires grad, has no operation creating it
/// @endcode

- /// \fn void backward(const Tensor & gradient={}, c10::optional<bool> retain_graph=c10::nullopt, bool create_graph=false) const;
+ /// \fn void backward(const Tensor & gradient={}, c10::optional<bool> retain_graph=c10::nullopt, bool create_graph=false, c10::optional<TensorList> inputs=c10::nullopt) const;
///
/// Computes the gradient of current tensor with respect to graph leaves.
///
@@ -522,6 +522,23 @@
/// \param create_graph If ``true``, graph of the derivative will
/// be constructed, allowing to compute higher order derivative
/// products. Defaults to ``false``.
+ /// \param inputs Inputs w.r.t. which the gradient will be accumulated into
+ /// ``at::Tensor::grad``. All other Tensors will be ignored. If not
+ /// provided, the gradient is accumulated into all the leaf Tensors
+ /// that were used to compute the current tensor. All the provided inputs
+ /// must be leaf Tensors.
+ void backward(const Tensor & gradient={}, c10::optional<bool> retain_graph=c10::nullopt, bool create_graph=false, c10::optional<TensorList> inputs=c10::nullopt) const {
+ // NB: Adding this wrapper to _backward here because we'd like our
+ // 'backwards' api to accept the 'inputs' argument optionally. Since code gen
+ // currently does not support optional of TensorList our approach is to replace
+ // backward in native_functions.yaml with _backward and call it here instead.
+ if (inputs.has_value()) {
+ TORCH_CHECK(inputs.value().size() > 0, "'inputs' argument to backward cannot be empty")
+ this->_backward(inputs.value(), gradient, retain_graph, create_graph);
+ } else {
+ this->_backward({}, gradient, retain_graph, create_graph);
+ }
+ }

/// \fn Tensor detach() const;
///
test/cpp/api/autograd.cpp (30 additions, 0 deletions)
@@ -762,6 +762,36 @@ TEST(CustomAutogradTest, HookNone) {
ASSERT_TRUE(was_called);
}

+ TEST(CustomAutogradTest, BackwardWithInputs) {
+ Variable x = torch::randn({5,5}, torch::requires_grad());
+ Variable y = torch::randn({5,5}, torch::requires_grad());
+ Variable z = x * x + x * y + y * y;
+ Variable x_grad_expected = 2 * x + y;
+ Variable y_grad_expected = x + 2 * y;
+
+ z.backward(torch::ones({5, 5}), false, false, {x});
+
+ ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
+ ASSERT_FALSE(y.grad().defined());
+ }
+
+ TEST(CustomAutogradTest, BackwardWithEmptyInputs) {
+ Variable x = torch::randn({5,5}, torch::requires_grad());
+ Variable y = torch::randn({5,5}, torch::requires_grad());
+ Variable z = x * x + x * y + y * y;
+ Variable x_grad_expected = 2 * x + y;
+ Variable y_grad_expected = x + 2 * y;
+ ASSERT_THROWS_WITH(z.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{}), "cannot be empty");
+ }
+
+ TEST(CustomAutogradTest, BackwardWithNonLeafInputs) {
+ Variable x = torch::randn({5,5}, torch::requires_grad());
+ Variable y = torch::randn({5,5}, torch::requires_grad());
+ Variable z = x * x;
+ Variable w = z + x * y + y * y;
+ ASSERT_THROWS_WITH(w.backward(torch::ones({5, 5}), false, false, {z}), "is not a leaf Tensor");
+ }
+
// TODO add these tests if needed
// test_once_differentiable
// test_sparse_backward
tools/autograd/gen_variable_type.py (1 addition, 1 deletion)
@@ -113,7 +113,7 @@
# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
MANUAL_BACKEND = set([
'options', 'data', 'set_data', 'is_leaf', 'output_nr', '_version', 'retain_grad',
- 'backward', 'requires_grad_',
+ '_backward', 'requires_grad_',
])

# For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys.
torch/csrc/autograd/TraceTypeManual.cpp (1 addition, 1 deletion)
@@ -131,7 +131,7 @@ TORCH_LIBRARY_IMPL(aten, Tracer, m) {
m.impl("copy_", copy_);

// Skip tracing for the following ops by registering fallthrough kernel explicitly.
- m.impl("backward", CppFunction::makeFallthrough());
+ m.impl("_backward", CppFunction::makeFallthrough());
m.impl("set_data", CppFunction::makeFallthrough());
m.impl("data", CppFunction::makeFallthrough());
m.impl("is_leaf", CppFunction::makeFallthrough());
torch/csrc/autograd/VariableTypeManual.cpp (6 additions, 4 deletions)
@@ -83,15 +83,17 @@ std::vector<at::Tensor> unpack(at::TensorList tl, const char *name, int pos) {

namespace {

- void backward(
+ void _backward(
const Tensor& self,
+ TensorList inputs,
const c10::optional<Tensor>& gradient,
c10::optional<bool> keep_graph,
bool create_graph) {
// TODO torch::autograd::backward should take the c10::optional<Tensor> gradient directly
// instead of us having to unwrap it to Tensor _gradient here.
Tensor _gradient = gradient.has_value() ? *gradient : Tensor();
- torch::autograd::backward({self}, {_gradient}, std::move(keep_graph), create_graph);
+ std::vector<torch::autograd::Variable> input_vars(inputs.begin(), inputs.end());
+ torch::autograd::backward({self}, {_gradient}, std::move(keep_graph), create_graph, input_vars);
}

void set_data(Tensor & self, const Tensor & new_data) {
@@ -317,12 +319,12 @@ TORCH_LIBRARY_IMPL(aten, Autograd, m) {
// tensor argument.
// TODO Once callBoxed() supports optional tensor arguments, we can enable `use_c10_dispatcher: full` for backward()
// and requires_grad_(), then remove the backend Autograd kernel here, only leaving the Math kernel.
- m.impl("backward", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::backward)));
+ m.impl("_backward", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::_backward)));
m.impl("requires_grad_", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::requires_grad_)));
}

TORCH_LIBRARY_IMPL(aten, DefaultBackend, m) {
- m.impl("backward", torch::dispatch(DispatchKey::DefaultBackend, TORCH_FN(VariableType::backward)));
+ m.impl("_backward", torch::dispatch(DispatchKey::DefaultBackend, TORCH_FN(VariableType::_backward)));
m.impl("requires_grad_", torch::dispatch(DispatchKey::DefaultBackend, TORCH_FN(VariableType::requires_grad_)));
}

torch/csrc/autograd/autograd.cpp (9 additions, 2 deletions)
@@ -93,6 +93,12 @@ variable_list run_backward(
if (!grad_fn) {
grad_fn = impl::try_get_grad_accumulator(input);
}
+ if (accumulate_grad) {
+ TORCH_CHECK(
+ input.is_leaf(),
+ "One of the differentiated Tensors given as 'inputs' to backward is not a leaf Tensor"
+ )
+ }
TORCH_CHECK(
input.requires_grad(),
"One of the differentiated Tensors does not require grad");
@@ -125,12 +131,13 @@ void backward(
const variable_list& tensors,
const variable_list& grad_tensors,
c10::optional<bool> retain_graph,
- bool create_graph) {
+ bool create_graph,
+ const variable_list& inputs) {
variable_list gradients = _make_grads(tensors, grad_tensors);
if (!retain_graph) {
retain_graph = create_graph;
}
- run_backward(tensors, gradients, retain_graph.value(), create_graph, {}, /*allow_unused=*/true, /*accumulate_grad=*/true);
+ run_backward(tensors, gradients, retain_graph.value(), create_graph, inputs, /*allow_unused=*/true, /*accumulate_grad=*/true);
}

variable_list grad(
torch/csrc/autograd/autograd.h (6 additions, 1 deletion)
@@ -32,11 +32,16 @@ namespace autograd {
/// value of `create_graph`.
/// \param create_graph If `true`, graph of the derivative will be constructed, allowing
/// to compute higher order derivative products. Defaults to `false`.
+ /// \param inputs Inputs w.r.t. which the gradient will be accumulated into
+ /// `at::Tensor::grad`. All other Tensors will be ignored. If not provided, the gradient
+ /// is accumulated into all the leaf Tensors that were used to compute param `tensors`.
+ /// All the provided inputs must be leaf Tensors.
TORCH_API void backward(
const variable_list& tensors,
const variable_list& grad_tensors = {},
c10::optional<bool> retain_graph = c10::nullopt,
- bool create_graph = false);
+ bool create_graph = false,
+ const variable_list& inputs = {});

/// Computes and returns the sum of gradients of outputs with respect to the inputs.
///
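
For completeness, a minimal sketch of the same restriction through the free function declared above (placeholder names again; this assumes the declaration is reachable via the usual torch headers such as torch/torch.h, otherwise include torch/csrc/autograd/autograd.h directly):

#include <torch/torch.h>

int main() {
  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto y = torch::randn({5, 5}, torch::requires_grad());
  auto z = x * x + x * y + y * y;
  // Restrict gradient accumulation to x; y is ignored even though it requires grad.
  torch::autograd::backward(
      /*tensors=*/{z},
      /*grad_tensors=*/{torch::ones({5, 5})},
      /*retain_graph=*/c10::nullopt,
      /*create_graph=*/false,
      /*inputs=*/{x});
  return 0;
}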
