Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix store based barrier to only use 'add'. #49930

Closed
Closed
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 18 additions & 25 deletions torch/distributed/distributed_c10d.py
@@ -1,8 +1,8 @@
import contextlib
import logging
import pickle
import torch
import warnings
import contextlib
import sys
import time
from torch._six import string_classes
from datetime import timedelta
Expand All @@ -18,7 +18,6 @@
AllreduceCoalescedOptions,
AllToAllOptions,
BroadcastOptions,
FileStore,
GatherOptions,
PrefixStore,
ProcessGroup,
Expand All @@ -27,15 +26,8 @@
ReduceScatterOptions,
ScatterOptions,
Store,
TCPStore,
)

if sys.platform != 'win32':
from torch._C._distributed_c10d import (
HashStore,
)


_MPI_AVAILABLE = True
_NCCL_AVAILABLE = True
_GLOO_AVAILABLE = True
Expand Down Expand Up @@ -191,16 +183,25 @@ def _store_based_barrier(rank, store, timeout):
"""
store_key = "{}:{}".format(STORE_BASED_BARRIER_PREFIX, _group_count)
store.add(store_key, 1)
logging.info('Added key: {} to store for rank: {}'.format(store_key, rank))

# Now wait for all workers to check in with the store.
world_size = get_world_size()
worker_count = int(store.get(store_key))
# Use 'add' instead of 'get' since for some store implementations 'add'
# doesn't work well with 'get'. Ideally the store implementations should
# be fixed, but for backward compatibility reasons it is risky to change
# the store implementations. Once we completely migrate away from these
# legacy stores, we can use 'get' here instead.
pritamdamania87 marked this conversation as resolved.
Show resolved Hide resolved
worker_count = store.add(store_key, 0)
start = time.time()
while worker_count != world_size:
time.sleep(0.01)
worker_count = int(store.get(store_key))
worker_count = store.add(store_key, 0)
pritamdamania87 marked this conversation as resolved.
Show resolved Hide resolved
pritamdamania87 marked this conversation as resolved.
Show resolved Hide resolved
if timedelta(seconds=(time.time() - start)) > timeout:
raise RuntimeError("Timed out initializing process group")
raise RuntimeError(
"Timed out initializing process group in store based barrier on "
pritamdamania87 marked this conversation as resolved.
Show resolved Hide resolved
"rank: {}, for key: {} (world_size={}, worker_count={})".format(
rank, store_key, world_size, worker_count))

def _rank_not_in_group(group: ProcessGroup):
"""
Expand Down Expand Up @@ -504,12 +505,8 @@ def init_process_group(backend,
# barrier at the end to ensure that once we return from this method, all
# process groups including global variables are updated correctly on all
# ranks.
if backend == Backend.MPI or not (
isinstance(store, TCPStore) or
isinstance(store, FileStore) or
(sys.platform != 'win32' and isinstance(store, HashStore))
):
# MPI doesn't have store.
if backend == Backend.MPI:
# MPI backend doesn't use store.
barrier()
else:
# Use store based barrier here since barrier() used a bunch of
Expand Down Expand Up @@ -2491,16 +2488,12 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None):
# barrier at the end to ensure that once we return from this method, all
# process groups including global variables are updated correctly on all
# ranks.
if backend == Backend.MPI or not (
isinstance(default_store, TCPStore) or
isinstance(default_store, FileStore) or
(sys.platform != 'win32' and isinstance(default_store, HashStore))
):
if backend == Backend.MPI:
# MPI backend doesn't use a store.
barrier()
else:
# Use store based barrier here since barrier() used a bunch of
# default devices and messes up NCCL internal state.
_store_based_barrier(group_rank, default_store, timeout)
_store_based_barrier(global_rank, default_store, timeout)

return pg