[c10d] Increment sequence numbers on collectives.
Pull Request resolved: #55718

Increments the sequence number whenever ProcessGroupGloo::enqueue or
ProcessGroupNCCL::collective runs; each is a common code path that all
collectives on that backend go through. The next step will be to log these
sequence numbers along with other collective info in debug mode, and to
integrate them with the process group wrapper.
ghstack-source-id: 127099736

Differential Revision: [D27690690](https://our.internmc.facebook.com/intern/diff/D27690690/)
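
As a rough illustration (not part of this commit), the sketch below shows how the
counter can be observed from Python once this change is in place. It assumes two
processes launched with the usual env:// variables (MASTER_ADDR, MASTER_PORT,
RANK, WORLD_SIZE) and uses the private _get_sequence_number_for_group() accessor
that the tests below also rely on; treat it as a sketch, not a supported API.

# Sketch only: assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are set in the env.
import torch
import torch.distributed as dist

def main():
    dist.init_process_group("gloo")  # env:// init from the variables above
    pg = dist.distributed_c10d._get_default_group()

    before = pg._get_sequence_number_for_group()
    t = torch.ones(1)
    # all_reduce goes through ProcessGroupGloo::enqueue, which now bumps the counter.
    dist.all_reduce(t, group=pg)
    after = pg._get_sequence_number_for_group()

    # Every collective should increment the local sequence number by one.
    assert after == before + 1

if __name__ == "__main__":
    main()
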
rohan-varma committed Apr 21, 2021
1 parent 8ee1347 commit 2423b51
Showing 3 changed files with 126 additions and 5 deletions.
122 changes: 117 additions & 5 deletions test/distributed/test_c10d.py
@@ -4818,7 +4818,12 @@ def op_timeout_sec(self):

     @property
     def world_size(self):
-        return 2
+        # Test runs with world size of 2 in CI, but can be configured via
+        # DEV_WORLD_SIZE env var for dev purposes. Would like to use
+        # torch.cuda.device_count() here, but runs into CUDA re-init in forked
+        # subprocess error. TODO: enable these tests in CI for a larger world
+        # size.
+        return int(os.environ.get("DEV_WORLD_SIZE", 2))
 
     def _test_broadcast_coalesced(self, process_group, device, root_rank):
         half = torch.float16
@@ -4888,6 +4893,111 @@ def test_broadcast_coalesced_gloo_cpu(self):
         for root_rank in ranks:
             self._test_broadcast_coalesced(process_group, device, root_rank)
 
+    def _verify_sequence_number_across_pg(self, pg, verify_pg):
+
+        seq_num = pg._get_sequence_number_for_group()
+        obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
+        # We use a separate pg to verify the sequence numbers, otherwise these
+        # collectives will themselves increment the sequence number.
+        dist.all_gather_object(obj_list, seq_num, group=verify_pg)
+        self.assertEqual(len(set(obj_list)), 1)
+        return obj_list[0]
+
+    def _test_sequence_num_incremented(self, process_group, ranks):
+        # verify initial sequence numbers. Use a distinct process group for
+        # verification to keep counts as expected with respect to process_group.
+        verify_pg = dist.new_group(
+            ranks=ranks,
+            backend="gloo",
+        )
+        assert dist.get_world_size(process_group) == dist.get_world_size(verify_pg)
+
+        initial_num = (
+            self._verify_sequence_number_across_pg(
+                pg=process_group, verify_pg=verify_pg
+            )
+            if not c10d.distributed_c10d._rank_not_in_group(process_group)
+            else -1
+        )
+        # Verify sequence numbers are appropriately incremented
+        for i in range(10):
+            t = torch.ones(1, device=torch.cuda.current_device())
+            dist.all_reduce(t, group=process_group)
+            if not c10d.distributed_c10d._rank_not_in_group(process_group):
+                seq_num = self._verify_sequence_number_across_pg(
+                    pg=process_group,
+                    verify_pg=verify_pg,
+                )
+                self.assertEqual(initial_num + i + 1, seq_num)
+
+        if dist.get_world_size(process_group) > 2:
+            # Test when certain ranks don't call collectives
+            if dist.get_rank(process_group) not in [0, 2]:
+                dist.all_reduce(t, group=process_group, async_op=True)
+            # Now ranks 0 and 2 should be lagging by 1.
+            if not c10d.distributed_c10d._rank_not_in_group(process_group):
+                seq_num = process_group._get_sequence_number_for_group()
+                rank = dist.get_rank(process_group)
+                obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
+                dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg)
+                rank_to_seq_num = {rank: num for (rank, num) in obj_list}
+                self.assertEqual(len(set(rank_to_seq_num.values())), 2)
+                self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2])
+                expected_same = {
+                    rank_to_seq_num[i]
+                    for i in rank_to_seq_num.keys()
+                    if i not in [0, 2]
+                }
+                self.assertEqual(len(expected_same), 1)
+                self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])
+
+    def _test_sequence_num_incremented_default_group(self, backend_name):
+        torch.cuda.set_device(self.rank)
+        store = c10d.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            backend_name,
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        self._test_sequence_num_incremented(
+            c10d.distributed_c10d._get_default_group(),
+            ranks=list(i for i in range(dist.get_world_size())),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    @requires_nccl()
+    def test_sequence_num_incremented_nccl_default(self):
+        self._test_sequence_num_incremented_default_group("nccl")
+
+    @skip_if_lt_x_gpu(2)
+    @requires_gloo()
+    def test_sequence_num_incremented_gloo_default(self):
+        self._test_sequence_num_incremented_default_group("gloo")
+
+    def _test_sequence_num_incremented_subgroup(self, backend_name):
+        torch.cuda.set_device(self.rank)
+        store = c10d.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            backend_name,
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        subgroup_ranks = [0, 1, 2]
+        subgroup = dist.new_group(subgroup_ranks)
+        self._test_sequence_num_incremented(subgroup, subgroup_ranks)
+
+    @skip_if_lt_x_gpu(4)
+    @requires_gloo()
+    def test_sequence_num_incremented_gloo_subgroup(self):
+        self._test_sequence_num_incremented_subgroup("gloo")
+
+    @skip_if_lt_x_gpu(4)
+    @requires_nccl()
+    def test_sequence_num_incremented_nccl_subgroup(self):
+        self._test_sequence_num_incremented_subgroup("nccl")
+
     def _test_sequence_num_set_default_pg(self, backend):
         store = c10d.FileStore(self.file_name, self.world_size)
         dist.init_process_group(
@@ -4924,10 +5034,12 @@ def _test_sequence_num_set_new_group(self, backend):
         )
 
         subgroup = dist.new_group([0, 1])
-        subgroup_seq = subgroup._get_sequence_number_for_group()
-        obj_list = [None for _ in range(dist.get_world_size())]
-        dist.all_gather_object(obj_list, subgroup_seq)
-        self.assertEqual(len(set(obj_list)), 1)
+
+        if not c10d.distributed_c10d._rank_not_in_group(subgroup):
+            subgroup_seq = subgroup._get_sequence_number_for_group()
+            obj_list = [None for _ in range(dist.get_world_size(subgroup))]
+            dist.all_gather_object(obj_list, subgroup_seq, group=subgroup)
+            self.assertEqual(len(set(obj_list)), 1)
 
     @requires_gloo()
     @skip_if_lt_x_gpu(2)
4 changes: 4 additions & 0 deletions torch/lib/c10d/ProcessGroupGloo.cpp
@@ -710,6 +710,10 @@ void ProcessGroupGloo::runLoop(int workerIndex) {

 void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
   std::unique_lock<std::mutex> lock(workMutex_);
+  // Bump collective counter
+  if (sequenceNum_) {
+    sequenceNum_->increment();
+  }
   workQueue_.push_back(std::move(work));
   lock.unlock();

5 changes: 5 additions & 0 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -1060,6 +1060,11 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
     PostProcess post,
     OpType opType,
     const char* profilingTitle) {
+
+  // Bump collective counter
+  if (sequenceNum_) {
+    sequenceNum_->increment();
+  }
   const auto devices = getDeviceList(inputs);
   const auto key = getKeyFromDevices(devices);
   auto& ncclComms = getNCCLComm(key, devices, opType);
