Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.

Commit 0c345df

Browse files
authored
Calls assert_allclose in assertEqualArray (#380)
1 parent 47b00fa commit 0c345df

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

_unittests/ut_pycode/test_extunittest.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,24 @@ def test_arr(self):
8484
self.assertRaise(lambda: self.assertEqualArray(None, df),
8585
AssertionError)
8686

87+
def test_arr_decimal(self):
88+
from numpy import array
89+
df = array([[0, 1], [1, 2.01]])
90+
df1 = array([[0, 1], [1, 2]])
91+
self.assertEqualArray(df, df1, decimal=1)
92+
93+
def test_arr_atol(self):
94+
from numpy import array
95+
df = array([[0.5, 1], [1, 2]])
96+
df1 = array([[0, 1], [1, 2]])
97+
self.assertEqualArray(df, df1, atol=0.5)
98+
99+
def test_arr_rtol(self):
100+
from numpy import array
101+
df = array([[0, 1], [1, 2.2]])
102+
df1 = array([[0, 1], [1, 2]])
103+
self.assertEqualArray(df, df1, rtol=0.11)
104+
87105
def test_nan(self):
88106
from numpy import array, nan
89107
df = array([[nan, 1], [1, 2]])

src/pyquickhelper/pycode/unittestclass.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,15 @@ def assertEqualArray(self, d1, d2, squeeze=False, **kwargs):
148148
raise AssertionError("d1 is None, d2 is not")
149149
if d2 is None:
150150
raise AssertionError("d1 is not None, d2 is")
151-
from numpy.testing import assert_almost_equal
152-
import numpy
151+
from numpy.testing import assert_almost_equal, assert_allclose
152+
from numpy import squeeze
153153
if squeeze:
154-
d1 = numpy.squeeze(d1)
155-
d2 = numpy.squeeze(d2)
156-
assert_almost_equal(d1, d2, **kwargs)
154+
d1 = squeeze(d1)
155+
d2 = squeeze(d2)
156+
if 'decimal' in kwargs:
157+
assert_almost_equal(d1, d2, **kwargs)
158+
else:
159+
assert_allclose(d1, d2, **kwargs)
157160

158161
def assertHasNoNan(self, a): # pylint: disable=W0221
159162
"""

0 commit comments

Comments
 (0)