Skip to content

Commit

Permalink
TST Add test with reduced tolerance
Browse files Browse the repository at this point in the history
  • Loading branch information
jnothman committed Sep 8, 2016
1 parent f58fdf7 commit b93067f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
13 changes: 11 additions & 2 deletions sklearn/utils/extmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,12 +853,21 @@ def _deterministic_vector_sign_flip(u):
return u


def stable_cumsum(arr):
def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
"""Use high precision for cumsum and check that final value matches sum
Parameters
----------
arr : array-like
To be cumulatively summed as flat
rtol : float
Relative tolerance, see ``np.allclose``
atol : float
Absolute tolerance, see ``np.allclose``
"""
out = np.cumsum(arr, dtype=np.float64)
expected = np.sum(arr, dtype=np.float64)
if not np.allclose(out[-1], expected):
if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
raise RuntimeError('cumsum was found to be unstable: '
'its last element does not correspond to sum')
return out
11 changes: 11 additions & 0 deletions sklearn/utils/tests/test_extmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sklearn.utils.testing import assert_false
from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_raise_message
from sklearn.utils.testing import skip_if_32bit

from sklearn.utils.extmath import density
Expand All @@ -32,6 +33,7 @@
from sklearn.utils.extmath import _incremental_mean_and_var
from sklearn.utils.extmath import _deterministic_vector_sign_flip
from sklearn.utils.extmath import softmax
from sklearn.utils.extmath import stable_cumsum
from sklearn.datasets.samples_generator import make_low_rank_matrix


Expand Down Expand Up @@ -643,3 +645,12 @@ def test_softmax():
exp_X = np.exp(X)
sum_exp_X = np.sum(exp_X, axis=1).reshape((-1, 1))
assert_array_almost_equal(softmax(X), exp_X / sum_exp_X)


def test_stable_cumsum():
assert_array_equal(stable_cumsum([1, 2, 3]), np.cumsum([1, 2, 3]))
r = np.random.RandomState(0).rand(100000)
assert_raise_message(RuntimeError,
'cumsum was found to be unstable: its last element '
'does not correspond to sum',
stable_cumsum, r, rtol=0, atol=0)

0 comments on commit b93067f

Please sign in to comment.