1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -552,6 +552,7 @@ _(aten, pixel_shuffle) \
_(aten, poisson) \
_(aten, polygamma) \
_(aten, pow) \
_(aten, float_power) \
_(aten, prelu) \
_(aten, prelu_backward) \
_(aten, prod) \
42 changes: 42 additions & 0 deletions aten/src/ATen/native/Pow.cpp
@@ -81,6 +81,48 @@ Tensor pow(Scalar base, const Tensor& exp) {
return native::pow_out(result, base, exp);
}

Tensor& float_power_out(Tensor& result, const Tensor& base, const Tensor& exp) {
auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ?
at::kComplexDouble : at::kDouble;
TORCH_CHECK(result.scalar_type() == dtype,
"output type ", result.scalar_type(), "is not the desired output type ", dtype);

return at::pow_out(result, base.to(dtype), exp.to(dtype));
}

Tensor& float_power_out(Tensor& result, const Tensor& base, Scalar exp) {
Collaborator:

Should this Tensor x Scalar variant and the next Scalar x Tensor variant both call the prior Tensor x Tensor variant after wrapping the scalar in a tensor? That might make the code more maintainable.

Contributor Author:

Addressed! Thanks for the advice!

return at::float_power_out(result, base, c10::scalar_to_tensor(exp, base.device()));
}

Tensor& float_power_out(Tensor& result, Scalar base, const Tensor& exp) {
return at::float_power_out(result, c10::scalar_to_tensor(base, exp.device()), exp);
}

Tensor float_power(const Tensor& base, const Tensor& exp) {
auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
return at::pow(base.to(dtype), exp.to(dtype));
}

Tensor float_power(const Tensor& base, Scalar exp) {
Collaborator:

Same question about reusing the previous Tensor x Tensor variant for the Tensor x Scalar variant here and the Scalar x Tensor variant below.

Contributor Author:

Addressed!

return at::float_power(base, c10::scalar_to_tensor(exp, base.device()));
}

Tensor float_power(Scalar base, const Tensor& exp) {
return at::float_power(c10::scalar_to_tensor(base, exp.device()), exp);
}

Tensor& float_power_(Tensor& base, const Tensor& exp) {
auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
TORCH_CHECK(base.scalar_type() == dtype,
"self tensor type ", base.scalar_type(), "is not the desired type ", dtype);

return base.pow_(exp.to(dtype));
}

Tensor& float_power_(Tensor& base, Scalar exp) {
return base.float_power_(c10::scalar_to_tensor(exp, base.device()));
}

} // namespace native

} // namespace at
41 changes: 41 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -6871,6 +6871,47 @@
CPU, CUDA: pow
SparseCPU, SparseCUDA: pow_sparse_scalar

- func: float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
Collaborator:

Do these need explicit dispatches? I think the system will work out which functions to call. Explicitly defining the dispatch will, perhaps surprisingly, prevent float_power from deriving its derivative if it's implemented as a composite operation (as suggested above).
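
For reference, a minimal Python sketch of what "deriving its derivative" means here, assuming float_power ends up implemented as a composite of at::pow plus dtype casts (illustrative only, not part of this PR's tests):

import torch
from torch.autograd import gradcheck

# If float_power is a composite of pow and dtype casts, autograd differentiates
# through pow's backward formula, so no float_power entry in derivatives.yaml is needed.
base = torch.rand(3, dtype=torch.double, requires_grad=True) + 0.5
exp = torch.rand(3, dtype=torch.double, requires_grad=True) + 0.5
print(gradcheck(torch.float_power, (base, exp)))  # expected: True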

Collaborator:

pow also has an inplace variant, pow_:

- func: pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)

Those entries should probably be moved next to the other pow entries, too.

Contributor Author:

Once I remove the explicit dispatches, the following error appears:

AssertionError: There's a formula for float_power(or its functional variant) in derivatives.yaml. It's required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA or DefaultBackend in native_functions.yaml. Please see https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword for instructions to choose the right dispatch keyword.

So, I guess we need the explicit dispatches here for autograd.

In addition, the improvement for pow_ is now in progress in the separate PR #46830.

Collaborator:

The way dispatch happens has actually changed. I think the correct dispatch for these is now Math: instead of CPU, CUDA, since float_power is implemented using pow.

Contributor Author:

It seems that when using Math: as the dispatch, both the autograd test and the XLA test suffer from a loss of precision. Even if I call pow directly without any dtype casting, like below, the problem still exists.

Tensor float_power(const Tensor& base, const Tensor& exp) {
  return at::pow(base, exp);
}

Since using CPU, CUDA avoids the precision loss, maybe we should revert to it?
I am not familiar with autograd yet, so maybe I have missed something... 😕

Collaborator (@mruberry, Nov 12, 2020):

Sorry, can you elaborate on the lack of precision you're seeing, and how changing the dispatch can help with it?

Contributor Author (@Kiyosora, Nov 17, 2020):

Sorry for the late reply, @mruberry.
When I use Math as the dispatch, an assertEqual error occurs in test_autograd, saying that the gradients calculated by the in-place variant are inconsistent with those from the out-of-place variant, like this:

>>> a=torch.tensor([1.,2.,3.], dtype=torch.double, requires_grad=True)
>>> b=torch.tensor([1.,2.,3.], dtype=torch.double, requires_grad=True)
>>> grad=torch.randn_like(a).double()
>>> a.float_power(2.).backward(grad)
>>> a.grad
tensor([-4.0256, -1.6108,  1.2338], dtype=torch.float64)
>>> b.float_power_(2.).backward(grad)
>>> b.grad
tensor([-6.0385, -2.0134,  1.4394], dtype=torch.float64)

But in fact, the in-place variants usually do not allow calculating gradients like this; the original pow behaves the same way:

>>> a.pow_(2.).backward(grad)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

And when I changed the dispatch from Math to CPU, CUDA and added a definition in tools/autograd/derivatives.yaml (as in the previous version), the abnormal behavior above was eliminated.
It seems that no in-place variant has used Math as its dispatch so far, so I suspect that may be related to this behavior...

Collaborator:

Thank you for this explanation, this is extremely interesting. cc @ailzhang and @albanD to take a look, too.

Collaborator (@albanD, Nov 18, 2020):

> RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

This error is unrelated to the pow formula; it only happens because you modify your leaf in place. Doing a.clone().pow_(2.) should work just fine.

> saying that the gradients calculated by the in-place variant are inconsistent with those from the out-of-place variant

If you don't provide a formula directly in derivatives.yaml, you need to make sure to only ever call functions that do have one from your generic aten implementation. In particular, always call the at:: version of ops, not the native:: version.
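
A minimal sketch of the first point, assuming the float_power behavior described in this PR (illustrative only): the in-place call is rejected on the leaf itself but is fine on a clone.

import torch

a = torch.tensor([1., 2., 3.], dtype=torch.double, requires_grad=True)
# a.float_power_(2.) would raise the leaf-Variable error quoted above,
# but applying the in-place op to a clone (a non-leaf) is allowed:
c = a.clone().float_power_(2.)
print(c)  # tensor([1., 4., 9.], dtype=torch.float64, grad_fn=...)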

Math: float_power_out

- func: float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
use_c10_dispatcher: full
variants: function, method
dispatch:
Math: float_power

- func: float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
Math: float_power_out

- func: float_power.Scalar(Scalar self, Tensor exponent) -> Tensor
use_c10_dispatcher: full
dispatch:
Math: float_power

- func: float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
Math: float_power_out

- func: float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
use_c10_dispatcher: full
variants: function, method
dispatch:
Math: float_power

- func: float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
use_c10_dispatcher: full
variants: method
dispatch:
Math: float_power_

- func: float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)
use_c10_dispatcher: full
variants: method
dispatch:
Math: float_power_

- func: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)
variants: method
dispatch:
2 changes: 2 additions & 0 deletions docs/source/tensors.rst
@@ -322,6 +322,8 @@ view of a storage and defines numeric operations on it.
.. automethod:: fliplr
.. automethod:: flipud
.. automethod:: float
.. automethod:: float_power
.. automethod:: float_power_
.. automethod:: floor
.. automethod:: floor_
.. automethod:: floor_divide
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -296,6 +296,7 @@ Pointwise Ops
exp2
expm1
fix
float_power
floor
floor_divide
fmod
96 changes: 96 additions & 0 deletions test/test_torch.py
@@ -6205,6 +6205,102 @@ def test_pow(self, device):
torch.pow(m1, 1, out=out)
self.assertEqual(out, m1)

@dtypes(*list(product(torch.testing.get_all_dtypes(include_bool=False),
torch.testing.get_all_dtypes(include_bool=False))))
def test_float_power(self, device, dtypes):
def to_np(value):
if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16:
return value.to(torch.float).cpu().numpy()
return value.cpu().numpy() if isinstance(value, torch.Tensor) else value

base_dtype = dtypes[0]
exp_dtype = dtypes[1]
out_dtype = torch.complex128 if base_dtype.is_complex or exp_dtype.is_complex else torch.float64

base = make_tensor((30,), device, base_dtype, low=1, high=100)
# Complex and real results do not agree between PyTorch and NumPy when computing negative and zero power of 0
# Related: https://github.com/pytorch/pytorch/issues/48000
# base[0] = base[3] = base[7] = 0
exp = make_tensor((30,), device, exp_dtype, low=-2, high=2)
exp[0] = exp[4] = exp[6] = 0

expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp)))

exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2]
complex_exponents = exponents + [-2.5j, -1.0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j]

for op in (torch.float_power, torch.Tensor.float_power, torch.Tensor.float_power_):

# Case of Tensor x Tensor
if op is torch.Tensor.float_power_ and base_dtype != out_dtype:
with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
op(base.clone(), exp)
else:
result = op(base.clone(), exp)
self.assertEqual(expected, result)

if op is torch.float_power:
out = torch.empty_like(base).to(device=device, dtype=out_dtype)
op(base, exp, out=out)
self.assertEqual(expected, out)

# Case of Tensor x Scalar
for i in complex_exponents if exp_dtype.is_complex else exponents:
out_dtype_scalar_exp = torch.complex128 if base_dtype.is_complex or type(i) == complex else torch.float64
expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i))

if op is torch.Tensor.float_power_ and base_dtype != out_dtype_scalar_exp:
with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
op(base.clone(), i)
else:
result = op(base.clone(), i)
self.assertEqual(expected_scalar_exp, result)

if op is torch.float_power:
Collaborator:

What's this part of the test for?

Contributor Author (@Kiyosora, Nov 11, 2020):

This part tests the special case of the Scalar x Tensor variant, which only applies to torch.float_power (we won't have i.float_power or i.float_power_, since self is a Scalar).
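
For reference, a minimal sketch of the Scalar x Tensor form, with illustrative values (assuming the float_power behavior added in this PR):

import torch

exp = torch.tensor([1, 2, 3])
# Scalar base, Tensor exponent; neither input is complex, so the result is float64.
print(torch.float_power(2., exp))  # tensor([2., 4., 8.], dtype=torch.float64)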

Collaborator:

Sounds good; thanks for the explanation

out = torch.empty_like(base).to(device=device, dtype=out_dtype_scalar_exp)
op(base, i, out=out)
self.assertEqual(expected_scalar_exp, out)

# Case of Scalar x Tensor
for i in complex_exponents if base_dtype.is_complex else exponents:
out_dtype_scalar_base = torch.complex128 if exp_dtype.is_complex or type(i) == complex else torch.float64
expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp)))

result = torch.float_power(i, exp)
self.assertEqual(expected_scalar_base, result)

out = torch.empty_like(exp).to(device=device, dtype=out_dtype_scalar_base)
torch.float_power(i, exp, out=out)
self.assertEqual(expected_scalar_base, out)

def test_float_power_exceptions(self, device):
def _promo_helper(x, y):
for i in (x, y):
if type(i) == complex:
return torch.complex128
elif type(i) == torch.Tensor and i.is_complex():
return torch.complex128
return torch.double

test_cases = ((torch.tensor([-2, -1, 0, 1, 2], device=device), -.25),
(torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device), 2.))
for base, exp in test_cases:
for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble):
out = torch.empty(1, device=device, dtype=out_dtype)
required_dtype = _promo_helper(base, exp)

if out.dtype == required_dtype:
torch.float_power(base, exp, out=out)
else:
with self.assertRaisesRegex(RuntimeError, "is not the desired output type"):
torch.float_power(base, exp, out=out)

if base.dtype == required_dtype:
torch.Tensor.float_power_(base.clone(), exp)
else:
with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
torch.Tensor.float_power_(base.clone(), exp)

@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@onlyOnCPUAndCUDA
@dtypes(torch.int8, torch.int16, torch.int32, torch.int64)
14 changes: 14 additions & 0 deletions torch/_tensor_docs.py
@@ -2671,6 +2671,20 @@ def callable(a, b) -> number
In-place version of :meth:`~Tensor.pow`
""")

add_docstr_all('float_power',
r"""
float_power(exponent) -> Tensor

See :func:`torch.float_power`
""")

add_docstr_all('float_power_',
r"""
float_power_(exponent) -> Tensor

In-place version of :meth:`~Tensor.float_power`
""")

add_docstr_all('prod',
r"""
prod(dim=None, keepdim=False, dtype=None) -> Tensor
40 changes: 40 additions & 0 deletions torch/_torch_docs.py
@@ -6338,6 +6338,46 @@ def merge_dicts(*dicts):
tensor([ 2., 4., 8., 16.])
""".format(**common_args))

add_docstr(torch.float_power,
r"""
float_power(input, exponent, *, out=None) -> Tensor

Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision.
If neither input is complex returns a ``torch.float64`` tensor,
and if one or more inputs is complex returns a ``torch.complex128`` tensor.

.. note::
This function always computes in double precision, unlike :func:`torch.pow`,
which implements more typical :ref:`type promotion <type-promotion-doc>`.
This is useful when the computation needs to be performed in a wider or more precise dtype,
or the results of the computation may contain fractional values not representable in the input dtypes,
like when an integer base is raised to a negative integer exponent.

Args:
input (Tensor or Number): the base value(s)
exponent (Tensor or Number): the exponent value(s)

Keyword args:
{out}

Example::

>>> a = torch.randint(10, (4,))
>>> a
tensor([6, 4, 7, 1])
>>> torch.float_power(a, 2)
tensor([36., 16., 49., 1.], dtype=torch.float64)

>>> a = torch.arange(1, 5)
>>> a
tensor([ 1, 2, 3, 4])
>>> exp = torch.tensor([2, -3, 4, -5])
>>> exp
tensor([ 2, -3, 4, -5])
>>> torch.float_power(a, exp)
tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64)
""".format(**common_args))

add_docstr(torch.prod,
r"""
prod(input, *, dtype=None) -> Tensor
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -366,6 +366,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1,
torch.floor: lambda input, out=None: -1,
torch.floor_divide: lambda input, other: -1,
torch.float_power: lambda input, exponent, out=None: -1,
torch.fmod: lambda input, other, out=None: -1,
torch.frac: lambda input, out=None: -1,
torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
8 changes: 8 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -688,6 +688,14 @@ def method_tests():
('pow', uniform_scalar(1e-3 * (1 + 1j), requires_grad=True), (3.14,), 'complex_scalar_constant', (True,)),
('pow', uniform_scalar(1e-3 * (1 + 1j), requires_grad=True), (3.14j,), 'complex_imaginary_exponent', (True,)),
('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True, 'aten::pow')),
('float_power', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,), ''),
('float_power', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs'),
('float_power', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'),
('float_power', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,), 'broadcast_all'),
('float_power', uniform_scalar(1e-3, requires_grad=True), (uniform_scalar(0.1),), 'scalar'),
('float_power', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs'),
('float_power', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs'),
('float_power', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'),
('transpose', (1, 2, 3), (1, 2), 'dim', (False,), [0, 1]),
('transpose', (), (0, 0), 'scalar', (False,)),
('transpose', (1,), (0, 0), '1d', (False,)),