Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only populate grad accumulator to var mapping for find_unused_parameters=True in DDP #45942

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 6 additions & 3 deletions torch/csrc/distributed/c10d/reducer.cpp
Expand Up @@ -117,8 +117,11 @@ Reducer::Reducer(

// Map raw function pointer to replica index and parameter index.
// This is used later on when the autograd graph is traversed
// to check for parameters for which no gradient is computed.
func_[grad_accumulator.get()] = index;
// to check for parameters for which no gradient is computed, if
// find_unused_parameters=True.
if (find_unused_parameters_) {
gradAccToVariableMap_[grad_accumulator.get()] = index;
}

// The gradient accumulator is stored as weak_ptr in the autograd
// metadata of the variable, so we have to keep it alive here for
Expand Down Expand Up @@ -995,7 +998,7 @@ void Reducer::prepare_for_backward(
}

// Find accumulator functions that don't show up in this graph.
for (const auto& it : func_) {
for (const auto& it : gradAccToVariableMap_) {
// If the accumulator function is present in the graph, we know
// a gradient will be computed for the corresponding parameter.
if (seen.count(it.first) > 0) {
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/distributed/c10d/reducer.h
Expand Up @@ -122,7 +122,8 @@ class Reducer {

std::vector<std::vector<std::shared_ptr<torch::autograd::Node>>>
grad_accumulators_;
std::unordered_map<torch::autograd::Node*, VariableIndex> func_;
std::unordered_map<torch::autograd::Node*, VariableIndex>
gradAccToVariableMap_;
std::vector<std::pair<uintptr_t, std::shared_ptr<torch::autograd::Node>>>
hooks_;

Expand Down