From 71f14cd9e6dbb4566f4e5836b6b43ba3d1725f99 Mon Sep 17 00:00:00 2001
From: Tom DLT
Date: Wed, 7 Oct 2015 15:46:41 +0200
Subject: [PATCH] FIX improve 'precompute' handling in Lars

---
 doc/whats_new.rst                            |  7 +++
 sklearn/linear_model/least_angle.py          | 60 +++++++++----------
 sklearn/linear_model/randomized_l1.py        | 20 +++++--
 .../linear_model/tests/test_least_angle.py   | 14 +++++
 .../linear_model/tests/test_randomized_l1.py | 26 +++++++-
 5 files changed, 88 insertions(+), 39 deletions(-)

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 9730cdcfb9c11..b1d088fdacd48 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -375,6 +375,13 @@ Bug fixes
    - Fix AIC/BIC criterion computation in :class:`linear_model.LassoLarsIC`
      by `Alexandre Gramfort`_ and :user:`Mehmet Basbug `.
 
+   - Fixed a bug in :class:`linear_model.RandomizedLasso`,
+     :class:`linear_model.Lars`, :class:`linear_model.LassoLars`,
+     :class:`linear_model.LarsCV` and :class:`linear_model.LassoLarsCV`,
+     where the parameter ``precompute`` was not used consistently across
+     classes, and some values proposed in the docstring could raise errors.
+     :issue:`5359` by `Tom Dupre la Tour`_.
+
 
 API changes summary
 -------------------

diff --git a/sklearn/linear_model/least_angle.py b/sklearn/linear_model/least_angle.py
index 4878ea23acb84..dfd7acb01993e 100644
--- a/sklearn/linear_model/least_angle.py
+++ b/sklearn/linear_model/least_angle.py
@@ -170,16 +170,19 @@ def lars_path(X, y, Xy=None, Gram=None, max_iter=500,
     swap, nrm2 = linalg.get_blas_funcs(('swap', 'nrm2'), (X,))
     solve_cholesky, = get_lapack_funcs(('potrs',), (X,))
 
-    if Gram is None:
+    if Gram is None or Gram is False:
+        Gram = None
         if copy_X:
             # force copy. setting the array to be fortran-ordered
             # speeds up the calculation of the (partial) Gram matrix
             # and allows to easily swap columns
             X = X.copy('F')
-    elif isinstance(Gram, string_types) and Gram == 'auto':
-        Gram = None
-        if X.shape[0] > X.shape[1]:
+
+    elif (isinstance(Gram, string_types) and Gram == 'auto') or Gram is True:
+        if Gram is True or X.shape[0] > X.shape[1]:
             Gram = np.dot(X.T, X)
+        else:
+            Gram = None
+
     elif copy_Gram:
         Gram = Gram.copy()
 
@@ -598,16 +601,14 @@ def __init__(self, fit_intercept=True, verbose=False, normalize=True,
         self.copy_X = copy_X
         self.fit_path = fit_path
 
-    def _get_gram(self):
-        # precompute if n_samples > n_features
-        precompute = self.precompute
-        if hasattr(precompute, '__array__'):
-            Gram = precompute
-        elif precompute == 'auto':
-            Gram = 'auto'
-        else:
-            Gram = None
-        return Gram
+    def _get_gram(self, precompute, X, y):
+        if (not hasattr(precompute, '__array__')) and (
+                (precompute is True) or
+                (precompute == 'auto' and X.shape[0] > X.shape[1]) or
+                (precompute == 'auto' and y.shape[1] > 1)):
+            precompute = np.dot(X.T, X)
+
+        return precompute
 
     def _fit(self, X, y, max_iter, alpha, fit_path, Xy=None):
         """Auxiliary method to fit the model using X, y as training data"""
@@ -623,14 +624,7 @@ def _fit(self, X, y, max_iter, alpha, fit_path, Xy=None):
 
         n_targets = y.shape[1]
 
-        precompute = self.precompute
-        if not hasattr(precompute, '__array__') and (
-                precompute is True or
-                (precompute == 'auto' and X.shape[0] > X.shape[1]) or
-                (precompute == 'auto' and y.shape[1] > 1)):
-            Gram = np.dot(X.T, X)
-        else:
-            Gram = self._get_gram()
+        Gram = self._get_gram(self.precompute, X, y)
 
         self.alphas_ = []
         self.n_iter_ = []
@@ -1000,10 +994,10 @@ class LarsCV(Lars):
     copy_X : boolean, optional, default True
        If ``True``, X will be copied; else, it may be overwritten.
-    precompute : True | False | 'auto' | array-like
+    precompute : True | False | 'auto'
         Whether to use a precomputed Gram matrix to speed up
-        calculations. If set to ``'auto'`` let us decide. The Gram
-        matrix can also be passed as argument.
+        calculations. If set to ``'auto'`` let us decide. The Gram matrix
+        cannot be passed as an argument since we will use only subsets of X.
 
     max_iter : integer, optional
         Maximum number of iterations to perform.
 
@@ -1108,7 +1102,13 @@ def fit(self, X, y):
         # init cross-validation generator
         cv = check_cv(self.cv, classifier=False)
 
-        Gram = 'auto' if self.precompute else None
+        # As we use cross-validation, the Gram matrix is not precomputed here
+        Gram = self.precompute
+        if hasattr(Gram, '__array__'):
+            warnings.warn("Parameter 'precompute' cannot be an array in "
+                          "%s. Automatically switching to 'auto' instead."
+                          % self.__class__.__name__)
+            Gram = 'auto'
 
         cv_paths = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
             delayed(_lars_path_residues)(
@@ -1212,10 +1212,10 @@ class LassoLarsCV(LarsCV):
         :class:`sklearn.preprocessing.StandardScaler` before calling ``fit``
         on an estimator with ``normalize=False``.
 
-    precompute : True | False | 'auto' | array-like
+    precompute : True | False | 'auto'
         Whether to use a precomputed Gram matrix to speed up
-        calculations. If set to ``'auto'`` let us decide. The Gram
-        matrix can also be passed as argument.
+        calculations. If set to ``'auto'`` let us decide. The Gram matrix
+        cannot be passed as an argument since we will use only subsets of X.
 
     max_iter : integer, optional
         Maximum number of iterations to perform.
 
@@ -1471,7 +1471,7 @@ def fit(self, X, y, copy_X=True):
             X, y, self.fit_intercept, self.normalize, self.copy_X)
         max_iter = self.max_iter
 
-        Gram = self._get_gram()
+        Gram = self.precompute
 
         alphas_, active_, coef_path_, self.n_iter_ = lars_path(
             X, y, Gram=Gram, copy_X=copy_X, copy_Gram=True, alpha_min=0.0,

diff --git a/sklearn/linear_model/randomized_l1.py b/sklearn/linear_model/randomized_l1.py
index 5ee0782b7f2a2..ba6a424a96ff2 100644
--- a/sklearn/linear_model/randomized_l1.py
+++ b/sklearn/linear_model/randomized_l1.py
@@ -157,6 +157,7 @@ def _randomized_lasso(X, y, weights, mask, alpha=1., verbose=False,
     alpha = np.atleast_1d(np.asarray(alpha, dtype=np.float64))
 
     X = (1 - weights) * X
+
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', ConvergenceWarning)
        alphas_, _, coef_ = lars_path(X, y,
@@ -230,10 +231,11 @@ class RandomizedLasso(BaseRandomizedLinearModel):
         use `preprocessing.StandardScaler` before calling `fit` on an
         estimator with `normalize=False`.
 
-    precompute : True | False | 'auto'
-        Whether to use a precomputed Gram matrix to speed up
-        calculations. If set to 'auto' let us decide. The Gram
-        matrix can also be passed as argument.
+    precompute : True | False | 'auto' | array-like
+        Whether to use a precomputed Gram matrix to speed up calculations.
+        If set to 'auto' let us decide.
+        The Gram matrix can also be passed as an argument, but it will be
+        used only to select the parameter alpha, when alpha is 'aic' or 'bic'.
 
     max_iter : integer, optional
         Maximum number of iterations to perform in the Lars algorithm.
@@ -334,7 +336,6 @@ def __init__(self, alpha='aic', scaling=.5, sample_fraction=.75,
         self.memory = memory
 
     def _make_estimator_and_params(self, X, y):
-        assert self.precompute in (True, False, None, 'auto')
         alpha = self.alpha
         if isinstance(alpha, six.string_types) and alpha in ('aic', 'bic'):
             model = LassoLarsIC(precompute=self.precompute,
@@ -343,9 +344,16 @@ def _make_estimator_and_params(self, X, y):
                                 eps=self.eps)
             model.fit(X, y)
             self.alpha_ = alpha = model.alpha_
+
+        precompute = self.precompute
+        # A precomputed Gram array is useless, since _randomized_lasso
+        # changes X at each iteration
+        if hasattr(precompute, '__array__'):
+            precompute = 'auto'
+        assert precompute in (True, False, None, 'auto')
         return _randomized_lasso, dict(alpha=alpha, max_iter=self.max_iter,
                                        eps=self.eps,
-                                       precompute=self.precompute)
+                                       precompute=precompute)
 
 
 ###############################################################################

diff --git a/sklearn/linear_model/tests/test_least_angle.py b/sklearn/linear_model/tests/test_least_angle.py
index 53df763b05c8e..0586b8433943d 100644
--- a/sklearn/linear_model/tests/test_least_angle.py
+++ b/sklearn/linear_model/tests/test_least_angle.py
@@ -172,6 +172,20 @@ def test_no_path_all_precomputed():
     assert_true(alpha_ == alphas_[-1])
 
 
+def test_lars_precompute():
+    # Check for different values of precompute
+    X, y = diabetes.data, diabetes.target
+    G = np.dot(X.T, X)
+    for classifier in [linear_model.Lars, linear_model.LarsCV,
+                       linear_model.LassoLarsIC]:
+        clf = classifier(precompute=G)
+        output_1 = ignore_warnings(clf.fit)(X, y).coef_
+        for precompute in [True, False, 'auto', None]:
+            clf = classifier(precompute=precompute)
+            output_2 = clf.fit(X, y).coef_
+            assert_array_almost_equal(output_1, output_2, decimal=8)
+
+
 def test_singular_matrix():
     # Test when input is a singular matrix
     X1 = np.array([[1, 1.], [1., 1.]])

diff --git a/sklearn/linear_model/tests/test_randomized_l1.py b/sklearn/linear_model/tests/test_randomized_l1.py
index f1744876c710b..37eb66faab339 100644
--- a/sklearn/linear_model/tests/test_randomized_l1.py
+++ b/sklearn/linear_model/tests/test_randomized_l1.py
@@ -59,17 +59,18 @@ def test_randomized_lasso():
     # Check randomized lasso
     scaling = 0.3
     selection_threshold = 0.5
+    n_resampling = 20
 
     # or with 1 alpha
     clf = RandomizedLasso(verbose=False, alpha=1, random_state=42,
-                          scaling=scaling,
+                          scaling=scaling, n_resampling=n_resampling,
                           selection_threshold=selection_threshold)
     feature_scores = clf.fit(X, y).scores_
     assert_array_equal(np.argsort(F)[-3:], np.argsort(feature_scores)[-3:])
 
     # or with many alphas
     clf = RandomizedLasso(verbose=False, alpha=[1, 0.8], random_state=42,
-                          scaling=scaling,
+                          scaling=scaling, n_resampling=n_resampling,
                           selection_threshold=selection_threshold)
     feature_scores = clf.fit(X, y).scores_
     assert_equal(clf.all_scores_.shape, (X.shape[1], 2))
@@ -93,7 +94,7 @@ def test_randomized_lasso():
     assert_equal(X_full.shape, X.shape)
 
     clf = RandomizedLasso(verbose=False, alpha='aic', random_state=42,
-                          scaling=scaling)
+                          scaling=scaling, n_resampling=100)
     feature_scores = clf.fit(X, y).scores_
     assert_allclose(feature_scores, [1., 1., 1., 0.225, 1.], rtol=0.2)
 
@@ -104,6 +105,25 @@ def test_randomized_lasso():
     assert_raises(ValueError, clf.fit, X, y)
 
 
+def test_randomized_lasso_precompute():
+    # Check randomized lasso for different values of precompute
+    n_resampling = 20
+    alpha = 1
+    random_state = 42
+
+    G = np.dot(X.T, X)
+
+    clf = RandomizedLasso(alpha=alpha, random_state=random_state,
+                          precompute=G, n_resampling=n_resampling)
+    feature_scores_1 = clf.fit(X, y).scores_
+
+    for precompute in [True, False, None, 'auto']:
+        clf = RandomizedLasso(alpha=alpha, random_state=random_state,
+                              precompute=precompute,
+                              n_resampling=n_resampling)
+        feature_scores_2 = clf.fit(X, y).scores_
+        assert_array_equal(feature_scores_1, feature_scores_2)
+
+
 def test_randomized_logistic():
     # Check randomized sparse logistic regression
     iris = load_iris()
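
Usage note (not part of the patch): a minimal sketch of the behavior this fix establishes, assuming a scikit-learn build with the patch applied. The diabetes dataset and the `decimal=8` tolerance mirror the new tests; any regression data would do.

```python
import warnings
import numpy as np
from numpy.testing import assert_array_almost_equal
from sklearn.datasets import load_diabetes
from sklearn.linear_model import Lars, LarsCV

d = load_diabetes()
X, y = d.data, d.target
G = np.dot(X.T, X)  # explicit Gram matrix

# Lars now accepts an explicit Gram matrix as well as True, False,
# 'auto' and None, and every value yields the same coefficients.
coef_ref = Lars(precompute=G).fit(X, y).coef_
for precompute in (True, False, 'auto', None):
    coef = Lars(precompute=precompute).fit(X, y).coef_
    assert_array_almost_equal(coef_ref, coef, decimal=8)

# LarsCV fits on subsets of X during cross-validation, so an explicit
# Gram matrix cannot be used: the estimator warns and falls back to 'auto'.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    LarsCV(precompute=G).fit(X, y)
assert any("cannot be an array" in str(w.message) for w in caught)
```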