diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index b0fc0e0fd98e..13a950024af9 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1616,7 +1616,7 @@ def scatter_object_list(
             element will store the object scattered to this rank.
         scatter_object_input_list (List[Any]): List of input objects to scatter.
             Each object must be picklable. Only objects on the ``src`` rank will
-            be scattered, and the argument can be ``None`` for non-src ranks. 
+            be scattered, and the argument can be ``None`` for non-src ranks.
         src (int): Source rank from which to scatter
             ``scatter_object_input_list``.
         group: (ProcessGroup, optional): The process group to work on.
@@ -1646,7 +1646,7 @@ def scatter_object_list(
             "Expected argument scatter_object_output_list to be a list of size at least 1."
         )
 
-    my_rank = get_rank()
+    my_rank = get_rank(group)
     if my_rank == src:
         tensor_list, tensor_sizes = zip(
             *[_object_to_tensor(obj) for obj in scatter_object_input_list]
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 1312ac409551..897e985576c2 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -4278,9 +4278,10 @@ def test_scatter_object_list(self):
             if self.rank == src_rank
             else [None for _ in collectives_object_test_list]
         )
-        scatter_list = scatter_list[: int(os.environ["WORLD_SIZE"])]
+        world_size = dist.get_world_size()
+        scatter_list = scatter_list[: world_size]
         i = 0
-        while len(scatter_list) < int(os.environ["WORLD_SIZE"]):
+        while len(scatter_list) < world_size:
             scatter_list.append(scatter_list[i])
             i += 1
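
Note on the fix: `get_rank()` returns the caller's rank in the default (world) process group, while the `my_rank == src` check in `scatter_object_list` needs the rank relative to the `group` argument. On a subgroup whose global ranks differ from its group-local ranks, the old comparison could misidentify the scatter source. A minimal sketch of the distinction the patch relies on (the 4-process setup and the choice of ranks are illustrative assumptions, not part of the patch):

    import torch.distributed as dist

    # Assumes a 4-process job already initialized via init_process_group.
    # Build a subgroup out of global ranks 2 and 3 (illustrative choice).
    group = dist.new_group(ranks=[2, 3])

    if dist.get_rank() == 2:
        print(dist.get_rank())       # 2 -- rank in the default (world) group
        print(dist.get_rank(group))  # 0 -- rank within `group`
        # The old code compared the world rank (2) against `src`, so the
        # source was misidentified whenever the two ranks differ;
        # `get_rank(group)` makes the comparison group-relative.

The test change follows the same reasoning: `dist.get_world_size()` reports the process group's size directly instead of re-reading the `WORLD_SIZE` environment variable.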