[reland] Support torch.distributed.irecv(src=None, ...) #49383
Reland of #47137 Differential Revision: [D25551910](https://our.internmc.facebook.com/intern/diff/D25551910/) [ghstack-poisoned]
Reland of #47137 Differential Revision: [D25551910](https://our.internmc.facebook.com/intern/diff/D25551910/) ghstack-source-id: 118586219 Pull Request resolved: #49383
# Each rank would have 2 * (world_size - 1) sends, verify that
# globally we receive the same amount on the other end.
recv_ranks_tensor = torch.cat((torch.tensor(recv_ranks), torch.tensor(irecv_ranks)), 0)
global_recv_ranks = [
    torch.empty_like(recv_ranks_tensor),
    torch.empty_like(recv_ranks_tensor),
    torch.empty_like(recv_ranks_tensor),
    torch.empty_like(recv_ranks_tensor),
]
dist.all_gather(global_recv_ranks, recv_ranks_tensor)
global_recv_ranks_list = []
for tensor in global_recv_ranks:
    global_recv_ranks_list += tensor.tolist()

from itertools import groupby
global_recv_ranks_list.sort()
frequency = [len(list(group)) for key, group in groupby(global_recv_ranks_list)]
self.assertEqual(dist.get_world_size(), len(frequency))
self.assertEqual([2 * (dist.get_world_size() - 1)] * dist.get_world_size(), frequency)
The validation here has been made more robust compared to #47137, since recvAnySource can potentially recv from anywhere.
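The counting idea behind this check can be tried without a process group. Below is a minimal sketch, assuming the sender ranks recorded by every receiver have already been gathered into one flat list. The helper names `check_recv_frequency` and `simulate_gathered_senders` are hypothetical and not part of the PR; the simulation only stands in for `dist.all_gather`.

```python
from itertools import groupby


def check_recv_frequency(global_recv_ranks_list, world_size):
    """Return True if every rank shows up as a sender exactly 2 * (world_size - 1) times."""
    ordered = sorted(global_recv_ranks_list)
    # groupby only groups adjacent equal items, hence the sort above.
    frequency = [len(list(group)) for _, group in groupby(ordered)]
    return frequency == [2 * (world_size - 1)] * world_size


def simulate_gathered_senders(world_size):
    """Hypothetical stand-in for all_gather: one recv and one irecv per peer on each rank."""
    gathered = []
    for receiver in range(world_size):
        peers = [rank for rank in range(world_size) if rank != receiver]
        gathered += peers                  # senders observed by recv
        gathered += list(reversed(peers))  # senders observed by irecv, in a different order
    return gathered
```

Because only the global counts are compared, the check does not depend on which message a given any-source receive happened to match, which seems to be the point of the comment above.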
💊 CI failures summary and remediations
As of commit ff34456:
🕵️ 1 new failure recognized by patterns. The following CI failure does not appear to be due to upstream breakages: pytorch_xla_linux_bionic_py3_6_clang9_build (1/1), step: "Build".
Pull Request resolved: #49383 Reland of #47137 ghstack-source-id: 118669599 Differential Revision: [D25551910](https://our.internmc.facebook.com/intern/diff/D25551910/)
        return pg.recv_anysource([tensor], tag)
    else:
        if pg is GroupMember.WORLD:
            pg.recv([tensor], src, tag).wait()
Why do we block in this case?
Good catch, this is a bug.
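For illustration, here is a rough sketch of what a non-blocking version of that branch could look like. It is an assumption about the shape of the fix, not the merged code: `FakeWork`, `FakeProcessGroup`, and `irecv_sketch` are hypothetical stand-ins that only mirror the two calls quoted in the diff (`recv` and `recv_anysource`).

```python
class FakeWork:
    """Hypothetical stand-in for the work handle a process group returns."""

    def __init__(self, description):
        self.description = description
        self.waited = False

    def wait(self):
        # Record that the caller chose to block on this request.
        self.waited = True
        return True


class FakeProcessGroup:
    """Hypothetical stand-in exposing only the two calls quoted in the diff."""

    def recv(self, tensors, src, tag):
        return FakeWork(f"recv from rank {src}")

    def recv_anysource(self, tensors, tag):
        return FakeWork("recv from any source")


def irecv_sketch(pg, tensor, src=None, tag=0):
    # Return the handle in both branches; the caller decides when to block.
    if src is None:
        return pg.recv_anysource([tensor], tag)
    return pg.recv([tensor], src, tag)
```

With this shape, `irecv_sketch(pg, tensor, src=1)` hands back a handle that has not been waited on, and the caller invokes `wait()` once the received data is actually needed.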
Pull Request resolved: #49383 Reland of #47137 ghstack-source-id: 118686023 Differential Revision: [D25551910](https://our.internmc.facebook.com/intern/diff/D25551910/)
Pull Request resolved: #49383 Reland of #47137 ghstack-source-id: 118735407 Differential Revision: [D25551910](https://our.internmc.facebook.com/intern/diff/D25551910/)
This pull request has been merged in db2ecef.
Summary: Pull Request resolved: pytorch#49383 Reland of pytorch#47137 ghstack-source-id: 118735407 Test Plan: waitforbuildbot Reviewed By: osalpekar Differential Revision: D25551910 fbshipit-source-id: 2e1f2f77e7c69204056dfe6ed178e8ad7650ab32
Stack from ghstack:
Reland of #47137
Differential Revision: D25551910