Skip to content

Commit

Permalink
Add test for complex
Browse files Browse the repository at this point in the history
  • Loading branch information
Rodrigo Kumpera committed Sep 13, 2022
1 parent b569ff8 commit 7d31ac2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 0 deletions.
19 changes: 19 additions & 0 deletions test/distributed/test_c10d_common.py
Expand Up @@ -1204,6 +1204,25 @@ def _test_tensor_dtype_mismatch(self, backend):
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
dist.scatter(tensor_h, tensor_list)

def _test_tensor_dtype_complex(self, backend):
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend,
world_size=self.world_size,
rank=self.rank,
store=store,
)

tensor = torch.rand(2, device=self.device)
tensor_c = torch.view_as_complex(tensor)
tensor_list = [torch.rand(2, device=self.device) for _ in range(self.world_size)]
tensor_list_c = list(tensor_list)
tensor_list_c[1] = torch.view_as_complex(tensor_list_c[1])

dist.all_gather(tensor_list, tensor)
dist.all_gather(tensor_list, tensor_c)
dist.all_gather(tensor_list_c, tensor)
dist.all_gather(tensor_list_c, tensor_c)

class CommTest(AbstractCommTest, MultiProcessTestCase):
def setUp(self):
Expand Down
5 changes: 5 additions & 0 deletions test/distributed/test_c10d_gloo.py
Expand Up @@ -2353,6 +2353,11 @@ def test_gloo_rank_membership(self):
def test_tensor_dtype_mismatch(self):
self._test_tensor_dtype_mismatch(backend="gloo")

@skip_if_lt_x_gpu(2)
@requires_gloo()
def test_tensor_dtype_complex(self):
self._test_tensor_dtype_complex(backend="gloo")

class CompilerTest(test_c10d_common.CompilerTest):

@property
Expand Down
5 changes: 5 additions & 0 deletions test/distributed/test_c10d_nccl.py
Expand Up @@ -2816,6 +2816,11 @@ def test_nncl_rank_membership(self):
def test_tensor_dtype_mismatch(self):
self._test_tensor_dtype_mismatch(backend="nccl")

@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_tensor_dtype_complex(self):
self._test_tensor_dtype_complex(backend="nccl")


class CompilerTest(test_c10d_common.CompilerTest):

Expand Down

0 comments on commit 7d31ac2

Please sign in to comment.