Delete DDP hooks in Reducer destructor #21591
Strictly speaking this can still cause interference between reducers. If you have a model where you alternate between 2 DDP wrapped modules of the same original module, and delete one of them, this will take down all the hooks, not just the hooks of one of them. I think we can easily keep the pointers somewhere and delete those, instead of nuking all of them of the same type.
torch/csrc/autograd/function.h
// delete all post hooks matching HookType
template <typename HookType>
void delete_post_hooks() {
  std::lock_guard<std::mutex> lock(this->mutex_);
None of the other accessors use a mutex. I suppose it is fine if you have a single control thread that decides when to mutate structures like this one. I think either none of them or all of them should use a mutex.
If we have 2 DDP wrapping the same model, will that mess up the gradient hooks on Reducer construction? Both will insert their own set of hooks, which will be triggered by the backward call of either.
torch/csrc/autograd/function.h
@@ -325,6 +339,7 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
I think the right thing is to keep a copy of the raw pointer for every post-hook in `reducer.h`, then walk through that list of post hook raw pointers on destruction of the Reducer, and call a (to-be-written) `delete_post_hook` method in `function.h` that removes that specific hook from the array (O(n)). It's a little ugly to register a `unique_ptr<>` and then keep a raw copy around, but you're only using the raw copy as a "key" for deletion, not to actually dereference, so it should be safe.
Now that I think about it, you can have `add_post_hook` return a `uintptr_t` "key" for the caller to store, so it's opaque to the caller and there's no weirdness. Then they just pass the `uintptr_t` back in to `delete_post_hook` and no abstractions are leaked or violated.
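A minimal sketch of the opaque-key idea described above, assuming simplified stand-in types: `PostHook`, `Node`, `CountingHook`, and `num_post_hooks` are hypothetical names for illustration and are not the real `torch::autograd` classes. The key is only compared, never dereferenced.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for FunctionPostHook.
struct PostHook {
  virtual ~PostHook() = default;
  virtual void operator()() = 0;
};

// Hypothetical stand-in for the autograd Function node.
struct Node {
  // Takes ownership of the hook and returns its address as an opaque key.
  std::uintptr_t add_post_hook(std::unique_ptr<PostHook> hook) {
    post_hooks_.push_back(std::move(hook));
    return reinterpret_cast<std::uintptr_t>(post_hooks_.back().get());
  }

  // Removes the one hook whose address equals `key` (linear scan, O(n)).
  // Returns false if no such hook is registered.
  bool delete_post_hook(std::uintptr_t key) {
    auto it = std::find_if(
        post_hooks_.begin(), post_hooks_.end(),
        [key](const std::unique_ptr<PostHook>& h) {
          return reinterpret_cast<std::uintptr_t>(h.get()) == key;
        });
    if (it == post_hooks_.end()) {
      return false;
    }
    post_hooks_.erase(it);
    return true;
  }

  std::size_t num_post_hooks() const { return post_hooks_.size(); }

 private:
  std::vector<std::unique_ptr<PostHook>> post_hooks_;
};

// Trivial hook used only to exercise the sketch.
struct CountingHook : PostHook {
  explicit CountingHook(int* c) : counter(c) {}
  void operator()() override { ++*counter; }
  int* counter;
};
```

Because the caller only ever holds an integer, it cannot accidentally call through a hook that the node has already destroyed.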
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
torch/csrc/distributed/c10d/init.cpp
@@ -62,7 +62,8 @@ PyObject* c10d_init(PyObject* _unused) {
 [](::c10d::Reducer& reducer, const torch::autograd::Variable& output)
 -> void { reducer.prepare_for_backward({output}); },
 py::call_guard<py::gil_scoped_release>())
-.def("get_backward_stats", &::c10d::Reducer::get_backward_stats);
+.def("get_backward_stats", &::c10d::Reducer::get_backward_stats)
+.def("remove_all_hooks", &::c10d::Reducer::remove_all_hooks);
The Python binding can be removed I think.
Kutta asked for exposing this API, just in case Python GC does not deterministically kick in. @kuttas do you still need this?
@pietern Even if we call `__del__` on the `DistributedDataParallel` object, I am not totally sure whether the C++ destructor (`Reducer::~Reducer`) will get called immediately or lazily (at GC time). If this is 100% deterministic, then we don't need this call. Otherwise, we need to implement `__del__` and explicitly clean up hooks in DDP. I searched the pybind documentation and can't figure out what the expected behavior is if you don't have referential cycles: is the ref count update immediate in Python, and does that immediately trigger the C++ destructor of the py-bound class?
It doesn't matter if it is lazily deleted. The hooks may still be called, but only as long as the reducer object is still alive (per the destructor). If the `prepare_for_backward` function isn't called by the DDP class that wraps the reducer, all the hooks will be nops.
Regarding referential cycles -- IIRC this is where the GC comes in. Objects can be destructed immediately if their refcount drops to zero and then there is a separate GC pass for the cycles.
I'm not super familiar with this part of the code but LGTM.
@pytorchbot rebase this please
LGTM!
Summary: kuttas pointed out that the DDP Reducer only needs to remember `uintptr, Function` pairs, and hence does not need an unordered map as added by #21591. Using a vector should speed it up a bit. Pull Request resolved: #21783 Differential Revision: D15854312 Pulled By: mrshenli fbshipit-source-id: 153ba035b8d658c7878a613f16a42de977d89c43
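The point about storing pairs in a vector can be sketched as follows. This is an illustration under assumptions, not the code from #21783: `Node`, its `Hook` alias, and this toy `Reducer` are hypothetical stand-ins. The reducer never looks a hook up by key; it only replays every recorded pair on destruction, so a flat vector is enough.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for a grad accumulator node.
struct Node {
  using Hook = std::function<void()>;

  // Stores the hook and returns its address as an opaque key.
  std::uintptr_t add_post_hook(std::unique_ptr<Hook> hook) {
    hooks.push_back(std::move(hook));
    return reinterpret_cast<std::uintptr_t>(hooks.back().get());
  }

  // Removes only the hook registered under `key`.
  bool del_post_hook(std::uintptr_t key) {
    auto it = std::find_if(
        hooks.begin(), hooks.end(),
        [key](const std::unique_ptr<Hook>& h) {
          return reinterpret_cast<std::uintptr_t>(h.get()) == key;
        });
    if (it == hooks.end()) {
      return false;
    }
    hooks.erase(it);
    return true;
  }

  std::vector<std::unique_ptr<Hook>> hooks;
};

// Hypothetical reducer that remembers (key, node) pairs in a vector.
class Reducer {
 public:
  explicit Reducer(const std::vector<std::shared_ptr<Node>>& nodes) {
    for (const auto& node : nodes) {
      std::unique_ptr<Node::Hook> hook(new Node::Hook([] {}));
      hooks_.emplace_back(node->add_post_hook(std::move(hook)), node);
    }
  }

  // Copying would make two objects try to remove the same hooks.
  Reducer(const Reducer&) = delete;
  Reducer& operator=(const Reducer&) = delete;

  // Remove exactly the hooks this reducer registered, nothing else.
  ~Reducer() {
    for (auto& entry : hooks_) {
      entry.second->del_post_hook(entry.first);
    }
  }

 private:
  std::vector<std::pair<std::uintptr_t, std::shared_ptr<Node>>> hooks_;
};
```

With two such reducers attached to the same node, destroying one leaves the other's hook in place, which is the multi-DDP case raised earlier in the review.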
Summary: Pull Request resolved: #21914. #21591 added a needed feature to clean up grad accumulator post hooks when the DistributedDataParallel model object is cleaned up. There's a minor typo that causes it to loop infinitely over the first element. Differential Revision: D15878884 fbshipit-source-id: b7fd0bbd51eb187579d639b1709c6f7b62b85e7a
Closes #21344
DDP assigns the original module to the first module replica instead of creating a new one. Then, it creates a new Reducer to add post hooks to sync gradients. However, because every reconstructed DDP instance wraps the same original module, all their reducers will add hooks to the same set of variables. This PR deletes DDP hooks from variables when destructing Reducer, trying to make DDP failure recoverable.
@pietern @kuttas and I discussed the following solutions:
Solution 1
Keep the `add_post_hook` API intact, and do a `dynamic_cast` in `del_post_hook` to check the hook type. If the type matches Reducer's hook, delete it. As @pietern mentioned, this will not work if we create multiple DDP instances from the same original model.
Solution 2
Use a counter to generate a unique key for every hook in `Function`, and keep them in a map. Return the key to the caller of `add_post_hook`, and ask the caller to provide the key if it needs to delete the hook.
Con: this would add extra overhead to `add_post_hook` and every `Function` object.
Solution 3 [Current implementation]
@kuttas suggests that, instead of generating a unique key, directly using the address of the pointer would be better. In order to avoid messing up dereferencing, let `add_post_hook` return a `uintptr_t`.
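For comparison, Solution 2 could look roughly like the sketch below. `CounterKeyedNode` and its members are hypothetical names chosen for illustration; nothing here is taken from the PR. It shows where the extra cost mentioned above comes from: a counter on every node and a map insertion on every registration.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <utility>

// Hypothetical node implementing the counter-based keys of Solution 2.
struct CounterKeyedNode {
  using Hook = std::function<void()>;

  // Hands out a fresh key from the per-node counter and stores the hook.
  std::uint64_t add_post_hook(Hook hook) {
    const std::uint64_t key = next_key_++;
    hooks_.emplace(key, std::move(hook));
    return key;
  }

  // Deletes the hook registered under `key`; false if it is not present.
  bool del_post_hook(std::uint64_t key) { return hooks_.erase(key) > 0; }

  std::size_t size() const { return hooks_.size(); }

 private:
  std::uint64_t next_key_ = 0;  // extra state carried by every node
  std::unordered_map<std::uint64_t, Hook> hooks_;  // note: no call order
};
```

Using the hook's address as the key, as in Solution 3, avoids both the counter and the map, and keeps the hooks in their original registration order.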