[MRG+1] FIX unstable cumsum (scikit-learn#7376)
* FIX unstable cumsum in utils.random

* equal_nan=True for isclose

* since numpy < 1.9 sum is as unstable as cumsum, fall back to np.cumsum

* added axis parameter to stable_cumsum

* FIX unstable cumsum in ensemble.weight_boosting and utils.stats

* FIX axis problem in stable_cumsum

* FIX unstable cumsum in mixture.gmm and mixture.dpgmm

* FIX unstable cumsum in cluster.k_means_, decomposition.pca, and manifold.locally_linear

* FIX unstable cumsum in datasets.samples_generator

* added docstring for parameter axis of stable_cumsum

* added comment for why fall back to np.cumsum when np version < 1.9

* remove unneeded stable_cumsum

* added stable_cumsum's axis testing

* FIX numpy docstring for make_sparse_spd_matrix

* change stable_cumsum from error to warning
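
The instability these fixes target is the drift of a sequential cumulative sum away from NumPy's pairwise sum. A minimal NumPy sketch of that drift (illustrative only; it assumes NumPy >= 1.9, where np.sum uses pairwise summation):

    import numpy as np

    r = np.random.RandomState(0).rand(100000).astype(np.float32)
    # cumsum accumulates sequentially in float32, so its last element drifts,
    # while np.sum uses pairwise summation and stays much closer to the true value.
    print(np.cumsum(r)[-1], np.sum(r))

stable_cumsum guards against this by computing the cumulative sum in float64 and checking its last element against np.sum; per the last bullet above, a mismatch now emits a warning instead of raising.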
yangarbiter authored and paulha committed Aug 19, 2017
1 parent 4058632 commit 58a2088
Showing 9 changed files with 40 additions and 21 deletions.
5 changes: 3 additions & 2 deletions sklearn/cluster/k_means_.py
@@ -18,7 +18,7 @@

from ..base import BaseEstimator, ClusterMixin, TransformerMixin
from ..metrics.pairwise import euclidean_distances
from ..utils.extmath import row_norms, squared_norm
from ..utils.extmath import row_norms, squared_norm, stable_cumsum
from ..utils.sparsefuncs_fast import assign_rows_csr
from ..utils.sparsefuncs import mean_variance_axis
from ..utils.fixes import astype
@@ -106,7 +106,8 @@ def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None):
# Choose center candidates by sampling with probability proportional
# to the squared distance to the closest existing center
rand_vals = random_state.random_sample(n_local_trials) * current_pot
candidate_ids = np.searchsorted(closest_dist_sq.cumsum(), rand_vals)
candidate_ids = np.searchsorted(stable_cumsum(closest_dist_sq),
rand_vals)

# Compute distances to center candidates
distance_to_candidates = euclidean_distances(
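
This hunk is the k-means++ seeding step: candidate centers are drawn by inverse-CDF sampling, so if the cumulative sum's last element drifts below current_pot, np.searchsorted can return an index one past the end of the array. A standalone sketch of the sampling pattern (sample_proportional is a hypothetical helper for illustration, not scikit-learn API):

    import numpy as np

    def sample_proportional(weights, n_draws, random_state):
        # Inverse-CDF sampling: draw indices with probability proportional to weights.
        cdf = np.cumsum(weights, dtype=np.float64)
        rand_vals = random_state.random_sample(n_draws) * cdf[-1]
        return np.searchsorted(cdf, rand_vals)

    rng = np.random.RandomState(0)
    print(sample_proportional(np.array([0.1, 0.4, 0.5]), 5, rng))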
2 changes: 1 addition & 1 deletion sklearn/datasets/samples_generator.py
@@ -1194,7 +1194,7 @@ def make_sparse_spd_matrix(dim=1, alpha=0.95, norm_diag=False,
The size of the random matrix to generate.
alpha : float between 0 and 1, optional (default=0.95)
The probability that a coefficient is zero (see notes). Larger values
The probability that a coefficient is zero (see notes). Larger values
enforce more sparsity.
random_state : int, RandomState instance or None, optional (default=None)
3 changes: 2 additions & 1 deletion sklearn/decomposition/pca.py
@@ -24,6 +24,7 @@
from ..utils import check_random_state, as_float_array
from ..utils import check_array
from ..utils.extmath import fast_dot, fast_logdet, randomized_svd, svd_flip
from ..utils.extmath import stable_cumsum
from ..utils.validation import check_is_fitted
from ..utils.arpack import svds

@@ -393,7 +394,7 @@ def _fit_full(self, X, n_components):
elif 0 < n_components < 1.0:
# number of components for which the cumulated explained
# variance percentage is superior to the desired threshold
ratio_cumsum = explained_variance_ratio_.cumsum()
ratio_cumsum = stable_cumsum(explained_variance_ratio_)
n_components = np.searchsorted(ratio_cumsum, n_components) + 1

# Compute noise covariance using Probabilistic PCA model
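
To make the searchsorted step concrete (values below are made up for illustration): with explained variance ratios [0.6, 0.25, 0.1, 0.05] and n_components=0.9, the cumulative sum is [0.6, 0.85, 0.95, 1.0], searchsorted returns 2, and the +1 keeps 3 components, the smallest number whose cumulative ratio reaches the threshold:

    import numpy as np

    ratios = np.array([0.6, 0.25, 0.1, 0.05])
    ratio_cumsum = np.cumsum(ratios)                  # [0.6, 0.85, 0.95, 1.0]
    n_components = np.searchsorted(ratio_cumsum, 0.9) + 1
    print(n_components)                               # 3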
5 changes: 3 additions & 2 deletions sklearn/ensemble/weight_boosting.py
@@ -38,6 +38,7 @@
from ..tree.tree import BaseDecisionTree
from ..tree._tree import DTYPE
from ..utils import check_array, check_X_y, check_random_state
from ..utils.extmath import stable_cumsum
from ..metrics import accuracy_score, r2_score
from sklearn.utils.validation import has_fit_parameter, check_is_fitted

@@ -1002,7 +1003,7 @@ def _boost(self, iboost, X, y, sample_weight, random_state):

# Weighted sampling of the training set with replacement
# For NumPy >= 1.7.0 use np.random.choice
cdf = sample_weight.cumsum()
cdf = stable_cumsum(sample_weight)
cdf /= cdf[-1]
uniform_samples = random_state.random_sample(X.shape[0])
bootstrap_idx = cdf.searchsorted(uniform_samples, side='right')
@@ -1059,7 +1060,7 @@ def _get_median_predict(self, X, limit):
sorted_idx = np.argsort(predictions, axis=1)

# Find index of median prediction for each sample
weight_cdf = self.estimator_weights_[sorted_idx].cumsum(axis=1)
weight_cdf = stable_cumsum(self.estimator_weights_[sorted_idx], axis=1)
median_or_above = weight_cdf >= 0.5 * weight_cdf[:, -1][:, np.newaxis]
median_idx = median_or_above.argmax(axis=1)
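
The weighted-median hunk finds, for each sample, the first estimator (in sorted-prediction order) whose cumulative weight reaches half the total weight. A one-dimensional sketch of the same idea with made-up numbers:

    import numpy as np

    predictions = np.array([3.0, 1.0, 2.0])
    weights = np.array([0.2, 0.3, 0.5])

    sorted_idx = np.argsort(predictions)              # prediction order: 1.0, 2.0, 3.0
    weight_cdf = np.cumsum(weights[sorted_idx])       # [0.3, 0.8, 1.0]
    median_or_above = weight_cdf >= 0.5 * weight_cdf[-1]
    median_idx = median_or_above.argmax()             # first True -> index 1
    print(predictions[sorted_idx[median_idx]])        # 2.0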

3 changes: 2 additions & 1 deletion sklearn/manifold/locally_linear.py
@@ -10,6 +10,7 @@
from ..base import BaseEstimator, TransformerMixin
from ..utils import check_random_state, check_array
from ..utils.arpack import eigsh
from ..utils.extmath import stable_cumsum
from ..utils.validation import check_is_fitted
from ..utils.validation import FLOAT_DTYPES
from ..neighbors import NearestNeighbors
@@ -420,7 +421,7 @@ def locally_linear_embedding(
# this is the size of the largest set of eigenvalues
# such that Sum[v; v in set]/Sum[v; v not in set] < eta
s_range = np.zeros(N, dtype=int)
evals_cumsum = np.cumsum(evals, 1)
evals_cumsum = stable_cumsum(evals, 1)
eta_range = evals_cumsum[:, -1:] / evals_cumsum[:, :-1] - 1
for i in range(N):
s_range[i] = np.searchsorted(eta_range[i, ::-1], eta)
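
The comment above defines s as the size of the largest set of trailing (smallest) eigenvalues whose sum, relative to the sum of the remaining ones, stays below eta. A single-point numeric sketch with made-up eigenvalues (sorted in decreasing order, as they are at this point in the routine):

    import numpy as np

    evals = np.array([4.0, 2.0, 1.0, 0.5])
    eta = 0.3
    evals_cumsum = np.cumsum(evals)                          # [4. , 6. , 7. , 7.5]
    eta_range = evals_cumsum[-1] / evals_cumsum[:-1] - 1     # [0.875, 0.25, 0.071...]
    s = np.searchsorted(eta_range[::-1], eta)
    print(s)   # 2: {1.0, 0.5} sums to 1.5, versus 6.0 for {4.0, 2.0}; 0.25 < eta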
4 changes: 2 additions & 2 deletions sklearn/mixture/dpgmm.py
@@ -24,7 +24,7 @@

from ..externals.six.moves import xrange
from ..utils import check_random_state, check_array, deprecated
from ..utils.extmath import logsumexp, pinvh, squared_norm
from ..utils.extmath import logsumexp, pinvh, squared_norm, stable_cumsum
from ..utils.validation import check_is_fitted
from .. import cluster
from .gmm import _GMMBase
@@ -462,7 +462,7 @@ def _bound_proportions(self, z):
dg1 = digamma(self.gamma_.T[1]) - dg12
dg2 = digamma(self.gamma_.T[2]) - dg12

cz = np.cumsum(z[:, ::-1], axis=-1)[:, -2::-1]
cz = stable_cumsum(z[:, ::-1], axis=-1)[:, -2::-1]
logprior = np.sum(cz * dg2[:-1]) + np.sum(z * dg1)
del cz # Save memory
z_non_zeros = z[z > np.finfo(np.float32).eps]
23 changes: 16 additions & 7 deletions sklearn/utils/extmath.py
@@ -25,7 +25,7 @@
from ..externals.six.moves import xrange
from .sparsefuncs_fast import csr_row_norms
from .validation import check_array
from ..exceptions import NonBLASDotWarning
from ..exceptions import ConvergenceWarning, NonBLASDotWarning


def norm(x):
@@ -844,21 +844,30 @@ def _deterministic_vector_sign_flip(u):
return u


def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08):
"""Use high precision for cumsum and check that final value matches sum
Parameters
----------
arr : array-like
To be cumulatively summed as flat
axis : int, optional
Axis along which the cumulative sum is computed.
The default (None) is to compute the cumsum over the flattened array.
rtol : float
Relative tolerance, see ``np.allclose``
atol : float
Absolute tolerance, see ``np.allclose``
"""
out = np.cumsum(arr, dtype=np.float64)
expected = np.sum(arr, dtype=np.float64)
if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
raise RuntimeError('cumsum was found to be unstable: '
'its last element does not correspond to sum')
# sum is as unstable as cumsum for numpy < 1.9
if np_version < (1, 9):
return np.cumsum(arr, axis=axis, dtype=np.float64)

out = np.cumsum(arr, axis=axis, dtype=np.float64)
expected = np.sum(arr, axis=axis, dtype=np.float64)
if not np.all(np.isclose(out.take(-1, axis=axis), expected, rtol=rtol,
atol=atol, equal_nan=True)):
warnings.warn('cumsum was found to be unstable: '
'its last element does not correspond to sum',
ConvergenceWarning)
return out
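
After this change stable_cumsum accepts an axis and only warns when the stability check fails. A short usage sketch, assuming a scikit-learn build that includes this patch:

    import numpy as np
    from sklearn.utils.extmath import stable_cumsum

    A = np.random.RandomState(36).randint(1000, size=(5, 5, 5))
    out = stable_cumsum(A, axis=1)   # computed in float64, last slice checked against np.sum

    # With zero tolerances a long random array trips the check, which now
    # emits a ConvergenceWarning instead of raising a RuntimeError.
    r = np.random.RandomState(0).rand(100000)
    _ = stable_cumsum(r, rtol=0, atol=0)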
3 changes: 2 additions & 1 deletion sklearn/utils/stats.py
@@ -1,6 +1,7 @@
import numpy as np
from scipy.stats import rankdata as _sp_rankdata
from .fixes import bincount
from ..utils.extmath import stable_cumsum


# To remove when we support scipy 0.13
@@ -53,7 +54,7 @@ def _weighted_percentile(array, sample_weight, percentile=50):
sorted_idx = np.argsort(array)

# Find index of median prediction for each sample
weight_cdf = sample_weight[sorted_idx].cumsum()
weight_cdf = stable_cumsum(sample_weight[sorted_idx])
percentile_idx = np.searchsorted(
weight_cdf, (percentile / 100.) * weight_cdf[-1])
return array[sorted_idx[percentile_idx]]
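
To illustrate the searchsorted step with made-up numbers: for array = [1, 2, 3, 4], sample_weight = [1, 1, 1, 3] and percentile=50, the cumulative weights are [1, 2, 3, 6], the threshold is 0.5 * 6 = 3, searchsorted returns index 2, and the weighted 50th percentile is 3:

    import numpy as np

    array = np.array([1, 2, 3, 4])
    sample_weight = np.array([1., 1., 1., 3.])

    sorted_idx = np.argsort(array)
    weight_cdf = np.cumsum(sample_weight[sorted_idx])        # [1., 2., 3., 6.]
    percentile_idx = np.searchsorted(weight_cdf, 0.5 * weight_cdf[-1])
    print(array[sorted_idx[percentile_idx]])                 # 3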
13 changes: 9 additions & 4 deletions sklearn/utils/tests/test_extmath.py
@@ -18,6 +18,7 @@
from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_raise_message
from sklearn.utils.testing import assert_warns
from sklearn.utils.testing import skip_if_32bit
from sklearn.utils.testing import SkipTest
from sklearn.utils.fixes import np_version
@@ -36,6 +37,7 @@
from sklearn.utils.extmath import _deterministic_vector_sign_flip
from sklearn.utils.extmath import softmax
from sklearn.utils.extmath import stable_cumsum
from sklearn.exceptions import ConvergenceWarning
from sklearn.datasets.samples_generator import make_low_rank_matrix


@@ -654,7 +656,10 @@ def test_stable_cumsum():
raise SkipTest("Sum is as unstable as cumsum for numpy < 1.9")
assert_array_equal(stable_cumsum([1, 2, 3]), np.cumsum([1, 2, 3]))
r = np.random.RandomState(0).rand(100000)
assert_raise_message(RuntimeError,
'cumsum was found to be unstable: its last element '
'does not correspond to sum',
stable_cumsum, r, rtol=0, atol=0)
assert_warns(ConvergenceWarning, stable_cumsum, r, rtol=0, atol=0)

# test axis parameter
A = np.random.RandomState(36).randint(1000, size=(5, 5, 5))
assert_array_equal(stable_cumsum(A, axis=0), np.cumsum(A, axis=0))
assert_array_equal(stable_cumsum(A, axis=1), np.cumsum(A, axis=1))
assert_array_equal(stable_cumsum(A, axis=2), np.cumsum(A, axis=2))
