From f3765e9a06838a4fea772361a0ea8fe7b455b686 Mon Sep 17 00:00:00 2001
From: Jane Xu
Date: Tue, 13 Oct 2020 07:16:02 -0700
Subject: [PATCH] fix vmap test and remove fractional values from test case

---
 aten/src/ATen/native/Pow.cpp | 2 +-
 test/test_torch.py           | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp
index 8b6ad52d34f9..d65404bad361 100644
--- a/aten/src/ATen/native/Pow.cpp
+++ b/aten/src/ATen/native/Pow.cpp
@@ -11,7 +11,7 @@ DEFINE_DISPATCH(pow_tensor_tensor_stub);
 DEFINE_DISPATCH(pow_tensor_scalar_stub);
 
 Tensor& pow_out(Tensor& result, const Tensor& base, const Tensor& exp) {
-  if (exp.dim() == 0) {
+  if (exp.dim() == 0 && base.dim() != 0) {
     return native::pow_out(result, base, exp.item());
   }
   auto iter = TensorIterator::binary_op(result, base, exp);
diff --git a/test/test_torch.py b/test/test_torch.py
index 48bb4e87ae67..59de6a6feb6d 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -13793,9 +13793,9 @@ def test_float_scalar_pow_float_tensor(self, device):
     @onlyCUDA
     @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
     def test_cuda_tensor_pow_scalar_tensor(self, device):
-        cuda_tensor = torch.randn((3, 3), device='cuda')
-        scalars_tensors = [torch.tensor(5), torch.tensor(4.2), torch.tensor(-0.5)]
-        for exp in scalars_tensors:
+        cuda_tensor = torch.randn((3, 3), device=device)
+        scalar_tensors = [torch.tensor(5), torch.tensor(-3), torch.tensor(1)]
+        for exp in scalar_tensors:
             self._test_pow(cuda_tensor, exp)
 
     @onlyOnCPUAndCUDA
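
Below is a minimal sketch (not part of the patch) of the case the change above handles and the updated test exercises: raising a CUDA tensor to a 0-dim CPU scalar tensor. It assumes a CUDA-enabled PyTorch build and uses a plain assertion in place of the test suite's _test_pow helper.

import torch

# Requires a CUDA device. After this patch, pow_out routes a 0-dim
# exponent through exp.item() only when the base is not itself 0-dim,
# so a 0-dim CPU scalar tensor can serve as the exponent of a CUDA base.
if torch.cuda.is_available():
    base = torch.randn((3, 3), device='cuda')
    for exp in [torch.tensor(5), torch.tensor(-3), torch.tensor(1)]:
        result = base.pow(exp)          # CUDA base, 0-dim CPU exponent
        expected = base.cpu().pow(exp)  # reference computed on CPU
        assert torch.allclose(result.cpu(), expected)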