diff --git a/_unittests/ut_pycode/test_extunittest.py b/_unittests/ut_pycode/test_extunittest.py index 3c5f3cf1..f8ec2f2c 100644 --- a/_unittests/ut_pycode/test_extunittest.py +++ b/_unittests/ut_pycode/test_extunittest.py @@ -84,6 +84,24 @@ def test_arr(self): self.assertRaise(lambda: self.assertEqualArray(None, df), AssertionError) + def test_arr_decimal(self): + from numpy import array + df = array([[0, 1], [1, 2.01]]) + df1 = array([[0, 1], [1, 2]]) + self.assertEqualArray(df, df1, decimal=1) + + def test_arr_atol(self): + from numpy import array + df = array([[0.5, 1], [1, 2]]) + df1 = array([[0, 1], [1, 2]]) + self.assertEqualArray(df, df1, atol=0.5) + + def test_arr_rtol(self): + from numpy import array + df = array([[0, 1], [1, 2.2]]) + df1 = array([[0, 1], [1, 2]]) + self.assertEqualArray(df, df1, rtol=0.11) + def test_nan(self): from numpy import array, nan df = array([[nan, 1], [1, 2]]) diff --git a/src/pyquickhelper/pycode/unittestclass.py b/src/pyquickhelper/pycode/unittestclass.py index abd57ec3..44774640 100644 --- a/src/pyquickhelper/pycode/unittestclass.py +++ b/src/pyquickhelper/pycode/unittestclass.py @@ -148,12 +148,15 @@ def assertEqualArray(self, d1, d2, squeeze=False, **kwargs): raise AssertionError("d1 is None, d2 is not") if d2 is None: raise AssertionError("d1 is not None, d2 is") - from numpy.testing import assert_almost_equal - import numpy + from numpy.testing import assert_almost_equal, assert_allclose + from numpy import squeeze if squeeze: - d1 = numpy.squeeze(d1) - d2 = numpy.squeeze(d2) - assert_almost_equal(d1, d2, **kwargs) + d1 = squeeze(d1) + d2 = squeeze(d2) + if 'decimal' in kwargs: + assert_almost_equal(d1, d2, **kwargs) + else: + assert_allclose(d1, d2, **kwargs) def assertHasNoNan(self, a): # pylint: disable=W0221 """