Implement NumPy-like function torch.float_power() #44937
@@ -81,6 +81,48 @@ Tensor pow(Scalar base, const Tensor& exp) {
  return native::pow_out(result, base, exp);
}

Tensor& float_power_out(Tensor& result, const Tensor& base, const Tensor& exp) {
  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ?
      at::kComplexDouble : at::kDouble;
  TORCH_CHECK(result.scalar_type() == dtype,
      "output type ", result.scalar_type(), " is not the desired output type ", dtype);

  return at::pow_out(result, base.to(dtype), exp.to(dtype));
}

Tensor& float_power_out(Tensor& result, const Tensor& base, Scalar exp) {
  return at::float_power_out(result, base, c10::scalar_to_tensor(exp, base.device()));
}

Tensor& float_power_out(Tensor& result, Scalar base, const Tensor& exp) {
  return at::float_power_out(result, c10::scalar_to_tensor(base, exp.device()), exp);
}

Tensor float_power(const Tensor& base, const Tensor& exp) {
  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
  return at::pow(base.to(dtype), exp.to(dtype));
}

Tensor float_power(const Tensor& base, Scalar exp) {
  return at::float_power(base, c10::scalar_to_tensor(exp, base.device()));
}

Tensor float_power(Scalar base, const Tensor& exp) {
  return at::float_power(c10::scalar_to_tensor(base, exp.device()), exp);
}

Tensor& float_power_(Tensor& base, const Tensor& exp) {
  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
  TORCH_CHECK(base.scalar_type() == dtype,
      "self tensor type ", base.scalar_type(), " is not the desired type ", dtype);

  return base.pow_(exp.to(dtype));
}

Tensor& float_power_(Tensor& base, Scalar exp) {
  return base.float_power_(c10::scalar_to_tensor(exp, base.device()));
}

} // namespace native

} // namespace at
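For context, a short usage sketch of the promotion rule these overloads implement (results are always computed in double precision, or complex double when either input is complex), mirroring numpy.float_power:

```python
import torch

a = torch.arange(1, 4)                  # int64 input
print(torch.float_power(a, 2))          # tensor([1., 4., 9.], dtype=torch.float64)

c = torch.tensor([1 + 1j, 2 + 2j])      # complex64 input
print(torch.float_power(c, 2).dtype)    # torch.complex128
```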
@@ -6871,6 +6871,47 @@
    CPU, CUDA: pow
    SparseCPU, SparseCUDA: pow_sparse_scalar

- func: float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:

- func: pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
Those entries should probably be moved next to the other pow entries, too.
Once I remove the explicit dispatches, the following error appears:

AssertionError: There's a formula for float_power(or its functional variant) in derivatives.yaml. It's required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA or DefaultBackend in native_functions.yaml. Please see https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword for instructions to choose the right dispatch keyword.

So I guess we need the explicit dispatches here for autograd.
In addition, the improvement for pow_ is now progressing in the separate PR #46830.
The way dispatch happens has actually changed. I think the correct dispatch for these is now Math: instead of CPU, CUDA, since float_power is implemented using pow.
It seems that when using Math: as the dispatch key, both the autograd test and the XLA test suffer from a loss of precision. Even if I call pow directly, without any dtype casting, like this:

Tensor float_power(const Tensor& base, const Tensor& exp) {
  return at::pow(base, exp);
}

the problem still exists. Since using CPU, CUDA avoids the precision loss, maybe we should revert to it?
I am not familiar with autograd yet, so maybe I have missed something... 😕
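One way to probe the precision concern independently of the dispatch keyword is a double-precision gradient check; this is only an illustrative sketch, not part of this PR's test suite:

```python
import torch

# Compare float_power's analytic gradient against a numerical estimate
# in double precision; gradcheck raises if they disagree.
x = torch.linspace(1.0, 2.0, 5, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(lambda t: torch.float_power(t, 2.5), (x,))
```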
Sorry, can you elaborate on the lack of precision you're seeing, and how changing the dispatch can help with it?
Sorry for the late reply, @mruberry
When I use Math as the dispatch key, an assertEqual failure occurs in test_autograd, saying that the gradient calculated by the in-place variant is inconsistent with the out-of-place one, like this:

>>> a = torch.tensor([1., 2., 3.], dtype=torch.double, requires_grad=True)
>>> b = torch.tensor([1., 2., 3.], dtype=torch.double, requires_grad=True)
>>> grad = torch.randn_like(a).double()
>>> a.float_power(2.).backward(grad)
>>> a.grad
tensor([-4.0256, -1.6108, 1.2338], dtype=torch.float64)
>>> b.float_power_(2.).backward(grad)
>>> b.grad
tensor([-6.0385, -2.0134, 1.4394], dtype=torch.float64)

But in fact, in-place variants usually do not allow gradients to be calculated this way; the original pow behaves the same:

>>> a.pow_(2.).backward(grad)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

And when I changed the dispatch from Math to CPU, CUDA and added an entry in tools/autograd/derivatives.yaml (as we did in the previous version), the abnormal behavior above was eliminated.
It seems that no in-place variant has used Math as its dispatch key so far, so I suspect that may be related to this phenomenon...
> RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

This error is unrelated to the pow formula; it only happens because you modify your leaf in place. Doing a.clone().pow_(2.) should work just fine.

> saying that the gradients calculated by in-place variant is inconsistent with the general

If you don't provide a formula directly in derivatives.yaml, you need to make sure to only ever call functions that do have one from your generic ATen implementation. In particular, always call the at:: version and not the native:: version of ops.
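A minimal sketch of the workaround just described: cloning the leaf first lets the in-place op run and produces the same gradient as the out-of-place variant.

```python
import torch

a = torch.tensor([1., 2., 3.], dtype=torch.double, requires_grad=True)
grad = torch.ones_like(a)

# pow_ on the clone is fine: the clone is not a leaf, so the in-place
# restriction quoted above does not apply, and d(a**2)/da = 2*a.
a.clone().pow_(2.).backward(grad)
print(a.grad)  # tensor([2., 4., 6.], dtype=torch.float64)
```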
@@ -296,6 +296,7 @@ Pointwise Ops
    exp2
    expm1
    fix
    float_power
    floor
    floor_divide
    fmod
@@ -6205,6 +6205,102 @@ def test_pow(self, device):
        torch.pow(m1, 1, out=out)
        self.assertEqual(out, m1)

    @dtypes(*list(product(torch.testing.get_all_dtypes(include_bool=False),
                          torch.testing.get_all_dtypes(include_bool=False))))
    def test_float_power(self, device, dtypes):
        def to_np(value):
            if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16:
                return value.to(torch.float).cpu().numpy()
            return value.cpu().numpy() if isinstance(value, torch.Tensor) else value

        base_dtype = dtypes[0]
        exp_dtype = dtypes[1]
        out_dtype = torch.complex128 if base_dtype.is_complex or exp_dtype.is_complex else torch.float64

        base = make_tensor((30,), device, base_dtype, low=1, high=100)
        # Complex and real results do not agree between PyTorch and NumPy when computing negative and zero power of 0
        # Related: https://github.com/pytorch/pytorch/issues/48000
        # base[0] = base[3] = base[7] = 0
        exp = make_tensor((30,), device, exp_dtype, low=-2, high=2)
        exp[0] = exp[4] = exp[6] = 0

        expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp)))

        exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2]
        complex_exponents = exponents + [-2.5j, -1.0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j]

        for op in (torch.float_power, torch.Tensor.float_power, torch.Tensor.float_power_):

            # Case of Tensor x Tensor
            if op is torch.Tensor.float_power_ and base_dtype != out_dtype:
                with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
                    op(base.clone(), exp)
            else:
                result = op(base.clone(), exp)
                self.assertEqual(expected, result)

            if op is torch.float_power:
                out = torch.empty_like(base).to(device=device, dtype=out_dtype)
                op(base, exp, out=out)
                self.assertEqual(expected, out)

            # Case of Tensor x Scalar
            for i in complex_exponents if exp_dtype.is_complex else exponents:
                out_dtype_scalar_exp = torch.complex128 if base_dtype.is_complex or type(i) == complex else torch.float64
                expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i))

                if op is torch.Tensor.float_power_ and base_dtype != out_dtype_scalar_exp:
                    with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
                        op(base.clone(), i)
                else:
                    result = op(base.clone(), i)
                    self.assertEqual(expected_scalar_exp, result)

                if op is torch.float_power:
                    out = torch.empty_like(base).to(device=device, dtype=out_dtype_scalar_exp)
                    op(base, i, out=out)
                    self.assertEqual(expected_scalar_exp, out)

        # Case of Scalar x Tensor
        for i in complex_exponents if base_dtype.is_complex else exponents:
            out_dtype_scalar_base = torch.complex128 if exp_dtype.is_complex or type(i) == complex else torch.float64
            expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp)))

            result = torch.float_power(i, exp)
            self.assertEqual(expected_scalar_base, result)

            out = torch.empty_like(exp).to(device=device, dtype=out_dtype_scalar_base)
            torch.float_power(i, exp, out=out)
            self.assertEqual(expected_scalar_base, out)

    def test_float_power_exceptions(self, device):
        def _promo_helper(x, y):
            for i in (x, y):
                if type(i) == complex:
                    return torch.complex128
                elif type(i) == torch.Tensor and i.is_complex():
                    return torch.complex128
            return torch.double

        test_cases = ((torch.tensor([-2, -1, 0, 1, 2], device=device), -.25),
                      (torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device), 2.))
        for base, exp in test_cases:
            for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble):
                out = torch.empty(1, device=device, dtype=out_dtype)
                required_dtype = _promo_helper(base, exp)

                if out.dtype == required_dtype:
                    torch.float_power(base, exp, out=out)
                else:
                    with self.assertRaisesRegex(RuntimeError, "is not the desired output type"):
                        torch.float_power(base, exp, out=out)

                if base.dtype == required_dtype:
                    torch.Tensor.float_power_(base.clone(), exp)
                else:
                    with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
                        torch.Tensor.float_power_(base.clone(), exp)

    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @onlyOnCPUAndCUDA
    @dtypes(torch.int8, torch.int16, torch.int32, torch.int64)
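The out= check exercised by test_float_power_exceptions can also be seen directly; a minimal sketch (the exact error wording may differ between versions):

```python
import torch

base = torch.tensor([1, 2, 3])
out = torch.empty(3, dtype=torch.float)  # float32, but float_power promotes to float64

# The out tensor must already have the promoted dtype, so this raises a
# RuntimeError complaining about the output type.
try:
    torch.float_power(base, 2, out=out)
except RuntimeError as e:
    print(e)
```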
Should this Tensor x Scalar variant and the next Scalar x Tensor variant both call the prior Tensor x Tensor variant after wrapping the scalar in a tensor? That might make the code more maintainable.
Addressed! Thanks for the advice!
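As a quick illustration of that delegation (the Scalar overloads above wrap the scalar in a tensor and reuse the Tensor x Tensor path), passing a Python scalar and a wrapped tensor give the same result:

```python
import torch

base = torch.tensor([1., 2., 3.])

# Passing a Python scalar exponent is equivalent to wrapping it in a tensor
# and calling the Tensor x Tensor variant, which is how the Scalar overloads delegate.
via_scalar = torch.float_power(base, 2)
via_tensor = torch.float_power(base, torch.tensor(2.))
print(torch.equal(via_scalar, via_tensor))  # True
```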