[c10d] Increment sequence numbers on collectives.
Pull Request resolved: #55718

Increments the sequence number whenever ProcessGroupGloo::enqueue or
ProcessGroupNCCL::collective runs; both are common code paths that all
collectives go through. The next step will be to log these numbers along with
other collective info in debug mode, and to integrate them with the process
group wrapper.
ghstack-source-id: 126896607

Differential Revision: [D27690690](https://our.internmc.facebook.com/intern/diff/D27690690/)
rohan-varma committed Apr 19, 2021
1 parent fe41d7d commit 76e7313
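
Before the diff, a minimal usage sketch of the invariant this change establishes and the tests below exercise: any collective routed through these code paths should bump the group's sequence number by exactly one. This is illustrative only (not part of the diff) and assumes a gloo default process group has already been initialized on every rank.

import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d

# Assumes dist.init_process_group("gloo", ...) has already run on every rank.
pg = distributed_c10d._get_default_group()

before = pg._get_sequence_number_for_group()
dist.all_reduce(torch.ones(1))  # any collective; routed through ProcessGroupGloo::enqueue
after = pg._get_sequence_number_for_group()
assert after == before + 1      # exactly one increment per collective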
Showing 3 changed files with 130 additions and 1 deletion.
126 changes: 125 additions & 1 deletion test/distributed/test_c10d.py
@@ -4835,7 +4835,11 @@ def op_timeout_sec(self):

@property
def world_size(self):
return 2
# Test runs with world size of 2 in CI, but can be configured via
# WORLD_SIZE env var for dev purposes. Would like to use
# torch.cuda.device_count() here, but runs into CUDA re-init in forked
# subprocess error.
return int(os.environ.get("WORLD_SIZE", 2))

def _test_broadcast_coalesced(self, process_group, device, root_rank):
half = torch.float16
@@ -4905,6 +4909,126 @@ def test_broadcast_coalesced_gloo_cpu(self):
for root_rank in ranks:
self._test_broadcast_coalesced(process_group, device, root_rank)

def _verify_sequence_number_across_pg(self, pg, verify_pg):

seq_num = pg._get_sequence_number_for_group()
obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
# We use a separate pg to verify the sequence numbers, otherwise these
# collectives will themselves increment the sequence number.
dist.all_gather_object(obj_list, seq_num, group=verify_pg)
self.assertEqual(len(set(obj_list)), 1)
return obj_list[0]

def _test_sequence_num_incremented(self, process_group, ranks):
# verify initial sequence numbers. Use a distinct process group for
# verification to keep counts as expected with respect to process_group.
verify_pg = dist.new_group(
ranks=ranks,
backend="gloo",
)
assert dist.get_world_size(process_group) == dist.get_world_size(verify_pg)

initial_num = (
self._verify_sequence_number_across_pg(
pg=process_group, verify_pg=verify_pg
)
if not c10d.distributed_c10d._rank_not_in_group(process_group)
else -1
)
# Verify sequence numbers are appropriately incremented
for i in range(10):
t = torch.ones(1, device=torch.cuda.current_device())
dist.all_reduce(t, group=process_group)
if not c10d.distributed_c10d._rank_not_in_group(process_group):
seq_num = self._verify_sequence_number_across_pg(
pg=process_group,
verify_pg=verify_pg,
)
self.assertEqual(initial_num + i + 1, seq_num)

if dist.get_world_size(process_group) > 2:
# Test when certain ranks don't call collectives
if dist.get_rank(process_group) not in [0, 1]:
dist.all_reduce(t, group=process_group, async_op=True)
# Now ranks 0 and 1 should be lagging by 1.
if not c10d.distributed_c10d._rank_not_in_group(process_group):
seq_num = process_group._get_sequence_number_for_group()
rank = dist.get_rank(process_group)
obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg)
rank_to_seq_num = {rank: num for (rank, num) in obj_list}
self.assertEqual(len(set(rank_to_seq_num.values())), 2)
self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[1])
expected_same = {
rank_to_seq_num[i]
for i in rank_to_seq_num.keys()
if i not in [0, 1]
}
self.assertEqual(len(expected_same), 1)
self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[2])

@skip_if_lt_x_gpu(2)
@requires_gloo()
def test_sequence_num_incremented_gloo_default(self):
torch.cuda.set_device(self.rank)
store = c10d.FileStore(self.file_name, self.world_size)
dist.init_process_group(
"gloo",
world_size=self.world_size,
rank=self.rank,
store=store,
)
self._test_sequence_num_incremented(
c10d.distributed_c10d._get_default_group(),
ranks=list(range(dist.get_world_size())),
)

@skip_if_lt_x_gpu(4)
@requires_gloo()
def test_sequence_num_incremented_gloo_subgroup(self):
torch.cuda.set_device(self.rank)
store = c10d.FileStore(self.file_name, self.world_size)
dist.init_process_group(
"gloo",
world_size=self.world_size,
rank=self.rank,
store=store,
)
subgroup_ranks = [0, 1, 2]
subgroup = dist.new_group(subgroup_ranks)
self._test_sequence_num_incremented(subgroup, subgroup_ranks)

@skip_if_lt_x_gpu(4)
@requires_nccl()
def test_sequence_num_incremented_nccl_subgroup(self):
torch.cuda.set_device(self.rank)
store = c10d.FileStore(self.file_name, self.world_size)
dist.init_process_group(
"nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
subgroup_ranks = [0, 1, 2]
subgroup = dist.new_group(subgroup_ranks)
self._test_sequence_num_incremented(subgroup, subgroup_ranks)

@skip_if_lt_x_gpu(2)
@requires_nccl()
def test_sequence_num_incremented_nccl_default(self):
torch.cuda.set_device(self.rank)
store = c10d.FileStore(self.file_name, self.world_size)
dist.init_process_group(
"nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
self._test_sequence_num_incremented(
c10d.distributed_c10d._get_default_group(),
ranks=list(range(dist.get_world_size())),
)

def _test_sequence_num_set_default_pg(self, backend):
store = c10d.FileStore(self.file_name, self.world_size)
dist.init_process_group(
2 changes: 2 additions & 0 deletions torch/lib/c10d/ProcessGroupGloo.cpp
@@ -708,6 +708,8 @@ void ProcessGroupGloo::runLoop(int workerIndex) {

void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
std::unique_lock<std::mutex> lock(workMutex_);
// Bump collective counter
sequenceNum_->increment();
workQueue_.push_back(std::move(work));
lock.unlock();

3 changes: 3 additions & 0 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -1058,6 +1058,9 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
PostProcess post,
OpType opType,
const char* profilingTitle) {

// Bump collective counter
sequenceNum_->increment();
const auto devices = getDeviceList(inputs);
const auto key = getKeyFromDevices(devices);
auto& ncclComms = getNCCLComm(key, devices, opType);
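
As a hedged illustration of the process group wrapper integration mentioned in the commit message (hypothetical; not implemented in this diff), a wrapper could gather every rank's sequence number over a side-channel gloo group and flag a desync before launching the next collective. The names check_collective_sync, wrapped_pg, and helper_pg are illustrative, not real APIs.

import torch.distributed as dist

def check_collective_sync(wrapped_pg, helper_pg):
    # Gather (rank, sequence number) pairs over the helper group so the check
    # itself does not bump wrapped_pg's counter.
    seq = wrapped_pg._get_sequence_number_for_group()
    gathered = [None] * dist.get_world_size(helper_pg)
    dist.all_gather_object(gathered, (dist.get_rank(helper_pg), seq), group=helper_pg)
    seqs = {s for _, s in gathered}
    if len(seqs) != 1:
        lagging = [r for r, s in gathered if s == min(seqs)]
        raise RuntimeError(
            f"Mismatched collective sequence numbers {gathered}; "
            f"ranks {lagging} appear to have skipped a collective."
        )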
