Add guard for non-default stream in DDP's autograd engine callback (#40115)

Summary:
Pull Request resolved: #40115

Closes #37790
Closes #37944

A user may wish to run DDP's forward + backward step under a non-default CUDA stream, i.e. a stream created with `torch.cuda.Stream()` and made current via `with torch.cuda.stream(stream)`. In this case, the user is responsible for synchronizing that stream with the other streams used in the program (per the documentation at https://pytorch.org/docs/stable/notes/cuda.html#cuda-semantics), but DDP currently has a bug that causes it to fail under non-default streams.

If a user does the following:
```
model = DDP(...)
loss = model(input).sum()
loss.backward()
grad = model.module.weight.grad
# All-reduce a copy of the gradient and average it; since DDP has already
# averaged gradients across workers, `average` should equal `grad`.
average = grad.clone()
dist.all_reduce(average)
average.div_(dist.get_world_size())
```

There is a chance that `average` and `grad` will not be equal. This is because the CUDA kernels launched by the `all_reduce` call may run before `loss.backward()`'s kernels have finished. Specifically, DDP copies the allreduced gradients back into the model's parameter gradients in an autograd engine callback, but this callback runs on the default stream, while the user's `all_reduce` runs on the non-default stream. The application could also work around this by synchronizing with the default stream before reading the gradients, but it should not be expected to, since it never uses the default stream itself.
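
For concreteness, here is a minimal sketch of that workaround (not the fix in this PR). It assumes the process group has already been initialized and that the launcher exports `LOCAL_RANK`; the tiny `Linear` model mirrors the new unit test below.
```
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes dist.init_process_group(...) has already been called and that the
# launcher sets LOCAL_RANK; both are assumptions of this sketch.
rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(rank)

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    model = DDP(torch.nn.Linear(1, 1, bias=False).cuda(rank), device_ids=[rank])
    batch = torch.tensor([rank]).float().cuda(rank)
    loss = model(batch).sum()
    loss.backward()
    # Without this PR, DDP writes the averaged gradients back on the default
    # stream, so wait for that stream before reading .grad on this one.
    stream.wait_stream(torch.cuda.default_stream())
    average = model.module.weight.grad.clone()
    dist.all_reduce(average)
    average.div_(dist.get_world_size())
```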

This PR fixes the issue by passing the current CUDA stream into DDP's autograd engine callback and running the callback under that stream via a stream guard.
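
Conceptually, the change captures the stream that is current when the callback is queued and re-enters it inside the callback. A Python-level sketch of the idea only; the real implementation is the C++ change to `reducer.cpp` shown below:
```
import torch

# Sketch only: capture the user's (possibly non-default) current stream...
current = torch.cuda.current_stream()

def finalize_callback():
    # ...and re-enter it inside the deferred callback, so the copy of the
    # averaged gradients back into param.grad is ordered with the rest of
    # the user's work on that stream.
    with torch.cuda.stream(current):
        pass  # gradient copy-back would happen here
```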

Tested by adding a unit test, `test_DistributedDataParallel_non_default_stream`, that fails without this PR.
ghstack-source-id: 106481208

Differential Revision: D22073353

fbshipit-source-id: 70da9b44e5f546ff8b6d8c42022ecc846dff033e
rohan-varma committed Jul 8, 2020
1 parent af9600b commit dec846f
Showing 2 changed files with 49 additions and 0 deletions.
39 changes: 39 additions & 0 deletions test/distributed/test_distributed.py
@@ -1906,6 +1906,45 @@ def test_DistributedDataParallel_requires_grad(self):
        # a module without gradients shouldn't be accepted
        self.assertRaises(AssertionError, lambda: nn.parallel.DistributedDataParallel(nn.Module()))

    @unittest.skipIf(
        BACKEND != "nccl" and BACKEND != "gloo",
        "Only NCCL and GLOO backend support DistributedDataParallel",
    )
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm
    def test_DistributedDataParallel_non_default_stream(self):
        stream = torch.cuda.Stream()
        rank = self.rank
        with torch.cuda.stream(stream):
            net = torch.nn.parallel.DistributedDataParallel(
                torch.nn.Linear(1, 1, bias=False).cuda(rank), device_ids=[rank]
            )
            for i in range(1000):
                # Clear gradients manually
                grad = net.module.weight.grad
                if grad is not None:
                    grad.detach_()
                    grad.zero_()
                # Forward + BW
                batch = torch.tensor([rank]).float().cuda(rank)
                loss = net(batch).sum()
                loss.backward()
                # For each worker, the gradient on the weight should be worker_rank.
                grad = net.module.weight.grad
                avg = grad.clone()
                # All-reducing the gradient averages should give us the gradient
                # average. If not, then one of the workers has not correctly
                # written back the averaged gradient before this all-reduce call.
                dist.all_reduce(avg)
                world_size = int(os.environ["WORLD_SIZE"])
                avg.div_(world_size)
                expected_grad = sum(i for i in range(world_size)) / world_size
                self.assertEqual(
                    avg[0, 0],
                    expected_grad,
                    msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",
                )

    @unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
                     "Only Nccl & Gloo backend support DistributedDataParallel")
    @skip_if_no_gpu
10 changes: 10 additions & 0 deletions torch/csrc/distributed/c10d/reducer.cpp
@@ -3,6 +3,7 @@
#include <functional>

#include <c10/core/DeviceGuard.h>
#include <c10/core/StreamGuard.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function_hook.h>
@@ -483,8 +484,17 @@ void Reducer::mark_variable_ready(VariableIndex index) {
  }
  local_used_work_ = process_group_->allreduce(local_used_maps_dev_);

  // The autograd engine uses the default stream when running callbacks, so we
  // pass in the current CUDA stream in case it is not the default.
  c10::DeviceType deviceType = replica.contents.device().type();
  const c10::impl::VirtualGuardImpl guard =
      c10::impl::VirtualGuardImpl{deviceType};
  const c10::Stream currentStream =
      guard.getStream(replica.contents.device());
  torch::autograd::Engine::get_default_engine().queue_callback([=] {
    std::unique_lock<std::mutex> lock(this->mutex_);
    // Run callback with the current stream
    c10::OptionalStreamGuard currentStreamGuard{currentStream};
    this->finalize_backward();
    // Rebuild bucket if this is the first time to rebuild
    if (!rebuilt_params_.empty()) {
