Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix torch.pow when the scalar base is a complex number #45259

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion aten/src/ATen/native/Pow.cpp
Expand Up @@ -43,7 +43,9 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) {
}

Tensor& pow_out(Tensor& result, Scalar base, const Tensor& exp) {
if (base.toDouble() == 1.0) {
if (base.isComplex() && base.toComplexDouble() == 1.0) {
result.resize_as_(exp).fill_(1);
} else if (!base.isComplex() && base.toDouble() == 1.0) {
result.resize_as_(exp).fill_(1);
} else {
native::pow_out(result, c10::scalar_to_tensor(base, exp.device()), exp);
Expand Down
9 changes: 9 additions & 0 deletions test/test_torch.py
Expand Up @@ -13716,6 +13716,15 @@ def test_float_scalar_pow_float_tensor(self, device):
for base in floats:
self._test_pow(base, tensor)

@onlyOnCPUAndCUDA
@unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
@dtypes(*(torch.testing.get_all_dtypes(include_bool=False, include_bfloat16=False)))
def test_complex_scalar_pow_tensor(self, device, dtype):
complexes = [0.5j, 1. + 1.j, -1.5j, 2.2 - 1.6j]
tensor = torch.rand(100).to(dtype=dtype, device=device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For integer dtypes this resolves to a tensor full of zeros, which may not be the most interesting test case. We have a make_tensor function to generate a random tensor that would be nicer to use:

def make_tensor(size, device: torch.device, dtype: torch.dtype, *,
low, high, requires_grad: bool = False) -> torch.Tensor:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much for pointing out this issue, and offering the kind tip of using the make_tensor function.
#47101 has been created to address this comment and special cases of zero exponents and the 1 + 0j base are added as well, please kindly help review.

for base in complexes:
self._test_pow(base, tensor)

@unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
def test_tensor_pow_tensor(self, dev):
def rotate(l, n):
Expand Down