[c10d] Increment sequence numbers on collectives. #55718

Closed. Wants to merge 11 commits. Changes shown below are from 1 commit.
127 changes: 126 additions & 1 deletion test/distributed/test_c10d.py
@@ -4776,7 +4776,12 @@ def op_timeout_sec(self):

    @property
    def world_size(self):
        return 2
        # Would like to use torch.cuda.device_count() here, but runs into
        # CUDA re-init in forked subprocess error.
        if os.environ.get("WORLD_SIZE", None) is not None:
            return int(os.environ["WORLD_SIZE"])
        else:
            return 2

    def _test_broadcast_coalesced(self, process_group, device, root_rank):
        half = torch.float16
@@ -4846,6 +4851,126 @@ def test_broadcast_coalesced_gloo_cpu(self):
        for root_rank in ranks:
            self._test_broadcast_coalesced(process_group, device, root_rank)

    def _verify_sequence_number_across_pg(self, pg, verify_pg):
        seq_num = pg._get_sequence_number_for_group()
        obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
        # We use a separate pg to verify the sequence numbers, otherwise these
        # collectives will themselves increment the sequence number.
        dist.all_gather_object(obj_list, seq_num, group=verify_pg)
        self.assertEqual(len(set(obj_list)), 1)
        return obj_list[0]

    def _test_sequence_num_incremented(self, process_group, ranks):
        # verify initial sequence numbers. Use a distinct process group for
        # verification to keep counts as expected with respect to process_group.
        verify_pg = dist.new_group(
            ranks=ranks,
            backend="gloo",
        )
        assert dist.get_world_size(process_group) == dist.get_world_size(verify_pg)

        initial_num = (
            self._verify_sequence_number_across_pg(
                pg=process_group, verify_pg=verify_pg
            )
            if not c10d.distributed_c10d._rank_not_in_group(process_group)
            else -1
        )

        # Verify sequence numbers are appropriately incremented
        for i in range(10):
            t = torch.ones(1, device=torch.cuda.current_device())
            dist.all_reduce(t, group=process_group)
            if not c10d.distributed_c10d._rank_not_in_group(process_group):
                seq_num = self._verify_sequence_number_across_pg(
                    pg=process_group,
                    verify_pg=verify_pg,
                )
                self.assertEqual(initial_num + i + 1, seq_num)

        if dist.get_world_size(process_group) > 2:
            # Test when certain ranks don't call collectives
            if dist.get_rank(process_group) not in [0, 1]:
[Reviewer comment] It may be better to also add a non-contiguous number like 3 in the list of [0, 1]. (A sketch of this variant follows the end of this method below.)
                dist.all_reduce(t, group=process_group, async_op=True)
            # Now ranks 0 and 1 should be lagging by 1.
            if not c10d.distributed_c10d._rank_not_in_group(process_group):
                seq_num = process_group._get_sequence_number_for_group()
                rank = dist.get_rank(process_group)
                obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
                dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg)
                rank_to_seq_num = {rank: num for (rank, num) in obj_list}
                self.assertEqual(len(set(rank_to_seq_num.values())), 2)
                self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[1])
                expected_same = {
                    rank_to_seq_num[i]
                    for i in rank_to_seq_num.keys()
                    if i not in [0, 1]
                }
                self.assertEqual(len(expected_same), 1)
                self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[2])
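
A minimal sketch of the variant suggested in the reviewer comment above, not part of this PR: the set of ranks that skip the extra collective is made non-contiguous ({0, 1, 3} instead of {0, 1}). It assumes it would replace the world_size > 2 branch at the end of _test_sequence_num_incremented, so process_group, verify_pg, t, and self refer to the names defined there, and it needs at least five ranks.

        if dist.get_world_size(process_group) > 4:
            # Hypothetical variant: a non-contiguous set of ranks skips the
            # extra collective, so ranks 0, 1 and 3 should lag the rest by 1.
            lagging_ranks = {0, 1, 3}
            if dist.get_rank(process_group) not in lagging_ranks:
                dist.all_reduce(t, group=process_group, async_op=True)
            if not c10d.distributed_c10d._rank_not_in_group(process_group):
                seq_num = process_group._get_sequence_number_for_group()
                rank = dist.get_rank(process_group)
                obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
                dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg)
                rank_to_seq_num = dict(obj_list)
                # Exactly two distinct counts: one shared by the lagging ranks,
                # one shared by every other rank, differing by exactly 1.
                self.assertEqual(len(set(rank_to_seq_num.values())), 2)
                lagging = {rank_to_seq_num[r] for r in lagging_ranks}
                advanced = {
                    num
                    for r, num in rank_to_seq_num.items()
                    if r not in lagging_ranks
                }
                self.assertEqual(len(lagging), 1)
                self.assertEqual(len(advanced), 1)
                self.assertEqual(lagging.pop() + 1, advanced.pop())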

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_sequence_num_incremented_gloo_default(self):
        torch.cuda.set_device(self.rank)
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "gloo",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        self._test_sequence_num_incremented(
            c10d.distributed_c10d._get_default_group(),
            ranks=list(i for i in range(dist.get_world_size())),
        )

    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    def test_sequence_num_incremented_gloo_subgroup(self):
        torch.cuda.set_device(self.rank)
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "gloo",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        subgroup_ranks = [0, 1, 2]
        subgroup = dist.new_group(subgroup_ranks)
        self._test_sequence_num_incremented(subgroup, subgroup_ranks)

    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_sequence_num_incremented_nccl_subgroup(self):
[Reviewer comment] Nit: Create a helper function to avoid the boilerplate code here, which is the same as test_sequence_num_incremented_gloo_subgroup? It seems that only one arg differs in these two tests. (A sketch of such a helper follows this test below.)

        torch.cuda.set_device(self.rank)
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        subgroup_ranks = [0, 1, 2]
        subgroup = dist.new_group(subgroup_ranks)
        self._test_sequence_num_incremented(subgroup, subgroup_ranks)
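
A minimal sketch of the helper suggested in the nit above, not part of this PR; the name _init_pg_and_test_sequence_num is hypothetical. The four test_sequence_num_incremented_* tests differ only in the backend string and in whether a subgroup is created, so the setup could be factored out:

    def _init_pg_and_test_sequence_num(self, backend, use_subgroup=False):
        # Shared setup for the gloo/nccl default-group and subgroup variants.
        torch.cuda.set_device(self.rank)
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        if use_subgroup:
            ranks = [0, 1, 2]
            pg = dist.new_group(ranks)
        else:
            pg = c10d.distributed_c10d._get_default_group()
            ranks = list(range(dist.get_world_size()))
        self._test_sequence_num_incremented(pg, ranks)

Each test body would then shrink to a single call, e.g. self._init_pg_and_test_sequence_num("nccl", use_subgroup=True) for the test above.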

    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_sequence_num_incremented_nccl_default(self):
        torch.cuda.set_device(self.rank)
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        self._test_sequence_num_incremented(
            c10d.distributed_c10d._get_default_group(),
            ranks=list(i for i in range(dist.get_world_size())),
        )

    def _test_sequence_num_set_default_pg(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
2 changes: 2 additions & 0 deletions torch/lib/c10d/ProcessGroupGloo.cpp
@@ -737,6 +737,8 @@ void ProcessGroupGloo::runLoop(int workerIndex) {

void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
  std::unique_lock<std::mutex> lock(workMutex_);
  // Bump collective counter
  sequenceNum_->increment();
  workQueue_.push_back(std::move(work));
  lock.unlock();

3 changes: 3 additions & 0 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -1048,6 +1048,9 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
    PostProcess post,
    OpType opType,
    const char* profilingTitle) {

  // Bump collective counter
  sequenceNum_->increment();
  const auto devices = getDeviceList(inputs);
  const auto key = getKeyFromDevices(devices);
  auto& ncclComms = getNCCLComm(key, devices, opType);
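
Taken together, the two C++ changes bump the counter at the point where each collective is issued: inside ProcessGroupGloo::enqueue for Gloo and at the top of ProcessGroupNCCL::collective for NCCL, which is what the Python tests above observe through _get_sequence_number_for_group(). Below is a minimal usage sketch of that observable behaviour, not part of this PR; it only relies on the calls already exercised by the tests (dist.all_reduce, dist.all_gather_object, and the private _get_sequence_number_for_group() accessor) and assumes the default process group has already been initialized on every rank.

import torch
import torch.distributed as dist

def report_sequence_numbers():
    # Assumes dist.init_process_group(...) has already been called on every rank.
    pg = dist.distributed_c10d._get_default_group()
    if torch.cuda.is_available():
        device = torch.device("cuda", torch.cuda.current_device())
    else:
        device = torch.device("cpu")
    t = torch.ones(1, device=device)
    dist.all_reduce(t)  # each collective bumps the per-group sequence number
    seq = pg._get_sequence_number_for_group()
    # Gather (rank, seq) pairs; on a healthy run every rank reports the same
    # value. Note that all_gather_object is itself a collective and will bump
    # the counter again, which is why `seq` is captured before it.
    gathered = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(gathered, (dist.get_rank(), seq))
    if dist.get_rank() == 0:
        print(dict(gathered))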