[DataLoader] Share seed via Distributed Store to get rid of CUDA dependency (#79829) (#79890)

Fixes #79828

In a distributed environment, before this PR, DataLoader would create a Tensor holding the shared seed on RANK 0 and send that Tensor to the other processes. However, when `NCCL` is used as the distributed backend, the Tensor must be moved to CUDA before it is broadcast from RANK 0 to the other RANKs. This is the root of the linked Issue: DataLoader never moves the Tensor to CUDA before sharing it via `NCCL`.
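
For illustration, a minimal sketch of what the tensor-based path amounts to (the function name and the explicit `.cuda()` step are assumptions for this example, not the DataLoader's actual code):

import torch
import torch.distributed as dist

def _broadcast_seed_via_tensor(generator=None):
    # Rank 0 draws the seed; every rank participates in the broadcast.
    seed_tensor = torch.empty((), dtype=torch.int64).random_(generator=generator)
    # With the NCCL backend, tensors taking part in collectives must live on a
    # CUDA device -- this is exactly the step the pre-PR DataLoader was missing.
    if dist.get_backend() == "nccl":
        seed_tensor = seed_tensor.cuda()
    dist.broadcast(seed_tensor, src=0)
    return seed_tensor.item()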

After an offline discussion with @mrshenli, we think the distributed Store is a better solution, since the shared seed is just an integer value. This removes the dependency on NCCL and CUDA when sharing information between distributed processes for DataLoader.
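
A minimal sketch of the Store-based approach (the key name matches the diff below; the helper function itself is only for illustration):

import torch
import torch.distributed as dist

def _share_seed_via_store(generator=None):
    store = dist.distributed_c10d._get_default_store()
    if dist.get_rank() == 0:
        seed = torch.empty((), dtype=torch.int64).random_(generator=generator).item()
        store.set("_dl_shared_seed", str(seed))   # plain string in the Store, no CUDA involved
    else:
        store.wait(["_dl_shared_seed"])           # block until rank 0 has published the key
        seed = int(store.get("_dl_shared_seed"))  # get() returns bytes; int() parses them
    return seed

Because the value travels through the Store's key/value channel rather than through a backend collective, the same code works regardless of which backend (gloo, NCCL, MPI) the process group was initialized with.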
Pull Request resolved: #79829
Approved by: https://github.com/VitalyFedyunin, https://github.com/NivekT
ejguan committed Jun 21, 2022
1 parent 01d9324 commit 8186aa7
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions torch/utils/data/dataloader.py
@@ -561,21 +561,28 @@ def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
 
     def _get_shared_seed(self):
         if isinstance(self.dataset, IterDataPipe):
-            _shared_tensor_seed = torch.empty((), dtype=torch.int64).random_(generator=self.generator)
+            _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=self.generator).item()
             if dist.is_available() and dist.is_initialized():
                 rank = dist.get_rank()
+                ws = dist.get_world_size()
+                store = dist.distributed_c10d._get_default_store()
                 if rank == 0:
-                    ws = dist.get_world_size()
-                    reqs = []
-                    for rank_id in range(1, ws):
-                        req = dist.isend(tensor=_shared_tensor_seed, dst=rank_id, tag=rank_id)
-                        reqs.append(req)
-                    for req in reqs:
-                        req.wait()
+                    store.set("_dl_shared_seed", str(_shared_seed))
+                    # Reset after all distributed processes have received the shared seed
+                    store.add("_dl_shared_seed_recv_cnt", 1)
+                    _shared_seed_recv_cnt = 1
+                    while _shared_seed_recv_cnt != ws:
+                        _shared_seed_recv_cnt = int(store.get("_dl_shared_seed_recv_cnt"))
+                    store.set("_dl_shared_seed", "")
+                    store.add("_dl_shared_seed_recv_cnt", -ws)
+                    assert int(store.get("_dl_shared_seed_recv_cnt")) == 0
                 else:
-                    dist.recv(tensor=_shared_tensor_seed, src=0, tag=rank)
-            _shared_seed = _shared_tensor_seed.item()
-            del _shared_tensor_seed
+                    _shared_seed_str = ""
+                    store.wait(["_dl_shared_seed"], _utils.MP_STATUS_CHECK_INTERVAL)
+                    while len(_shared_seed_str) == 0:
+                        _shared_seed_str = store.get("_dl_shared_seed")
+                    store.add("_dl_shared_seed_recv_cnt", 1)
+                    _shared_seed = int(_shared_seed_str)
             return _shared_seed
         else:
             return None
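Note that the rank-0 branch in the new code also serves as a reset barrier: it spins until `_dl_shared_seed_recv_cnt` reaches the world size, then clears the seed key and winds the counter back to zero so both keys can be reused on the next call. A standalone sketch of such a counter-based rendezvous (names and structure are illustrative, not the DataLoader's exact protocol):

def _store_rendezvous(store, rank, world_size, cnt_key="_recv_cnt"):
    # Every rank checks in with an atomic increment.
    store.add(cnt_key, 1)
    if rank == 0:
        # Spin until all ranks (including rank 0) have incremented the counter,
        # then subtract the world size so the counter reads zero for the next round.
        while int(store.get(cnt_key)) != world_size:
            pass
        store.add(cnt_key, -world_size)
        assert int(store.get(cnt_key)) == 0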
