Implement NumPy-like function torch.float_power() (#44937)
Summary:
- Related to #38349
- Implements the NumPy-like function `torch.float_power()`.

Pull Request resolved: #44937

Reviewed By: ngimel

Differential Revision: D25192119

Pulled By: mruberry

fbshipit-source-id: 2e446b8e0c2825f045fe057e30c9419335557a05
Kiyosora authored and facebook-github-bot committed Nov 28, 2020
1 parent 25ab39a commit 272f4db
Showing 10 changed files with 246 additions and 0 deletions.
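For orientation, a minimal usage sketch of the new operator, mirroring the examples added to _torch_docs.py below (assumes a PyTorch build that includes this change):

import torch

a = torch.tensor([6, 4, 7, 1])                 # integer (int64) input
torch.float_power(a, 2)
# tensor([36., 16., 49., 1.], dtype=torch.float64)

exp = torch.tensor([2, -3, 4, -5])
torch.float_power(torch.arange(1, 5), exp)
# tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64)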
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -552,6 +552,7 @@ _(aten, pixel_shuffle) \
_(aten, poisson) \
_(aten, polygamma) \
_(aten, pow) \
_(aten, float_power) \
_(aten, prelu) \
_(aten, prelu_backward) \
_(aten, prod) \
42 changes: 42 additions & 0 deletions aten/src/ATen/native/Pow.cpp
@@ -81,6 +81,48 @@ Tensor pow(Scalar base, const Tensor& exp) {
  return native::pow_out(result, base, exp);
}

// float_power promotes both inputs to double (or complex double) and defers to pow.
Tensor& float_power_out(Tensor& result, const Tensor& base, const Tensor& exp) {
  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ?
                at::kComplexDouble : at::kDouble;
  TORCH_CHECK(result.scalar_type() == dtype,
              "output type ", result.scalar_type(), " is not the desired output type ", dtype);

  return at::pow_out(result, base.to(dtype), exp.to(dtype));
}

Tensor& float_power_out(Tensor& result, const Tensor& base, Scalar exp) {
  return at::float_power_out(result, base, c10::scalar_to_tensor(exp, base.device()));
}

Tensor& float_power_out(Tensor& result, Scalar base, const Tensor& exp) {
  return at::float_power_out(result, c10::scalar_to_tensor(base, exp.device()), exp);
}

Tensor float_power(const Tensor& base, const Tensor& exp) {
  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
  return at::pow(base.to(dtype), exp.to(dtype));
}

Tensor float_power(const Tensor& base, Scalar exp) {
  return at::float_power(base, c10::scalar_to_tensor(exp, base.device()));
}

Tensor float_power(Scalar base, const Tensor& exp) {
  return at::float_power(c10::scalar_to_tensor(base, exp.device()), exp);
}

// In-place variant: self must already have the promoted (double or complex double) dtype.
Tensor& float_power_(Tensor& base, const Tensor& exp) {
  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
  TORCH_CHECK(base.scalar_type() == dtype,
              "self tensor type ", base.scalar_type(), " is not the desired type ", dtype);

  return base.pow_(exp.to(dtype));
}

Tensor& float_power_(Tensor& base, Scalar exp) {
  return base.float_power_(c10::scalar_to_tensor(exp, base.device()));
}

} // namespace native

} // namespace at
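The overloads above all funnel through the same promotion rule: any complex input promotes to kComplexDouble, everything else to kDouble, and the work is delegated to pow. A rough Python equivalent, for illustration only (float_power_reference is a hypothetical helper, not part of this commit):

import torch

def float_power_reference(base, exp):
    # Any complex input promotes to complex128; everything else goes to float64,
    # then the computation is delegated to torch.pow on the promoted inputs.
    dtype = torch.complex128 if (base.is_complex() or exp.is_complex()) else torch.float64
    return torch.pow(base.to(dtype), exp.to(dtype))

a = torch.tensor([1, 2, 3])
b = torch.tensor([2, -3, 4])
print(float_power_reference(a, b))  # 1.0, 0.125, 81.0 in float64
print(torch.float_power(a, b))      # should match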
41 changes: 41 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -6872,6 +6872,47 @@
    CPU, CUDA: pow
    SparseCPU, SparseCUDA: pow_sparse_scalar

- func: float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    Math: float_power_out

- func: float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
  use_c10_dispatcher: full
  variants: function, method
  dispatch:
    Math: float_power

- func: float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    Math: float_power_out

- func: float_power.Scalar(Scalar self, Tensor exponent) -> Tensor
  use_c10_dispatcher: full
  dispatch:
    Math: float_power

- func: float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    Math: float_power_out

- func: float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
  use_c10_dispatcher: full
  variants: function, method
  dispatch:
    Math: float_power

- func: float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
  use_c10_dispatcher: full
  variants: method
  dispatch:
    Math: float_power_

- func: float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)
  use_c10_dispatcher: full
  variants: method
  dispatch:
    Math: float_power_

- func: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)
  variants: method
  dispatch:
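These registrations expose function, method, in-place, and out= variants. A quick sketch of how each is called (assuming this build; per the checks in Pow.cpp above, the out= tensor and the in-place self must already have the promoted dtype):

import torch

x = torch.rand(3)

y1 = torch.float_power(x, 2)        # function variant
y2 = x.float_power(2)               # method variant

out = torch.empty(3, dtype=torch.float64)
torch.float_power(x, 2, out=out)    # out= variant: out must be float64 (or complex128)

z = x.to(torch.float64)
z.float_power_(2)                   # in-place variant: self must already be float64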
2 changes: 2 additions & 0 deletions docs/source/tensors.rst
@@ -322,6 +322,8 @@ view of a storage and defines numeric operations on it.
.. automethod:: fliplr
.. automethod:: flipud
.. automethod:: float
.. automethod:: float_power
.. automethod:: float_power_
.. automethod:: floor
.. automethod:: floor_
.. automethod:: floor_divide
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -296,6 +296,7 @@ Pointwise Ops
exp2
expm1
fix
float_power
floor
floor_divide
fmod
96 changes: 96 additions & 0 deletions test/test_torch.py
@@ -6205,6 +6205,102 @@ def test_pow(self, device):
        torch.pow(m1, 1, out=out)
        self.assertEqual(out, m1)

    @dtypes(*list(product(torch.testing.get_all_dtypes(include_bool=False),
                          torch.testing.get_all_dtypes(include_bool=False))))
    def test_float_power(self, device, dtypes):
        def to_np(value):
            if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16:
                return value.to(torch.float).cpu().numpy()
            return value.cpu().numpy() if isinstance(value, torch.Tensor) else value

        base_dtype = dtypes[0]
        exp_dtype = dtypes[1]
        out_dtype = torch.complex128 if base_dtype.is_complex or exp_dtype.is_complex else torch.float64

        base = make_tensor((30,), device, base_dtype, low=1, high=100)
        # Complex and real results do not agree between PyTorch and NumPy when computing negative and zero power of 0
        # Related: https://github.com/pytorch/pytorch/issues/48000
        # base[0] = base[3] = base[7] = 0
        exp = make_tensor((30,), device, exp_dtype, low=-2, high=2)
        exp[0] = exp[4] = exp[6] = 0

        expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp)))

        exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2]
        complex_exponents = exponents + [-2.5j, -1.0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j]

        for op in (torch.float_power, torch.Tensor.float_power, torch.Tensor.float_power_):

            # Case of Tensor x Tensor
            if op is torch.Tensor.float_power_ and base_dtype != out_dtype:
                with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
                    op(base.clone(), exp)
            else:
                result = op(base.clone(), exp)
                self.assertEqual(expected, result)

            if op is torch.float_power:
                out = torch.empty_like(base).to(device=device, dtype=out_dtype)
                op(base, exp, out=out)
                self.assertEqual(expected, out)

            # Case of Tensor x Scalar
            for i in complex_exponents if exp_dtype.is_complex else exponents:
                out_dtype_scalar_exp = torch.complex128 if base_dtype.is_complex or type(i) == complex else torch.float64
                expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i))

                if op is torch.Tensor.float_power_ and base_dtype != out_dtype_scalar_exp:
                    with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
                        op(base.clone(), i)
                else:
                    result = op(base.clone(), i)
                    self.assertEqual(expected_scalar_exp, result)

                if op is torch.float_power:
                    out = torch.empty_like(base).to(device=device, dtype=out_dtype_scalar_exp)
                    op(base, i, out=out)
                    self.assertEqual(expected_scalar_exp, out)

        # Case of Scalar x Tensor
        for i in complex_exponents if base_dtype.is_complex else exponents:
            out_dtype_scalar_base = torch.complex128 if exp_dtype.is_complex or type(i) == complex else torch.float64
            expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp)))

            result = torch.float_power(i, exp)
            self.assertEqual(expected_scalar_base, result)

            out = torch.empty_like(exp).to(device=device, dtype=out_dtype_scalar_base)
            torch.float_power(i, exp, out=out)
            self.assertEqual(expected_scalar_base, out)

    def test_float_power_exceptions(self, device):
        def _promo_helper(x, y):
            for i in (x, y):
                if type(i) == complex:
                    return torch.complex128
                elif type(i) == torch.Tensor and i.is_complex():
                    return torch.complex128
            return torch.double

        test_cases = ((torch.tensor([-2, -1, 0, 1, 2], device=device), -.25),
                      (torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device), 2.))
        for base, exp in test_cases:
            for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble):
                out = torch.empty(1, device=device, dtype=out_dtype)
                required_dtype = _promo_helper(base, exp)

                if out.dtype == required_dtype:
                    torch.float_power(base, exp, out=out)
                else:
                    with self.assertRaisesRegex(RuntimeError, "is not the desired output type"):
                        torch.float_power(base, exp, out=out)

                if base.dtype == required_dtype:
                    torch.Tensor.float_power_(base.clone(), exp)
                else:
                    with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
                        torch.Tensor.float_power_(base.clone(), exp)

    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @onlyOnCPUAndCUDA
    @dtypes(torch.int8, torch.int16, torch.int32, torch.int64)
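The tests above validate results against NumPy's np.float_power. A standalone sketch of that parity check:

import numpy as np
import torch

base = torch.arange(1, 5)                    # int64
exp = torch.tensor([2., -3., 4., -5.])

torch_res = torch.float_power(base, exp)
np_res = np.float_power(base.numpy(), exp.numpy())

assert np.allclose(torch_res.numpy(), np_res)   # both are float64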
14 changes: 14 additions & 0 deletions torch/_tensor_docs.py
@@ -2671,6 +2671,20 @@ def callable(a, b) -> number
In-place version of :meth:`~Tensor.pow`
""")

add_docstr_all('float_power',
               r"""
float_power(exponent) -> Tensor

See :func:`torch.float_power`
""")

add_docstr_all('float_power_',
               r"""
float_power_(exponent) -> Tensor

In-place version of :meth:`~Tensor.float_power`
""")

add_docstr_all('prod',
r"""
prod(dim=None, keepdim=False, dtype=None) -> Tensor
40 changes: 40 additions & 0 deletions torch/_torch_docs.py
@@ -6367,6 +6367,46 @@ def merge_dicts(*dicts):
tensor([ 2., 4., 8., 16.])
""".format(**common_args))

add_docstr(torch.float_power,
           r"""
float_power(input, exponent, *, out=None) -> Tensor

Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision.
If neither input is complex, this returns a ``torch.float64`` tensor,
and if one or more inputs is complex, it returns a ``torch.complex128`` tensor.

.. note::
    This function always computes in double precision, unlike :func:`torch.pow`,
    which implements more typical :ref:`type promotion <type-promotion-doc>`.
    This is useful when the computation needs to be performed in a wider or more precise dtype,
    or when the results of the computation may contain fractional values not representable in the input dtypes,
    like when an integer base is raised to a negative integer exponent.

Args:
    input (Tensor or Number): the base value(s)
    exponent (Tensor or Number): the exponent value(s)

Keyword args:
    {out}

Example::

    >>> a = torch.randint(10, (4,))
    >>> a
    tensor([6, 4, 7, 1])
    >>> torch.float_power(a, 2)
    tensor([36., 16., 49., 1.], dtype=torch.float64)

    >>> a = torch.arange(1, 5)
    >>> a
    tensor([ 1, 2, 3, 4])
    >>> exp = torch.tensor([2, -3, 4, -5])
    >>> exp
    tensor([ 2, -3, 4, -5])
    >>> torch.float_power(a, exp)
    tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64)
""".format(**common_args))

add_docstr(torch.prod,
r"""
prod(input, *, dtype=None) -> Tensor
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -366,6 +366,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1,
torch.floor: lambda input, out=None: -1,
torch.floor_divide: lambda input, other: -1,
torch.float_power: lambda input, exponent, out=None: -1,
torch.fmod: lambda input, other, out=None: -1,
torch.frac: lambda input, out=None: -1,
torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
8 changes: 8 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -694,6 +694,14 @@ def method_tests():
('pow', uniform_scalar(1e-3 * (1 + 1j), requires_grad=True), (3.14,), 'complex_scalar_constant', (True,)),
('pow', uniform_scalar(1e-3 * (1 + 1j), requires_grad=True), (3.14j,), 'complex_imaginary_exponent', (True,)),
('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True, 'aten::pow')),
('float_power', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,), ''),
('float_power', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs'),
('float_power', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'),
('float_power', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,), 'broadcast_all'),
('float_power', uniform_scalar(1e-3, requires_grad=True), (uniform_scalar(0.1),), 'scalar'),
('float_power', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs'),
('float_power', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs'),
('float_power', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'),
('transpose', (1, 2, 3), (1, 2), 'dim', (False,), [0, 1]),
('transpose', (), (0, 0), 'scalar', (False,)),
('transpose', (1,), (0, 0), '1d', (False,)),