trying to make pow work for tensor raised to the power of a scalar (#46185)

Summary:
Fixes #46037

I'm not sure this is the most performant solution, but this works:

`torch.pow(cuda_tensor, 5)` should work, and it worked before.
`torch.pow(cuda_tensor, torch.tensor(5))` should work, **and it works now!**
`torch.pow(cuda_tensor, torch.tensor((5,)))` should NOT work, should complain that the tensors are on different devices, and it indeed continues to complain.
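
A minimal sketch of those three cases (assuming a CUDA device is available; the tensor names here are illustrative, not from the test suite):

```python
import torch

base = torch.randn(3, 3, device="cuda")

torch.pow(base, 5)                # plain Python scalar: worked before this change
torch.pow(base, torch.tensor(5))  # 0-dim CPU tensor exponent: works after this change

try:
    torch.pow(base, torch.tensor((5,)))  # 1-dim CPU tensor: still rejected
except RuntimeError as err:
    print(err)  # complains that the tensors are on different devices
```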

Pull Request resolved: #46185

Reviewed By: glaringlee, malfet

Differential Revision: D24257687

Pulled By: janeyx99

fbshipit-source-id: 2daf235d62ec5886d7c153da05445c2ec71dec98
janeyx99 authored and facebook-github-bot committed Oct 13, 2020
1 parent 1a57b39 commit ad376f1
Showing 2 changed files with 11 additions and 0 deletions.
3 changes: 3 additions & 0 deletions aten/src/ATen/native/Pow.cpp
@@ -11,6 +11,9 @@ DEFINE_DISPATCH(pow_tensor_tensor_stub);
DEFINE_DISPATCH(pow_tensor_scalar_stub);

Tensor& pow_out(Tensor& result, const Tensor& base, const Tensor& exp) {
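  // If exp is a 0-dim tensor, extract it as a Scalar so that (e.g.) a CPU
  // scalar tensor can be used as the exponent for a CUDA base tensor.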
  if (exp.dim() == 0 && base.dim() != 0) {
    return native::pow_out(result, base, exp.item());
  }
  auto iter = TensorIterator::binary_op(result, base, exp);
  pow_tensor_tensor_stub(iter.device_type(), iter);
  return result;
8 changes: 8 additions & 0 deletions test/test_torch.py
@@ -13849,6 +13849,14 @@ def test_float_scalar_pow_float_tensor(self, device):
        for base in floats:
            self._test_pow(base, tensor)

    @onlyCUDA
    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
    def test_cuda_tensor_pow_scalar_tensor(self, device):
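        # Each exponent is a 0-dim CPU tensor; pow should accept it with a CUDA base.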
        cuda_tensor = torch.randn((3, 3), device=device)
        scalar_tensors = [torch.tensor(5), torch.tensor(-3), torch.tensor(1)]
        for exp in scalar_tensors:
            self._test_pow(cuda_tensor, exp)

    @onlyOnCPUAndCUDA
    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
    @dtypes(*(torch.testing.get_all_dtypes(include_bool=False, include_bfloat16=False)))
