[reland] Support torch.distributed.irecv(src=None, ...)
Reland of #47137

Differential Revision: [D25551910](https://our.internmc.facebook.com/intern/diff/D25551910/)

ghstack-source-id: 118586219
Pull Request resolved: #49383
pritamdamania committed Dec 15, 2020
1 parent f54ab8f commit 8b560c0
Showing 2 changed files with 51 additions and 17 deletions.
17 changes: 11 additions & 6 deletions torch/distributed/distributed_c10d.py
@@ -695,15 +695,16 @@ def isend(tensor,


 def irecv(tensor,
-          src,
+          src=None,
           group=None,
           tag=0):
     """
     Receives a tensor asynchronously.

     Arguments:
         tensor (Tensor): Tensor to fill with received data.
-        src (int): Source rank.
+        src (int, optional): Source rank. Will receive from any
+            process if unspecified.
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         tag (int, optional): Tag to match recv with remote send
@@ -718,11 +719,15 @@ def irecv(tensor,
         return

     if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        return default_pg.recv([tensor], src, tag)
+        pg = _get_default_group()
     else:
-        group_src_rank = _get_group_rank(group, src)
-        return group.recv([tensor], group_src_rank, tag)
+        pg = group
+
+    if src is None:
+        return pg.recv_anysource([tensor], tag)
+    else:
+        group_src_rank = _get_group_rank(pg, src)
+        return pg.recv([tensor], group_src_rank, tag)


 def send(tensor,
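For reference, a minimal usage sketch of the new any-source form (illustrative, not part of the commit): it assumes a default process group has already been initialized with, e.g., the gloo backend, and it uses the private Work._source_rank() helper that the test below also relies on.

import torch
import torch.distributed as dist

def demo(rank, world_size):
    # Assumes dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # has already run on every participating process.
    if rank == 0:
        buf = torch.zeros(10)
        work = dist.irecv(buf)  # src defaults to None: receive from any rank
        work.wait()
        sender = work._source_rank()  # private helper; see the test below
        # Each sender fills its tensor with its own rank, so this holds.
        assert buf.eq(sender).all()
    else:
        dist.send(torch.full((10,), float(rank)), dst=0)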
51 changes: 40 additions & 11 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -855,27 +855,56 @@ def test_send_recv(self):
     def test_send_recv_any_source(self):
         rank = dist.get_rank()
         tensor = _build_tensor(10, value=rank)
-        recv_ranks = set()
+        recv_ranks = list()
+        irecv_ranks = list()

         for dst in range(0, dist.get_world_size()):
             if dst == rank:
                 # Recv mode
                 for dst in range(0, dist.get_world_size()):
                     if dst == rank:
                         continue
-                    output_tensor = _build_tensor(10, value=-1)
-                    sender = dist.recv(output_tensor)
-
-                    # Assert the scalar value "sender" that should be
-                    # equal to the rank of the sender is equal to all
-                    # values in the received tensor.
-                    self.assertTrue(output_tensor.eq(sender).all())
-                    recv_ranks.add(sender)
+
+                    for recv in ["recv", "irecv"]:
+                        output_tensor = _build_tensor(10, value=-1)
+
+                        if recv == "recv":
+                            sender = dist.recv(output_tensor)
+                            recv_ranks.append(sender)
+                        elif recv == "irecv":
+                            work = dist.irecv(output_tensor)
+                            work.wait()
+                            sender = work._source_rank()
+                            irecv_ranks.append(sender)
+
+                        # Assert the scalar value "sender" that should be
+                        # equal to the rank of the sender is equal to all
+                        # values in the received tensor.
+                        self.assertTrue(output_tensor.eq(sender).all())
             else:
                 # Send mode
-                dist.send(tensor, dst)
+                dist.send(tensor, dst)  # recv
+                dist.send(tensor, dst)  # irecv

-        self.assertEqual(len(recv_ranks), dist.get_world_size() - 1)
+        # Each rank would have 2 * (world_size - 1) sends, verify that
+        # globally we receive the same amount on the other end.
+        recv_ranks_tensor = torch.cat((torch.tensor(recv_ranks), torch.tensor(irecv_ranks)), 0)
+        global_recv_ranks = [
+            torch.empty_like(recv_ranks_tensor),
+            torch.empty_like(recv_ranks_tensor),
+            torch.empty_like(recv_ranks_tensor),
+            torch.empty_like(recv_ranks_tensor),
+        ]
+        dist.all_gather(global_recv_ranks, recv_ranks_tensor)
+        global_recv_ranks_list = []
+        for tensor in global_recv_ranks:
+            global_recv_ranks_list += tensor.tolist()
+
+        from itertools import groupby
+        global_recv_ranks_list.sort()
+        frequency = [len(list(group)) for key, group in groupby(global_recv_ranks_list)]
+        self.assertEqual(dist.get_world_size(), len(frequency))
+        self.assertEqual([2 * (dist.get_world_size() - 1)] * dist.get_world_size(), frequency)
         self._barrier()

     # SEND RECV WITH TAG
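A note on the verification step above: the test gathers every rank's observed sender IDs, sorts the combined list, and counts occurrences per rank with itertools.groupby. A self-contained sketch of just that counting logic, with made-up values for a hypothetical world size of 3:

from itertools import groupby

# Made-up gathered sender ranks for world_size = 3: each rank sends two
# messages (one for recv, one for irecv) to each of its 2 peers, so every
# rank should appear exactly 2 * (3 - 1) = 4 times in the combined list.
senders = [0, 1, 2, 1, 0, 2, 2, 0, 1, 2, 1, 0]
senders.sort()
frequency = [len(list(group)) for _, group in groupby(senders)]
assert len(frequency) == 3     # all 3 ranks showed up as senders
assert frequency == [4, 4, 4]  # each exactly 2 * (world_size - 1) times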
