From 02d89f9f1d7f32ebf7ec509d5c14b2f39690997a Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Fri, 4 Dec 2020 18:40:24 -0800 Subject: [PATCH] scatter_object_list API for c10d (#43930) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43930 Closes #23232. As part of addressing #23232, this PR adds support for scatter_object_list which is an API to scatter arbitrary picklable objects to all the other ranks. The implementation approach follows a similar approach as https://github.com/pytorch/pytorch/pull/42189. The result of the `scatter` is stored as the first element of `scatter_object_output_list`, and the src rank is expected to provide an input list `scatter_object_input_list` which contains the objects to scatter. Note that this API requires 1 broadcast and 2 scatters. This is because we must communicate the maximum object size to be scattered, which only the src rank knows about. After that, we also need to communicate the objects themselves as well as the true sizes of the object. Note that the API is designed to match the tensor-based collectives other than supporting async_op. For now, it is a blocking call. If we see demand to support async_op, we will have to make more progress on merging work/future to support this. It only works for Gloo because NCCL doesn't support scatter. ghstack-source-id: 117904065 Reviewed By: mrshenli Differential Revision: D23430686 fbshipit-source-id: f033b89cd82dadd194f2b036312a98423449c26b --- torch/distributed/distributed_c10d.py | 84 +++++++++++++++++++ .../_internal/distributed/distributed_test.py | 29 +++++++ 2 files changed, 113 insertions(+) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 48b132811839..13a950024af9 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1601,6 +1601,90 @@ def broadcast_object_list(object_list, src, group=group.WORLD): object_list[i] = _tensor_to_object(obj_view, obj_size) +def scatter_object_list( + scatter_object_output_list, scatter_object_input_list, src=0, group=group.WORLD +): + """ + Scatters picklable objects in ``scatter_object_input_list`` to the whole + group. Similar to :func:`scatter`, but Python objects can be passed in. On + each rank, the scattered object will be stored as the first element of + ``scatter_object_output_list``. Note that all objects in + ``scatter_object_input_list`` must be picklable in order to be scattered. + + Arguments: + scatter_object_output_list (List[Any]): Non-empty list whose first + element will store the object scattered to this rank. + scatter_object_input_list (List[Any]): List of input objects to scatter. + Each object must be picklable. Only objects on the ``src`` rank will + be scattered, and the argument can be ``None`` for non-src ranks. + src (int): Source rank from which to scatter + ``scatter_object_input_list``. + group: (ProcessGroup, optional): The process group to work on. + + Returns: + ``None``. If rank is part of the group, ``scatter_object_output_list`` + will have its first element set to the scattered object for this rank. + + .. note:: Note that this API differs slightly from the scatter collective + since it does not provide an ``async_op`` handle and thus will be a + blocking call. + + .. warning:: + :func:`scatter_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + """ + if _rank_not_in_group(group): + return + + if ( + not isinstance(scatter_object_output_list, list) + or len(scatter_object_output_list) < 1 + ): + raise RuntimeError( + "Expected argument scatter_object_output_list to be a list of size at least 1." + ) + + my_rank = get_rank(group) + if my_rank == src: + tensor_list, tensor_sizes = zip( + *[_object_to_tensor(obj) for obj in scatter_object_input_list] + ) + tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes) + + obj_tensor_size = torch.LongTensor([0]) + # Src rank broadcasts the maximum tensor size. This is because all ranks are + # expected to call into scatter() with equal-sized tensors. + if my_rank == src: + max_tensor_size = max(tensor_sizes) + for tensor in tensor_list: + tensor.resize_(max_tensor_size) + else: + max_tensor_size = torch.LongTensor([0]) + broadcast(max_tensor_size, src=src, group=group) + + # Scatter actual serialized objects + output_tensor = torch.ByteTensor(max_tensor_size.item()) + scatter( + output_tensor, + scatter_list=None if my_rank != src else tensor_list, + src=src, + group=group, + ) + + # Scatter per-object sizes to trim tensors when deserializing back to object + scatter( + obj_tensor_size, + scatter_list=None if my_rank != src else tensor_sizes, + src=src, + group=group, + ) + + # Deserialize back to object + scatter_object_output_list[0] = _tensor_to_object(output_tensor, obj_tensor_size) + + def all_gather(tensor_list, tensor, group=group.WORLD, diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 943eb24a0b5e..cbe8e9d630bf 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -4264,3 +4264,32 @@ def forward(self, x): ) if i == 1 else suppress(): loss = model(random_input).sum() loss.backward() + + @require_backend({"gloo"}) + @unittest.skipIf(BACKEND == "nccl", "NCCL does not support scatter") + def test_scatter_object_list(self): + src_rank = 0 + scatter_list = ( + collectives_object_test_list + if self.rank == src_rank + else [None for _ in collectives_object_test_list] + ) + world_size = dist.get_world_size() + scatter_list = scatter_list[: world_size] + i = 0 + while len(scatter_list) < world_size: + scatter_list.append(scatter_list[i]) + i += 1 + + output_obj_list = [None] + dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank) + self.assertEqual( + output_obj_list[0], + collectives_object_test_list[self.rank % len(collectives_object_test_list)], + ) + # Ensure errors are raised upon incorrect arguments. + with self.assertRaisesRegex( + RuntimeError, + "Expected argument scatter_object_output_list to be a list of size at least 1.", + ): + dist.scatter_object_list([], scatter_list, src=src_rank)