Update on "scatter_object_list API for c10d"
Closes #23232. This PR adds support for `scatter_object_list`, an API to scatter arbitrary picklable objects from the `src` rank to every rank in the group.

The implementation follows a similar approach to #42189. The result of the `scatter` is stored as the first element of `scatter_object_output_list`, and the `src` rank is expected to provide an input list, `scatter_object_input_list`, containing the objects to scatter.
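
For reference, a minimal usage sketch (a hedged example, assuming a Gloo process group has already been initialized via `dist.init_process_group`; the per-rank payloads are purely illustrative):

```python
import torch.distributed as dist


def demo_scatter_object_list(src=0):
    """Scatter one picklable object from ``src`` to every rank and return it."""
    world_size = dist.get_world_size()
    if dist.get_rank() == src:
        # Only the src rank needs real inputs: one picklable object per rank.
        scatter_object_input_list = [
            {"rank": r, "payload": list(range(r))} for r in range(world_size)
        ]
    else:
        scatter_object_input_list = [None] * world_size

    # The scattered object lands in the first (and only required) slot.
    scatter_object_output_list = [None]
    dist.scatter_object_list(
        scatter_object_output_list, scatter_object_input_list, src=src
    )
    return scatter_object_output_list[0]
```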

Note that this API requires 1 broadcast and 2 scatters: the broadcast communicates the maximum serialized object size, which only the src rank knows about; the two scatters then communicate the true size of each object and the serialized objects themselves.
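
For illustration only, a rough sketch of that communication pattern is below. This is not the actual implementation; it inlines simplified stand-ins for the internal `_object_to_tensor` / `_tensor_to_object` helpers referenced in the diff, and it omits error handling and device placement:

```python
import pickle

import torch
import torch.distributed as dist


def _object_to_tensor(obj):
    # Pickle ``obj`` into a uint8 tensor and return it with its size.
    data = pickle.dumps(obj)
    byte_tensor = torch.tensor(bytearray(data), dtype=torch.uint8)
    return byte_tensor, torch.tensor([byte_tensor.numel()], dtype=torch.long)


def _tensor_to_object(tensor, size):
    # Unpickle the first ``size`` bytes of ``tensor``.
    return pickle.loads(tensor[: int(size)].numpy().tobytes())


def scatter_object_list_sketch(output_list, input_list, src=0, group=None):
    my_rank = dist.get_rank(group)
    if my_rank == src:
        tensors, sizes = zip(*[_object_to_tensor(obj) for obj in input_list])
        max_size = max(int(s.item()) for s in sizes)
    else:
        tensors, sizes, max_size = None, None, 0

    # 1) Broadcast the maximum serialized size, which only src knows,
    #    so every rank can allocate a large-enough receive buffer.
    max_size_t = torch.tensor([max_size], dtype=torch.long)
    dist.broadcast(max_size_t, src=src, group=group)
    max_size = int(max_size_t.item())

    # 2) Scatter the true serialized size of each rank's object.
    my_size = torch.zeros(1, dtype=torch.long)
    dist.scatter(
        my_size,
        scatter_list=list(sizes) if my_rank == src else None,
        src=src,
        group=group,
    )

    # 3) Scatter the serialized objects themselves, padded to max_size.
    buf = torch.empty(max_size, dtype=torch.uint8)
    padded = (
        [torch.cat((t, t.new_zeros(max_size - t.numel()))) for t in tensors]
        if my_rank == src
        else None
    )
    dist.scatter(buf, scatter_list=padded, src=src, group=group)

    # Trim the padding and unpickle into the first output slot.
    output_list[0] = _tensor_to_object(buf, my_size.item())
```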

Note that the API is designed to match the tensor-based collectives, except that it does not support `async_op`. For now, it is a blocking call. If we see demand for `async_op` support, we will have to make more progress on merging work/future to support it.

This API only works with the Gloo backend, because NCCL does not support scatter.

Differential Revision: [D23430686](https://our.internmc.facebook.com/intern/diff/D23430686/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D23430686/)!

[ghstack-poisoned]
rohan-varma committed Dec 4, 2020
2 parents 137931f + 22d28da commit 370522c
Showing 2 changed files with 5 additions and 4 deletions.
`torch/distributed/distributed_c10d.py` — 2 additions, 2 deletions

```diff
@@ -1616,7 +1616,7 @@ def scatter_object_list(
             element will store the object scattered to this rank.
         scatter_object_input_list (List[Any]): List of input objects to scatter.
             Each object must be picklable. Only objects on the ``src`` rank will
-            be scattered, and the argument can be ``None`` for non-src ranks.
+            be scattered, and the argument can be ``None`` for non-src ranks.
         src (int): Source rank from which to scatter
             ``scatter_object_input_list``.
         group: (ProcessGroup, optional): The process group to work on.
@@ -1646,7 +1646,7 @@ def scatter_object_list(
             "Expected argument scatter_object_output_list to be a list of size at least 1."
         )
 
-    my_rank = get_rank()
+    my_rank = get_rank(group)
    if my_rank == src:
        tensor_list, tensor_sizes = zip(
            *[_object_to_tensor(obj) for obj in scatter_object_input_list]
```
`torch/testing/_internal/distributed/distributed_test.py` — 3 additions, 2 deletions

```diff
@@ -4278,9 +4278,10 @@ def test_scatter_object_list(self):
             if self.rank == src_rank
             else [None for _ in collectives_object_test_list]
         )
-        scatter_list = scatter_list[: int(os.environ["WORLD_SIZE"])]
+        world_size = dist.get_world_size()
+        scatter_list = scatter_list[: world_size]
         i = 0
-        while len(scatter_list) < int(os.environ["WORLD_SIZE"]):
+        while len(scatter_list) < world_size:
             scatter_list.append(scatter_list[i])
             i += 1
 
```
