diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index 53541960f300..74942d1c77d8 100644
--- a/torch/csrc/distributed/c10d/reducer.cpp
+++ b/torch/csrc/distributed/c10d/reducer.cpp
@@ -117,8 +117,11 @@ Reducer::Reducer(

         // Map raw function pointer to replica index and parameter index.
         // This is used later on when the autograd graph is traversed
-        // to check for parameters for which no gradient is computed.
-        func_[grad_accumulator.get()] = index;
+        // to check for parameters for which no gradient is computed, if
+        // find_unused_parameters=True.
+        if (find_unused_parameters_) {
+          gradAccToVariableMap_[grad_accumulator.get()] = index;
+        }

         // The gradient accumulator is stored as weak_ptr in the autograd
         // metadata of the variable, so we have to keep it alive here for
@@ -991,7 +994,7 @@ void Reducer::prepare_for_backward(
   }

   // Find accumulator functions that don't show up in this graph.
-  for (const auto& it : func_) {
+  for (const auto& it : gradAccToVariableMap_) {
     // If the accumulator function is present in the graph, we know
     // a gradient will be computed for the corresponding parameter.
     if (seen.count(it.first) > 0) {
diff --git a/torch/csrc/distributed/c10d/reducer.h b/torch/csrc/distributed/c10d/reducer.h
index efb2060a5533..29bdace7ce00 100644
--- a/torch/csrc/distributed/c10d/reducer.h
+++ b/torch/csrc/distributed/c10d/reducer.h
@@ -122,7 +122,8 @@ class Reducer {

   std::vector<std::vector<std::shared_ptr<torch::autograd::Node>>>
       grad_accumulators_;
-  std::unordered_map<torch::autograd::Node*, VariableIndex> func_;
+  std::unordered_map<torch::autograd::Node*, VariableIndex>
+      gradAccToVariableMap_;
   std::vector<std::pair<uintptr_t, std::shared_ptr<torch::autograd::Node>>>
       hooks_;
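
To make the bookkeeping concrete, below is a minimal, self-contained C++ sketch of the pattern this diff introduces: the accumulator-to-index map is only populated when unused-parameter detection is enabled, and a later pass flags every parameter whose accumulator was not seen while traversing the autograd graph. ToyReducer, GradAcc, and the simplified prepare_for_backward signature are illustrative stand-ins, not the real Reducer API; the guard matters because the map is only ever read when find_unused_parameters is true.

// Toy model of the bookkeeping pattern in the diff above. Not the real
// Reducer: GradAcc stands in for torch::autograd::Node, and the index is a
// plain size_t instead of VariableIndex.
#include <cstddef>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct GradAcc {};  // stand-in for a gradient accumulator node

struct ToyReducer {
  bool find_unused_parameters_;
  std::vector<std::shared_ptr<GradAcc>> grad_accumulators_;
  std::unordered_map<GradAcc*, size_t> gradAccToVariableMap_;
  std::vector<size_t> unused_parameters_;

  ToyReducer(size_t num_params, bool find_unused)
      : find_unused_parameters_(find_unused) {
    for (size_t index = 0; index < num_params; ++index) {
      auto acc = std::make_shared<GradAcc>();
      // Only build the reverse map if we will actually search for unused
      // parameters later -- this mirrors the guarded insert in the diff.
      if (find_unused_parameters_) {
        gradAccToVariableMap_[acc.get()] = index;
      }
      grad_accumulators_.push_back(std::move(acc));
    }
  }

  // 'seen' plays the role of the accumulators reached while traversing the
  // autograd graph from the model outputs.
  void prepare_for_backward(const std::unordered_set<GradAcc*>& seen) {
    unused_parameters_.clear();
    for (const auto& it : gradAccToVariableMap_) {
      if (seen.count(it.first) > 0) {
        continue;  // a gradient will be computed for this parameter
      }
      unused_parameters_.push_back(it.second);
    }
  }
};

int main() {
  ToyReducer reducer(/*num_params=*/3, /*find_unused=*/true);
  // Pretend only parameters 0 and 2 participate in this iteration.
  std::unordered_set<GradAcc*> seen = {
      reducer.grad_accumulators_[0].get(),
      reducer.grad_accumulators_[2].get()};
  reducer.prepare_for_backward(seen);
  for (size_t index : reducer.unused_parameters_) {
    std::cout << "unused parameter index: " << index << "\n";
  }
}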