FIX improve 'precompute' handling in Lars
TomDLT committed Jun 12, 2017
1 parent 2537c31 commit 71f14cd
Showing 5 changed files with 88 additions and 39 deletions.
doc/whats_new.rst: 7 additions, 0 deletions
@@ -375,6 +375,13 @@ Bug fixes
- Fix AIC/BIC criterion computation in :class:`linear_model.LassoLarsIC`
by `Alexandre Gramfort`_ and :user:`Mehmet Basbug <mehmetbasbug>`.

- Fixed a bug in :class:`linear_model.RandomizedLasso`,
:class:`linear_model.Lars`, :class:`linear_model.LarsLasso`,
:class:`linear_model.LarsCV` and :class:`linear_model.LarsLassoCV`,
  where the parameter ``precompute`` was not used consistently across
  classes, and some values suggested in the docstrings could raise errors.
:issue:`5359` by `Tom Dupre la Tour`_.
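
A minimal sketch of the fixed behaviour (illustrative data, not part of the
commit; it mirrors the new test_lars_precompute further down): every
documented value of ``precompute`` is now accepted by the plain estimators.

    import numpy as np
    from sklearn.linear_model import Lars

    X = np.random.RandomState(0).randn(40, 5)
    y = np.random.RandomState(1).randn(40)
    for precompute in (True, False, 'auto', None, np.dot(X.T, X)):
        Lars(precompute=precompute).fit(X, y)  # none of these raises anymore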

API changes summary
-------------------

sklearn/linear_model/least_angle.py: 30 additions, 30 deletions
@@ -170,16 +170,19 @@ def lars_path(X, y, Xy=None, Gram=None, max_iter=500,
swap, nrm2 = linalg.get_blas_funcs(('swap', 'nrm2'), (X,))
solve_cholesky, = get_lapack_funcs(('potrs',), (X,))

if Gram is None:
if Gram is None or Gram is False:
Gram = None
if copy_X:
# force copy. setting the array to be fortran-ordered
# speeds up the calculation of the (partial) Gram matrix
# and allows to easily swap columns
X = X.copy('F')
elif isinstance(Gram, string_types) and Gram == 'auto':
Gram = None
if X.shape[0] > X.shape[1]:

elif isinstance(Gram, string_types) and Gram == 'auto' or Gram is True:
if Gram is True or X.shape[0] > X.shape[1]:
Gram = np.dot(X.T, X)
else:
Gram = None
elif copy_Gram:
Gram = Gram.copy()
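
With the branches above, lars_path now resolves ``Gram`` itself: ``None`` or
``False`` disables precomputation, ``True`` forces it, and ``'auto'`` builds
the Gram matrix only when n_samples > n_features. A small sketch of the
resulting equivalence (illustrative data, not from the commit):

    import numpy as np
    from sklearn.linear_model import lars_path

    X = np.random.RandomState(0).randn(50, 10)
    y = X[:, 2] + 0.1 * np.random.RandomState(1).randn(50)
    ref_coefs = lars_path(X, y, Gram=np.dot(X.T, X))[2]
    for gram in (None, False, True, 'auto'):
        alphas, active, coefs = lars_path(X, y, Gram=gram)
        np.testing.assert_allclose(coefs, ref_coefs, atol=1e-8)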

@@ -598,16 +601,14 @@ def __init__(self, fit_intercept=True, verbose=False, normalize=True,
self.copy_X = copy_X
self.fit_path = fit_path

def _get_gram(self):
# precompute if n_samples > n_features
precompute = self.precompute
if hasattr(precompute, '__array__'):
Gram = precompute
elif precompute == 'auto':
Gram = 'auto'
else:
Gram = None
return Gram
def _get_gram(self, precompute, X, y):
if (not hasattr(precompute, '__array__')) and (
(precompute is True) or
(precompute == 'auto' and X.shape[0] > X.shape[1]) or
(precompute == 'auto' and y.shape[1] > 1)):
precompute = np.dot(X.T, X)

return precompute

def _fit(self, X, y, max_iter, alpha, fit_path, Xy=None):
"""Auxiliary method to fit the model using X, y as training data"""
@@ -623,14 +624,7 @@ def _fit(self, X, y, max_iter, alpha, fit_path, Xy=None):

n_targets = y.shape[1]

precompute = self.precompute
if not hasattr(precompute, '__array__') and (
precompute is True or
(precompute == 'auto' and X.shape[0] > X.shape[1]) or
(precompute == 'auto' and y.shape[1] > 1)):
Gram = np.dot(X.T, X)
else:
Gram = self._get_gram()
Gram = self._get_gram(self.precompute, X, y)

self.alphas_ = []
self.n_iter_ = []
@@ -1000,10 +994,10 @@ class LarsCV(Lars):
copy_X : boolean, optional, default True
If ``True``, X will be copied; else, it may be overwritten.
precompute : True | False | 'auto' | array-like
precompute : True | False | 'auto'
Whether to use a precomputed Gram matrix to speed up
calculations. If set to ``'auto'`` let us decide. The Gram
matrix can also be passed as argument.
calculations. If set to ``'auto'`` let us decide. The Gram matrix
cannot be passed as argument since we will use only subsets of X.
max_iter : integer, optional
Maximum number of iterations to perform.
@@ -1108,7 +1102,13 @@ def fit(self, X, y):
# init cross-validation generator
cv = check_cv(self.cv, classifier=False)

Gram = 'auto' if self.precompute else None
# As we use cross-validation, the Gram matrix is not precomputed here
Gram = self.precompute
if hasattr(Gram, '__array__'):
            warnings.warn("Parameter 'precompute' cannot be an array in "
                          "%s. Automatically switching to 'auto' instead."
                          % self.__class__.__name__)
Gram = 'auto'

cv_paths = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
delayed(_lars_path_residues)(
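
The fallback above is visible from user code: passing a Gram array to LarsCV
now warns and behaves like 'auto' instead of failing. A short sketch
(illustrative data, not from the commit):

    import warnings
    import numpy as np
    from sklearn.linear_model import LarsCV

    X = np.random.RandomState(0).randn(60, 8)
    y = np.random.RandomState(1).randn(60)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        LarsCV(precompute=np.dot(X.T, X)).fit(X, y)
    assert any('precompute' in str(w.message) for w in caught)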
@@ -1212,10 +1212,10 @@ class LassoLarsCV(LarsCV):
:class:`sklearn.preprocessing.StandardScaler` before calling ``fit``
on an estimator with ``normalize=False``.
precompute : True | False | 'auto' | array-like
precompute : True | False | 'auto'
Whether to use a precomputed Gram matrix to speed up
calculations. If set to ``'auto'`` let us decide. The Gram
matrix can also be passed as argument.
calculations. If set to ``'auto'`` let us decide. The Gram matrix
cannot be passed as argument since we will use only subsets of X.
max_iter : integer, optional
Maximum number of iterations to perform.
@@ -1471,7 +1471,7 @@ def fit(self, X, y, copy_X=True):
X, y, self.fit_intercept, self.normalize, self.copy_X)
max_iter = self.max_iter

Gram = self._get_gram()
Gram = self.precompute

alphas_, active_, coef_path_, self.n_iter_ = lars_path(
X, y, Gram=Gram, copy_X=copy_X, copy_Gram=True, alpha_min=0.0,
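LassoLarsIC now forwards ``precompute`` untouched and lets lars_path resolve
it, so a user-supplied Gram matrix is honoured. A sketch (illustrative data
and settings, not from the commit); the Gram matrix should match the design
actually used for fitting, hence fit_intercept=False and normalize=False here:

    import numpy as np
    from sklearn.linear_model import LassoLarsIC

    X = np.random.RandomState(0).randn(60, 8)
    y = np.random.RandomState(1).randn(60)
    ic = LassoLarsIC(criterion='bic', fit_intercept=False, normalize=False,
                     precompute=np.dot(X.T, X))
    ic.fit(X, y)
    print(ic.alpha_)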
sklearn/linear_model/randomized_l1.py: 14 additions, 6 deletions
@@ -157,6 +157,7 @@ def _randomized_lasso(X, y, weights, mask, alpha=1., verbose=False,
alpha = np.atleast_1d(np.asarray(alpha, dtype=np.float64))

X = (1 - weights) * X

with warnings.catch_warnings():
warnings.simplefilter('ignore', ConvergenceWarning)
alphas_, _, coef_ = lars_path(X, y,
@@ -230,10 +231,11 @@ class RandomizedLasso(BaseRandomizedLinearModel):
use `preprocessing.StandardScaler` before calling `fit` on an
estimator with `normalize=False`.
precompute : True | False | 'auto'
Whether to use a precomputed Gram matrix to speed up
calculations. If set to 'auto' let us decide. The Gram
matrix can also be passed as argument.
precompute : True | False | 'auto' | array-like
Whether to use a precomputed Gram matrix to speed up calculations.
If set to 'auto' let us decide.
The Gram matrix can also be passed as argument, but it will be used
only for the selection of parameter alpha, if alpha is 'aic' or 'bic'.
max_iter : integer, optional
Maximum number of iterations to perform in the Lars algorithm.
@@ -334,7 +336,6 @@ def __init__(self, alpha='aic', scaling=.5, sample_fraction=.75,
self.memory = memory

def _make_estimator_and_params(self, X, y):
assert self.precompute in (True, False, None, 'auto')
alpha = self.alpha
if isinstance(alpha, six.string_types) and alpha in ('aic', 'bic'):
model = LassoLarsIC(precompute=self.precompute,
@@ -343,9 +344,16 @@ def _make_estimator_and_params(self, X, y):
eps=self.eps)
model.fit(X, y)
self.alpha_ = alpha = model.alpha_

precompute = self.precompute
        # A precomputed Gram array is useless, since _randomized_lasso
        # changes X at each iteration
if hasattr(precompute, '__array__'):
precompute = 'auto'
assert precompute in (True, False, None, 'auto')
return _randomized_lasso, dict(alpha=alpha, max_iter=self.max_iter,
eps=self.eps,
precompute=self.precompute)
precompute=precompute)


###############################################################################
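Tying the RandomizedLasso pieces together: with a numeric ``alpha``, a
user-supplied Gram array is replaced by 'auto' (each resampling rescales X,
so the array cannot be reused), which makes the two fits below identical.
A sketch mirroring the new test (illustrative data, not from the commit):

    import numpy as np
    from sklearn.linear_model import RandomizedLasso

    X = np.random.RandomState(0).randn(50, 10)
    y = np.random.RandomState(1).randn(50)
    kwargs = dict(alpha=1, n_resampling=20, random_state=42)
    rl1 = RandomizedLasso(precompute=np.dot(X.T, X), **kwargs)
    rl2 = RandomizedLasso(precompute='auto', **kwargs)
    assert np.array_equal(rl1.fit(X, y).scores_, rl2.fit(X, y).scores_)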
sklearn/linear_model/tests/test_least_angle.py: 14 additions, 0 deletions
@@ -172,6 +172,20 @@ def test_no_path_all_precomputed():
assert_true(alpha_ == alphas_[-1])


def test_lars_precompute():
# Check for different values of precompute
X, y = diabetes.data, diabetes.target
G = np.dot(X.T, X)
for classifier in [linear_model.Lars, linear_model.LarsCV,
linear_model.LassoLarsIC]:
clf = classifier(precompute=G)
output_1 = ignore_warnings(clf.fit)(X, y).coef_
for precompute in [True, False, 'auto', None]:
clf = classifier(precompute=precompute)
output_2 = clf.fit(X, y).coef_
assert_array_almost_equal(output_1, output_2, decimal=8)


def test_singular_matrix():
# Test when input is a singular matrix
X1 = np.array([[1, 1.], [1., 1.]])
sklearn/linear_model/tests/test_randomized_l1.py: 23 additions, 3 deletions
@@ -59,17 +59,18 @@ def test_randomized_lasso():
# Check randomized lasso
scaling = 0.3
selection_threshold = 0.5
n_resampling = 20

# or with 1 alpha
clf = RandomizedLasso(verbose=False, alpha=1, random_state=42,
scaling=scaling,
scaling=scaling, n_resampling=n_resampling,
selection_threshold=selection_threshold)
feature_scores = clf.fit(X, y).scores_
assert_array_equal(np.argsort(F)[-3:], np.argsort(feature_scores)[-3:])

# or with many alphas
clf = RandomizedLasso(verbose=False, alpha=[1, 0.8], random_state=42,
scaling=scaling,
scaling=scaling, n_resampling=n_resampling,
selection_threshold=selection_threshold)
feature_scores = clf.fit(X, y).scores_
assert_equal(clf.all_scores_.shape, (X.shape[1], 2))
@@ -93,7 +94,7 @@ def test_randomized_lasso():
assert_equal(X_full.shape, X.shape)

clf = RandomizedLasso(verbose=False, alpha='aic', random_state=42,
scaling=scaling)
scaling=scaling, n_resampling=100)
feature_scores = clf.fit(X, y).scores_
assert_allclose(feature_scores, [1., 1., 1., 0.225, 1.], rtol=0.2)

@@ -104,6 +105,25 @@ def test_randomized_lasso():
assert_raises(ValueError, clf.fit, X, y)


def test_randomized_lasso_precompute():
# Check randomized lasso for different values of precompute
n_resampling = 20
alpha = 1
random_state = 42

G = np.dot(X.T, X)

clf = RandomizedLasso(alpha=alpha, random_state=random_state,
precompute=G, n_resampling=n_resampling)
feature_scores_1 = clf.fit(X, y).scores_

for precompute in [True, False, None, 'auto']:
clf = RandomizedLasso(alpha=alpha, random_state=random_state,
precompute=precompute, n_resampling=n_resampling)
feature_scores_2 = clf.fit(X, y).scores_
assert_array_equal(feature_scores_1, feature_scores_2)


def test_randomized_logistic():
# Check randomized sparse logistic regression
iris = load_iris()
