diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index a1eee28fef96..4a34b27b7fe4 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -552,6 +552,7 @@ _(aten, pixel_shuffle) \
 _(aten, poisson) \
 _(aten, polygamma) \
 _(aten, pow) \
+_(aten, float_power) \
 _(aten, prelu) \
 _(aten, prelu_backward) \
 _(aten, prod) \
diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp
index ca5d1848a4b8..bfc5f910e093 100644
--- a/aten/src/ATen/native/Pow.cpp
+++ b/aten/src/ATen/native/Pow.cpp
@@ -81,6 +81,48 @@ Tensor pow(Scalar base, const Tensor& exp) {
   return native::pow_out(result, base, exp);
 }
 
+Tensor& float_power_out(Tensor& result, const Tensor& base, const Tensor& exp) {
+  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ?
+                at::kComplexDouble : at::kDouble;
+  TORCH_CHECK(result.scalar_type() == dtype,
+              "output type ", result.scalar_type(), "is not the desired output type ", dtype);
+
+  return at::pow_out(result, base.to(dtype), exp.to(dtype));
+}
+
+Tensor& float_power_out(Tensor& result, const Tensor& base, Scalar exp) {
+  return at::float_power_out(result, base, c10::scalar_to_tensor(exp, base.device()));
+}
+
+Tensor& float_power_out(Tensor& result, Scalar base, const Tensor& exp) {
+  return at::float_power_out(result, c10::scalar_to_tensor(base, exp.device()), exp);
+}
+
+Tensor float_power(const Tensor& base, const Tensor& exp) {
+  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
+  return at::pow(base.to(dtype), exp.to(dtype));
+}
+
+Tensor float_power(const Tensor& base, Scalar exp) {
+  return at::float_power(base, c10::scalar_to_tensor(exp, base.device()));
+}
+
+Tensor float_power(Scalar base, const Tensor& exp) {
+  return at::float_power(c10::scalar_to_tensor(base, exp.device()), exp);
+}
+
+Tensor& float_power_(Tensor& base, const Tensor& exp) {
+  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
+  TORCH_CHECK(base.scalar_type() == dtype,
+              "self tensor type ", base.scalar_type(), "is not the desired type ", dtype);
+
+  return base.pow_(exp.to(dtype));
+}
+
+Tensor& float_power_(Tensor& base, Scalar exp) {
+  return base.float_power_(c10::scalar_to_tensor(exp, base.device()));
+}
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 9a289cb38726..2d6e570d25c8 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6872,6 +6872,47 @@
     CPU, CUDA: pow
     SparseCPU, SparseCUDA: pow_sparse_scalar
 
+- func: float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    Math: float_power_out
+
+- func: float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+  dispatch:
+    Math: float_power
+
+- func: float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    Math: float_power_out
+
+- func: float_power.Scalar(Scalar self, Tensor exponent) -> Tensor
+  use_c10_dispatcher: full
+  dispatch:
+    Math: float_power
+
+- func: float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    Math: float_power_out
+
+- func: float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+  dispatch:
+    Math: float_power
+
+- func: float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
+  use_c10_dispatcher: full
+  variants: method
+  dispatch:
+    Math: float_power_
+
+- func: float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)
+  use_c10_dispatcher: full
+  variants: method
+  dispatch:
+    Math: float_power_
+
 - func: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)
   variants: method
   dispatch:
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 730a3856c32c..d7b0af757d92 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -322,6 +322,8 @@ view of a storage and defines numeric operations on it.
    .. automethod:: fliplr
    .. automethod:: flipud
    .. automethod:: float
+   .. automethod:: float_power
+   .. automethod:: float_power_
    .. automethod:: floor
    .. automethod:: floor_
    .. automethod:: floor_divide
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index aab84fc05d79..4399e63c3b01 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -296,6 +296,7 @@ Pointwise Ops
     exp2
     expm1
     fix
+    float_power
     floor
     floor_divide
     fmod
diff --git a/test/test_torch.py b/test/test_torch.py
index f417d10ec471..86dd20241354 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6205,6 +6205,102 @@ def test_pow(self, device):
         torch.pow(m1, 1, out=out)
         self.assertEqual(out, m1)
 
+    @dtypes(*list(product(torch.testing.get_all_dtypes(include_bool=False),
+                          torch.testing.get_all_dtypes(include_bool=False))))
+    def test_float_power(self, device, dtypes):
+        def to_np(value):
+            if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16:
+                return value.to(torch.float).cpu().numpy()
+            return value.cpu().numpy() if isinstance(value, torch.Tensor) else value
+
+        base_dtype = dtypes[0]
+        exp_dtype = dtypes[1]
+        out_dtype = torch.complex128 if base_dtype.is_complex or exp_dtype.is_complex else torch.float64
+
+        base = make_tensor((30,), device, base_dtype, low=1, high=100)
+        # Complex and real results do not agree between PyTorch and NumPy when computing negative and zero power of 0
+        # Related: https://github.com/pytorch/pytorch/issues/48000
+        # base[0] = base[3] = base[7] = 0
+        exp = make_tensor((30,), device, exp_dtype, low=-2, high=2)
+        exp[0] = exp[4] = exp[6] = 0
+
+        expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp)))
+
+        exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2]
+        complex_exponents = exponents + [-2.5j, -1.0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j]
+
+        for op in (torch.float_power, torch.Tensor.float_power, torch.Tensor.float_power_):
+
+            # Case of Tensor x Tensor
+            if op is torch.Tensor.float_power_ and base_dtype != out_dtype:
+                with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
+                    op(base.clone(), exp)
+            else:
+                result = op(base.clone(), exp)
+                self.assertEqual(expected, result)
+
+            if op is torch.float_power:
+                out = torch.empty_like(base).to(device=device, dtype=out_dtype)
+                op(base, exp, out=out)
+                self.assertEqual(expected, out)
+
+            # Case of Tensor x Scalar
+            for i in complex_exponents if exp_dtype.is_complex else exponents:
+                out_dtype_scalar_exp = torch.complex128 if base_dtype.is_complex or type(i) == complex else torch.float64
+                expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i))
+
+                if op is torch.Tensor.float_power_ and base_dtype != out_dtype_scalar_exp:
+                    with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
+                        op(base.clone(), i)
+                else:
+                    result = op(base.clone(), i)
+                    self.assertEqual(expected_scalar_exp, result)
+
+                if op is torch.float_power:
+                    out = torch.empty_like(base).to(device=device, dtype=out_dtype_scalar_exp)
+                    op(base, i, out=out)
+                    self.assertEqual(expected_scalar_exp, out)
+
+            # Case of Scalar x Tensor
+            for i in complex_exponents if base_dtype.is_complex else exponents:
+                out_dtype_scalar_base = torch.complex128 if exp_dtype.is_complex or type(i) == complex else torch.float64
+                expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp)))
+
+                result = torch.float_power(i, exp)
+                self.assertEqual(expected_scalar_base, result)
+
+                out = torch.empty_like(exp).to(device=device, dtype=out_dtype_scalar_base)
+                torch.float_power(i, exp, out=out)
+                self.assertEqual(expected_scalar_base, out)
+
+    def test_float_power_exceptions(self, device):
+        def _promo_helper(x, y):
+            for i in (x, y):
+                if type(i) == complex:
+                    return torch.complex128
+                elif type(i) == torch.Tensor and i.is_complex():
+                    return torch.complex128
+            return torch.double
+
+        test_cases = ((torch.tensor([-2, -1, 0, 1, 2], device=device), -.25),
+                      (torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device), 2.))
+        for base, exp in test_cases:
+            for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble):
+                out = torch.empty(1, device=device, dtype=out_dtype)
+                required_dtype = _promo_helper(base, exp)
+
+                if out.dtype == required_dtype:
+                    torch.float_power(base, exp, out=out)
+                else:
+                    with self.assertRaisesRegex(RuntimeError, "is not the desired output type"):
+                        torch.float_power(base, exp, out=out)
+
+                if base.dtype == required_dtype:
+                    torch.Tensor.float_power_(base.clone(), exp)
+                else:
+                    with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
+                        torch.Tensor.float_power_(base.clone(), exp)
+
     @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
     @onlyOnCPUAndCUDA
     @dtypes(torch.int8, torch.int16, torch.int32, torch.int64)
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index c11a5b455dbb..2d8533b512df 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -2671,6 +2671,20 @@ def callable(a, b) -> number
 In-place version of :meth:`~Tensor.pow`
 """)
 
+add_docstr_all('float_power',
+               r"""
+float_power(exponent) -> Tensor
+
+See :func:`torch.float_power`
+""")
+
+add_docstr_all('float_power_',
+               r"""
+float_power_(exponent) -> Tensor
+
+In-place version of :meth:`~Tensor.float_power`
+""")
+
 add_docstr_all('prod',
                r"""
 prod(dim=None, keepdim=False, dtype=None) -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 545578c390e3..3b6ee12e7a68 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -6367,6 +6367,46 @@ def merge_dicts(*dicts):
 tensor([ 2.,  4.,  8., 16.])
 """.format(**common_args))
 
+add_docstr(torch.float_power,
+           r"""
+float_power(input, exponent, *, out=None) -> Tensor
+
+Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision.
+If neither input is complex returns a ``torch.float64`` tensor,
+and if one or more inputs is complex returns a ``torch.complex128`` tensor.
+
+.. note::
+    This function always computes in double precision, unlike :func:`torch.pow`,
+    which implements more typical :ref:`type promotion <type-promotion-doc>`.
+    This is useful when the computation needs to be performed in a wider or more precise dtype,
+    or the results of the computation may contain fractional values not representable in the input dtypes,
+    like when an integer base is raised to a negative integer exponent.
+
+Args:
+    input (Tensor or Number): the base value(s)
+    exponent (Tensor or Number): the exponent value(s)
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> a = torch.randint(10, (4,))
+    >>> a
+    tensor([6, 4, 7, 1])
+    >>> torch.float_power(a, 2)
+    tensor([36., 16., 49.,  1.], dtype=torch.float64)
+
+    >>> a = torch.arange(1, 5)
+    >>> a
+    tensor([ 1,  2,  3,  4])
+    >>> exp = torch.tensor([2, -3, 4, -5])
+    >>> exp
+    tensor([ 2, -3,  4, -5])
+    >>> torch.float_power(a, exp)
+    tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64)
+""".format(**common_args))
+
 add_docstr(torch.prod,
            r"""
 prod(input, *, dtype=None) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index 2c48e712d48b..eb863c74b6ae 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -366,6 +366,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1,
         torch.floor: lambda input, out=None: -1,
         torch.floor_divide: lambda input, other: -1,
+        torch.float_power: lambda input, exponent, out=None: -1,
         torch.fmod: lambda input, other, out=None: -1,
         torch.frac: lambda input, out=None: -1,
         torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 36e7fb6d0c7c..8850c3f7bf49 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -694,6 +694,14 @@ def method_tests():
         ('pow', uniform_scalar(1e-3 * (1 + 1j), requires_grad=True), (3.14,), 'complex_scalar_constant', (True,)),
         ('pow', uniform_scalar(1e-3 * (1 + 1j), requires_grad=True), (3.14j,), 'complex_imaginary_exponent', (True,)),
         ('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True, 'aten::pow')),
+        ('float_power', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,), ''),
+        ('float_power', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs'),
+        ('float_power', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'),
+        ('float_power', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,), 'broadcast_all'),
+        ('float_power', uniform_scalar(1e-3, requires_grad=True), (uniform_scalar(0.1),), 'scalar'),
+        ('float_power', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs'),
+        ('float_power', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs'),
+        ('float_power', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'),
         ('transpose', (1, 2, 3), (1, 2), 'dim', (False,), [0, 1]),
         ('transpose', (), (0, 0), 'scalar', (False,)),
         ('transpose', (1,), (0, 0), '1d', (False,)),
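As a quick illustration of the semantics this patch introduces, here is a minimal usage sketch. It is not part of the diff above; it assumes a PyTorch build with this change applied and NumPy available.

# Usage sketch only -- not part of the patch above.
import numpy as np
import torch

# Integer inputs are promoted to float64, so negative integer exponents work here,
# unlike torch.pow, which rejects integer bases raised to negative integer powers.
a = torch.arange(1, 5)
exp = torch.tensor([2, -3, 4, -5])
out = torch.float_power(a, exp)
assert out.dtype == torch.float64
assert np.allclose(out.numpy(), np.float_power(a.numpy(), exp.numpy()))

# Any complex operand promotes the result to complex128 instead.
c = torch.float_power(torch.tensor([1.0 + 1.0j, 2.0 - 0.5j]), 2)
assert c.dtype == torch.complex128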