Stop using c10::scalar_to_tensor in float_power. (#50105)
Summary:
Pull Request resolved: #50105

There should be no functional change here.

A couple of reasons for the change:
1) c10::scalar_to_tensor is generally an anti-pattern (#49758), so it is good to minimize its usage in the code base.
2) pow itself already has a fair amount of smarts, such as not broadcasting scalar/tensor combinations, so we should defer to it (see the sketch below).
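
For illustration only — this sketch is not part of the commit, and the two helper names are hypothetical; c10::scalar_to_tensor and at::pow are the real ATen calls involved:

#include <ATen/ATen.h>
#include <ATen/ScalarOps.h>

// Hypothetical helpers contrasting the removed pattern with the kept one.
at::Tensor pow_via_scalar_to_tensor(const at::Tensor& base, const at::Scalar& exp) {
  // Anti-pattern (#49758): materialize the scalar as a 0-dim tensor on the
  // base's device, then dispatch the tensor-tensor kernel.
  return at::pow(base, c10::scalar_to_tensor(exp, base.device()));
}

at::Tensor pow_via_scalar_overload(const at::Tensor& base, const at::Scalar& exp) {
  // Preferred: defer to pow's tensor-scalar overload, which keeps the exponent
  // a Scalar and avoids the tensor/broadcast machinery entirely.
  return at::pow(base, exp);
}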

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D25786172

Pulled By: gchanan

fbshipit-source-id: 89de03aa0b900ce011a62911224a5441f15e331a
gchanan authored and facebook-github-bot committed Jan 8, 2021
1 parent 55919a4 commit 88bd69b
Showing 2 changed files with 42 additions and 16 deletions.
50 changes: 38 additions & 12 deletions aten/src/ATen/native/Pow.cpp
@@ -28,7 +28,7 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) {
 
   auto common_dtype = at::result_type(base, exp);
   TORCH_CHECK(at::can_cast(common_dtype, result.scalar_type()),
-              "result type ", common_dtype, "can't be cast to the desired output type ",
+              "result type ", common_dtype, " can't be cast to the desired output type ",
               result.scalar_type());
 
   if (exp.equal(0.0)) {
@@ -83,42 +83,68 @@ Tensor& float_power_out(Tensor& result, const Tensor& base, const Tensor& exp) {
   auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ?
                 at::kComplexDouble : at::kDouble;
   TORCH_CHECK(result.scalar_type() == dtype,
-              "output type ", result.scalar_type(), "is not the desired output type ", dtype);
+              "the output given to float_power has dtype ", result.scalar_type(),
+              " but the operation's result requires dtype ", dtype);
 
   return at::pow_out(result, base.to(dtype), exp.to(dtype));
 }
 
 Tensor& float_power_out(Tensor& result, const Tensor& base, Scalar exp) {
-  return at::float_power_out(result, base, c10::scalar_to_tensor(exp, base.device()));
+  auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble;
+  TORCH_CHECK(result.scalar_type() == dtype,
+              "the output given to float_power has dtype ", result.scalar_type(),
+              " but the operation's result requires dtype ", dtype);
+
+  // Note: need the casts inside the ternary because conversion functions return e.g. c10::complex,
+  // which causes a complex scalar to always be returned.
+  exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
+  return at::pow_out(result, base.to(dtype), exp);
 }
 
 Tensor& float_power_out(Tensor& result, Scalar base, const Tensor& exp) {
-  return at::float_power_out(result, c10::scalar_to_tensor(base, exp.device()), exp);
-}
+  auto dtype = (at::isComplexType(exp.scalar_type()) || base.isComplex()) ? at::kComplexDouble : at::kDouble;
+  TORCH_CHECK(result.scalar_type() == dtype,
+              "the output given to float_power has dtype ", result.scalar_type(),
+              " but the operation's result requires dtype ", dtype);
 
-Tensor float_power(const Tensor& base, const Tensor& exp) {
-  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
-  return at::pow(base.to(dtype), exp.to(dtype));
+  base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble());
+  return at::pow_out(result, base, exp.to(dtype));
 }
 
 Tensor float_power(const Tensor& base, Scalar exp) {
-  return at::float_power(base, c10::scalar_to_tensor(exp, base.device()));
+  auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble;
+  exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
+  return at::pow(base.to(dtype), exp);
 }
 
 Tensor float_power(Scalar base, const Tensor& exp) {
-  return at::float_power(c10::scalar_to_tensor(base, exp.device()), exp);
+  auto dtype = (at::isComplexType(exp.scalar_type()) || base.isComplex()) ? at::kComplexDouble : at::kDouble;
+  base = (dtype == at::kComplexDouble) ? Scalar(base.toComplexDouble()) : Scalar(base.toDouble());
+  return at::pow(base, exp.to(dtype));
+}
+
+Tensor float_power(const Tensor& base, const Tensor& exp) {
+  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
+  return at::pow(base.to(dtype), exp.to(dtype));
 }
 
 Tensor& float_power_(Tensor& base, const Tensor& exp) {
   auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
   TORCH_CHECK(base.scalar_type() == dtype,
-              "self tensor type ", base.scalar_type(), "is not the desired type ", dtype);
+              "the base given to float_power_ has dtype ", base.scalar_type(),
+              " but the operation's result requires dtype ", dtype);
 
   return base.pow_(exp.to(dtype));
 }
 
 Tensor& float_power_(Tensor& base, Scalar exp) {
-  return base.float_power_(c10::scalar_to_tensor(exp, base.device()));
+  auto dtype = (at::isComplexType(base.scalar_type()) || exp.isComplex()) ? at::kComplexDouble : at::kDouble;
+  TORCH_CHECK(base.scalar_type() == dtype,
+              "the base given to float_power_ has dtype ", base.scalar_type(),
+              " but the operation's result requires dtype ", dtype);
+
+  exp = (dtype == at::kComplexDouble) ? Scalar(exp.toComplexDouble()) : Scalar(exp.toDouble());
+  return base.pow_(exp);
 }
 
 } // namespace native
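
An aside on the new note about keeping the Scalar(...) casts inside the ternary — a minimal sketch (not part of the diff; pick_exponent is a hypothetical name) of the C++ subtlety it guards against: the conditional expression's common type of c10::complex<double> and double is c10::complex<double>, so without the casts the exponent would come back as a complex Scalar even on the real path.

#include <c10/core/Scalar.h>
#include <c10/util/complex.h>

// Hypothetical helper illustrating the cast placement.
c10::Scalar pick_exponent(const c10::Scalar& exp, bool want_complex) {
  // Without the casts, the double branch is converted to c10::complex<double>
  // to give the ternary a common type, so this is a complex Scalar either way:
  //   c10::Scalar s = want_complex ? exp.toComplexDouble() : exp.toDouble();
  // Casting each branch to Scalar first keeps the real path a double Scalar:
  return want_complex ? c10::Scalar(exp.toComplexDouble())
                      : c10::Scalar(exp.toDouble());
}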
8 changes: 4 additions & 4 deletions test/test_binary_ufuncs.py
@@ -2424,7 +2424,7 @@ def to_np(value):
 
                 # Case of Tensor x Tensor
                 if op is torch.Tensor.float_power_ and base_dtype != out_dtype:
-                    with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
+                    with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"):
                         op(base.clone(), exp)
                 else:
                     result = op(base.clone(), exp)
@@ -2441,7 +2441,7 @@ def to_np(value):
                 expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i))
 
                 if op is torch.Tensor.float_power_ and base_dtype != out_dtype_scalar_exp:
-                    with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
+                    with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"):
                         op(base.clone(), i)
                 else:
                     result = op(base.clone(), i)
@@ -2483,13 +2483,13 @@ def _promo_helper(x, y):
                 if out.dtype == required_dtype:
                     torch.float_power(base, exp, out=out)
                 else:
-                    with self.assertRaisesRegex(RuntimeError, "is not the desired output type"):
+                    with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"):
                         torch.float_power(base, exp, out=out)
 
                 if base.dtype == required_dtype:
                     torch.Tensor.float_power_(base.clone(), exp)
                 else:
-                    with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
+                    with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"):
                         torch.Tensor.float_power_(base.clone(), exp)
 
     @skipIf(not TEST_SCIPY, "Scipy required for the test.")
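
For context on the regex change, a rough sketch (not part of the test diff; assumes a libtorch build) of the user-visible check those assertions now match — an out tensor whose dtype is not the required double/complex-double result dtype trips the new TORCH_CHECK message:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto base = at::arange(3, at::kFloat);

  // float_power always computes in double here (no complex inputs), so a
  // float-typed output fails the check with the new message.
  auto bad = at::empty({3}, at::kFloat);
  try {
    at::float_power_out(bad, base, 2);
  } catch (const c10::Error& e) {
    std::cout << e.what() << "\n";  // "... but the operation's result requires dtype Double"
  }

  // A double-typed output satisfies the check; ok now holds 0., 1., 4.
  auto ok = at::empty({3}, at::kDouble);
  at::float_power_out(ok, base, 2);
}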
