Skip to content
Permalink
Browse files

[tensor] Make isEqual work correctly with NAN (#2443)

  • Loading branch information...
bertmaher committed Feb 27, 2019
1 parent ae9d948 commit 1c8eadb20b46657f2dc35c18b1e47ebc4b219066
Showing with 28 additions and 7 deletions.
  1. +7 −6 cmake/modules/GlowDefaults.cmake
  2. +3 −1 include/glow/Base/Tensor.h
  3. +18 −0 tests/unittests/TensorsTest.cpp
@@ -33,18 +33,19 @@ else()
include(CheckCXXCompilerFlag)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wnon-virtual-dtor -fno-exceptions -fno-rtti")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fno-omit-frame-pointer -O0")
set(FAST_MATH_FLAGS "-ffast-math -fno-finite-math-only")
CHECK_CXX_COMPILER_FLAG("-Wno-psabi" HAS_W_NO_PSABI)
if(HAS_W_NO_PSABI)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-psabi")
endif()
if((CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL armv7)
AND CMAKE_CXX_COMPILER_ID STREQUAL Clang)
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -ffast-math")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -ffast-math")
set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} -ffast-math")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} ${FAST_MATH_FLAGS}")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} ${FAST_MATH_FLAGS}")
set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} ${FAST_MATH_FLAGS}")
else()
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -march=native -ffast-math")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -march=native -ffast-math")
set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} -march=native -ffast-math")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -march=native ${FAST_MATH_FLAGS}")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -march=native ${FAST_MATH_FLAGS}")
set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} -march=native ${FAST_MATH_FLAGS}")
endif()
endif()
@@ -472,7 +472,9 @@ class Tensor final {
auto const *otherData = other.getRawDataPointer<ElemTy>();
for (size_t i = 0, e = size(); i < e; i++) {
double delta = myData[i] - otherData[i];
if (std::abs(delta) > allowedError) {
// Since any comparison with NAN returns false, we use a negated condition
// so that this function correctly returns false when delta is NAN.
if (!(std::abs(delta) <= allowedError)) {
return false;
}
}
@@ -153,6 +153,24 @@ TEST(Tensor, equalHandles) {
}
}

TEST(Tensor, equalNAN) {
{
Tensor A = {0.5, 0, 0, 25};
Tensor B = {NAN, 0, NAN, NAN};
EXPECT_FALSE(A.isEqual(B));
}
{
Tensor A = {NAN, 0, NAN, NAN};
Tensor B = {0.5, 0, 0, 25};
EXPECT_FALSE(A.isEqual(B));
}
{
Tensor A = {NAN, 0, NAN, NAN};
Tensor B = {NAN, 0, NAN, NAN};
EXPECT_FALSE(A.isEqual(B));
}
}

template <typename Ty> void testAssignment(const Type &ty) {
// Testing some tensor operations.
Tensor T(ty);

0 comments on commit 1c8eadb

Please sign in to comment.
You can’t perform that action at this time.