[Experimental] Remove store barrier after PG init (#99937)
Store-based barrier is not scalable. Experimenting to see if removing it breaks any CI.

Pull Request resolved: #99937
Approved by: https://github.com/kumpera, https://github.com/H-Huang
kwen2501 authored and pytorchmergebot committed Apr 27, 2023
Parent: 7bece14 · Commit: ae0eb23
Showing 1 changed file with 49 additions and 21 deletions: torch/distributed/distributed_c10d.py
@@ -530,6 +530,13 @@ def _get_pg_device(group: ProcessGroup):
         return torch.device("cpu")
 
 
+# Environment variable to control whether we do a barrier after process group
+# init. Default value is 1 for now to stay the same with previous behavior.
+# Users can change it to 0 if such behavior is undesired. We reserve the right
+# to change the default value to 0 if small rollout is successful.
+_barrier_after_init = int(os.getenv("TORCH_DIST_INIT_BARRIER", "1"))
+
+
 def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, logging_interval=timedelta(seconds=10)):
     """
     Barrier based on store which is used for synchronizing processes after
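Note that `_barrier_after_init` is read from `os.getenv` once, at module import time, so the variable must be in the environment before `torch.distributed` is imported (most simply, exported by the shell that launches the job). A minimal opt-out sketch in Python, for illustration only:

import os

# Must happen before torch.distributed is imported, because
# _barrier_after_init is evaluated once at module import time.
os.environ["TORCH_DIST_INIT_BARRIER"] = "0"

import torch.distributed as dist  # noqa: E402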
@@ -1025,16 +1032,27 @@ def init_process_group(
     _backend = _world.pg_map[GroupMember.WORLD][0]  # type: ignore[index]
     _default_pg_init_method = init_method
 
-    # barrier at the end to ensure that once we return from this method, all
-    # process groups including global variables are updated correctly on all
-    # ranks.
-    if backend == Backend.MPI:
-        # MPI backend doesn't use store.
-        barrier()
+    if _barrier_after_init == 1:
+        # barrier at the end to ensure that once we return from this method, all
+        # process groups including global variables are updated correctly on all
+        # ranks.
+        # Update 04/2023: for large-scale runs, this barrier (esp. store-based
+        # barrier) may be costly and/or unscalable. Also, in a lot of cases,
+        # these barriers may be unnecessary, as proved by a green CI after
+        # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been
+        # added which, when set to 0, will disable these barriers.
+        if backend == Backend.MPI:
+            # MPI backend doesn't use store.
+            barrier()
+        else:
+            # Use store based barrier here since barrier() used a bunch of
+            # default devices and messes up NCCL internal state.
+            _store_based_barrier(rank, store, group_name, world_size, timeout)
     else:
-        # Use store based barrier here since barrier() used a bunch of
-        # default devices and messes up NCCL internal state.
-        _store_based_barrier(rank, store, group_name, world_size, timeout)
+        logger.info(
+            "TORCH_DIST_INIT_BARRIER is set to 0, omitting the barrier after "
+            "ProcessGroup initialization."
+        )
 
 
 def _new_process_group_helper(
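For context on what is being skipped: `_store_based_barrier` follows a counter-in-a-store pattern in which every rank increments a shared key and then polls until the count reaches the rendezvous size. The sketch below is a simplified illustration of that pattern, not the actual implementation (the real function also logs progress every `logging_interval` and reports missing ranks on timeout); `toy_store_barrier` is a hypothetical name:

import time
from datetime import timedelta

def toy_store_barrier(store, key, world_size, timeout=timedelta(seconds=300)):
    # Each rank checks in by atomically incrementing the shared counter.
    store.add(key, 1)
    deadline = time.monotonic() + timeout.total_seconds()
    # add(key, 0) returns the current count without changing it,
    # so it doubles as an atomic read.
    while store.add(key, 0) < world_size:
        if time.monotonic() > deadline:
            raise RuntimeError(f"Timed out waiting on barrier key {key!r}")
        time.sleep(0.01)

Every rank polling a single key on one store host is exactly why this barrier scales poorly with world size, which is the motivation for making it optional here.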
@@ -3785,19 +3803,29 @@ def _new_group_with_tag(
         global_rank: group_rank for group_rank, global_rank in enumerate(ranks)
     }
 
-    # barrier at the end to ensure that once we return from this method, all
-    # process groups including global variables are updated correctly on all
-    # ranks.
-    if backend == Backend.MPI:
-        # MPI doesn't have store.
-        barrier()
+    if _barrier_after_init == 1:
+        # barrier at the end to ensure that once we return from this method, all
+        # process groups including global variables are updated correctly on all
+        # ranks.
+        # Update 04/2023: for large-scale runs, these barriers (esp. store-based
+        # barrier) may be costly and/or unscalable. Also, in a lot of cases,
+        # these barriers may be unnecessary, as proved by a green CI after
+        # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been
+        # added which, when set to 0, will disable these barriers.
+        if backend == Backend.MPI:
+            # MPI doesn't have store.
+            barrier()
+        else:
+            barrier_store = pg_store if use_local_synchronization else default_store
+            world_size = len(ranks) if use_local_synchronization else get_world_size()
+            # Use store based barrier here since barrier() used a bunch of
+            # default devices and messes up NCCL internal state.
+            _store_based_barrier(global_rank, barrier_store, group_name, world_size, timeout)
     else:
-        barrier_store = pg_store if use_local_synchronization else default_store
-        world_size = len(ranks) if use_local_synchronization else get_world_size()
-
-        # Use store based barrier here since barrier() used a bunch of
-        # default devices and messes up NCCL internal state.
-        _store_based_barrier(global_rank, barrier_store, group_name, world_size, timeout)
+        logger.info(
+            "TORCH_DIST_INIT_BARRIER is set to 0, omitting the barrier after "
+            "ProcessGroup initialization."
+        )
 
     return pg
 
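With the flag set to 0, both gated paths above (in `init_process_group` and in `_new_group_with_tag`, which backs `new_group`) skip their trailing barrier and only emit the log message. A hypothetical usage sketch, assuming the standard rendezvous variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) are already exported alongside TORCH_DIST_INIT_BARRIER=0 and that ranks 0 and 1 exist:

import torch.distributed as dist

# Neither call performs a post-init barrier when TORCH_DIST_INIT_BARRIER=0;
# each returns as soon as this rank's own process-group bookkeeping is done.
dist.init_process_group(backend="gloo")
subgroup = dist.new_group(ranks=[0, 1])  # gated by the same flag via _new_group_with_tag

Callers that implicitly relied on the barrier to know all ranks finished initialization may need an explicit dist.barrier() of their own once they opt out.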
