[10/N] Update barrier with CPU/CUDA implementations #86368
Changes from all commits: d7b2fdb, 3c69c7c, 6679593, 51f3b75
@@ -1464,7 +1464,9 @@ def _call_collective_with_varying_tensors(self, backend, collective, *args):
             # ensure supported devices (cpu, cuda) succeeds during dispatch call
             tensor = torch.zeros(2, 2, device=torch.device(device))
             # multi tensor collectives
-            if collective == dist.all_gather:
+            if collective == dist.barrier:
+                collective()
+            elif collective == dist.all_gather:
                 collective([tensor], tensor, *args)
             elif collective == dist.reduce_scatter:
                 if backend != "gloo":
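For context, the helper calls barrier with no arguments because, unlike the tensor collectives, dist.barrier() takes no input or output tensors. A minimal single-rank sketch of that distinction (mine, not the PR's test code; it assumes a gloo backend, world_size=1, and an env:// rendezvous with MASTER_ADDR/MASTER_PORT set):

    import os
    import torch
    import torch.distributed as dist

    # single-rank group so every collective completes locally
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    tensor = torch.zeros(2, 2)
    dist.all_reduce(tensor)   # tensor collectives take tensor arguments
    dist.barrier()            # barrier takes none, mirroring the bare collective() call above

    dist.destroy_process_group()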
@@ -1488,6 +1490,7 @@ def _test_collectives(self, backend):
             (dist.all_reduce,),
             (dist.all_gather,),
             (dist.reduce_scatter,),
+            (dist.barrier,),
         ]
         for collective, *args in collectives_and_args:
             with self.subTest(collective=collective, args=args):

Review comment: Sorry, I may have missed this earlier. Would passing self.rank here cause this to hang?

Reply: Passing self.rank causes the broadcast operation to be sourced from that rank, but it does not hang waiting for other ranks to ACK the broadcast it sent. The same logic applies to reduce, so that's why I believe it is not hanging.
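To illustrate the reply about self.rank, here is a hedged single-rank sketch (my own, not the PR's test): a broadcast sourced from the only rank, and a reduce targeting it, return immediately because there is no peer whose acknowledgement is awaited.

    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    rank = dist.get_rank()
    tensor = torch.ones(2, 2)
    dist.broadcast(tensor, src=rank)  # sourced from this rank; no peer ACK to wait on
    dist.reduce(tensor, dst=rank)     # same reasoning applies to reduce

    dist.destroy_process_group()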
@@ -219,6 +219,22 @@ reduce_scatter_cuda_(
       output_tensors, work);
 }
 
+c10::intrusive_ptr<Work> barrier_cpu(
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    const std::vector<int64_t>& device_ids,
+    int64_t timeout) {
+  return process_group->barrier(
+      BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});
+}
+
+c10::intrusive_ptr<Work> barrier_cuda(
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    const std::vector<int64_t>& device_ids,
+    int64_t timeout) {
+  return process_group->barrier(
+      BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});
+}
+
 // register functions to dispatcher
 namespace {
 TORCH_LIBRARY_IMPL(c10d, CPU, m) {
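As a rough Python-level counterpart (an assumption about how these options surface in the public API, not code from this PR): torch.distributed.barrier accepts an optional device_ids list, which together with the group's timeout roughly corresponds to the BarrierOptions built above.

    import os
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29502")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # With gloo this exercises the CPU path; under NCCL one could additionally pass
    # device_ids, e.g. dist.barrier(device_ids=[torch.cuda.current_device()]).
    dist.barrier()

    dist.destroy_process_group()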
@@ -286,6 +302,15 @@ TORCH_LIBRARY_IMPL(c10d, CPU, m) {
 TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
   m.impl("reduce_scatter_", reduce_scatter_cuda_);
 }
+
+TORCH_LIBRARY_IMPL(c10d, CPU, m) {
+  m.impl("barrier", barrier_cpu);
+}
+
+TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
+  m.impl("barrier", barrier_cuda);
+}
 
 } // namespace
 
 } // namespace ops

Review comment (on m.impl("barrier", barrier_cpu)): How do you decide if there should be a trailing underscore?

Reply: The convention for PT operators (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml) is that if the tensor is modified in place, the operator name should be appended with _.

Reply: Thank you Professor Huang!
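To make the naming convention in the reply concrete, here is a generic ATen-level illustration (not c10d code): in-place operators carry the trailing underscore, which is why barrier, which mutates no tensor, is registered without one.

    import torch

    t = torch.ones(3)
    out = torch.add(t, 1)  # out-of-place: returns a new tensor, t is unchanged
    t.add_(1)              # in-place variant: mutates t, hence the trailing underscore
    # barrier modifies no tensor in place, so the op is registered as "barrier", not "barrier_".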
Review comment: Maybe the test would be easier to read if we just write out each test call in _test_collectives.

Review comment: Is collective([tensor], tensor, *args) a correct format for all_gather? i.e., it will have a list of only one tensor for the output. Or are we testing the dispatching functionality with WORLD_SIZE=1 here? (If so, the code makes sense.)
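Regarding the all_gather format question, a hedged sketch under the WORLD_SIZE=1 assumption (not the PR's code): the output argument is a list with one tensor per rank, so a single-element list is the expected shape when only the dispatch path is being exercised.

    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29503")
    dist.init_process_group("gloo", rank=0, world_size=1)

    tensor = torch.zeros(2, 2)
    # one output slot per rank; with world_size=1 this is a one-element list
    output = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(output, tensor)
    assert torch.equal(output[0], tensor)

    dist.destroy_process_group()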