2 changes: 1 addition & 1 deletion test/cpp/api/autograd.cpp
@@ -211,7 +211,7 @@ TEST(CustomAutogradTest, FunctionReturnsUndefined) {
};

auto x = torch::ones(1, torch::requires_grad());

MyFunction::apply(x).backward();
ASSERT_FALSE(x.grad().defined());

1 change: 1 addition & 0 deletions test/cpp/jit/test_autodiff.cpp
@@ -81,6 +81,7 @@ variable_list grad(
grad_outputs,
true,
false,
false,
fmap(inputs, get_edge));
}

58 changes: 58 additions & 0 deletions test/test_autograd.py
@@ -866,6 +866,64 @@ def call_backwards():
torch.autograd.backward([z, q], [torch.ones(5, 5), torch.ones(5, 5)])
self.assertRaises(RuntimeError, call_backwards)

def test_backward_with_inputs(self):
x = torch.randn(2, 2, requires_grad=True)
y = torch.randn(2, 2, requires_grad=True)

def fn():
return x ** 2 + y * x + y ** 2

gradient = torch.ones(2, 2)
x_grad_expected = 2 * x + y
y_grad_expected = x + 2 * y

@torch.no_grad()
def reset_grad():
x.grad.zero_()
y.grad.zero_()

torch.autograd.backward(fn(), gradient, inputs=[x, y])
self.assertEqual(x.grad, x_grad_expected)
self.assertEqual(y.grad, y_grad_expected)

reset_grad()
torch.autograd.backward(fn(), gradient, inputs=[x])
self.assertEqual(x.grad, x_grad_expected)
self.assertEqual(y.grad, torch.zeros(2, 2))

reset_grad()
torch.autograd.backward(fn(), gradient, inputs=[y])
self.assertEqual(y.grad, y_grad_expected)
self.assertEqual(x.grad, torch.zeros(2, 2))

reset_grad()
self.assertRaisesRegex(RuntimeError, 'cannot be empty',
lambda: torch.autograd.backward(fn(), gradient, inputs=[]))

def test_backward_with_nonleaf_inputs(self):
x = torch.randn(2, 2, requires_grad=True)
x_nonleaf = x * 1
y = torch.randn(2, 2, requires_grad=True)
z = torch.randn(2, 2, requires_grad=True)

out = x_nonleaf ** 2 + y * x_nonleaf + y ** 2

out.backward(torch.ones(2, 2), create_graph=True, inputs=[x, y])
x_grad_expected = 2 * x + y
y_grad_expected = x + 2 * y

self.assertEqual(y.grad, y_grad_expected)
self.assertEqual(x.grad, x_grad_expected)

self.assertRaisesRegex(RuntimeError, 'not a leaf Tensor',
lambda: out.backward(torch.ones(2, 2), create_graph=True, inputs=[x, y, x_nonleaf]))

# backward() doesn't have an allow_unused flag, so when a variable is not
# part of the graph, backward behaves as if allow_unused were True:
# z.grad will simply be None.
out.backward(torch.ones(2, 2), create_graph=True, inputs=[z])
self.assertIsNone(z.grad)

def test_dependent_backward(self):
x = torch.randn(10, requires_grad=True)
y = x ** 2
15 changes: 12 additions & 3 deletions torch/autograd/__init__.py
@@ -71,6 +71,7 @@ def backward(
retain_graph: Optional[bool] = None,
create_graph: bool = False,
grad_variables: Optional[_TensorOrTensors] = None,
inputs: Optional[Sequence[torch.Tensor]] = None,
) -> None:
r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.

@@ -116,6 +117,11 @@ def backward(
create_graph (bool, optional): If ``True``, graph of the derivative will
be constructed, allowing to compute higher order derivative products.
Defaults to ``False``.
inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be
accumulated into ``.grad``. All other Tensors will be ignored. If not
provided, the gradient is accumulated into all the leaf Tensors that were
used to compute :attr:`tensors`. All the provided inputs must be leaf
Tensors.
"""
if grad_variables is not None:
warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
@@ -125,17 +131,20 @@
raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
"arguments both passed to backward(). Please only "
"use 'grad_tensors'.")
if inputs is not None and len(inputs) == 0:
raise RuntimeError("'inputs' argument to backward() cannot be empty.")

tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
inputs = tuple(inputs) if inputs is not None else tuple()

grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
grad_tensors_ = _make_grads(tensors, grad_tensors_)
if retain_graph is None:
retain_graph = create_graph

Variable._execution_engine.run_backward(
tensors, grad_tensors_, retain_graph, create_graph,
allow_unreachable=True) # allow_unreachable flag
tensors, grad_tensors_, retain_graph, create_graph, inputs,
allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag


def grad(
@@ -213,7 +222,7 @@ def grad(

return Variable._execution_engine.run_backward(
outputs, grad_outputs_, retain_graph, create_graph,
inputs, allow_unused)
inputs, allow_unused, accumulate_grad=False)


# This function applies in case of gradient checkpointing for memory
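For reference, a minimal sketch of how the new ``inputs`` argument to ``torch.autograd.backward`` is used, based on the docstring and the tests added above; the gradient expressions in the comments follow from this particular formula and are illustrative only:

import torch

x = torch.randn(2, 2, requires_grad=True)
y = torch.randn(2, 2, requires_grad=True)
out = x ** 2 + y * x + y ** 2

# Accumulate the gradient only into x; y is ignored, so y.grad stays None here.
torch.autograd.backward(out, torch.ones(2, 2), inputs=[x])
# x.grad == 2 * x + y, y.grad is None

# Per the checks added in this PR, both of these raise a RuntimeError:
# torch.autograd.backward(out, torch.ones(2, 2), inputs=[])     # 'inputs' cannot be empty
# torch.autograd.backward(out, torch.ones(2, 2), inputs=[out])  # out is not a leaf Tensor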
9 changes: 5 additions & 4 deletions torch/csrc/autograd/autograd.cpp
@@ -68,7 +68,8 @@ variable_list run_backward(
bool keep_graph,
bool create_graph,
const variable_list& inputs,
bool allow_unused) {
bool allow_unused,
bool accumulate_grad) {
size_t num_tensors = outputs.size();
edge_list roots;
roots.reserve(num_tensors);
@@ -104,7 +105,7 @@
}

variable_list grad_inputs = Engine::get_default_engine().execute(
roots, grad_outputs, keep_graph, create_graph, output_edges);
roots, grad_outputs, keep_graph, create_graph, accumulate_grad, output_edges);
// check if grad_inputs contains None or not based on the allow_unused flag
if (!inputs.empty() && !allow_unused) {
size_t num_inputs = inputs.size();
@@ -129,7 +130,7 @@ void backward(
if (!retain_graph) {
retain_graph = create_graph;
}
run_backward(tensors, gradients, retain_graph.value(), create_graph, {}, /*allow_unused=*/true);
run_backward(tensors, gradients, retain_graph.value(), create_graph, {}, /*allow_unused=*/true, /*accumulate_grad=*/true);
}

variable_list grad(
@@ -144,7 +145,7 @@ variable_list grad(
retain_graph = create_graph;
}
return run_backward(
outputs, gradients, retain_graph.value(), create_graph, inputs, allow_unused);
outputs, gradients, retain_graph.value(), create_graph, inputs, allow_unused, /*accumulate_grad=*/false);
}

} // namespace autograd
18 changes: 12 additions & 6 deletions torch/csrc/autograd/engine.cpp
@@ -843,6 +843,7 @@ auto Engine::execute(const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
bool create_graph,
bool accumulate_grad,
const edge_list& outputs) -> variable_list {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
validate_outputs(roots, const_cast<variable_list&>(inputs), [](const std::string& msg) {
@@ -867,7 +868,7 @@
compute_dependencies(graph_root.get(), *graph_task);

if (!outputs.empty()) {
graph_task->init_to_execute(*graph_root, outputs);
graph_task->init_to_execute(*graph_root, outputs, accumulate_grad);
}

execute_with_graph_task(graph_task, graph_root);
@@ -1079,16 +1080,21 @@ void Engine::add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task) {
thread_pool_shared_->work_.notify_one();
}

void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs) {
void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad) {
exec_info_[&graph_root].needed_ = true;

int output_idx = 0;
for (auto & output_edge : outputs) {
Node *output = output_edge.function.get();
auto & info = exec_info_[output];
if (!info.captures_)
info.captures_ = make_unique<std::vector<ExecInfo::Capture>>();
info.captures_->emplace_back(output_edge.input_nr, output_idx++);
if (accumulate_grad) {
info.needed_ = true;
} else {
if (!info.captures_) {
info.captures_ = make_unique<std::vector<ExecInfo::Capture>>();
}
info.captures_->emplace_back(output_edge.input_nr, output_idx++);
}
}
captured_vars_.resize(output_idx);

@@ -1136,7 +1142,7 @@ void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs) {
auto it = exec_info_.find(edge.function.get());
return it != exec_info_.end() && it->second.should_execute();
});
exec_info_[frame.fn_].needed_ = needed;
exec_info_[frame.fn_].needed_ |= needed;
stack.pop_back();
}
}
3 changes: 2 additions & 1 deletion torch/csrc/autograd/engine.h
@@ -109,7 +109,7 @@ struct GraphTask: std::enable_shared_from_this<GraphTask> {

std::unordered_set<c10::Stream> leaf_streams;

void init_to_execute(Node& graph_root, const edge_list& outputs);
void init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad);

// The value of worker_device in the thread that created this task.
// See Note [Reentrant backwards]
@@ -272,6 +272,7 @@ struct TORCH_API Engine {
const variable_list& inputs,
bool keep_graph,
bool create_graph,
bool accumulate_grad,
const edge_list& outputs = {});

// Given a pre-populated GraphTask and GraphRoot, computes the backward pass
27 changes: 16 additions & 11 deletions torch/csrc/autograd/python_engine.cpp
@@ -87,13 +87,14 @@ variable_list PythonEngine::execute(
const variable_list& inputs,
bool keep_graph,
bool create_graph,
bool accumulate_grad,
const edge_list& outputs) {
TORCH_CHECK(!PyGILState_Check(), "The autograd engine was called while holding the GIL. If you are using the C++ "
"API, the autograd engine is an expensive operation that does not require the "
"GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'"
". If you are not using the C++ API, please report a bug to the pytorch team.")
try {
return Engine::execute(roots, inputs, keep_graph, create_graph, outputs);
return Engine::execute(roots, inputs, keep_graph, create_graph, accumulate_grad, outputs);
} catch (python_error& e) {
e.restore();
throw;
@@ -128,14 +129,14 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg
unsigned char create_graph = 0;
PyObject *inputs = nullptr;
unsigned char allow_unreachable = 0;
const char *accepted_kwargs[] = {
unsigned char accumulate_grad = 0;
const char *accepted_kwargs[] = { // NOLINT
"tensors", "grad_tensors", "keep_graph", "create_graph", "inputs",
"allow_unreachable", nullptr
"allow_unreachable", "accumulate_grad", nullptr
};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Ob", (char**)accepted_kwargs,
&tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable))
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Obb", (char**)accepted_kwargs,
&tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable, &accumulate_grad))
return nullptr;

THPUtils_assert(PyTuple_Check(tensors), "tensors argument is expected to "
"be a tuple, but got %s", THPUtils_typename(tensors));
THPUtils_assert(PyTuple_Check(grad_tensors), "grad_tensors argument is "
@@ -147,7 +148,7 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg
"gradients", num_tensors, num_gradients);

// The user either called autograd.backward(...) or autograd.grad(...) to get here
bool backward_api_called = inputs == nullptr;
bool backward_api_called = accumulate_grad;
TORCH_CHECK(!backward_api_called || at::impl::VmapMode::current_vmap_level() == 0,
Collaborator: cc @zou3519 just to confirm that this is fine from a vmap point of view?

Contributor: What does accumulate_grad do?

Contributor Author: Sorry for the late reply - accumulate_grad indicates whether run_backward was called from the backward api or grad api, so the replacement seemed okay at a glance. Just wanted to confirm if the vmap check depends specifically on inputs being nullptr.

"backward() called inside torch.vmap. This is not supported, "
"please call backward() outside torch.vmap or instead use "
@@ -193,7 +194,7 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg
}

std::vector<Edge> output_edges;
if (!backward_api_called) {
if (inputs != nullptr) {
int num_inputs = PyTuple_GET_SIZE(inputs);
output_edges.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
@@ -210,7 +211,7 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg
const auto output_nr = input_var->cdata.output_nr();
auto grad_fn = input_var->cdata.grad_fn();
if (!grad_fn) {
grad_fn = torch::autograd::impl::try_get_grad_accumulator(input_var->cdata);
grad_fn = torch::autograd::impl::try_get_grad_accumulator(input_var->cdata);
}
if (accumulate_grad) {
THPUtils_assert(input_var->cdata.is_leaf(),
"One of the differentiated Tensors given as 'inputs' to backward is not a leaf Tensor");
}
THPUtils_assert(input_var->cdata.requires_grad(),
"One of the differentiated Tensors does not require grad");
Expand All @@ -226,10 +231,10 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg
{
pybind11::gil_scoped_release no_gil;
auto& engine = python::PythonEngine::get_python_engine();
outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges);
outputs = engine.execute(roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
}

if (!backward_api_called) {
if (!backward_api_called && inputs != nullptr) {
int num_inputs = PyTuple_GET_SIZE(inputs);
THPObjectPtr py_outputs {PyTuple_New(num_inputs)};
if (!py_outputs) return nullptr;
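As a companion to the review discussion above, a small sketch of the user-visible distinction that accumulate_grad encodes: torch.autograd.backward drives the engine with accumulate_grad=True and writes into ``.grad``, while torch.autograd.grad runs with accumulate_grad=False and returns the captured gradients. This is a simplified illustration of the two entry points, not of the engine internals:

import torch

w = torch.randn(3, requires_grad=True)

# backward API -> run_backward(..., accumulate_grad=True):
# the gradient is accumulated into w.grad.
torch.autograd.backward((w * w).sum())
# w.grad == 2 * w

# grad API -> run_backward(..., accumulate_grad=False):
# the gradient is captured and returned; w.grad is not modified further.
(gw,) = torch.autograd.grad((w * w).sum(), (w,))
# gw == 2 * w, returned rather than accumulated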
1 change: 1 addition & 0 deletions torch/csrc/autograd/python_engine.h
@@ -23,6 +23,7 @@ struct PythonEngine : public Engine {
const variable_list& inputs,
bool keep_graph,
bool create_graph,
bool accumulate_grad,
Contributor: Could we add a comment here for what accumulate_grad does, for future code readers?

Contributor Author: Creating a separate PR for this.

const edge_list& outputs = {}) override;

std::shared_ptr<at::ivalue::Future> execute_with_graph_task(
2 changes: 1 addition & 1 deletion torch/csrc/distributed/autograd/engine/dist_engine.cpp
@@ -279,7 +279,7 @@ void DistEngine::computeDependencies(

// Create a dummy GraphRoot and run init_to_execute with it.
GraphRoot dummyRoot(edges, {});
graphTask->init_to_execute(dummyRoot, outputEdges);
graphTask->init_to_execute(dummyRoot, outputEdges, /*accumulate_grad=*/false);
for (auto& mapEntry : graphTask->exec_info_) {
auto& execInfo = mapEntry.second;
if (!execInfo.captures_) {
2 changes: 1 addition & 1 deletion torch/overrides.py
@@ -841,7 +841,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.apply_: lambda self, callable: -1,
Tensor.as_strided: lambda self, size, stride: -1,
Tensor.as_strided_: lambda self, size, stride: -1,
Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False: -1,
Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False, inputs=None: -1,
Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1,
Tensor.bool: lambda self, memory_format=torch.preserve_format: -1,
Tensor.byte: lambda self, memory_format=torch.preserve_format: -1,
12 changes: 9 additions & 3 deletions torch/tensor.py
@@ -178,7 +178,7 @@ def __repr__(self):
# All strings are unicode in Python 3.
return torch._tensor_str._str(self)

def backward(self, gradient=None, retain_graph=None, create_graph=False):
def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=None):
r"""Computes the gradient of current tensor w.r.t. graph leaves.

The graph is differentiated using the chain rule. If the tensor is
@@ -213,6 +213,11 @@ def backward(self, gradient=None, retain_graph=None, create_graph=False):
create_graph (bool, optional): If ``True``, graph of the derivative will
be constructed, allowing to compute higher order derivative
products. Defaults to ``False``.
inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be
accumulated into ``.grad``. All other Tensors will be ignored. If not
provided, the gradient is accumulated into all the leaf Tensors that were
used to compute :attr:`tensors`. All the provided inputs must be leaf
Tensors.
"""
relevant_args = (self,)
from torch.overrides import has_torch_function, handle_torch_function
@@ -223,8 +228,9 @@ def backward(self, gradient=None, retain_graph=None, create_graph=False):
self,
gradient=gradient,
retain_graph=retain_graph,
create_graph=create_graph)
torch.autograd.backward(self, gradient, retain_graph, create_graph)
create_graph=create_graph,
inputs=inputs)
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

def register_hook(self, hook):
r"""Registers a backward hook.
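Finally, the same argument as exposed through Tensor.backward, which simply forwards ``inputs`` to torch.autograd.backward. One practical use is accumulating gradients for only a subset of leaf tensors; a hedged sketch with illustrative names:

import torch

frozen = torch.randn(4, requires_grad=True)
trainable = torch.randn(4, requires_grad=True)
loss = (frozen * trainable).sum()

# Only `trainable` receives a gradient; `frozen.grad` is left untouched (None here).
loss.backward(inputs=[trainable])
# trainable.grad == frozen, frozen.grad is None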