Add inputs argument to autograd.backward() #46855
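The new `inputs` argument lets `torch.autograd.backward()` accumulate gradients into `.grad` for an explicit set of leaf tensors only, instead of every leaf reachable from the outputs. A minimal usage sketch of the intended behaviour (tensor names are illustrative; the semantics follow the diff below, where gradients are only accumulated for the listed inputs):

```python
import torch

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
out = (x * y).sum()

# Accumulate gradients only for x; y.grad is left untouched.
torch.autograd.backward(out, inputs=[x])

print(x.grad)  # d(out)/dx == y
print(y.grad)  # None: y was not listed in `inputs`
```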
Changes from all commits
```diff
@@ -81,6 +81,7 @@ variable_list grad(
       grad_outputs,
       true,
       false,
+      false,
       fmap(inputs, get_edge));
 }
```
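This first hunk is the C++ `grad()` entry point: the added `false` is the new `accumulate_grad` flag, kept off because `grad()` returns gradients to the caller rather than accumulating them into `.grad`. At the Python level the same distinction looks roughly like this (illustrative sketch, not part of the diff):

```python
import torch

x = torch.randn(3, requires_grad=True)
out = (x * 2).sum()

gx, = torch.autograd.grad(out, inputs=[x])
print(gx)      # tensor([2., 2., 2.])
print(x.grad)  # None: grad() runs with accumulate_grad=false
```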
```diff
@@ -87,13 +87,14 @@ variable_list PythonEngine::execute(
     const variable_list& inputs,
     bool keep_graph,
     bool create_graph,
+    bool accumulate_grad,
     const edge_list& outputs) {
   TORCH_CHECK(!PyGILState_Check(), "The autograd engine was called while holding the GIL. If you are using the C++ "
                                    "API, the autograd engine is an expensive operation that does not require the "
                                    "GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'"
                                    ". If you are not using the C++ API, please report a bug to the pytorch team.")
   try {
-    return Engine::execute(roots, inputs, keep_graph, create_graph, outputs);
+    return Engine::execute(roots, inputs, keep_graph, create_graph, accumulate_grad, outputs);
   } catch (python_error& e) {
     e.restore();
     throw;
```
```diff
@@ -128,14 +129,14 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwargs)
   unsigned char create_graph = 0;
   PyObject *inputs = nullptr;
   unsigned char allow_unreachable = 0;
-  const char *accepted_kwargs[] = {
+  unsigned char accumulate_grad = 0;
+  const char *accepted_kwargs[] = { // NOLINT
       "tensors", "grad_tensors", "keep_graph", "create_graph", "inputs",
-      "allow_unreachable", nullptr
+      "allow_unreachable", "accumulate_grad", nullptr
   };
-  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Ob", (char**)accepted_kwargs,
-        &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable))
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Obb", (char**)accepted_kwargs,
+        &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable, &accumulate_grad))
     return nullptr;

   THPUtils_assert(PyTuple_Check(tensors), "tensors argument is expected to "
     "be a tuple, but got %s", THPUtils_typename(tensors));
   THPUtils_assert(PyTuple_Check(grad_tensors), "grad_tensors argument is "
```
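`THPEngine_run_backward` is the binding behind `Variable._execution_engine.run_backward`; the format string grows from `"OObb|Ob"` to `"OObb|Obb"` so an optional `accumulate_grad` boolean is parsed after `allow_unreachable`. A hedged sketch of a direct call to this binding, with the keyword names taken from `accepted_kwargs` above (application code should keep using `torch.autograd.backward()`/`grad()` instead):

```python
import torch

x = torch.randn(3, requires_grad=True)
out = (x * 3.0).sum()

# Low-level engine call corresponding to the parsing code in this hunk.
torch.autograd.Variable._execution_engine.run_backward(
    (out,),                   # tensors
    (torch.ones_like(out),),  # grad_tensors
    False,                    # keep_graph
    False,                    # create_graph
    (x,),                     # inputs
    allow_unreachable=True,
    accumulate_grad=True)     # backward()-style: accumulate into .grad

print(x.grad)  # tensor([3., 3., 3.])
```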
```diff
@@ -147,7 +148,7 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwargs)
       "gradients", num_tensors, num_gradients);

   // The user either called autograd.backward(...) or autograd.grad(...) to get here
-  bool backward_api_called = inputs == nullptr;
+  bool backward_api_called = accumulate_grad;
   TORCH_CHECK(!backward_api_called || at::impl::VmapMode::current_vmap_level() == 0,
       "backward() called inside torch.vmap. This is not supported, "
       "please call backward() outside torch.vmap or instead use "
```

Review comments on the `backward_api_called` change above:
- cc @zou3519 just to confirm that this is fine from a vmap point of view?
- What does …
- Sorry for the late reply - …
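Previously `inputs == nullptr` was a reliable proxy for "the user called backward()", because only `grad()` passed `inputs`. With this PR both entry points may pass `inputs`, so the explicit `accumulate_grad` flag now identifies the backward() path, which is also the path the vmap check above guards. Roughly, at the Python level (illustrative sketch):

```python
import torch

x = torch.randn(2, requires_grad=True)
out = (x ** 2).sum()

# Both calls now reach THPEngine_run_backward with a non-null `inputs`;
# only accumulate_grad tells them apart.
torch.autograd.backward(out, inputs=[x], retain_graph=True)  # accumulate_grad=True
gx, = torch.autograd.grad(out, inputs=[x])                   # accumulate_grad=False
```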
```diff
@@ -193,7 +194,7 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwargs)
   }

   std::vector<Edge> output_edges;
-  if (!backward_api_called) {
+  if (inputs != nullptr) {
     int num_inputs = PyTuple_GET_SIZE(inputs);
     output_edges.reserve(num_inputs);
     for (int i = 0; i < num_inputs; ++i) {
```
```diff
@@ -210,7 +211,11 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwargs)
       const auto output_nr = input_var->cdata.output_nr();
       auto grad_fn = input_var->cdata.grad_fn();
       if (!grad_fn) {
-          grad_fn = torch::autograd::impl::try_get_grad_accumulator(input_var->cdata);
+        grad_fn = torch::autograd::impl::try_get_grad_accumulator(input_var->cdata);
+      }
+      if (accumulate_grad) {
+        THPUtils_assert(input_var->cdata.is_leaf(),
+          "One of the differentiated Tensors given as 'inputs' to backward is not a leaf Tensor");
       }
       THPUtils_assert(input_var->cdata.requires_grad(),
         "One of the differentiated Tensors does not require grad");
```
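When `accumulate_grad` is set (the backward() path), every tensor passed via `inputs` must be a leaf, and the new assert fires otherwise. A small example that would trigger the error message added in this hunk (names are illustrative):

```python
import torch

a = torch.randn(3, requires_grad=True)
b = a * 2          # non-leaf: b has a grad_fn
out = b.sum()

try:
    torch.autograd.backward(out, inputs=[b])
except RuntimeError as e:
    # "One of the differentiated Tensors given as 'inputs' to backward
    #  is not a leaf Tensor"
    print(e)
```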
```diff
@@ -226,10 +231,10 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwargs)
   {
     pybind11::gil_scoped_release no_gil;
     auto& engine = python::PythonEngine::get_python_engine();
-    outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges);
+    outputs = engine.execute(roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
   }

-  if (!backward_api_called) {
+  if (!backward_api_called && inputs != nullptr) {
     int num_inputs = PyTuple_GET_SIZE(inputs);
     THPObjectPtr py_outputs {PyTuple_New(num_inputs)};
     if (!py_outputs) return nullptr;
```
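The engine call now forwards `accumulate_grad`, and the tuple of gradient outputs is only built for the grad() path; with backward(), gradients are accumulated in place and nothing is returned. Illustrative behaviour at the Python level:

```python
import torch

x = torch.randn(3, requires_grad=True)
out = x.sum()

ret = torch.autograd.backward(out, inputs=[x])
print(ret)     # None: backward() accumulates into x.grad, returns nothing
print(x.grad)  # tensor([1., 1., 1.])
```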
```diff
@@ -23,6 +23,7 @@ struct PythonEngine : public Engine {
       const variable_list& inputs,
       bool keep_graph,
       bool create_graph,
+      bool accumulate_grad,
       const edge_list& outputs = {}) override;

   std::shared_ptr<at::ivalue::Future> execute_with_graph_task(
```

Review comments on the new `accumulate_grad` parameter:
- Could we add a comment here for what accumulate_grad does, for future code readers?
- Creating a separate PR for this