Prioritize raising error message about unused parameters when rebuild_buckets fails (#45933)

Summary:
Pull Request resolved: #45933

Occasionally users run DDP with models that have unused parameters. In that case
we want to surface an error message telling them to run with
find_unused_parameters=True. However, a recent change to the rebuild_buckets logic
(#44798) made us raise a size-mismatch error instead when this happens, even though
the unused-parameters information is more useful and unused parameters are the most
common cause of this failure. Prefer raising the unused-parameters error over the
subsequent size-mismatch error; a minimal repro of the scenario is sketched below.
ghstack-source-id: 113914759

Test Plan: Added unittest

Reviewed By: mrshenli

Differential Revision: D24151256

fbshipit-source-id: 5d349a988b4aac7d3e0ef7b3cd84dfdcbe9db675
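For context, the user-facing scenario looks roughly like the sketch below. This is a
minimal, illustrative repro and is not part of the commit: the module name
(PartiallyUsedModel), the gloo backend on CPU, the rendezvous address/port, the tensor
shapes, and the torch.multiprocessing.spawn launcher are all assumptions chosen for a
single-machine example. With the default find_unused_parameters=False, the second
backward pass raises the "Expected to have finished reduction in the prior iteration"
error that this change now surfaces from rebuild_buckets; passing
find_unused_parameters=True, as the error message recommends, lets both iterations
succeed.

# Illustrative repro sketch (not part of this commit): a DDP module whose
# forward() never touches self.unused.
import os
import torch
import torch.distributed as dist
import torch.nn as nn


class PartiallyUsedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(10, 10, bias=False)
        self.unused = nn.Linear(10, 10, bias=False)  # never used in forward

    def forward(self, x):
        return self.used(x)


def run(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = nn.parallel.DistributedDataParallel(
        PartiallyUsedModel(),
        # What the prioritized error message recommends; without this flag the
        # second backward() raises "Expected to have finished reduction ...".
        find_unused_parameters=True,
    )
    for _ in range(2):
        model(torch.rand(1, 10)).sum().backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    torch.multiprocessing.spawn(run, args=(world_size,), nprocs=world_size)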
rohan-varma authored and facebook-github-bot committed Oct 9, 2020
1 parent 9fb8e33 commit 62554a3
Showing 3 changed files with 68 additions and 25 deletions.
58 changes: 33 additions & 25 deletions torch/csrc/distributed/c10d/reducer.cpp
@@ -946,31 +946,6 @@ void Reducer::prepare_for_backward(
   std::unordered_set<torch::autograd::Node*> seen;
   std::vector<torch::autograd::Node*> queue;
 
-  // Check that any prior reduction has finished.
-  // The variable `require_finalize_` is true until all gradients
-  // have been computed and reduction of all buckets has been kicked off.
-  if (require_finalize_) {
-    TORCH_CHECK(
-        false,
-        "Expected to have finished reduction in the prior iteration before ",
-        "starting a new one. ",
-        "",
-        "This error indicates that your module has parameters that were ",
-        "not used in producing loss. ",
-        "",
-        "You can enable unused parameter detection by (1) passing the keyword "
-        "argument `find_unused_parameters=True` to ",
-        "`torch.nn.parallel.DistributedDataParallel`; (2) making sure all ",
-        "`forward` function outputs participate in calculating loss. "
-        "",
-        "If you already have done the above two steps, then the distributed ",
-        "data parallel module wasn't able to locate the output tensors in the ",
-        "return value of your module's `forward` function. ",
-        "Please include the loss function and the structure of the return ",
-        "value of `forward` of your module when reporting this issue (e.g. ",
-        "list, dict, iterable).");
-  }
-
   // Reset accounting.
   expect_autograd_hooks_ = true;
   next_bucket_ = 0;
@@ -1325,6 +1300,11 @@ void Reducer::sync_bucket_indices(
 }
 
 bool Reducer::rebuild_buckets() {
+  // Ensure reduction for previous backwards pass is finished. If user's model
+  // has unused parameters for example, this will raise an error recommending to
+  // run with find_unused_parameters=True, instead of the size mismatch
+  // exception below.
+  ensure_prior_reduction_finished();
   std::lock_guard<std::mutex> lock(mutex_);
   if (!should_rebuild_buckets() || rebuilt_params_.empty()) {
     return false;
@@ -1381,6 +1361,34 @@ void Reducer::register_comm_hook(std::unique_ptr<CommHookInterface> iface) {
   comm_hook_ = std::move(iface);
 }
 
+void Reducer::ensure_prior_reduction_finished() {
+  // Check that any prior reduction has finished.
+  // The variable `require_finalize_` is true until all gradients
+  // have been computed and reduction of all buckets has been kicked off.
+  if (require_finalize_) {
+    TORCH_CHECK(
+        false,
+        "Expected to have finished reduction in the prior iteration before ",
+        "starting a new one. ",
+        "",
+        "This error indicates that your module has parameters that were ",
+        "not used in producing loss. ",
+        "",
+        "You can enable unused parameter detection by (1) passing the keyword "
+        "argument `find_unused_parameters=True` to ",
+        "`torch.nn.parallel.DistributedDataParallel`; (2) making sure all ",
+        "`forward` function outputs participate in calculating loss. "
+        "",
+        "If you already have done the above two steps, then the distributed ",
+        "data parallel module wasn't able to locate the output tensors in the ",
+        "return value of your module's `forward` function. ",
+        "Please include the loss function and the structure of the return ",
+        "value of `forward` of your module when reporting this issue (e.g. ",
+        "list, dict, iterable).");
+  }
+
+}
+
 namespace {
 
 // Tensors may be coalesced into buckets. Buckets must contain tensors of
4 changes: 4 additions & 0 deletions torch/csrc/distributed/c10d/reducer.h
@@ -170,6 +170,10 @@ class Reducer {
 
   void finalize_backward();
 
+  // Asserts that the reduction for the previous iteration has finished before
+  // rebuilding buckets or kicking off the next one.
+  void ensure_prior_reduction_finished();
+
   // Broadcast rebuilt buckets from rank 0 to other ranks before initializing
   // the buckets
   void sync_bucket_indices(std::vector<std::vector<size_t>>& bucket_indices);
31 changes: 31 additions & 0 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -3662,6 +3662,37 @@ def forward(self, x):
         # isolate failure hangs.
         torch.cuda.synchronize(device=self.rank)
 
+    @require_backend({"gloo", "nccl"})
+    @require_backends_available({"gloo", "nccl"})
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_ddp_unused_params_rebuild_buckets_exception(self):
+        class ToyModel(nn.Module):
+            def __init__(self):
+                super(ToyModel, self).__init__()
+                self.net1 = nn.Linear(10, 10, bias=False)
+                self.net2 = nn.Linear(10, 10, bias=False)
+
+            def forward(self, x):
+                return self.net1(x)
+
+        ddp = torch.nn.parallel.DistributedDataParallel(
+            ToyModel().cuda(self.rank), device_ids=[self.rank]
+        )
+        for i in range(2):
+            inp = torch.rand(1, 10)
+            if i > 0:
+                # On 2nd iteration, this will fail during rebuild_buckets,
+                # but we should report an error regarding unused parameters
+                # since that is the underlying root cause.
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "Expected to have finished reduction in the prior iteration",
+                ):
+                    ddp(inp).sum().backward()
+            else:
+                ddp(inp).sum().backward()
+
     @require_backend({"gloo", "nccl"})
     @require_backends_available({"gloo", "nccl"})
     @skip_if_lt_x_gpu(2)
