Use std::shared_ptr for DistAutogradContext. #29770
Conversation
We were passing around const and non-const references for DistAutogradContext from DistAutogradContainer. This wasn't safe since the context could be deleted from the container and a thread might still be using the reference. This usually would happen when a backward pass fails on the node driving the backward pass (resulting in delete context messages being sent to all nodes) but other nodes are still executing code related to that autograd context. This was also the reason why `test_backward_autograd_engine_error` was flaky. Using a std::shared_ptr everywhere ensures we're safe and never crash. Closes #28928 Closes #26922 Differential Revision: [D18494814](https://our.internmc.facebook.com/intern/diff/D18494814/) [ghstack-poisoned]
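To make the failure mode described above concrete, here is a minimal sketch (not the actual PyTorch API; the `Container` class and its method names are invented for illustration) of why handing out references from the container is unsafe, while handing out a `std::shared_ptr` keeps the context alive for as long as any caller is still using it:

```cpp
#include <cstdint>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>

class DistAutogradContext {};
using ContextPtr = std::shared_ptr<DistAutogradContext>;

class Container {
 public:
  // Unsafe: the returned reference dangles if releaseContext() erases the
  // entry on another thread while the caller is still using it.
  DistAutogradContext& retrieveContextRef(int64_t id) {
    std::lock_guard<std::mutex> guard(lock_);
    return *contexts_.at(id);
  }

  // Safe: the caller shares ownership, so erasing the entry from the map
  // cannot destroy the context while someone is still holding the pointer.
  ContextPtr retrieveContext(int64_t id) {
    std::lock_guard<std::mutex> guard(lock_);
    auto it = contexts_.find(id);
    if (it == contexts_.end()) {
      throw std::runtime_error("Could not find autograd context with id");
    }
    return it->second;
  }

  // Dropping the map entry only releases the container's ownership share.
  void releaseContext(int64_t id) {
    std::lock_guard<std::mutex> guard(lock_);
    contexts_.erase(id);
  }

 private:
  std::mutex lock_;
  std::unordered_map<int64_t, ContextPtr> contexts_;
};
```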
t.start()

with self.assertRaisesRegex(RuntimeError, "Could not find autograd context with id"):
    dist_autograd.backward([t1.sum()])
I'm not 100% clear on what behavior is being tested here and what should be different from before this PR. Does this result in the backward call throwing an exception (if so, what exactly is the exception that should be thrown)?
The exception being thrown is "Could not find autograd context id for". This happens because the autograd context is cleaned up while some thread on some node is still looking for that autograd context.
Before this PR, this test would cause the process to crash, since a thread would be using a reference to DistAutogradContext but our cleanup logic would delete the context from the container.
Does this assume that the thread cleared the context before the backward uses it? How do we guarantee that?
test/dist_autograd_test.py (outdated)
# ensures we simulate a case where we clean up the context while the
# backward pass is running.
while not DistAutogradTest._my_backward_func_executed:
    time.sleep(0.1)
This assumes the backward pass takes longer than 0.1s. Potential flakiness.
class DistAutogradContext;
using ContextPtr = std::shared_ptr<DistAutogradContext>;
Defined twice? Also in dist_autograd_container.h
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing, @pritamdamania87. One more pedantic concern.
test/dist_autograd_test.py (outdated)
# backward pass is running.
with DistAutogradTest._my_backward_func_executed:
    DistAutogradTest._my_backward_func_executed.wait()
dist_autograd._release_context(context._context_id())
Is it possible that this executes after the backwards pass has completed? It's a long shot, seeing as this thread would have to be stalled for at least another 50 autograd steps, but calling it out for the sake of robustness.
Would it be possible to simply call `release_context` from the `backward` function itself?
Otherwise, adding another condition to mark completion of the release would make this always work. In the current setting, you could insert a sleep before the `release_context` and make the test fail, which means that unfortunate thread scheduling could in theory make the test fail.
@@ -67,7 +67,7 @@ class TORCH_API DistEngine {
   // We also determine all leaf nodes(functions) in the graph and accumulate
   // them in outputEdges.
   void computeDependencies(
-      DistAutogradContext& context,
+      ContextPtr context,
This can be a reference to the `shared_ptr`, as the call site always holds a copy of the `shared_ptr`?
You can use a const ref. A const ref to a `shared_ptr` doesn't mean the wrapped object is const.
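A minimal sketch of that point (the `Widget` type and `touch` method are made up for illustration):

```cpp
#include <memory>

struct Widget {
  int value = 0;
  void touch() { ++value; }  // non-const member function
};

// Passing `const std::shared_ptr<Widget>&` avoids copying the shared_ptr
// (no refcount bump), but the Widget it points to is still mutable.
void useWidget(const std::shared_ptr<Widget>& w) {
  w->touch();      // OK: the element type is Widget, not const Widget
  // w.reset();    // error: the shared_ptr object itself is const
}

int main() {
  auto w = std::make_shared<Widget>();
  useWidget(w);    // w->value is now 1
  return 0;
}
```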
@pietern @mrshenli Made a few changes to ensure there is no race in the unit test. This was a little tricky since …
test/dist_autograd_test.py (outdated)
def backward(ctx, input):
    assert(DistAutogradTest._test_clean_context_backward_context_id is not None)

    # Release the context to simulate error (use barrier before releasing context to ensure all nodes execute the backward function).
comment line too long
What is the convention for python line length in OSS? The python linter isn't complaining and I recall internally the line length is 150 chars.
I am not aware of any convention here; I personally use 100. Feel free to ignore this comment though if none of the linters complain.
test/dist_autograd_test.py (outdated)
# Send the context id to all nodes.
for i in range(0, self.world_size):
    if i != self.rank:
        rpc.rpc_sync("worker{}".format(i), _set_rpc_done, args=(context_id, 1))
Why is the rank distance always 1 for `_set_rpc_done` here?
Didn't matter for this test case, will fix it though.
test/dist_autograd_test.py (outdated)
for i in range(0, 100):
    dst = self._next_rank()
    t1 = rpc.rpc_sync("worker{}".format(dst), torch.add, args=(t1, t1))
    if i == 99:
Maybe just do this outside of the loop instead of having an if clause here?
# Release the context to simulate error (use barrier before releasing context to ensure all nodes execute the backward function).
dist.barrier()
dist_autograd._release_context(DistAutogradTest._test_clean_context_backward_context_id)
Is this blocking? It has to be to guarantee correctness, right?
This is not blocking; it releases the local context and then just sends async RPCs to release the other contexts. The method below, `_all_contexts_cleaned_up`, is blocking and ensures that contexts are cleaned up on all nodes.
The changes LGTM!
dist_autograd.backward([t1.sum()])

# HACK: Killing workers since otherwise the autograd engine gets stuck on
# other nodes. The proper fix would be addressing:
Let's be more specific on why it might get stuck on other nodes.
Would I be correct if I assume it gets stuck because the crashing backward destroyed the context on this node, and hence the `next_rank` won't be able to clear the context when exiting the scope?
@@ -20,7 +21,7 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node {
  public:
   explicit RecvRpcBackward(
       const AutogradMetadata& autogradMetadata,
-      DistAutogradContext& autogradContext,
+      std::shared_ptr<DistAutogradContext> autogradContext,
Why is this not using `ContextPtr`?
That would create a circular dependency; that's why we need a forward declaration for `DistAutogradContext`.
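A minimal sketch of the pattern being described (the file, namespace, and class names here are invented for illustration, not the actual PyTorch headers): a forward declaration plus a `std::shared_ptr` member lets the header avoid including the context header, which breaks the include cycle.

```cpp
// recv_backward_sketch.h -- hypothetical header, not the real one.
#pragma once
#include <memory>

namespace example {

// Forward declaration: the full definition of DistAutogradContext lives in
// another header that (directly or indirectly) depends on this one, so
// including it here would create a circular dependency.
class DistAutogradContext;

class RecvBackwardSketch {
 public:
  explicit RecvBackwardSketch(std::shared_ptr<DistAutogradContext> ctx)
      : autogradContext_(std::move(ctx)) {}

 private:
  // std::shared_ptr can be declared, moved, and destroyed with an incomplete
  // element type; only code that dereferences it (in the .cpp file) needs
  // the complete definition of DistAutogradContext.
  std::shared_ptr<DistAutogradContext> autogradContext_;
};

} // namespace example
```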
@@ -30,7 +31,7 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node {
   const AutogradMetadata autogradMetadata_;

   // Hold a reference to the autograd context.
-  DistAutogradContext& autogradContext_;
+  std::shared_ptr<DistAutogradContext> autogradContext_;
same here
LGTM!
New commit triggers build failures. Will drop "request changes" when build and test pass.
Tests look OK now. Created #30110 to track the failed test, which is irrelevant to this PR.
This pull request has been merged in 63c957c.
Stack from ghstack:
We were passing around const and non-const references for
DistAutogradContext from DistAutogradContainer. This wasn't safe since the
context could be deleted from the container and a thread might still be using
the reference. This usually would happen when a backward pass fails on the node
driving the backward pass (resulting in delete context messages being sent to
all nodes) but other nodes are still executing code related to that autograd
context.
This was also the reason why `test_backward_autograd_engine_error` was flaky. Using a std::shared_ptr everywhere ensures we're safe and never crash.
Closes #28928
Closes #26922
Differential Revision: D18494814