Only populate grad accumulator to var mapping for find_unused_parameters=True in DDP (#45942)

Summary:
Pull Request resolved: #45942

We only need this mapping for traversing the autograd graph when
find_unused_parameters=True. Without this guard, we populate and keep the
mapping in memory unconditionally, which costs roughly sizeof(pointer) *
(number of grad accumulators) of extra memory. A user-level sketch of the
flag follows the commit metadata below.
ghstack-source-id: 114219289

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D24154407

fbshipit-source-id: 220d723e262f36590a03a3fd2dab47cbfdb87d40
rohan-varma authored and facebook-github-bot committed Oct 14, 2020
1 parent 31bcd96 commit f739875
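
For context, find_unused_parameters is chosen when a model is wrapped in
torch.nn.parallel.DistributedDataParallel from Python; the Reducer only needs
the grad-accumulator-to-variable map when that flag is True. Below is a
minimal, hedged single-process sketch, not taken from this commit; the gloo
backend, the localhost rendezvous port, and the toy Linear model are all
assumptions for illustration.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process process group so DDP can be constructed locally (assumed
# setup: CPU, gloo backend, free local port 29500).
dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29500",
    rank=0,
    world_size=1,
)

model = torch.nn.Linear(8, 2)

# find_unused_parameters=True: some parameters may receive no gradient, so
# the Reducer traverses the autograd graph before backward and needs the
# grad-accumulator -> parameter-index mapping for that traversal.
ddp_model = DDP(model, find_unused_parameters=True)

# With the default find_unused_parameters=False, every parameter is expected
# to receive a gradient, and after this change the mapping is never
# populated, saving memory proportional to the parameter count:
# ddp_model = DDP(model)

dist.destroy_process_group()
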
Showing 2 changed files with 8 additions and 4 deletions.
9 changes: 6 additions & 3 deletions torch/csrc/distributed/c10d/reducer.cpp
@@ -117,8 +117,11 @@ Reducer::Reducer(
 
       // Map raw function pointer to replica index and parameter index.
       // This is used later on when the autograd graph is traversed
-      // to check for parameters for which no gradient is computed.
-      func_[grad_accumulator.get()] = index;
+      // to check for parameters for which no gradient is computed, if
+      // find_unused_parameters=True.
+      if (find_unused_parameters_) {
+        gradAccToVariableMap_[grad_accumulator.get()] = index;
+      }
 
       // The gradient accumulator is stored as weak_ptr in the autograd
       // metadata of the variable, so we have to keep it alive here for
@@ -991,7 +994,7 @@ void Reducer::prepare_for_backward(
   }
 
   // Find accumulator functions that don't show up in this graph.
-  for (const auto& it : func_) {
+  for (const auto& it : gradAccToVariableMap_) {
     // If the accumulator function is present in the graph, we know
     // a gradient will be computed for the corresponding parameter.
     if (seen.count(it.first) > 0) {
3 changes: 2 additions & 1 deletion torch/csrc/distributed/c10d/reducer.h
@@ -122,7 +122,8 @@ class Reducer {
 
   std::vector<std::vector<std::shared_ptr<torch::autograd::Node>>>
       grad_accumulators_;
-  std::unordered_map<torch::autograd::Node*, VariableIndex> func_;
+  std::unordered_map<torch::autograd::Node*, VariableIndex>
+      gradAccToVariableMap_;
   std::vector<std::pair<uintptr_t, std::shared_ptr<torch::autograd::Node>>>
       hooks_;
 
