Skip to content

Commit

Permalink
Fix distributed store to use add for the counter of DL shared seed
Browse files Browse the repository at this point in the history
  • Loading branch information
ejguan committed Jun 27, 2022
1 parent 80b50df commit 0de7982
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions torch/utils/data/dataloader.py
Expand Up @@ -5,12 +5,15 @@
in `./_utils/worker.py`.
"""

import functools
import itertools
import logging
import os
import queue
import threading
import itertools
import time
import warnings
import queue
import functools

from typing import Any, Callable, Iterable, TypeVar, Generic, Sequence, List, Optional, Union

import multiprocessing as python_multiprocessing
Expand Down Expand Up @@ -63,6 +66,8 @@

get_worker_info = _utils.worker.get_worker_info

logger = logging.getLogger(__name__)


class _DatasetKind(object):
Map = 0
Expand Down Expand Up @@ -567,20 +572,26 @@ def _get_shared_seed(self):
ws = dist.get_world_size()
store = dist.distributed_c10d._get_default_store()
if rank == 0:
store.set("_dl_shared_seed", str(_shared_seed))
_shared_seed_str = str(_shared_seed)
store.set("_dl_shared_seed", _shared_seed_str)
logger.info(f"Shared seed ({_shared_seed_str}) sent to store on rank 0")
# Use 'add' instead of 'get' since for some store implementations 'add'
# doesn't work well with 'get'.
_shared_seed_recv_cnt = store.add("_dl_shared_seed_recv_cnt", 1)
while _shared_seed_recv_cnt < ws:
time.sleep(0.01)
_shared_seed_recv_cnt = store.add("_dl_shared_seed_recv_cnt", 0)
# Reset after all distributed processes have received the shared seed
store.add("_dl_shared_seed_recv_cnt", 1)
_shared_seed_recv_cnt = 1
while _shared_seed_recv_cnt != ws:
_shared_seed_recv_cnt = int(store.get("_dl_shared_seed_recv_cnt"))
store.set("_dl_shared_seed", "")
store.add("_dl_shared_seed_recv_cnt", -ws)
assert int(store.get("_dl_shared_seed_recv_cnt")) == 0
_shared_seed_recv_cnt = store.add("_dl_shared_seed_recv_cnt", -ws)
assert _shared_seed_recv_cnt == 0
else:
_shared_seed_str = ""
store.wait(["_dl_shared_seed"], _utils.MP_STATUS_CHECK_INTERVAL)
while len(_shared_seed_str) == 0:
time.sleep(0.01)
_shared_seed_str = store.get("_dl_shared_seed")
logger.info(f"Shared seed ({_shared_seed_str}) received from store on rank {rank}")
store.add("_dl_shared_seed_recv_cnt", 1)
_shared_seed = int(_shared_seed_str)
return _shared_seed
Expand Down

0 comments on commit 0de7982

Please sign in to comment.