Fix auto exponent issue for torch.pow #49809

Closed · wants to merge 7 commits
6 changes: 2 additions & 4 deletions aten/src/ATen/native/Pow.cpp
@@ -31,11 +31,9 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) {
"result type ", common_dtype, "can't be cast to the desired output type ",
result.scalar_type());

auto exponent = (exp.isComplex()) ? exp.toComplexDouble() : exp.toDouble();

if (exponent == 0.0) {
if (exp.equal(0.0)) {
result.resize_as_(base).fill_(1);
} else if (exponent == 1.0) {
} else if (exp.equal(1.0)) {
result.resize_as_(base).copy_(base);
} else {
auto iter = TensorIterator::unary_op(result, base.to(common_dtype));
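For context, a small sketch (not part of the diff) of why the fast paths now go through Scalar::equal: a complex exponent that is numerically real still takes them, while a genuinely complex exponent falls through to the kernel. The behavior shown matches the TestScalar.TestEqual cases added below.

```cpp
#include <c10/core/Scalar.h>

// Illustrative sketch: Scalar::equal compares across payload kinds, so the
// fill/copy fast paths in pow_out still fire for complex-but-real exponents.
void fast_path_demo() {
  c10::Scalar real_one(1.0);
  c10::Scalar complex_one(c10::complex<double>{1.0, 0.0});
  c10::Scalar complex_exp(c10::complex<double>{1.0, 2.0});

  bool a = real_one.equal(1.0);     // true: copy fast path
  bool b = complex_one.equal(1.0);  // true: real-valued complex also qualifies
  bool c = complex_exp.equal(1.0);  // false: handled by the pow kernel instead
  (void)a; (void)b; (void)c;
}
```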
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/PowKernel.cpp
@@ -63,7 +63,7 @@ void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) {
       );
     } else if (exp == -0.5) {
       cpu_kernel_vec(iter,
-        [](scalar_t base) -> scalar_t {
+        [](scalar_t base) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
           return 1.0 / std::sqrt(base);
         },
         [](Vec base) -> Vec { return base.rsqrt(); }
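Aside on the annotation above: for base == 0 this lambda computes 1.0 / 0.0, which IEEE-754 defines as +inf (the mathematically expected result of pow(0, -0.5)) but which UBSan's float-divide-by-zero check would report. A standalone sketch of the arithmetic:

```cpp
#include <cmath>
#include <cstdio>

// Standalone illustration (not PyTorch code): dividing by the +0.0 produced
// by sqrt(0.0) yields +inf under IEEE-754. The result is well defined; the
// annotation only silences the sanitizer diagnostic.
int main() {
  double base = 0.0;
  double r = 1.0 / std::sqrt(base);  // +inf; flagged by -fsanitize=float-divide-by-zero
  std::printf("%f\n", r);           // prints "inf"
}
```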
20 changes: 20 additions & 0 deletions aten/src/ATen/test/scalar_test.cpp
@@ -138,3 +138,23 @@ TEST(TestScalar, TestConj) {
   ASSERT_EQ(float_scalar.conj().toDouble(), 3.0);
   ASSERT_EQ(complex_scalar.conj().toComplexDouble(), c10::complex<double>(2.3, -3.5));
 }
+
+TEST(TestScalar, TestEqual) {
+  ASSERT_FALSE(Scalar(1.0).equal(false));
+  ASSERT_FALSE(Scalar(1.0).equal(true));
+  ASSERT_FALSE(Scalar(true).equal(1.0));
+  ASSERT_TRUE(Scalar(true).equal(true));
+
+  ASSERT_TRUE(Scalar(c10::complex<double>{2.0, 5.0}).equal(c10::complex<double>{2.0, 5.0}));
+  ASSERT_TRUE(Scalar(c10::complex<double>{2.0, 0}).equal(2.0));
+  ASSERT_TRUE(Scalar(c10::complex<double>{2.0, 0}).equal(2));
+
+  ASSERT_TRUE(Scalar(2.0).equal(c10::complex<double>{2.0, 0.0}));
+  ASSERT_FALSE(Scalar(2.0).equal(c10::complex<double>{2.0, 4.0}));
+  ASSERT_FALSE(Scalar(2.0).equal(3.0));
+  ASSERT_TRUE(Scalar(2.0).equal(2));
+
+  ASSERT_TRUE(Scalar(2).equal(c10::complex<double>{2.0, 0}));
+  ASSERT_TRUE(Scalar(2).equal(2));
+  ASSERT_TRUE(Scalar(2).equal(2.0));
+}
10 changes: 10 additions & 0 deletions c10/core/Scalar.cpp
@@ -21,4 +21,14 @@ Scalar Scalar::conj() const {
   }
 }
 
+Scalar Scalar::log() const {
+  if (isComplex()) {
+    return std::log(v.z);
+  } else if (isFloatingPoint()) {
+    return std::log(v.d);
+  } else {
+    return std::log(v.i);

Collaborator: Is there a bool case to be worried about here?

Contributor Author: No, because a Scalar only stores a double, a complex, or an int. When a Scalar is initialized with a bool, it saves a flag indicating that it is a bool and stores the value as an int.

+  }
+}
+
 } // namespace c10
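A simplified sketch of the storage layout the author describes (the real c10::Scalar has more constructors and tags; the member names v.d / v.i / v.z mirror the diff):

```cpp
#include <cstdint>
#include <c10/util/complex.h>

// Simplified sketch, not the actual c10::Scalar definition: a tag plus a
// union. A bool payload reuses the integer slot, with the tag recording
// that it should be treated as a boolean.
class ScalarSketch {
  enum class Tag { HAS_d, HAS_i, HAS_z, HAS_b } tag;
  union Payload {
    double d;                // floating point
    int64_t i;               // integers, and bools (stored as 0/1)
    c10::complex<double> z;  // complex
    Payload() : i(0) {}
  } v;

 public:
  /* implicit */ ScalarSketch(bool b) : tag(Tag::HAS_b) { v.i = b ? 1 : 0; }
  /* implicit */ ScalarSketch(int64_t n) : tag(Tag::HAS_i) { v.i = n; }
  /* implicit */ ScalarSketch(double x) : tag(Tag::HAS_d) { v.d = x; }

  bool isBoolean() const { return tag == Tag::HAS_b; }
};
```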
39 changes: 39 additions & 0 deletions c10/core/Scalar.h
@@ -88,6 +88,45 @@ class C10_API Scalar {

   Scalar operator-() const;
   Scalar conj() const;
+  Scalar log() const;
+
+  template<typename T, typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
+  bool equal(T num) const {

Collaborator: When comparing a Scalar against a C++ number, should people always use equal?

Contributor Author: Yes.

Collaborator: I am not very familiar with the C++ semantics here; why is it different from operator==?

Contributor Author: To compare a Scalar to a number, we first have to obtain the value stored in the Scalar, which is usually done with .toDouble(). But now that a Scalar can also be complex, we can't always do that anymore. The new equal method takes care of this automatically.
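Concretely, an illustrative sketch of that difference (using the equal() added in this diff): the old pattern needs a type branch at every call site, while equal() folds the branch into the Scalar itself.

```cpp
#include <c10/core/Scalar.h>

// Illustrative only. Before equal(), callers had to branch on the payload
// kind themselves, because .toDouble() cannot represent a complex value.
bool is_two_old_way(const c10::Scalar& s) {
  return s.isComplex()
      ? (s.toComplexDouble() == c10::complex<double>(2.0, 0.0))
      : (s.toDouble() == 2.0);
}

// With this PR, the branch lives inside Scalar::equal.
bool is_two_new_way(const c10::Scalar& s) {
  return s.equal(2.0);
}
```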

Collaborator: Offline conversation: add a special case for bool.

Collaborator: Oh, could you not template operator== the same way as you do with equal? Given that using .equal() rather than == is a (small) constraint, I'm just trying to understand why we actually need to do it. But this is not blocking for this PR!

Contributor Author: operator== can be overloaded to compare two objects of the same type (in this case, Scalars), but not objects of different types (a Scalar and a number). If we did add this overload, we'd first have to convert the number to a Scalar and then compare it with the Scalar in question. That, to me, seems like overkill.
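A hypothetical sketch of what such an overload could look like (not part of c10, and not added by this PR); in practice it would just delegate to equal(), which is why exposing equal() directly is the simpler choice:

```cpp
#include <c10/core/Scalar.h>

// Hypothetical free function, NOT in c10: a spelling that would permit
// `scalar == 2.0`. It delegates to the equal() introduced in this PR.
inline bool operator==(const c10::Scalar& s, double num) {
  return s.equal(num);
}
```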

+    if (isComplex()) {
+      auto val = v.z;
+      return (val.real() == num) && (val.imag() == T());
+    } else if (isFloatingPoint()) {
+      return v.d == num;
+    } else if (isIntegral(/*includeBool=*/false)) {
+      return v.i == num;
+    } else {
+      // a boolean Scalar is never equal to a non-boolean value
+      return false;
+    }
+  }
+
+  template<typename T, typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
+  bool equal(T num) const {
+    if (isComplex()) {
+      return v.z == num;
+    } else if (isFloatingPoint()) {
+      return (v.d == num.real()) && (num.imag() == T());
+    } else if (isIntegral(/*includeBool=*/false)) {
+      return (v.i == num.real()) && (num.imag() == T());
+    } else {
+      // a boolean Scalar is never equal to a non-boolean value
+      return false;
+    }
+  }
+
+  bool equal(bool num) const {
+    if (isBoolean()) {
+      return static_cast<bool>(v.i) == num;
+    } else {
+      return false;
+    }
+  }
+
   ScalarType type() const {
     if (isComplex()) {
       return ScalarType::ComplexDouble;
2 changes: 1 addition & 1 deletion test/cpp/api/autograd.cpp
@@ -175,7 +175,7 @@ TEST(AutogradAPITests, AnomalyMode) {
   auto y = x.pow(1.5);
   auto gr =
       grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true);
-  ASSERT_THROWS_WITH(grad({gr[0]}, {x});, "returned nan");
+  ASSERT_THROWS_WITH(grad({gr[0]}, {x}, {torch::tensor({0.0})});, "returned nan");
   auto msgs = warnings.messages();
   ASSERT_EQ(msgs.size(), 2);
   ASSERT_TRUE(
17 changes: 8 additions & 9 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -204,12 +204,12 @@ Tensor norm_backward(Tensor grad, const Tensor & self, const optional<Scalar> &
   return norm_backward(grad, self, p_, norm);
 }
 
-Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent_) {
-  auto exponent = (exponent_.isComplex()) ? exponent_.toComplexDouble() : exponent_.toDouble();
-  if (exponent == 0.0) {
+Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent) {
+  if (exponent.equal(0.0)) {
     return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   } else {
-    auto out = grad * (exponent * self.pow(exponent - 1)).conj();
+    auto grad_lambda = [&](auto exp) { return grad * (exp * self.pow(exp - 1)).conj(); };
+    Tensor out = (exponent.isComplex()) ? grad_lambda(exponent.toComplexDouble()) : grad_lambda(exponent.toDouble());
     return handle_r_to_c(self, out);
   }
 }
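For reference, the gradient these branches implement: for $f(x) = x^{e}$,

$$\frac{\partial}{\partial x}\,x^{e} = e\,x^{e-1},$$

which is identically zero when $e = 0$ (the zeros_like branch). The conj() and handle_r_to_c follow autograd's convention for complex inputs; the generic lambda simply lets the same expression be instantiated with either a double or a c10::complex&lt;double&gt; exponent.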
@@ -242,9 +242,8 @@ Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& expo
 }
 
 Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) {
-  auto base_ = base.isComplex() ? base.toComplexDouble() : base.toDouble();
-  auto grad_lambda = [](auto a, auto b) { return (a * std::log(b)).conj(); };
-  if (base_ == 0.0) {
+  auto grad_lambda = [](Tensor a, Scalar b) { return (a * b.log()).conj(); };
+  if (base.equal(0.0)) {
     auto cond = [](auto exp) {
       if (exp.is_complex()) {
         return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0);
@@ -254,10 +254,10 @@ Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exp
     };
     auto out = grad * at::where(cond(exponent),
                                 at::zeros({}, grad.options()),
-                                grad_lambda(result, base_));
+                                grad_lambda(result, base));
     return handle_r_to_c(exponent, out);
   } else {
-    auto out = grad * grad_lambda(result, base_);
+    auto out = grad * grad_lambda(result, base);
     return handle_r_to_c(exponent, out);
   }
 }
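Likewise, the exponent gradient implemented above: for $g(e) = b^{e}$,

$$\frac{\partial}{\partial e}\,b^{e} = b^{e}\ln b = \text{result}\cdot\ln b.$$

At $b = 0$, $\ln b$ is $-\infty$, so the product would produce NaN even where the true limiting gradient is 0; the at::where over cond(exponent) substitutes that limit whenever the exponent is real (zero imaginary part) and nonnegative.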