
Commit 129156b

Add guard for non-default stream in DDP's autograd engine callback
Pull Request resolved: #40115
Closes #37790
Closes #37944

A user may wish to run DDP's forward + backward step under a non-default CUDA stream, such as one entered via `with torch.cuda.stream(stream)`. In this case, the user is responsible for synchronizing events on this stream with the other streams used in the program (per the documentation at https://pytorch.org/docs/stable/notes/cuda.html#cuda-semantics), but DDP currently has a bug that causes it to fail under non-default streams. If a user does the following:

```
model = DDP(...)
loss = model(input).sum()
loss.backward()
grad = model.module.weight.grad
average = grad.clone()
dist.all_reduce(average)
```

there is a chance that `average` and `grad` will not be equal. This is because the CUDA kernels corresponding to the `all_reduce` call may run before `loss.backward()`'s kernels have finished. Specifically, DDP copies the all-reduced gradients back to the model parameter gradients in an autograd engine callback, but this callback runs on the default stream. The issue could also be avoided by having the application synchronize with the default stream, but that should not be expected of the application, since it does not use the default stream at all. This PR fixes the issue by passing the current stream into DDP's callback.

Tested by adding a unit test, `test_DistributedDataParallel_non_default_stream`, that fails without this PR.

ghstack-source-id: 106481208
Differential Revision: [D22073353](https://our.internmc.facebook.com/intern/diff/D22073353/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22073353/)!
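For illustration only (not part of this commit), below is a hedged sketch of the workaround an application would otherwise need: because the gradient copy-back runs on the default stream, a collective enqueued on the user's stream has to wait for the default stream before reading the gradients. The helper name, the tiny `Linear` model, and the use of `WORLD_SIZE` are assumptions made for the sketch, and it presumes `torch.distributed` has already been initialized with a CUDA-capable backend.

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def step_on_side_stream(rank: int) -> torch.Tensor:
    """One DDP forward/backward on a non-default stream (illustrative sketch)."""
    stream = torch.cuda.Stream()
    net = DDP(torch.nn.Linear(1, 1, bias=False).cuda(rank), device_ids=[rank])
    with torch.cuda.stream(stream):
        loss = net(torch.tensor([rank], dtype=torch.float32, device=rank)).sum()
        loss.backward()
        # Without this fix, DDP's autograd-engine callback copies the
        # all-reduced gradients back to .grad on the *default* stream, so the
        # all_reduce below (enqueued on `stream`) may read stale values unless
        # the application inserts a cross-stream dependency like this one:
        stream.wait_stream(torch.cuda.default_stream())  # workaround only
        avg = net.module.weight.grad.clone()
        dist.all_reduce(avg)
        avg.div_(int(os.environ["WORLD_SIZE"]))
    return avg
```

With this commit, the explicit `wait_stream` call above becomes unnecessary, because DDP's callback itself runs under the stream that was current during the backward pass.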
1 parent cbd53bf · commit 129156b

File tree

2 files changed, +49 −0 lines changed


test/distributed/test_distributed.py

Lines changed: 39 additions & 0 deletions

```diff
@@ -1906,6 +1906,45 @@ def test_DistributedDataParallel_requires_grad(self):
         # a module without gradients shouldn't be accepted
         self.assertRaises(AssertionError, lambda: nn.parallel.DistributedDataParallel(nn.Module()))

+    @unittest.skipIf(
+        BACKEND != "nccl" and BACKEND != "gloo",
+        "Only NCCL and GLOO backend support DistributedDataParallel",
+    )
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_DistributedDataParallel_non_default_stream(self):
+        stream = torch.cuda.Stream()
+        rank = self.rank
+        with torch.cuda.stream(stream):
+            net = torch.nn.parallel.DistributedDataParallel(
+                torch.nn.Linear(1, 1, bias=False).cuda(rank), device_ids=[rank]
+            )
+            for i in range(1000):
+                # Clear gradients manually
+                grad = net.module.weight.grad
+                if grad is not None:
+                    grad.detach_()
+                    grad.zero_()
+                # Forward + BW
+                batch = torch.tensor([rank]).float().cuda(rank)
+                loss = net(batch).sum()
+                loss.backward()
+                # For each worker, the gradient on the weight should be worker_rank.
+                grad = net.module.weight.grad
+                avg = grad.clone()
+                # All-reducing the gradient averages should give us the gradient
+                # average. If not, then one of the workers has not correctly
+                # written back the averaged gradient before this all-reduce call.
+                dist.all_reduce(avg)
+                world_size = int(os.environ["WORLD_SIZE"])
+                avg.div_(world_size)
+                expected_grad = sum(i for i in range(world_size)) / world_size
+                self.assertEqual(
+                    avg[0, 0],
+                    expected_grad,
+                    msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",
+                )
+
     @unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
                      "Only Nccl & Gloo backend support DistributedDataParallel")
     @skip_if_no_gpu
```

torch/csrc/distributed/c10d/reducer.cpp

Lines changed: 10 additions & 0 deletions

```diff
@@ -3,6 +3,7 @@
 #include <functional>

 #include <c10/core/DeviceGuard.h>
+#include <c10/core/StreamGuard.h>
 #include <c10/util/Exception.h>
 #include <torch/csrc/autograd/engine.h>
 #include <torch/csrc/autograd/function_hook.h>
@@ -483,8 +484,17 @@ void Reducer::mark_variable_ready(VariableIndex index) {
   }
   local_used_work_ = process_group_->allreduce(local_used_maps_dev_);

+  // The autograd engine uses the default stream when running callbacks, so we
+  // pass in the current CUDA stream in case it is not the default.
+  c10::DeviceType deviceType = replica.contents.device().type();
+  const c10::impl::VirtualGuardImpl guard =
+      c10::impl::VirtualGuardImpl{deviceType};
+  const c10::Stream currentStream =
+      guard.getStream(replica.contents.device());
   torch::autograd::Engine::get_default_engine().queue_callback([=] {
     std::unique_lock<std::mutex> lock(this->mutex_);
+    // Run callback with the current stream
+    c10::OptionalStreamGuard currentStreamGuard{currentStream};
     this->finalize_backward();
     // Rebuild bucket if this is the first time to rebuild
     if (!rebuilt_params_.empty()) {
```
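In Python terms, the change above amounts to capturing the stream that is current when the callback is queued and re-entering it when the callback runs. A rough, hedged sketch of that pattern is shown below; the wrapper name is hypothetical, and the `torch.cuda.stream` context manager stands in for `c10::OptionalStreamGuard`.

```python
import torch


def queue_with_current_stream(queue_callback, callback):
    # Hypothetical analogue of the reducer.cpp change: remember which CUDA
    # stream is current at enqueue time, and restore it around the callback.
    captured = torch.cuda.current_stream()

    def wrapped():
        # Rough Python counterpart of c10::OptionalStreamGuard.
        with torch.cuda.stream(captured):
            callback()

    queue_callback(wrapped)
```

Because the C++ lambda captures `currentStream` by value, whatever stream was active while the gradients were produced is the one restored inside the callback that writes them back.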
