From 34a707e53785cf8a524589f33a570a7516fe064e Mon Sep 17 00:00:00 2001
From: vfdev
Date: Thu, 31 Aug 2023 17:10:14 +0200
Subject: [PATCH] Added support for all_gather object (#3047)

* Added support for all_gather object

* Apply suggestions from code review

Co-authored-by: Sadra Barikbin

* Added new test in _test_distrib_all_gather_group

* Handling pytorch old versions

---------

Co-authored-by: Sadra Barikbin
---
 ignite/distributed/comp_models/base.py        | 13 +++-
 ignite/distributed/comp_models/horovod.py     |  6 ++
 ignite/distributed/comp_models/native.py      | 30 +++++++-
 ignite/distributed/comp_models/xla.py         |  3 +
 tests/ignite/distributed/utils/__init__.py    | 76 ++++++++++++++++---
 tests/ignite/distributed/utils/test_native.py |  3 +
 6 files changed, 116 insertions(+), 15 deletions(-)

diff --git a/ignite/distributed/comp_models/base.py b/ignite/distributed/comp_models/base.py
index 82466f2244d..c22259833fa 100644
--- a/ignite/distributed/comp_models/base.py
+++ b/ignite/distributed/comp_models/base.py
@@ -181,7 +181,7 @@ def _apply_op(
         return tensor
 
     def _collective_op(
-        self, tensor: Union[torch.Tensor, float, str], fn: Callable, *args: Any, **kwargs: Any
+        self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any
     ) -> Union[torch.Tensor, float, List[float], List[str]]:
         tensor_to_number = tensor_to_str = False
         device = self.device()
@@ -216,10 +216,10 @@ def all_reduce(
         return cast(Union[torch.Tensor, float], self._collective_op(tensor, self._do_all_reduce, op, group=group))
 
     def all_gather(
-        self, tensor: Union[torch.Tensor, float, str], group: Optional[Any] = None
+        self, tensor: Union[torch.Tensor, float, str, Any], group: Optional[Any] = None
     ) -> Union[torch.Tensor, float, List[float], List[str]]:
         if not isinstance(tensor, (torch.Tensor, Number, str)):
-            raise TypeError(f"Unhandled input type {type(tensor)}")
+            return self._do_all_gather_object(tensor, group=group)
 
         return self._collective_op(tensor, self._do_all_gather, group=group)
 
@@ -282,6 +282,10 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
     def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         pass
 
+    @abstractmethod
+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
+        pass
+
     @abstractmethod
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         pass
@@ -373,6 +377,9 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
     def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         return tensor
 
+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> Any:
+        return tensor
+
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         return ranks
 
diff --git a/ignite/distributed/comp_models/horovod.py b/ignite/distributed/comp_models/horovod.py
index 62b59d3caea..36f15f4428d 100644
--- a/ignite/distributed/comp_models/horovod.py
+++ b/ignite/distributed/comp_models/horovod.py
@@ -192,6 +192,12 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
             tensor = tensor.unsqueeze(0)
         return hvd.allgather(tensor)
 
+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
+        if group is not None:
+            raise NotImplementedError("all_gather with group for horovod is not implemented")
+
+        return hvd.allgather_object(tensor)
+
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         return hvd.ProcessSet(ranks)
 
diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py
index ee0d12858d1..c71c7d42311 100644
--- a/ignite/distributed/comp_models/native.py
+++ b/ignite/distributed/comp_models/native.py
@@ -423,6 +423,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
         if group is not None and not isinstance(group, dist.ProcessGroup):
             raise ValueError("Argument group should be list of int or ProcessGroup")
         reduce_op = self._reduce_op_map[op]
+        # We do if/else here for compatibility with older pytorch versions
         if group is not None:
             dist.all_reduce(tensor, reduce_op, group=group)
         else:
@@ -432,7 +433,8 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
     def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         if group == dist.GroupMember.NON_GROUP_MEMBER:
             return tensor
-        elif group is None:
+
+        if group is None:
             group_size = self.get_world_size()
         elif isinstance(group, dist.ProcessGroup):
             group_size = group.size()
@@ -441,12 +443,38 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
         if tensor.ndimension() == 0:
             tensor = tensor.unsqueeze(0)
         output = [torch.zeros_like(tensor) for _ in range(group_size)]
+        # We do if/else here for compatibility with older pytorch versions
         if group is not None:
             dist.all_gather(output, tensor, group=group)
         else:
             dist.all_gather(output, tensor)
         return torch.cat(output, dim=0)
 
+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
+        if Version(torch.__version__) < Version("1.7.0"):
+            raise RuntimeError(
+                "Current torch version does not implement dist.all_gather_object. "
+                "Required version should be >=1.7.0"
+            )
+
+        if group == dist.GroupMember.NON_GROUP_MEMBER:
+            return tensor
+
+        if group is None:
+            group_size = self.get_world_size()
+        elif isinstance(group, dist.ProcessGroup):
+            group_size = group.size()
+        else:
+            raise ValueError("Argument group should be list of int or ProcessGroup")
+        output = [None for _ in range(group_size)]
+        # We do if/else here for compatibility with older pytorch versions
+        if group is not None:
+            dist.all_gather_object(output, tensor, group=group)
+        else:
+            dist.all_gather_object(output, tensor)
+
+        return output
+
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         return dist.new_group(ranks=ranks, **kwargs)
 
diff --git a/ignite/distributed/comp_models/xla.py b/ignite/distributed/comp_models/xla.py
index 2c2cb27c9d1..eaaeceb0252 100644
--- a/ignite/distributed/comp_models/xla.py
+++ b/ignite/distributed/comp_models/xla.py
@@ -155,6 +155,9 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
         xm.all_reduce("sum", [output], groups=group)
         return output.reshape(-1, *output.shape[2:])
 
+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
+        raise NotImplementedError("all_gather on object is not implemented for xla")
+
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         return [ranks]
 
diff --git a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py
index 23960cac872..7845f0cd1ce 100644
--- a/tests/ignite/distributed/utils/__init__.py
+++ b/tests/ignite/distributed/utils/__init__.py
@@ -156,21 +156,22 @@ def _test_distrib_all_reduce_group(device):
 
 def _test_distrib_all_gather(device):
     rank = idist.get_rank()
+    ws = idist.get_world_size()
 
     res = torch.tensor(idist.all_gather(10), device=device)
-    true_res = torch.tensor([10] * idist.get_world_size(), device=device)
+    true_res = torch.tensor([10] * ws, device=device)
     assert (res == true_res).all()
 
     t = torch.tensor(rank, device=device)
     res = idist.all_gather(t)
-    true_res = torch.tensor([i for i in range(idist.get_world_size())], device=device)
+    true_res = torch.tensor([i for i in range(ws)], device=device)
     assert (res == true_res).all()
 
     x = "test-test"
     if rank == 0:
         x = "abc"
     res = idist.all_gather(x)
-    true_res = ["abc"] + ["test-test"] * (idist.get_world_size() - 1)
+    true_res = ["abc"] + ["test-test"] * (ws - 1)
     assert res == true_res
 
     base_x = "tests/ignite/distributed/utils/test_native.py" * 2000
@@ -179,27 +180,46 @@ def _test_distrib_all_gather(device):
         x = "abc"
 
     res = idist.all_gather(x)
-    true_res = ["abc"] + [base_x] * (idist.get_world_size() - 1)
+    true_res = ["abc"] + [base_x] * (ws - 1)
     assert res == true_res
 
     t = torch.arange(100, device=device).reshape(4, 25) * (rank + 1)
     in_dtype = t.dtype
     res = idist.all_gather(t)
-    assert res.shape == (idist.get_world_size() * 4, 25)
+    assert res.shape == (ws * 4, 25)
     assert res.dtype == in_dtype
-    true_res = torch.zeros(idist.get_world_size() * 4, 25, device=device)
-    for i in range(idist.get_world_size()):
+    true_res = torch.zeros(ws * 4, 25, device=device)
+    for i in range(ws):
         true_res[i * 4 : (i + 1) * 4, ...] = torch.arange(100, device=device).reshape(4, 25) * (i + 1)
     assert (res == true_res).all()
 
-    if idist.get_world_size() > 1:
-        with pytest.raises(TypeError, match=r"Unhandled input type"):
-            idist.all_reduce([0, 1, 2])
+    if ws > 1 and idist.backend() != "xla-tpu":
+        t = {
+            "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
+            "b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
+            "c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
+        }
+        res = idist.all_gather(t)
+        assert isinstance(res, list) and len(res) == ws
+        for i, obj in enumerate(res):
+            assert isinstance(obj, dict)
+            assert list(obj.keys()) == ["a", "b", "c"], obj
+            expected_device = (
+                device if torch.device(device).type == "cpu" else torch.device(f"{torch.device(device).type}:{i}")
+            )
+            expected = {
+                "a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
+                "b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
+                "c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
+            }
+            assert obj["a"] == expected["a"]
+            assert (obj["b"] == expected["b"]).all()
+            assert obj["c"] == expected["c"]
 
 
 def _test_distrib_all_gather_group(device):
     if idist.get_world_size() > 1:
-        ranks = [0, 1]
+        ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
         rank = idist.get_rank()
         bnd = idist.backend()
 
@@ -226,6 +246,40 @@ def _test_distrib_all_gather_group(device):
         else:
             assert res == t
 
+        t = {
+            "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
+            "b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
+            "c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
+        }
+        if bnd in ("xla-tpu"):
+            with pytest.raises(NotImplementedError, match=r"all_gather on object is not implemented for xla"):
+                res = idist.all_gather(t, group=ranks)
+        elif bnd in ("horovod"):
+            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+                res = idist.all_gather(t, group=ranks)
+        else:
+            res = idist.all_gather(t, group=ranks)
+            if rank in ranks:
+                assert isinstance(res, list) and len(res) == len(ranks)
+                for i, obj in zip(ranks, res):
+                    assert isinstance(obj, dict)
+                    assert list(obj.keys()) == ["a", "b", "c"], obj
+                    expected_device = (
+                        device
+                        if torch.device(device).type == "cpu"
+                        else torch.device(f"{torch.device(device).type}:{i}")
+                    )
+                    expected = {
+                        "a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
+                        "b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
+                        "c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
+                    }
+                    assert obj["a"] == expected["a"], (obj, expected)
+                    assert (obj["b"] == expected["b"]).all(), (obj, expected)
+                    assert obj["c"] == expected["c"], (obj, expected)
+            else:
+                assert res == t
+
         if bnd in ("nccl", "gloo", "mpi"):
             with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
                 res = idist.all_gather(t, group="abc")
diff --git a/tests/ignite/distributed/utils/test_native.py b/tests/ignite/distributed/utils/test_native.py
index c8f51021ca1..fda3e1126cc 100644
--- a/tests/ignite/distributed/utils/test_native.py
+++ b/tests/ignite/distributed/utils/test_native.py
@@ -3,6 +3,7 @@
 import pytest
 import torch
 import torch.distributed as dist
+from packaging.version import Version
 
 import ignite.distributed as idist
 from ignite.distributed.utils import has_native_dist_support
@@ -236,6 +237,7 @@ def test_idist_all_reduce_gloo(distributed_context_single_node_gloo):
 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="dist.all_gather_object is not implemented")
 def test_idist_all_gather_nccl(distributed_context_single_node_nccl):
     device = idist.device()
     _test_distrib_all_gather(device)
@@ -244,6 +246,7 @@
 
 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="dist.all_gather_object is not implemented")
 def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
     device = idist.device()
     _test_distrib_all_gather(device)
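
Usage note (not part of the commit): a minimal sketch of what this patch enables. With a native torch.distributed backend and torch>=1.7.0, passing a non-tensor, non-number, non-string object to idist.all_gather now routes it through dist.all_gather_object and returns a list with one entry per rank, instead of raising TypeError as before. The payload dict, the gloo backend, and nproc_per_node=2 below are illustrative assumptions, not something the patch prescribes.

    # sketch: gather an arbitrary picklable object from every process
    import torch
    import ignite.distributed as idist

    def training(local_rank):
        rank = idist.get_rank()
        # Non-tensor payload -> handled by _do_all_gather_object / dist.all_gather_object
        payload = {"rank": rank, "meta": {"loss": 0.1 * rank}, "t": torch.tensor(rank)}
        gathered = idist.all_gather(payload)  # list of length idist.get_world_size()
        if rank == 0:
            print(gathered)

    if __name__ == "__main__":
        # Spawn two CPU processes with the gloo backend (assumed setup for the sketch)
        with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
            parallel.run(training)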