Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix distributed store to use add for the counter of DL shared seed #80348

Closed
wants to merge 2 commits (source and target branch names not captured in this extraction)
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions torch/utils/data/_utils/__init__.py
Expand Up @@ -34,6 +34,15 @@
https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
"""

DATAPIPE_SHARED_SEED = "_dl_shared_seed"
r"""The key to share the same seed for shuffle DataPipe across distributed processes"""

DATAPIPE_SHARED_SEED_COUNTER = "_dl_shared_seed_recv_cnt"
r"""The key to count the number of distributed processes that have received the shared seed"""

DATAPIPE_SHARED_SEED_CHECK_INTERVAL = 0.01
r"""Interval to check if each rank has received the shared seed"""


try:
import numpy
Expand Down
39 changes: 25 additions & 14 deletions torch/utils/data/dataloader.py
Expand Up @@ -5,12 +5,15 @@
in `./_utils/worker.py`.
"""

import functools
import itertools
import logging
import os
import queue
import threading
import itertools
import time
import warnings
import queue
import functools

from typing import Any, Callable, Iterable, TypeVar, Generic, Sequence, List, Optional, Union

import multiprocessing as python_multiprocessing
Expand Down Expand Up @@ -63,6 +66,8 @@

get_worker_info = _utils.worker.get_worker_info

logger = logging.getLogger(__name__)


class _DatasetKind(object):
Map = 0
Expand Down Expand Up @@ -567,21 +572,27 @@ def _get_shared_seed(self):
ws = dist.get_world_size()
store = dist.distributed_c10d._get_default_store()
if rank == 0:
store.set("_dl_shared_seed", str(_shared_seed))
_shared_seed_str = str(_shared_seed)
store.set(_utils.DATAPIPE_SHARED_SEED, _shared_seed_str)
logger.info(f"Shared seed ({_shared_seed_str}) sent to store on rank 0")
# Use 'add' instead of 'get' since for some store implementations 'add'
# doesn't work well with 'get'.
_shared_seed_recv_cnt = store.add(_utils.DATAPIPE_SHARED_SEED_COUNTER, 1)
while _shared_seed_recv_cnt < ws:
time.sleep(_utils.DATAPIPE_SHARED_SEED_CHECK_INTERVAL)
_shared_seed_recv_cnt = store.add(_utils.DATAPIPE_SHARED_SEED_COUNTER, 0)
# Reset after all distributed processes have received the shared seed
store.add("_dl_shared_seed_recv_cnt", 1)
_shared_seed_recv_cnt = 1
while _shared_seed_recv_cnt != ws:
_shared_seed_recv_cnt = int(store.get("_dl_shared_seed_recv_cnt"))
store.set("_dl_shared_seed", "")
store.add("_dl_shared_seed_recv_cnt", -ws)
assert int(store.get("_dl_shared_seed_recv_cnt")) == 0
store.set(_utils.DATAPIPE_SHARED_SEED, "")
_shared_seed_recv_cnt = store.add(_utils.DATAPIPE_SHARED_SEED_COUNTER, -ws)
assert _shared_seed_recv_cnt == 0
else:
_shared_seed_str = ""
store.wait(["_dl_shared_seed"], _utils.MP_STATUS_CHECK_INTERVAL)
store.wait([_utils.DATAPIPE_SHARED_SEED], _utils.MP_STATUS_CHECK_INTERVAL)
while len(_shared_seed_str) == 0:
_shared_seed_str = store.get("_dl_shared_seed")
store.add("_dl_shared_seed_recv_cnt", 1)
time.sleep(_utils.DATAPIPE_SHARED_SEED_CHECK_INTERVAL)
_shared_seed_str = store.get(_utils.DATAPIPE_SHARED_SEED)
logger.info(f"Shared seed ({_shared_seed_str}) received from store on rank {rank}")
store.add(_utils.DATAPIPE_SHARED_SEED_COUNTER, 1)
_shared_seed = int(_shared_seed_str)
return _shared_seed
else:
Expand Down