Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug in native::_do_all_gather related to group #2947

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 11 additions & 6 deletions ignite/distributed/comp_models/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,15 +430,20 @@
return tensor

def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
if group is not None and not isinstance(group, dist.ProcessGroup):
if group == dist.GroupMember.NON_GROUP_MEMBER:
return tensor

Check warning on line 434 in ignite/distributed/comp_models/native.py

View check run for this annotation

Codecov / codecov/patch

ignite/distributed/comp_models/native.py#L434

Added line #L434 was not covered by tests
elif group is None:
group_size = self.get_world_size()
elif isinstance(group, dist.ProcessGroup):
group_size = group.size()
elif isinstance(group, list):
group_size = len(group)

Check warning on line 440 in ignite/distributed/comp_models/native.py

View check run for this annotation

Codecov / codecov/patch

ignite/distributed/comp_models/native.py#L440

Added line #L440 was not covered by tests
else:
raise ValueError("Argument group should be list of int or ProcessGroup")
if tensor.ndimension() == 0:
tensor = tensor.unsqueeze(0)
output = [torch.zeros_like(tensor) for _ in range(self.get_world_size())]
if group is not None:
dist.all_gather(output, tensor, group=group)
else:
dist.all_gather(output, tensor)
output = [torch.zeros_like(tensor) for _ in range(group_size)]
dist.all_gather(output, tensor, group=group)
return torch.cat(output, dim=0)

def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
Expand Down
10 changes: 6 additions & 4 deletions ignite/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,13 +356,15 @@ def all_gather(
"""Helper method to perform all gather operation.

Args:
tensor: tensor or number or str to collect across participating processes.
tensor: tensor or number or str to collect across participating processes. If tensor, it should have the
same shape across processes.
group: list of integer or the process group for each backend. If None, the default process group will be used.

Returns:
torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)`` if input is a tensor or
torch.Tensor of shape ``(world_size, )`` if input is a number or
List of strings if input is a string
If input is a tensor, returns a torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)``.
If input is a number, a torch.Tensor of shape ``(world_size, )`` is returned and finally a list of strings
is returned if input is a string. If current process does not belong to `group`, the very ``tensor`` is
returned.

.. versionchanged:: 0.4.11
added ``group``
Expand Down
22 changes: 14 additions & 8 deletions tests/ignite/distributed/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,32 +155,34 @@ def _test_distrib_all_reduce_group(device):


def _test_distrib_all_gather(device):
rank = idist.get_rank()

res = torch.tensor(idist.all_gather(10), device=device)
true_res = torch.tensor([10] * idist.get_world_size(), device=device)
assert (res == true_res).all()

t = torch.tensor(idist.get_rank(), device=device)
t = torch.tensor(rank, device=device)
res = idist.all_gather(t)
true_res = torch.tensor([i for i in range(idist.get_world_size())], device=device)
assert (res == true_res).all()

x = "test-test"
if idist.get_rank() == 0:
if rank == 0:
x = "abc"
res = idist.all_gather(x)
true_res = ["abc"] + ["test-test"] * (idist.get_world_size() - 1)
assert res == true_res

base_x = "tests/ignite/distributed/utils/test_native.py" * 2000
x = base_x
if idist.get_rank() == 0:
if rank == 0:
x = "abc"

res = idist.all_gather(x)
true_res = ["abc"] + [base_x] * (idist.get_world_size() - 1)
assert res == true_res

t = torch.arange(100, device=device).reshape(4, 25) * (idist.get_rank() + 1)
t = torch.arange(100, device=device).reshape(4, 25) * (rank + 1)
in_dtype = t.dtype
res = idist.all_gather(t)
assert res.shape == (idist.get_world_size() * 4, 25)
Expand Down Expand Up @@ -208,17 +210,21 @@ def _test_distrib_all_gather_group(device):
res = idist.all_gather(t, group=group)
else:
res = idist.all_gather(t, group=group)
assert torch.equal(res, torch.tensor(ranks, device=device))
if rank in ranks:
assert torch.equal(res, torch.tensor(ranks, device=device))
else:
assert res == t

t = torch.tensor([rank], device=device)
if bnd in ("horovod"):
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
res = idist.all_gather(t, group=ranks)
else:
res = idist.all_gather(t, group=ranks)
assert torch.equal(res, torch.tensor(ranks, device=device))

ranks = "abc"
if rank in ranks:
assert torch.equal(res, torch.tensor(ranks, device=device))
else:
assert res == t

if bnd in ("nccl", "gloo", "mpi"):
with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
Expand Down