
Delete DDP hooks in Reducer destructor #21591

Closed · wants to merge 7 commits

Conversation

@mrshenli (Contributor) commented on Jun 10, 2019:

Closes #21344

DDP assigns the original module as the first module replica instead of creating a new one. It then creates a new Reducer that adds post hooks to sync gradients. However, because every reconstructed DDP instance wraps the same original module, all of their reducers add hooks to the same set of variables. This PR deletes the DDP hooks from those variables when the Reducer is destructed, aiming to make DDP failures recoverable.

@pietern @kuttas and I discussed the following solutions:

Solution 1

Keep the add_post_hook API intact, and do a dynamic_cast in del_post_hook to check the hook type; if the type matches the Reducer's hook, delete it. As @pietern mentioned, this does not work if we create multiple DDP instances from the same original model.
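
A minimal sketch of how Solution 1 could look, using stand-in hook types rather than the real autograd classes (names here are illustrative, not the PR's code):

#include <memory>
#include <vector>

struct FunctionPostHook { virtual ~FunctionPostHook() = default; };  // stand-in
struct ReducerHook : FunctionPostHook {};                            // stand-in

// Remove every post hook whose dynamic type matches HookType. With several
// DDP instances wrapping the same module, all of their hooks share the same
// type, so this removes all of them at once -- the flaw noted above.
template <typename HookType>
void delete_matching_post_hooks(
    std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks) {
  for (auto it = post_hooks.begin(); it != post_hooks.end();) {
    if (dynamic_cast<HookType*>(it->get()) != nullptr) {
      it = post_hooks.erase(it);
    } else {
      ++it;
    }
  }
}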

Solution 2

Use a counter to generate a unique key for every hook in Function, and keep the hooks in a map keyed by it. Return the key to the caller of add_post_hook, and require the caller to pass the key back when it needs to delete the hook.

Con: this would add extra overhead to add_post_hook and to every Function object.
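
A minimal sketch of the counter-keyed variant, again with a stand-in hook type; the class and member names are illustrative only:

#include <cstdint>
#include <memory>
#include <unordered_map>

struct FunctionPostHook { virtual ~FunctionPostHook() = default; };  // stand-in

struct CounterKeyedFunction {
  // Every registration gets a fresh id; the counter and the map are the
  // per-Function overhead mentioned in the con above.
  uint64_t add_post_hook(std::unique_ptr<FunctionPostHook> hook) {
    const uint64_t key = next_key_++;
    keyed_post_hooks_.emplace(key, std::move(hook));
    return key;
  }

  bool del_post_hook(uint64_t key) {
    return keyed_post_hooks_.erase(key) > 0;
  }

  uint64_t next_key_ = 0;
  std::unordered_map<uint64_t, std::unique_ptr<FunctionPostHook>> keyed_post_hooks_;
};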

Solution 3 [Current implementation]

@kuttas suggested that, instead of generating a unique key, directly using the address of the hook pointer would be better. To avoid any temptation to dereference it, add_post_hook returns the address as a uintptr_t.
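
A minimal sketch of the address-as-key approach; the class and hook type are stand-ins, but the add_post_hook / del_post_hook contract matches the description above:

#include <cstdint>
#include <memory>
#include <vector>

struct FunctionPostHook { virtual ~FunctionPostHook() = default; };  // stand-in

struct AddressKeyedFunction {
  // The hook's own address doubles as its key; callers store it as an
  // opaque uintptr_t and never dereference it.
  uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook> hook) {
    post_hooks_.push_back(std::move(hook));
    return reinterpret_cast<uintptr_t>(post_hooks_.back().get());
  }

  // O(n) scan over the registered hooks; erase the one whose address matches.
  void del_post_hook(uintptr_t key) {
    for (auto it = post_hooks_.begin(); it != post_hooks_.end(); ++it) {
      if (reinterpret_cast<uintptr_t>(it->get()) == key) {
        post_hooks_.erase(it);
        return;
      }
    }
  }

  std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
};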

@mrshenli requested review from @pietern and @kuttas, and from @apaszke as a code owner, on Jun 10, 2019 at 14:57
@pytorchbot added the module: autograd and oncall: distributed labels on Jun 10, 2019
@pietern (Contributor) commented on Jun 10, 2019:

Strictly speaking this can still cause interference between reducers. If you have a model where you alternate between 2 DDP-wrapped modules of the same original module, and delete one of them, this will take down all the hooks, not just the hooks of the deleted one. I think we can easily keep the pointers somewhere and delete those, instead of nuking all hooks of the same type.

// delete all post hooks matching HookType
template <typename HookType>
void delete_post_hooks() {
  std::lock_guard<std::mutex> lock(this->mutex_);

Review comment (Contributor):

None of the other accessors use a mutex. I suppose it is fine if you have a single control thread that decides when to mutate structures like this one. I think either none of them or all of them should use a mutex.

@mrshenli (Contributor, Author) replied:
> If you have a model where you alternate between 2 DDP-wrapped modules of the same original module, and delete one of them, this will take down all the hooks, not just the hooks of the deleted one.

If we have 2 DDP instances wrapping the same model, won't that already mess up the gradient hooks at Reducer construction time? Both would insert their own set of hooks, and both sets would be triggered by a backward call on either.

@@ -325,6 +339,7 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;

@kuttas commented on Jun 10, 2019:

I think the right thing is to keep a copy of the raw pointer for every post hook in reducer.h, then walk through that list of post-hook raw pointers on destruction of the Reducer, and call a (to-be-written) delete_post_hook method in function.h that removes that specific hook from the array (O(n)). It's a little ugly to register a unique_ptr<> and then keep a raw copy around, but you're only using the raw copy as a "key" for deletion, never actually dereferencing it, so it should be safe.
Now that I think about it, you can have add_post_hook return a uintptr_t "key" for the caller to store, so it's opaque to the caller and there's no weirdness. The caller then just passes the uintptr_t back to delete_post_hook, and no abstractions are leaked or violated.
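
A sketch of the Reducer-side bookkeeping this suggests; the types and member names (GradAccumulator, hooks_, register_hook) are stand-ins for illustration, not the exact PR code:

#include <cstdint>
#include <memory>
#include <unordered_map>
#include <vector>

// Stand-in for an autograd grad accumulator that supports keyed post hooks,
// following the add_post_hook / del_post_hook sketch above.
struct PostHook { virtual ~PostHook() = default; };
struct GradAccumulator {
  uintptr_t add_post_hook(std::unique_ptr<PostHook> hook) {
    post_hooks_.push_back(std::move(hook));
    return reinterpret_cast<uintptr_t>(post_hooks_.back().get());
  }
  void del_post_hook(uintptr_t key) {
    for (auto it = post_hooks_.begin(); it != post_hooks_.end(); ++it) {
      if (reinterpret_cast<uintptr_t>(it->get()) == key) {
        post_hooks_.erase(it);
        return;
      }
    }
  }
  std::vector<std::unique_ptr<PostHook>> post_hooks_;
};

struct ReducerSketch {
  // One entry per hook this Reducer registered: key -> owning accumulator.
  std::unordered_map<uintptr_t, std::shared_ptr<GradAccumulator>> hooks_;

  void register_hook(const std::shared_ptr<GradAccumulator>& acc,
                     std::unique_ptr<PostHook> hook) {
    hooks_.emplace(acc->add_post_hook(std::move(hook)), acc);
  }

  // Remove exactly the hooks this Reducer added; hooks registered by other
  // Reducers on the same variables are left untouched.
  ~ReducerSketch() {
    for (const auto& kv : hooks_) {
      kv.second->del_post_hook(kv.first);
    }
  }
};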

@facebook-github-bot (Contributor): @mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@pytorchbot added the module: ci label on Jun 10, 2019

@facebook-github-bot (Contributor): @mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@pietern (Contributor) left a review:

LGTM aside from the one pybind11 nit.

@apaszke @albanD Thoughts on the new functions on autograd Function?

@@ -62,7 +62,8 @@ PyObject* c10d_init(PyObject* _unused) {
           [](::c10d::Reducer& reducer, const torch::autograd::Variable& output)
               -> void { reducer.prepare_for_backward({output}); },
           py::call_guard<py::gil_scoped_release>())
-          .def("get_backward_stats", &::c10d::Reducer::get_backward_stats);
+          .def("get_backward_stats", &::c10d::Reducer::get_backward_stats)
+          .def("remove_all_hooks", &::c10d::Reducer::remove_all_hooks);

Review comment (Contributor):

The Python binding can be removed I think.

@mrshenli (Contributor, Author) replied:

Kutta asked for exposing this API, just in case Python GC does not deterministically kick in. @kuttas do you still need this?

A reviewer replied:

@pietern Even if we call __del__ on the DistributedDataParallel object, I am not totally sure whether the C++ destructor (Reducer::~Reducer) will be called immediately or lazily (at GC time). If it is 100% deterministic, then we don't need this call. Otherwise, we need to implement __del__ and explicitly clean up the hooks in DDP. I searched the pybind11 documentation and can't figure out what the expected behavior is when there are no reference cycles: is the refcount update immediate in Python, and does it immediately trigger the C++ destructor of the py-bound class?

Reply (Contributor):

It doesn't matter if it is lazily deleted. The hooks may still be called, but only for as long as the reducer object is still alive (per the destructor). If the prepare_for_backward function isn't called by the DDP class that wraps the reducer, all the hooks are no-ops.

Regarding reference cycles -- IIRC this is where the GC comes in. Objects are destructed immediately when their refcount drops to zero, and there is a separate GC pass for the cycles.
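
A sketch of why a lazily deleted Reducer's hooks are harmless; the flag and method names are assumptions about the shape of the logic, not quotes of the real Reducer:

#include <cstddef>

struct ReducerGuardSketch {
  bool expect_autograd_hooks_ = false;  // assumed flag; armed per backward pass

  // DDP calls this right before backward; it arms the hooks.
  void prepare_for_backward() {
    expect_autograd_hooks_ = true;
  }

  // The grad-accumulator post hook ends up here. If prepare_for_backward was
  // never called (e.g. the wrapping DDP object is gone but the Reducer has
  // not been destructed yet), the hook returns immediately: a no-op.
  void autograd_hook(size_t variable_index) {
    if (!expect_autograd_hooks_) {
      return;
    }
    // ... mark the corresponding bucket ready, launch allreduce when full ...
    (void)variable_index;
  }
};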

@albanD (Collaborator) left a review:

I'm not super familiar with this part of the code but LGTM.

@mrshenli (Contributor, Author) commented:

@pytorchbot rebase this please

@facebook-github-bot (Contributor): @mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor): @mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@pietern (Contributor) left a review:

LGTM!

@facebook-github-bot (Contributor) commented:

@mrshenli merged this pull request in cbcb2b5.

facebook-github-bot pushed a commit that referenced this pull request Jun 17, 2019
Summary:
kuttas pointed out that the DDP Reducer only needs to remember `uintptr, Function` pairs, and hence does not need the unordered map added by #21591. Using a vector should speed it up a bit.
Pull Request resolved: #21783

Differential Revision: D15854312

Pulled By: mrshenli

fbshipit-source-id: 153ba035b8d658c7878a613f16a42de977d89c43
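
A sketch of the vector-of-pairs bookkeeping the summary describes; the names here are illustrative, not the follow-up PR's code:

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Stand-in for the autograd function that owns the hook.
struct GradAccumulator {
  void del_post_hook(uintptr_t /*key*/) {}
};

struct ReducerHooksSketch {
  // Append-only at construction, walked once at destruction, so a flat
  // vector of pairs is enough; no hashing or uniqueness checks needed.
  std::vector<std::pair<uintptr_t, std::shared_ptr<GradAccumulator>>> hooks_;

  ~ReducerHooksSketch() {
    for (const auto& hook : hooks_) {
      hook.second->del_post_hook(hook.first);
    }
  }
};
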
facebook-github-bot pushed a commit that referenced this pull request Jun 18, 2019
Summary:
Pull Request resolved: #21914

#21591 added a needed feature to clean up grad accumulator post hooks when the DistributedDataParallel model object is cleaned up. There's a minor typo that causes it to loop infinitely over the first element.

Differential Revision: D15878884

fbshipit-source-id: b7fd0bbd51eb187579d639b1709c6f7b62b85e7a
Labels: Merged · module: autograd · module: ci · oncall: distributed

Successfully merging this pull request may close these issues: Make DDP failure recoverable.