Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Append error msg to exception for test.assert_* method #22006

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 16 additions & 14 deletions tensorflow/python/framework/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,35 +1338,36 @@ def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
b.shape)
self.assertEqual(a.shape, b.shape, shape_mismatch_msg)

msgs = [msg]
if not np.allclose(a, b, rtol=rtol, atol=atol):
# Prints more details than np.testing.assert_allclose.
# Adds more details to np.testing.assert_allclose.
#
# NOTE: numpy.allclose (and numpy.testing.assert_allclose)
# checks whether two arrays are element-wise equal within a
# tolerance. The relative difference (rtol * abs(b)) and the
# absolute difference atol are added together to compare against
# the absolute difference between a and b. Here, we want to
# print out which elements violate such conditions.
# tell user which elements violate such conditions.
cond = np.logical_or(
np.abs(a - b) > atol + rtol * np.abs(b),
np.isnan(a) != np.isnan(b))
if a.ndim:
x = a[np.where(cond)]
y = b[np.where(cond)]
print("not close where = ", np.where(cond))
msgs.append("not close where = {}".format(np.where(cond)))
else:
# np.where is broken for scalars
x, y = a, b
print("not close lhs = ", x)
print("not close rhs = ", y)
print("not close dif = ", np.abs(x - y))
print("not close tol = ", atol + rtol * np.abs(y))
print("dtype = %s, shape = %s" % (a.dtype, a.shape))
msgs.append("not close lhs = {}".format(x))
msgs.append("not close rhs = {}".format(y))
msgs.append("not close dif = {}".format(np.abs(x - y)))
msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
msgs.append("dtype = {}, shape = {}".format(a.dtype, a.shape))
# TODO(xpan): There seems to be a bug:
# tensorflow/compiler/tests:binary_ops_test pass with float32
# nan even though the equal_nan is False by default internally.
np.testing.assert_allclose(
a, b, rtol=rtol, atol=atol, err_msg=msg, equal_nan=True)
a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True)

def _assertAllCloseRecursive(self,
a,
Expand Down Expand Up @@ -1548,19 +1549,20 @@ def assertAllEqual(self, a, b, msg=None):
np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype
]):
same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
msgs = [msg]
if not np.all(same):
# Prints more details than np.testing.assert_array_equal.
# Adds more details to np.testing.assert_array_equal.
diff = np.logical_not(same)
if a.ndim:
x = a[np.where(diff)]
y = b[np.where(diff)]
print("not equal where = ", np.where(diff))
msgs.append("not equal where = {}".format(np.where(diff)))
else:
# np.where is broken for scalars
x, y = a, b
print("not equal lhs = ", x)
print("not equal rhs = ", y)
np.testing.assert_array_equal(a, b, err_msg=msg)
msgs.append("not equal lhs = {}".format(x))
msgs.append("not equal rhs = {}".format(y))
np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))

def assertAllGreater(self, a, comparison_target):
"""Assert element values are all greater than a target value.
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/python/framework/test_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,11 @@ def testAllCloseScalars(self):
with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
self.assertAllClose(7, 7 + 1e-5)

@test_util.run_in_graph_and_eager_modes
def testAllCloseList(self):
with self.assertRaisesRegexp(AssertionError, r"not close dif"):
self.assertAllClose([0], [1])

@test_util.run_in_graph_and_eager_modes
def testAllCloseDictToNonDict(self):
with self.assertRaisesRegexp(ValueError, r"Can't compare dict to non-dict"):
Expand Down Expand Up @@ -452,6 +457,9 @@ def testAssertAllEqual(self):
self.assertAllEqual([120] * 3, k)
self.assertAllEqual([20] * 3, j)

with self.assertRaisesRegexp(AssertionError, r"not equal lhs"):
self.assertAllEqual([0] * 3, k)

@test_util.run_in_graph_and_eager_modes
def testAssertNotAllClose(self):
# Test with arrays
Expand Down