Commit 71f14cd

FIX improve 'precompute' handling in Lars
1 parent 2537c31 commit 71f14cd

File tree

doc/whats_new.rst
sklearn/linear_model/least_angle.py
sklearn/linear_model/randomized_l1.py
sklearn/linear_model/tests/test_least_angle.py
sklearn/linear_model/tests/test_randomized_l1.py

5 files changed: +88 additions, -39 deletions

doc/whats_new.rst

Lines changed: 7 additions & 0 deletions
@@ -375,6 +375,13 @@ Bug fixes
 - Fix AIC/BIC criterion computation in :class:`linear_model.LassoLarsIC`
   by `Alexandre Gramfort`_ and :user:`Mehmet Basbug <mehmetbasbug>`.

+- Fixed a bug in :class:`linear_model.RandomizedLasso`,
+  :class:`linear_model.Lars`, :class:`linear_model.LassoLars`,
+  :class:`linear_model.LarsCV` and :class:`linear_model.LassoLarsCV`,
+  where the parameter ``precompute`` was not used consistently across
+  classes, and some values proposed in the docstring could raise errors.
+  :issue:`5359` by `Tom Dupre la Tour`_.
+
 API changes summary
 -------------------

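As an editorial illustration of the fixed behaviour (not part of the commit; the diabetes data and the choice of estimators are examples only): every ``precompute`` value listed in the docstrings is now accepted and yields the same coefficients.

    import numpy as np
    from sklearn import linear_model
    from sklearn.datasets import load_diabetes

    diabetes = load_diabetes()
    X, y = diabetes.data, diabetes.target
    G = np.dot(X.T, X)  # explicit Gram matrix

    # All documented 'precompute' values now give consistent results.
    for estimator in [linear_model.Lars, linear_model.LassoLars]:
        coefs = [estimator(precompute=p).fit(X, y).coef_
                 for p in (True, False, 'auto', G)]
        for c in coefs[1:]:
            np.testing.assert_allclose(coefs[0], c, rtol=1e-6)
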
sklearn/linear_model/least_angle.py

Lines changed: 30 additions & 30 deletions
@@ -170,16 +170,19 @@ def lars_path(X, y, Xy=None, Gram=None, max_iter=500,
     swap, nrm2 = linalg.get_blas_funcs(('swap', 'nrm2'), (X,))
     solve_cholesky, = get_lapack_funcs(('potrs',), (X,))

-    if Gram is None:
+    if Gram is None or Gram is False:
+        Gram = None
         if copy_X:
             # force copy. setting the array to be fortran-ordered
             # speeds up the calculation of the (partial) Gram matrix
             # and allows to easily swap columns
             X = X.copy('F')
-    elif isinstance(Gram, string_types) and Gram == 'auto':
-        Gram = None
-        if X.shape[0] > X.shape[1]:
+
+    elif isinstance(Gram, string_types) and Gram == 'auto' or Gram is True:
+        if Gram is True or X.shape[0] > X.shape[1]:
             Gram = np.dot(X.T, X)
+        else:
+            Gram = None
     elif copy_Gram:
         Gram = Gram.copy()

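A short usage sketch of the values the rewritten branch accepts (not from the commit; the data set is illustrative): ``None`` and ``False`` recompute from ``X``, ``'auto'`` precomputes only when ``n_samples > n_features``, ``True`` always precomputes, and an explicit array is used as given.

    import numpy as np
    from sklearn.datasets import load_diabetes
    from sklearn.linear_model import lars_path

    diabetes = load_diabetes()
    X, y = diabetes.data, diabetes.target

    # Every Gram value yields the same path; only precomputation differs.
    for gram in (None, False, True, 'auto', np.dot(X.T, X)):
        alphas, active, coefs = lars_path(X, y, Gram=gram)
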
@@ -598,16 +601,14 @@ def __init__(self, fit_intercept=True, verbose=False, normalize=True,
         self.copy_X = copy_X
         self.fit_path = fit_path

-    def _get_gram(self):
-        # precompute if n_samples > n_features
-        precompute = self.precompute
-        if hasattr(precompute, '__array__'):
-            Gram = precompute
-        elif precompute == 'auto':
-            Gram = 'auto'
-        else:
-            Gram = None
-        return Gram
+    def _get_gram(self, precompute, X, y):
+        if (not hasattr(precompute, '__array__')) and (
+                (precompute is True) or
+                (precompute == 'auto' and X.shape[0] > X.shape[1]) or
+                (precompute == 'auto' and y.shape[1] > 1)):
+            precompute = np.dot(X.T, X)
+
+        return precompute

     def _fit(self, X, y, max_iter, alpha, fit_path, Xy=None):
         """Auxiliary method to fit the model using X, y as training data"""
@@ -623,14 +624,7 @@ def _fit(self, X, y, max_iter, alpha, fit_path, Xy=None):

         n_targets = y.shape[1]

-        precompute = self.precompute
-        if not hasattr(precompute, '__array__') and (
-                precompute is True or
-                (precompute == 'auto' and X.shape[0] > X.shape[1]) or
-                (precompute == 'auto' and y.shape[1] > 1)):
-            Gram = np.dot(X.T, X)
-        else:
-            Gram = self._get_gram()
+        Gram = self._get_gram(self.precompute, X, y)

         self.alphas_ = []
         self.n_iter_ = []
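To make the ``'auto'`` rule concrete, here is a standalone mirror of the helper above (the name ``resolve_gram`` and the random data are illustrative only):

    import numpy as np

    def resolve_gram(precompute, X, y):
        # Mirrors Lars._get_gram: arrays, False and None pass through;
        # True forces X.T @ X; 'auto' precomputes when n_samples > n_features
        # or when there are several targets.
        if (not hasattr(precompute, '__array__')) and (
                (precompute is True) or
                (precompute == 'auto' and X.shape[0] > X.shape[1]) or
                (precompute == 'auto' and y.shape[1] > 1)):
            precompute = np.dot(X.T, X)
        return precompute

    rng = np.random.RandomState(0)
    X, y = rng.randn(50, 5), rng.randn(50, 1)   # n_samples > n_features
    assert isinstance(resolve_gram('auto', X, y), np.ndarray)
    assert resolve_gram(False, X, y) is False   # passed through unchanged
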
@@ -1000,10 +994,10 @@ class LarsCV(Lars):
     copy_X : boolean, optional, default True
         If ``True``, X will be copied; else, it may be overwritten.

-    precompute : True | False | 'auto' | array-like
+    precompute : True | False | 'auto'
         Whether to use a precomputed Gram matrix to speed up
-        calculations. If set to ``'auto'`` let us decide. The Gram
-        matrix can also be passed as argument.
+        calculations. If set to ``'auto'`` let us decide. The Gram matrix
+        cannot be passed as argument since we will use only subsets of X.

     max_iter : integer, optional
         Maximum number of iterations to perform.
@@ -1108,7 +1102,13 @@ def fit(self, X, y):
         # init cross-validation generator
         cv = check_cv(self.cv, classifier=False)

-        Gram = 'auto' if self.precompute else None
+        # As we use cross-validation, the Gram matrix is not precomputed here
+        Gram = self.precompute
+        if hasattr(Gram, '__array__'):
+            warnings.warn("Parameter 'precompute' cannot be an array in "
+                          "%s. Automatically switch to 'auto' instead."
+                          % self.__class__.__name__)
+            Gram = 'auto'

         cv_paths = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
             delayed(_lars_path_residues)(
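A usage sketch of the new fallback (not from the commit; data set illustrative): passing a Gram array to ``LarsCV`` now warns and switches to ``'auto'`` instead of using a matrix that cannot match the cross-validation subsets of ``X``.

    import warnings
    import numpy as np
    from sklearn.datasets import load_diabetes
    from sklearn.linear_model import LarsCV

    diabetes = load_diabetes()
    X, y = diabetes.data, diabetes.target

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        LarsCV(precompute=np.dot(X.T, X)).fit(X, y)

    assert any("cannot be an array" in str(w.message) for w in caught)
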
@@ -1212,10 +1212,10 @@ class LassoLarsCV(LarsCV):
     :class:`sklearn.preprocessing.StandardScaler` before calling ``fit``
     on an estimator with ``normalize=False``.

-    precompute : True | False | 'auto' | array-like
+    precompute : True | False | 'auto'
         Whether to use a precomputed Gram matrix to speed up
-        calculations. If set to ``'auto'`` let us decide. The Gram
-        matrix can also be passed as argument.
+        calculations. If set to ``'auto'`` let us decide. The Gram matrix
+        cannot be passed as argument since we will use only subsets of X.

     max_iter : integer, optional
         Maximum number of iterations to perform.
@@ -1471,7 +1471,7 @@ def fit(self, X, y, copy_X=True):
             X, y, self.fit_intercept, self.normalize, self.copy_X)
         max_iter = self.max_iter

-        Gram = self._get_gram()
+        Gram = self.precompute

         alphas_, active_, coef_path_, self.n_iter_ = lars_path(
             X, y, Gram=Gram, copy_X=copy_X, copy_Gram=True, alpha_min=0.0,

sklearn/linear_model/randomized_l1.py

Lines changed: 14 additions & 6 deletions
@@ -157,6 +157,7 @@ def _randomized_lasso(X, y, weights, mask, alpha=1., verbose=False,
     alpha = np.atleast_1d(np.asarray(alpha, dtype=np.float64))

     X = (1 - weights) * X
+
     with warnings.catch_warnings():
         warnings.simplefilter('ignore', ConvergenceWarning)
         alphas_, _, coef_ = lars_path(X, y,
@@ -230,10 +231,11 @@ class RandomizedLasso(BaseRandomizedLinearModel):
         use `preprocessing.StandardScaler` before calling `fit` on an
         estimator with `normalize=False`.

-    precompute : True | False | 'auto'
-        Whether to use a precomputed Gram matrix to speed up
-        calculations. If set to 'auto' let us decide. The Gram
-        matrix can also be passed as argument.
+    precompute : True | False | 'auto' | array-like
+        Whether to use a precomputed Gram matrix to speed up calculations.
+        If set to 'auto' let us decide.
+        The Gram matrix can also be passed as argument, but it will be used
+        only for the selection of parameter alpha, if alpha is 'aic' or 'bic'.

     max_iter : integer, optional
         Maximum number of iterations to perform in the Lars algorithm.
@@ -334,7 +336,6 @@ def __init__(self, alpha='aic', scaling=.5, sample_fraction=.75,
         self.memory = memory

     def _make_estimator_and_params(self, X, y):
-        assert self.precompute in (True, False, None, 'auto')
         alpha = self.alpha
         if isinstance(alpha, six.string_types) and alpha in ('aic', 'bic'):
             model = LassoLarsIC(precompute=self.precompute,
@@ -343,9 +344,16 @@ def _make_estimator_and_params(self, X, y):
                                 eps=self.eps)
             model.fit(X, y)
             self.alpha_ = alpha = model.alpha_
+
+        precompute = self.precompute
+        # A precomputed Gram array is useless, since _randomized_lasso
+        # changes X at each iteration
+        if hasattr(precompute, '__array__'):
+            precompute = 'auto'
+        assert precompute in (True, False, None, 'auto')
         return _randomized_lasso, dict(alpha=alpha, max_iter=self.max_iter,
                                        eps=self.eps,
-                                       precompute=self.precompute)
+                                       precompute=precompute)


 ###############################################################################

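A minimal sketch of the documented behaviour (synthetic data; not part of the commit): a Gram array given to ``RandomizedLasso`` can only serve the ``'aic'``/``'bic'`` alpha selection, since the resampling rescales ``X`` at every iteration and therefore falls back to ``'auto'``.

    import numpy as np
    from sklearn.linear_model import RandomizedLasso

    rng = np.random.RandomState(42)
    X = rng.randn(100, 5)
    y = X[:, 0] + 0.1 * rng.randn(100)

    # The Gram array is used by LassoLarsIC while selecting alpha='aic';
    # the resampling loop then runs with precompute='auto'.
    scores = RandomizedLasso(alpha='aic', precompute=np.dot(X.T, X),
                             random_state=42).fit(X, y).scores_
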
sklearn/linear_model/tests/test_least_angle.py

Lines changed: 14 additions & 0 deletions
@@ -172,6 +172,20 @@ def test_no_path_all_precomputed():
     assert_true(alpha_ == alphas_[-1])


+def test_lars_precompute():
+    # Check for different values of precompute
+    X, y = diabetes.data, diabetes.target
+    G = np.dot(X.T, X)
+    for classifier in [linear_model.Lars, linear_model.LarsCV,
+                       linear_model.LassoLarsIC]:
+        clf = classifier(precompute=G)
+        output_1 = ignore_warnings(clf.fit)(X, y).coef_
+        for precompute in [True, False, 'auto', None]:
+            clf = classifier(precompute=precompute)
+            output_2 = clf.fit(X, y).coef_
+            assert_array_almost_equal(output_1, output_2, decimal=8)
+
+
 def test_singular_matrix():
     # Test when input is a singular matrix
     X1 = np.array([[1, 1.], [1., 1.]])

sklearn/linear_model/tests/test_randomized_l1.py

Lines changed: 23 additions & 3 deletions
@@ -59,17 +59,18 @@ def test_randomized_lasso():
     # Check randomized lasso
     scaling = 0.3
     selection_threshold = 0.5
+    n_resampling = 20

     # or with 1 alpha
     clf = RandomizedLasso(verbose=False, alpha=1, random_state=42,
-                          scaling=scaling,
+                          scaling=scaling, n_resampling=n_resampling,
                           selection_threshold=selection_threshold)
     feature_scores = clf.fit(X, y).scores_
     assert_array_equal(np.argsort(F)[-3:], np.argsort(feature_scores)[-3:])

     # or with many alphas
     clf = RandomizedLasso(verbose=False, alpha=[1, 0.8], random_state=42,
-                          scaling=scaling,
+                          scaling=scaling, n_resampling=n_resampling,
                           selection_threshold=selection_threshold)
     feature_scores = clf.fit(X, y).scores_
     assert_equal(clf.all_scores_.shape, (X.shape[1], 2))
@@ -93,7 +94,7 @@ def test_randomized_lasso():
     assert_equal(X_full.shape, X.shape)

     clf = RandomizedLasso(verbose=False, alpha='aic', random_state=42,
-                          scaling=scaling)
+                          scaling=scaling, n_resampling=100)
     feature_scores = clf.fit(X, y).scores_
     assert_allclose(feature_scores, [1., 1., 1., 0.225, 1.], rtol=0.2)

@@ -104,6 +105,25 @@ def test_randomized_lasso():
     assert_raises(ValueError, clf.fit, X, y)


+def test_randomized_lasso_precompute():
+    # Check randomized lasso for different values of precompute
+    n_resampling = 20
+    alpha = 1
+    random_state = 42
+
+    G = np.dot(X.T, X)
+
+    clf = RandomizedLasso(alpha=alpha, random_state=random_state,
+                          precompute=G, n_resampling=n_resampling)
+    feature_scores_1 = clf.fit(X, y).scores_
+
+    for precompute in [True, False, None, 'auto']:
+        clf = RandomizedLasso(alpha=alpha, random_state=random_state,
+                              precompute=precompute, n_resampling=n_resampling)
+        feature_scores_2 = clf.fit(X, y).scores_
+        assert_array_equal(feature_scores_1, feature_scores_2)
+
+
 def test_randomized_logistic():
     # Check randomized sparse logistic regression
     iris = load_iris()
