Skip to content

Commit

Permalink
[c10d] Increment sequence numbers on collectives.
Browse files Browse the repository at this point in the history
Increments sequence numbers when ProcessGroupGloo::enqueue or
ProcessGroupNCCL::collective is run, which is a common call all collectives
make. The next step will be to log these along with other collective info in
debug mode as well as integrating them with the process group wrapper.

Differential Revision: [D27690690](https://our.internmc.facebook.com/intern/diff/D27690690/)

ghstack-source-id: 126216756
Pull Request resolved: #55718
  • Loading branch information
rohan-varma committed Apr 9, 2021
1 parent 45040a8 commit d8aeab9
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 1 deletion.
127 changes: 126 additions & 1 deletion test/distributed/test_c10d.py
Expand Up @@ -4776,7 +4776,12 @@ def op_timeout_sec(self):

@property
def world_size(self):
return 2
# Would like to use torch.cuda.device_count() here, but runs into
# CUDA re-init in forked subprocess error.
if os.environ.get("WORLD_SIZE", None) is not None:
return int(os.environ["WORLD_SIZE"])
else:
return 2

def _test_broadcast_coalesced(self, process_group, device, root_rank):
half = torch.float16
Expand Down Expand Up @@ -4846,6 +4851,126 @@ def test_broadcast_coalesced_gloo_cpu(self):
for root_rank in ranks:
self._test_broadcast_coalesced(process_group, device, root_rank)

def _verify_sequence_number_across_pg(self, pg, verify_pg):

seq_num = pg._get_sequence_number_for_group()
obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
# We use a separate pg to verify the sequence numbers, otherwise these
# collectives will themselves increment the sequence number.
dist.all_gather_object(obj_list, seq_num, group=verify_pg)
self.assertEqual(len(set(obj_list)), 1)
return obj_list[0]

def _test_sequence_num_incremented(self, process_group, ranks):
# verify initial sequence numbers. Use a distinct process group for
# verification to keep counts as expected with respect to process_group.
verify_pg = dist.new_group(
ranks=ranks,
backend="gloo",
)
assert dist.get_world_size(process_group) == dist.get_world_size(verify_pg)

initial_num = (
self._verify_sequence_number_across_pg(
pg=process_group, verify_pg=verify_pg
)
if not c10d.distributed_c10d._rank_not_in_group(process_group)
else -1
)
# Verify sequence numbers are appropriately incremented
for i in range(10):
t = torch.ones(1, device=torch.cuda.current_device())
dist.all_reduce(t, group=process_group)
if not c10d.distributed_c10d._rank_not_in_group(process_group):
seq_num = self._verify_sequence_number_across_pg(
pg=process_group,
verify_pg=verify_pg,
)
self.assertEqual(initial_num + i + 1, seq_num)

if dist.get_world_size(process_group) > 2:
# Test when certain ranks don't call collectives
if dist.get_rank(process_group) not in [0, 1]:
dist.all_reduce(t, group=process_group, async_op=True)
# Now ranks 0 and 1 should be lagging by 1.
if not c10d.distributed_c10d._rank_not_in_group(process_group):
seq_num = process_group._get_sequence_number_for_group()
rank = dist.get_rank(process_group)
obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg)
rank_to_seq_num = {rank: num for (rank, num) in obj_list}
self.assertEqual(len(set(rank_to_seq_num.values())), 2)
self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[1])
expected_same = {
rank_to_seq_num[i]
for i in rank_to_seq_num.keys()
if i not in [0, 1]
}
self.assertEqual(len(expected_same), 1)
self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[2])

@skip_if_lt_x_gpu(2)
@requires_gloo()
def test_sequence_num_incremented_gloo_default(self):
torch.cuda.set_device(self.rank)
store = c10d.FileStore(self.file_name, self.world_size)
dist.init_process_group(
"gloo",
world_size=self.world_size,
rank=self.rank,
store=store,
)
self._test_sequence_num_incremented(
c10d.distributed_c10d._get_default_group(),
ranks=list(i for i in range(dist.get_world_size())),
)

@skip_if_lt_x_gpu(4)
@requires_gloo()
def test_sequence_num_incremented_gloo_subgroup(self):
torch.cuda.set_device(self.rank)
store = c10d.FileStore(self.file_name, self.world_size)
dist.init_process_group(
"gloo",
world_size=self.world_size,
rank=self.rank,
store=store,
)
subgroup_ranks = [0, 1, 2]
subgroup = dist.new_group(subgroup_ranks)
self._test_sequence_num_incremented(subgroup, subgroup_ranks)

@skip_if_lt_x_gpu(4)
@requires_nccl()
def test_sequence_num_incremented_nccl_subgroup(self):
torch.cuda.set_device(self.rank)
store = c10d.FileStore(self.file_name, self.world_size)
dist.init_process_group(
"nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
subgroup_ranks = [0, 1, 2]
subgroup = dist.new_group(subgroup_ranks)
self._test_sequence_num_incremented(subgroup, subgroup_ranks)

@skip_if_lt_x_gpu(2)
@requires_nccl()
def test_sequence_num_incremented_nccl_default(self):
torch.cuda.set_device(self.rank)
store = c10d.FileStore(self.file_name, self.world_size)
dist.init_process_group(
"nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
self._test_sequence_num_incremented(
c10d.distributed_c10d._get_default_group(),
ranks=list(i for i in range(dist.get_world_size())),
)

def _test_sequence_num_set_default_pg(self, backend):
store = c10d.FileStore(self.file_name, self.world_size)
dist.init_process_group(
Expand Down
2 changes: 2 additions & 0 deletions torch/lib/c10d/ProcessGroupGloo.cpp
Expand Up @@ -737,6 +737,8 @@ void ProcessGroupGloo::runLoop(int workerIndex) {

void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
std::unique_lock<std::mutex> lock(workMutex_);
// Bump collective counter
sequenceNum_->increment();
workQueue_.push_back(std::move(work));
lock.unlock();

Expand Down
3 changes: 3 additions & 0 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
Expand Up @@ -1048,6 +1048,9 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
PostProcess post,
OpType opType,
const char* profilingTitle) {

// Bump collective counter
sequenceNum_->increment();
const auto devices = getDeviceList(inputs);
const auto key = getKeyFromDevices(devices);
auto& ncclComms = getNCCLComm(key, devices, opType);
Expand Down

0 comments on commit d8aeab9

Please sign in to comment.