5 changes: 4 additions & 1 deletion test/distributed/test_c10d_common.py
@@ -1464,7 +1464,9 @@ def _call_collective_with_varying_tensors(self, backend, collective, *args):
         # ensure supported devices (cpu, cuda) succeeds during dispatch call
         tensor = torch.zeros(2, 2, device=torch.device(device))
         # multi tensor collectives
-        if collective == dist.all_gather:
+        if collective == dist.barrier:
+            collective()
+        elif collective == dist.all_gather:
             collective([tensor], tensor, *args)
Contributor:
1. nit: this if-else block is getting bigger (as a result of wrapping/templating). Maybe the test would be easier to read if we just write out each test call in _test_collectives, like:
   dist.barrier()
   dist.all_reduce(tensor)
   ...
2. Is collective([tensor], tensor, *args) a correct format for all_gather? i.e. it will have a list of only one tensor for the output. Or are we testing the dispatching functionality with WORLD_SIZE=1 here? (If so, the code makes sense.)

Member Author:
1. Makes sense, the if statement is getting a bit messy; I will change it in future PRs.
2. We are just testing dispatching functionality and making sure the operation is callable, so it is the latter case: collective correctness should be handled by other tests.
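For illustration, here is a rough sketch of what that written-out version could look like; the helper name, the world_size handling, and the exact argument shapes are assumptions on my part, not code from this PR:

# Hypothetical sketch only; assumes dist.init_process_group(...) has already
# been called by the test harness.
import torch
import torch.distributed as dist

def _test_collectives_written_out(tensor: torch.Tensor, world_size: int) -> None:
    dist.barrier()
    dist.all_reduce(tensor)
    # all_gather takes the output tensor list first, then the input tensor;
    # with world_size ranks the output list holds world_size tensors.
    dist.all_gather([torch.zeros_like(tensor) for _ in range(world_size)], tensor)
    # reduce_scatter takes the output tensor first, then the input tensor list
    # (the actual test skips it for the gloo backend).
    dist.reduce_scatter(tensor, [tensor.clone() for _ in range(world_size)])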

         elif collective == dist.reduce_scatter:
             if backend != "gloo":
@@ -1488,6 +1490,7 @@ def _test_collectives(self, backend):
             (dist.all_reduce,),
Contributor:
Sorry, I may have missed this earlier. Would passing self.rank to reduce and broadcast cause each rank to identify a different root and hang? They should have the same root.

Member Author:
Passing self.rank causes the broadcast operation to be sourced from that rank, but it does not hang waiting for other ranks to ACK the broadcast it sent. The same logic applies to reduce, which is why I believe it is not hanging.
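As a side illustration (not the PR's test code) of the pattern being discussed, with each rank passing its own rank as the root; assumes a process group has already been initialized:

import torch
import torch.distributed as dist

def _call_with_self_as_root(rank: int) -> None:
    # Each rank names itself as the root, so the roots differ across ranks.
    # This exercises the op dispatch path rather than collective semantics.
    tensor = torch.zeros(2, 2)
    dist.broadcast(tensor, src=rank)  # broadcast sourced from this rank
    dist.reduce(tensor, dst=rank)     # reduce targeted at this rank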

             (dist.all_gather,),
             (dist.reduce_scatter,),
+            (dist.barrier,),
         ]
         for collective, *args in collectives_and_args:
             with self.subTest(collective=collective, args=args):
25 changes: 25 additions & 0 deletions torch/csrc/distributed/c10d/OpsImpl.cpp
@@ -219,6 +219,22 @@ reduce_scatter_cuda_(
       output_tensors, work);
 }
+
+c10::intrusive_ptr<Work> barrier_cpu(
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    const std::vector<int64_t>& device_ids,
+    int64_t timeout) {
+  return process_group->barrier(
+      BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});
+}
+
+c10::intrusive_ptr<Work> barrier_cuda(
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    const std::vector<int64_t>& device_ids,
+    int64_t timeout) {
+  return process_group->barrier(
+      BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});
+}
+
 // register functions to dispatcher
 namespace {
 TORCH_LIBRARY_IMPL(c10d, CPU, m) {
@@ -286,6 +302,15 @@ TORCH_LIBRARY_IMPL(c10d, CPU, m) {
 TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
   m.impl("reduce_scatter_", reduce_scatter_cuda_);
 }
+
+TORCH_LIBRARY_IMPL(c10d, CPU, m) {
+  m.impl("barrier", barrier_cpu);
Collaborator:
How do you decide if there should be a trailing underscore?

Member Author:
The convention for PT operators (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml) is that if the operator modifies a tensor in place, its name should be suffixed with _. We don't do this for barrier and send, since they do not modify any tensor in place.
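A small example of the convention (my own illustration, not from this PR):

import torch

t = torch.zeros(2, 2)
u = t.add(1)  # out-of-place: t is unchanged, a new tensor is returned
t.add_(1)     # in-place: t itself is modified, hence the trailing underscore
# barrier does not modify any tensor in place, so the c10d op is registered
# as "barrier" rather than "barrier_".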

Collaborator:
Thank you, Professor Huang!

+}
+
+TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
+  m.impl("barrier", barrier_cuda);
+}
+
 } // namespace
 
 } // namespace ops