Skip to content

Commit

Permalink
Enhance the torch.pow testcase for the complex scalar base (#47101)
Browse files Browse the repository at this point in the history
Summary:
Related #45259

This PR is to address the #45259 (comment)

- leverage the `make_tensor`  function to generate a random tensor as the exponent, preventing the full zeros for the integer exponent.
- add some special cases for the zero exponents and the `1 + 0j` base.

Pull Request resolved: #47101

Reviewed By: mruberry

Differential Revision: D24682430

Pulled By: zou3519

fbshipit-source-id: f559dc0ba08f37ae070036fb25a52ede17a24149
  • Loading branch information
RockingJavaBean authored and facebook-github-bot committed Nov 2, 2020
1 parent 9b52654 commit 22b3d41
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions test/test_torch.py
Expand Up @@ -14001,10 +14001,11 @@ def test_cpu_tensor_pow_cuda_scalar_tensor(self, device):
@unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
@dtypes(*(torch.testing.get_all_dtypes(include_bool=False, include_bfloat16=False)))
def test_complex_scalar_pow_tensor(self, device, dtype):
complexes = [0.5j, 1. + 1.j, -1.5j, 2.2 - 1.6j]
tensor = torch.rand(100).to(dtype=dtype, device=device)
complexes = [0.5j, 1. + 1.j, -1.5j, 2.2 - 1.6j, 1 + 0j]
exp = make_tensor((100,), device, dtype, low=-2, high=2)
exp[0] = exp[10] = exp[20] = 0
for base in complexes:
self._test_pow(base, tensor)
self._test_pow(base, exp)

@unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
def test_tensor_pow_tensor(self, dev):
Expand Down

0 comments on commit 22b3d41

Please sign in to comment.