diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 387da70403b0..5f51577c5b19 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -695,7 +695,7 @@ def isend(tensor, def irecv(tensor, - src, + src=None, group=None, tag=0): """ @@ -703,7 +703,8 @@ def irecv(tensor, Arguments: tensor (Tensor): Tensor to fill with received data. - src (int): Source rank. + src (int, optional): Source rank. Will receive from any + process if unspecified. group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. tag (int, optional): Tag to match recv with remote send @@ -718,11 +719,18 @@ def irecv(tensor, return if group is None or group is GroupMember.WORLD: - default_pg = _get_default_group() - return default_pg.recv([tensor], src, tag) + pg = _get_default_group() else: - group_src_rank = _get_group_rank(group, src) - return group.recv([tensor], group_src_rank, tag) + pg = group + + if src is None: + return pg.recv_anysource([tensor], tag) + else: + if pg is GroupMember.WORLD: + return pg.recv([tensor], src, tag) + else: + group_src_rank = _get_group_rank(pg, src) + return pg.recv([tensor], group_src_rank, tag) def send(tensor, diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 5577d2322679..a4c3953bf2b6 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -855,7 +855,8 @@ def test_send_recv(self): def test_send_recv_any_source(self): rank = dist.get_rank() tensor = _build_tensor(10, value=rank) - recv_ranks = set() + recv_ranks = list() + irecv_ranks = list() for dst in range(0, dist.get_world_size()): if dst == rank: @@ -863,19 +864,42 @@ def test_send_recv_any_source(self): for dst in range(0, dist.get_world_size()): if dst == rank: continue - output_tensor = _build_tensor(10, value=-1) - sender = dist.recv(output_tensor) - # Assert the scalar value "sender" that should be - # equal to the rank of the sender is equal to all - # values in the received tensor. - self.assertTrue(output_tensor.eq(sender).all()) - recv_ranks.add(sender) + for recv in ["recv", "irecv"]: + output_tensor = _build_tensor(10, value=-1) + + if recv == "recv": + sender = dist.recv(output_tensor) + recv_ranks.append(sender) + elif recv == "irecv": + work = dist.irecv(output_tensor) + work.wait() + sender = work._source_rank() + irecv_ranks.append(sender) + + # Assert the scalar value "sender" that should be + # equal to the rank of the sender is equal to all + # values in the received tensor. + self.assertTrue(output_tensor.eq(sender).all()) else: # Send mode - dist.send(tensor, dst) - - self.assertEqual(len(recv_ranks), dist.get_world_size() - 1) + dist.send(tensor, dst) # recv + dist.send(tensor, dst) # irecv + + # Each rank would have 2 * (world_size - 1) sends, verify that + # globally we receive the same amount on the other end. + recv_ranks_tensor = torch.cat((torch.tensor(recv_ranks), torch.tensor(irecv_ranks)), 0) + global_recv_ranks = [torch.empty_like(recv_ranks_tensor) for _ in range(dist.get_world_size())] + dist.all_gather(global_recv_ranks, recv_ranks_tensor) + global_recv_ranks_list = [] + for tensor in global_recv_ranks: + global_recv_ranks_list += tensor.tolist() + + from itertools import groupby + global_recv_ranks_list.sort() + frequency = [len(list(group)) for key, group in groupby(global_recv_ranks_list)] + self.assertEqual(dist.get_world_size(), len(frequency)) + self.assertEqual([2 * (dist.get_world_size() - 1)] * dist.get_world_size(), frequency) self._barrier() # SEND RECV WITH TAG