Skip to content

Commit

Permalink
distributed utils - add test for pg wrapper (#644)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #644

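Add tests for `PGWrapper.scatter_object_list`: a gloo-backend test in tests/utils/test_distributed.py and an NCCL-backend test in tests/utils/test_distributed_gpu.py.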

Reviewed By: JKSenthil

Differential Revision: D51930000

fbshipit-source-id: b05b7192686daf8030bba657411bfc536890092f
  • Loading branch information
galrotem authored and facebook-github-bot committed Dec 7, 2023
1 parent 5bc8861 commit 2c6a7c8
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
22 changes: 22 additions & 0 deletions tests/utils/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
get_process_group_backend_from_device,
get_tcp_init_method,
get_world_size,
PGWrapper,
rank_zero_fn,
revert_sync_batchnorm,
sync_bool,
Expand Down Expand Up @@ -437,3 +438,24 @@ def test_get_tcp_init_method(self) -> None:
self.assertEqual(url_qs["world_size"], [str(world_size)])
self.assertIn("rank", url_qs)
self.assertEqual(url_qs["rank"], [str(rank)])

def test_pg_wrapper_scatter_object_list_gloo(self) -> None:
    """Launch 4 elastic workers that exercise ``PGWrapper.scatter_object_list``.

    Every worker runs ``_test_pg_wrapper_scatter_object_list``, which sets up
    a gloo process group, performs the scatter and asserts on the result.
    """
    config = get_pet_launch_config(4)
    # NOTE(review): assumes ``launcher`` is ``torch.distributed.launcher``.
    # There, ``elastic_launch(config, entrypoint)`` only builds the launcher
    # object; the workers start when that object is *called*. Without the
    # trailing ``()`` the entrypoint never runs and the test passes vacuously.
    launcher.elastic_launch(
        config, entrypoint=self._test_pg_wrapper_scatter_object_list
    )()

@classmethod
def _test_pg_wrapper_scatter_object_list(cls) -> None:
    """Worker entrypoint: scatter ``[1, 2, 3, 4]`` from ``src=0`` over gloo.

    Initializes the default process group with the ``"gloo"`` backend, wraps
    ``dist.group.WORLD`` in a ``PGWrapper`` and calls ``scatter_object_list``.
    The process whose local rank is 0 supplies the real payload; every other
    process passes a placeholder list of ``None``. The check is that element 0
    of the output list equals this process's local rank plus one.
    """
    dist.init_process_group("gloo")
    wrapper = PGWrapper(dist.group.WORLD)

    # NOTE(review): the payload owner is picked by *local* rank while ``src``
    # is passed as 0 -- presumably these coincide because all workers are
    # launched on a single node; confirm before reusing multi-node.
    local_rank = get_local_rank()
    if local_rank == 0:
        payload = [1, 2, 3, 4]
    else:
        payload = [None] * 4

    received = [None] * 4
    wrapper.scatter_object_list(
        output_list=received,
        input_list=payload,
        src=0,
    )

    # A standalone TestCase instance is used only for its assertion helper.
    unittest.TestCase().assertEqual(received[0], local_rank + 1)
25 changes: 24 additions & 1 deletion tests/utils/test_distributed_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import torch
import torch.distributed as dist
from torchtnt.utils.device import get_device_from_env
from torchtnt.utils.distributed import all_gather_tensors
from torchtnt.utils.distributed import all_gather_tensors, get_local_rank, PGWrapper
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import spawn_multi_process


Expand Down Expand Up @@ -41,3 +42,25 @@ def _test_ddp_gather_uneven_tensors_multidim_nccl() -> None:
val = result[idx]
assert val.shape == (idx + 1, 4 - idx)
assert (val == 1).all()

def test_pg_wrapper_scatter_object_list_nccl(self) -> None:
    """Run the ``PGWrapper.scatter_object_list`` worker via ``spawn_multi_process``.

    Passes ``2`` and ``"nccl"`` to ``spawn_multi_process`` together with the
    worker classmethod ``_test_pg_wrapper_scatter_object_list``.
    """
    # Presumably (process count, backend name) -- the worker sizes its lists
    # to 2 entries, which matches the first argument.
    world_size = 2
    backend = "nccl"
    worker = self._test_pg_wrapper_scatter_object_list
    spawn_multi_process(world_size, backend, worker)

@classmethod
def _test_pg_wrapper_scatter_object_list(cls) -> None:
    """Worker entrypoint: scatter ``[1, 2]`` from ``src=0`` and verify the result.

    Calls ``init_from_env()``, wraps ``dist.group.WORLD`` in a ``PGWrapper``
    and calls ``scatter_object_list``. The process whose local rank is 0
    supplies the real payload; the other passes a placeholder list of
    ``None``. The check is that element 0 of the output list equals this
    process's local rank plus one.
    """
    # Presumably sets up the device and process group from environment
    # variables provided by the spawner -- confirm against torchtnt.utils.env.
    init_from_env()
    wrapper = PGWrapper(dist.group.WORLD)

    # NOTE(review): the payload owner is picked by *local* rank while ``src``
    # is passed as 0 -- presumably equivalent for a single-node spawn.
    local_rank = get_local_rank()
    if local_rank == 0:
        payload = [1, 2]
    else:
        payload = [None] * 2

    received = [None] * 2
    wrapper.scatter_object_list(
        output_list=received,
        input_list=payload,
        src=0,
    )

    # A standalone TestCase instance is used only for its assertion helper.
    unittest.TestCase().assertEqual(received[0], local_rank + 1)

0 comments on commit 2c6a7c8

Please sign in to comment.