Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[reland] Support torch.distributed.irecv(src=None, ...) #49383

Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
17 changes: 11 additions & 6 deletions torch/distributed/distributed_c10d.py
Expand Up @@ -695,15 +695,16 @@ def isend(tensor,


def irecv(tensor,
-          src,
+          src=None,
group=None,
tag=0):
"""
Receives a tensor asynchronously.

Arguments:
tensor (Tensor): Tensor to fill with received data.
-        src (int): Source rank.
+        src (int, optional): Source rank. Will receive from any
+            process if unspecified.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
tag (int, optional): Tag to match recv with remote send
Expand All @@ -718,11 +719,15 @@ def irecv(tensor,
return

if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        return default_pg.recv([tensor], src, tag)
+        pg = _get_default_group()
     else:
-        group_src_rank = _get_group_rank(group, src)
-        return group.recv([tensor], group_src_rank, tag)
+        pg = group

+    if src is None:
+        return pg.recv_anysource([tensor], tag)
+    else:
+        group_src_rank = _get_group_rank(pg, src)
+        return pg.recv([tensor], group_src_rank, tag)


def send(tensor,
Expand Down
51 changes: 40 additions & 11 deletions torch/testing/_internal/distributed/distributed_test.py
Expand Up @@ -855,27 +855,56 @@ def test_send_recv(self):
def test_send_recv_any_source(self):
rank = dist.get_rank()
tensor = _build_tensor(10, value=rank)
-        recv_ranks = set()
+        recv_ranks = list()
+        irecv_ranks = list()

for dst in range(0, dist.get_world_size()):
if dst == rank:
# Recv mode
for dst in range(0, dist.get_world_size()):
if dst == rank:
continue
output_tensor = _build_tensor(10, value=-1)
sender = dist.recv(output_tensor)

# Assert the scalar value "sender" that should be
# equal to the rank of the sender is equal to all
# values in the received tensor.
self.assertTrue(output_tensor.eq(sender).all())
recv_ranks.add(sender)
for recv in ["recv", "irecv"]:
output_tensor = _build_tensor(10, value=-1)

if recv == "recv":
sender = dist.recv(output_tensor)
recv_ranks.append(sender)
elif recv == "irecv":
work = dist.irecv(output_tensor)
work.wait()
sender = work._source_rank()
irecv_ranks.append(sender)

# Assert the scalar value "sender" that should be
# equal to the rank of the sender is equal to all
# values in the received tensor.
self.assertTrue(output_tensor.eq(sender).all())
else:
# Send mode
-                dist.send(tensor, dst)
-
-        self.assertEqual(len(recv_ranks), dist.get_world_size() - 1)
+                dist.send(tensor, dst)  # recv
+                dist.send(tensor, dst)  # irecv

+        # Each rank would have 2 * (world_size - 1) sends, verify that
+        # globally we receive the same amount on the other end.
+        recv_ranks_tensor = torch.cat((torch.tensor(recv_ranks), torch.tensor(irecv_ranks)), 0)
+        global_recv_ranks = [
+            torch.empty_like(recv_ranks_tensor),
+            torch.empty_like(recv_ranks_tensor),
+            torch.empty_like(recv_ranks_tensor),
+            torch.empty_like(recv_ranks_tensor),
+        ]
+        dist.all_gather(global_recv_ranks, recv_ranks_tensor)
+        global_recv_ranks_list = []
+        for tensor in global_recv_ranks:
+            global_recv_ranks_list += tensor.tolist()
+
+        from itertools import groupby
+        global_recv_ranks_list.sort()
+        frequency = [len(list(group)) for key, group in groupby(global_recv_ranks_list)]
+        self.assertEqual(dist.get_world_size(), len(frequency))
+        self.assertEqual([2 * (dist.get_world_size() - 1)] * dist.get_world_size(), frequency)
Review comment (Contributor Author): The validation here has been made more robust compared to #47137, since recvAnySource can potentially recv from anywhere.

self._barrier()

# SEND RECV WITH TAG
Expand Down