Skip to content

Commit

Permalink
Update on "Fix auto exponent issue for torch.pow"
Browse files Browse the repository at this point in the history
Fixes pytorch/xla#2688 #46936

[ghstack-poisoned]
  • Loading branch information
anjali411 committed Dec 29, 2020
1 parent c0a1675 commit eebd3a9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
25 changes: 18 additions & 7 deletions aten/src/ATen/test/scalar_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

using std::cout;
using namespace at;
using namespace c10;

constexpr auto Float = ScalarType::Float;

Expand Down Expand Up @@ -140,11 +141,21 @@ TEST(TestScalar, TestConj) {
}

TEST(TestScalar, TestEqual) {
ASSERT_EQ(Scalar(1.0).equal(false), false);
ASSERT_EQ(Scalar(1.0).equal(true), false);
ASSERT_EQ(Scalar(true).equal(1.0), false);
ASSERT_EQ(Scalar(c10::complex<double>{2.0, 0}).equal(2.0), true);
ASSERT_EQ(Scalar(2.0).equal(3.0), false);
// this fails
// ASSERT_EQ(Scalar(2).equal(2), true);
ASSERT_FALSE(Scalar(1.0).equal(false));
ASSERT_FALSE(Scalar(1.0).equal(true));
ASSERT_FALSE(Scalar(true).equal(1.0));
ASSERT_TRUE(Scalar(true).equal(true));

ASSERT_TRUE(Scalar(c10::complex<double>{2.0, 5.0}).equal(c10::complex<double>{2.0, 5.0}));
ASSERT_TRUE(Scalar(c10::complex<double>{2.0, 0}).equal(2.0));
ASSERT_TRUE(Scalar(c10::complex<double>{2.0, 0}).equal(2));

ASSERT_TRUE(Scalar(2.0).equal(c10::complex<double>{2.0, 0.0}));
ASSERT_FALSE(Scalar(2.0).equal(c10::complex<double>{2.0, 4.0}));
ASSERT_FALSE(Scalar(2.0).equal(3.0));
ASSERT_TRUE(Scalar(2.0).equal(2));

ASSERT_TRUE(Scalar(2).equal(c10::complex<double>{2.0, 0}));
ASSERT_TRUE(Scalar(2).equal(2));
ASSERT_TRUE(Scalar(2).equal(2.0));
}
2 changes: 1 addition & 1 deletion c10/core/Scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class C10_API Scalar {
} else if (isIntegral(/*includeBool=*/false)) {
return v.i == num;
} else {
// boolean scalar
// boolean scalar does not equal to a non boolean value
return false;
}
}
Expand Down
10 changes: 10 additions & 0 deletions c10/util/complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,16 @@ constexpr bool operator==(const T& lhs, const complex<T>& rhs) {
return (lhs == rhs.real()) && (T() == rhs.imag());
}

template<typename T>
constexpr bool operator==(const complex<T>& lhs, const int& rhs) {
return (lhs.real() == rhs) && (lhs.imag() == T());
}

template<typename T>
constexpr bool operator==(const int& lhs, const complex<T>& rhs) {
return (lhs == rhs.real()) && (T() == rhs.imag());
}

template<typename T>
constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) {
return !(lhs == rhs);
Expand Down

0 comments on commit eebd3a9

Please sign in to comment.