
FIX switch to 'sparse_cg' solver in Ridge when X is sparse and fitting intercept #13995

Merged
merged 21 commits into from Jul 19, 2019

Conversation

jeromedockes (Contributor) commented May 31, 2019:

Ridge is failing the check introduced in #13246 that verifies estimators produce the same results for sparse and dense data.
This PR enforces the sparse_cg solver when X is sparse and fit_intercept=True, since at the moment this is the only solver that correctly fits an intercept with the default Ridge tol and max_iter when X is sparse.
@agramfort @glemaitre @ogrisel

solver = 'sparse_cg'
if self.solver not in ['auto', 'sparse_cg']:
    warnings.warn(
        'setting solver to "sparse_cg" because X is sparse')

rth (Member) commented May 31, 2019:

Better:

"solver={} does not support fitting the intercept on sparse data, "
"falling back to solver='sparse_cg'. To avoid this warning, either change "
"the solver to 'sparse_cg' explicitly or set `fit_intercept=False`."
@@ -545,6 +545,13 @@ def fit(self, X, y, sample_weight=None):
            accept_sparse=_accept_sparse,
            dtype=_dtype,
            multi_output=True, y_numeric=True)
        if sparse.issparse(X) and self.fit_intercept:
            solver = 'sparse_cg'
            if self.solver not in ['auto', 'sparse_cg']:

rth (Member) commented May 31, 2019:

The solver resolution (e.g. for 'auto') is normally done in _ridge_regression; maybe it is better to move this there as well.

jeromedockes (Author, Contributor) commented May 31, 2019:

The difficulty is that, depending on fit_intercept and on whether X is dense, we need to provide _ridge_regression with X_offset and X_scale (computed in preprocessing) or not:

params = {'X_offset': X_offset, 'X_scale': X_scale}

agramfort (Member) commented Jul 16, 2019:

I would not have prevented someone from using 'sag' with fit_intercept=True. It's not broken per se; it just needs a lot more iterations than the default value.

jeromedockes (Author, Contributor) commented Jul 16, 2019:

In this case should there be a warning? Users may be surprised by the number of
iterations they need to set:

>>> x, y = _make_sparse_offset_regression(n_samples=20, n_features=5, random_state=0)                                                             
>>> sp_ridge = Ridge(solver='sag', max_iter=10000000, tol=1e-8).fit(sparse.csr_matrix(x), y)                                                      
>>> ridge = Ridge(solver='sag', max_iter=10000000, tol=1e-8).fit(x, y)                                                                            
>>> sp_ridge.n_iter_[0]                                                                                                                           
566250
>>> ridge.n_iter_[0]                                                                                                                              
100
>>> np.allclose(sp_ridge.intercept_, ridge.intercept_, rtol=1e-3)                                                                                 
False

jeromedockes (Author, Contributor) commented Jul 16, 2019:

@agramfort I restored the possibility of using 'sag' and added a warning; let me know if I should remove the warning.

assert_raises_regex(ValueError, "In Ridge,", sparse.fit, X_csr, y)
for solver in ['saga', 'lsqr', 'sag']:
    sparse = Ridge(alpha=1., solver=solver, fit_intercept=True)
    assert_warns(UserWarning, sparse.fit, X_csr, y)

rth (Member) commented May 31, 2019:

We do need to match the warning message (a lot of unrelated things can raise a UserWarning), even if it means splitting checks for sag and other solvers.

rth (Member) commented May 31, 2019:

Also, I wonder if it isn't better to raise an exception on an unsupported solver rather than switch solver with a warning.

jeromedockes (Author, Contributor) commented May 31, 2019:

> Also, I wonder if it isn't better to raise an exception on an unsupported solver rather than switch solver with a warning.

Thanks! I changed the warning to a ValueError, improved the message as you suggested, and checked the message in the test.

glemaitre (Contributor) left a comment:

IMO, we can consider it a bug fix rather than a change of default.

So we will need an entry in what's new as a bug fix, and to document it in the model changes as well.

assert_raises_regex(ValueError, "In Ridge,", sparse.fit, X_csr, y)
for solver in ['saga', 'lsqr', 'sag']:
    sparse = Ridge(alpha=1., solver=solver, fit_intercept=True)
    assert_raises_regex(

glemaitre (Contributor) commented Jun 3, 2019:

Let's use pytest for this one.

jeromedockes and others added 3 commits Jun 3, 2019
Co-Authored-By: Guillaume Lemaitre <g.lemaitre58@gmail.com>
@@ -371,7 +371,7 @@ def test_sag_regressor_computed_correctly():
     n_samples = 40
     max_iter = 50
     tol = .000001
-    fit_intercept = True
+    fit_intercept = False

agramfort (Member) commented Jun 10, 2019:

Why do you need this @jeromedockes? Does SAG fail to get a good intercept?
Can you try initializing the intercept to np.mean(y_train) instead of 0 (if that is how it is done now)?

jeromedockes (Author, Contributor) commented Jun 11, 2019:

At the moment it does fail to fit a good intercept (with the default tol and max_iter).

the intercept is indeed initialized with zeros:

init = {'coef': np.zeros((n_features + int(return_intercept), 1),

but initializing it with the mean of y probably won't change much, since this mean is 0: y is always assumed to be dense and is centered in preprocessing. The mean of X is what causes the intercept to be nonzero.

jeromedockes (Author, Contributor) commented Jun 11, 2019:

I was thinking that we could revert to only allowing sparse_cg for sparse data in this PR to quickly fix the bug and unblock #13246, and then see if support for sparse data can be added to the sag solver in a separate PR. WDYT?

glemaitre (Contributor) commented Jun 13, 2019:

I am +1 for this path. I think that we should ensure that we give proper results. We can investigate later how to fix SAG for this case. @agramfort WDYT?

jeromedockes (Author, Contributor) commented Jun 13, 2019:

BTW, one of the reasons the intercept takes many iterations to converge can be this decay:

SPARSE_INTERCEPT_DECAY = 0.01

return dataset, intercept_decay

dataset, intercept_decay = make_dataset(X, y, sample_weight, random_state)
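A toy sketch of why a small intercept decay slows convergence (an assumed simplification, not the actual SAG update; the function name and learning rate are hypothetical): the intercept step is scaled by intercept_decay, so with the sparse-input value of 0.01 the intercept approaches its target far more slowly than with a decay of 1.

```python
# Hypothetical simplification of a decayed intercept update: minimize
# 0.5 * (b - target)**2 by gradient steps scaled by intercept_decay.
def fit_intercept_only(target, intercept_decay, lr=0.1, max_iter=20000,
                       tol=1e-8):
    b = 0.0
    for it in range(max_iter):
        step = lr * intercept_decay * (b - target)   # scaled gradient step
        b -= step
        if abs(step) < tol:
            return b, it + 1
    return b, max_iter

b_full, it_full = fit_intercept_only(3.0, intercept_decay=1.0)
b_decayed, it_decayed = fit_intercept_only(3.0, intercept_decay=0.01)
# the decayed update needs orders of magnitude more iterations to
# reach the same step-size tolerance
```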

glemaitre (Contributor) left a comment:

Couple of comments.

dense = Ridge(alpha=1., tol=1.e-15, solver=solver, fit_intercept=True)
sparse = Ridge(alpha=1., tol=1.e-15, solver=solver, fit_intercept=True)
# for now only sparse_cg can fit an intercept with sparse X
for solver in ['sparse_cg']:

glemaitre (Contributor) commented Jun 13, 2019:

You can remove the for loop; we will use pytest parametrization when we reintroduce sag.

assert_raises_regex(ValueError, "In Ridge,", sparse.fit, X_csr, y)
for solver in ['saga', 'lsqr', 'sag']:
    sparse = Ridge(alpha=1., solver=solver, fit_intercept=True)
    with pytest.raises(

glemaitre (Contributor) commented Jun 13, 2019:

This will be a bit more compact:

err_msg = "solver='{}' does not support".format(solver)
with pytest.raises(ValueError, match=err_msg):
    sparse.fit(X_csr, y)
sparse = Ridge(alpha=1., tol=1.e-15, solver=solver, fit_intercept=True)
assert_raises_regex(ValueError, "In Ridge,", sparse.fit, X_csr, y)
for solver in ['saga', 'lsqr', 'sag']:
    sparse = Ridge(alpha=1., solver=solver, fit_intercept=True)

glemaitre (Contributor) commented Jun 13, 2019:

Avoid calling it sparse. I think we sometimes have the following import: from scipy import sparse.

@glemaitre glemaitre changed the title switch to sparse_cg solver in Ridge when X is sparse and fit_intercept is True FIX switch to 'sparse_cg' solver in Ridge when X is sparse and fitting intercept Jun 13, 2019
jeromedockes and others added 2 commits Jun 13, 2019
Co-Authored-By: Guillaume Lemaitre <g.lemaitre58@gmail.com>

@@ -109,6 +111,10 @@ Changelog
of the maximization procedure in :term:`fit`.
:pr:`13618` by :user:`Yoshihiro Uchida <c56pony>`.

- |Fix| :class:`linear_model.Ridge` now correctly fits an intercept when
`X` is sparse, `solver="auto"` and `fit_intercept=True`. Setting the solver to

agramfort (Member) commented Jul 16, 2019:

Explain that it is because the default solver is now sparse_cg.

'"sag" solver requires many iterations to fit '
'an intercept with sparse inputs. Either set the '
'solver to "auto" or "sparse_cg", or set a low '
'"tol" and a high "max_iter".')

agramfort (Member) commented Jul 17, 2019:

What bothers me is that you will get the warning whatever you use for tol or max_iter. Maybe only warn if the parameters used are the defaults?
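A sketch of what that suggestion could look like (the helper function is hypothetical; the real change lives inside Ridge.fit, and the defaults shown, tol=1e-3 and max_iter=None, were Ridge's defaults at the time):

```python
import warnings

RIDGE_DEFAULT_TOL = 1e-3       # Ridge defaults at the time of this PR
RIDGE_DEFAULT_MAX_ITER = None

def maybe_warn_sag_sparse_intercept(solver, fit_intercept, is_sparse,
                                    tol, max_iter):
    """Warn only when the user kept the defaults (hypothetical helper)."""
    if (solver == 'sag' and fit_intercept and is_sparse
            and tol == RIDGE_DEFAULT_TOL
            and max_iter == RIDGE_DEFAULT_MAX_ITER):
        warnings.warn(
            '"sag" solver requires many iterations to fit an intercept '
            'with sparse inputs. Either set the solver to "auto" or '
            '"sparse_cg", or set a low "tol" and a high "max_iter".',
            UserWarning)
```

A user who already passed a custom tol or max_iter would then not see the warning.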

# tol and max_iter, sag should raise a warning and is handled in
# test_ridge_fit_intercept_sparse_sag
# "auto" should switch to "sparse_cg"
dense_ridge = Ridge(alpha=1., solver='sparse_cg', fit_intercept=True)

agramfort (Member) commented Jul 17, 2019:

sparse_cg for dense_ridge?

jeromedockes (Author, Contributor) commented Jul 17, 2019:

sparse_cg can fit both sparse and dense data. Since both "auto" and "sparse_cg" should result in "sparse_cg" being used when X is sparse, the reference is "sparse_cg" fitted on dense data, and Ridge(solver="auto") and Ridge(solver="sparse_cg"), fitted on sparse data, are compared to it.
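The comparison strategy described above can be sketched as follows (illustrative data and tolerances, not the actual test; assumes a scikit-learn release that includes this fix):

```python
import numpy as np
from scipy import sparse as sp
from sklearn.linear_model import Ridge

rng = np.random.RandomState(0)
X = rng.rand(60, 4) + 5.0                  # nonzero column means
y = X @ rng.rand(4) + 2.0

# reference: 'sparse_cg' fitted on dense data
dense_ridge = Ridge(alpha=1., solver='sparse_cg', tol=1e-12,
                    fit_intercept=True).fit(X, y)

# both 'auto' and 'sparse_cg' fitted on sparse data should match it
for solver in ('auto', 'sparse_cg'):
    sparse_ridge = Ridge(alpha=1., solver=solver, tol=1e-12,
                         fit_intercept=True).fit(sp.csr_matrix(X), y)
    assert np.allclose(dense_ridge.coef_, sparse_ridge.coef_, atol=1e-3)
    assert np.allclose(dense_ridge.intercept_, sparse_ridge.intercept_,
                       atol=1e-3)
```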

@@ -464,6 +464,7 @@ def test_sag_regressor():
y = 0.5 * X.ravel()

clf1 = Ridge(tol=tol, solver='sag', max_iter=max_iter,
fit_intercept=False,

agramfort (Member) commented Jul 17, 2019:

I would revert the changes to the test if they are not necessary. Just catch the warning if need be.

jeromedockes (Author, Contributor) commented Jul 17, 2019:

I reverted them; there is no warning because in those tests the tol and max_iter used are not the default ones.

jeromedockes and others added 3 commits Jul 17, 2019
Co-Authored-By: Alexandre Gramfort <alexandre.gramfort@m4x.org>
# for now only sparse_cg can fit an intercept with sparse X with default
# tol and max_iter, sag should raise a warning and is handled in
# test_ridge_fit_intercept_sparse_sag
# "auto" should switch to "sparse_cg"

agramfort (Member) commented Jul 18, 2019:

Is this comment still relevant? I don't see sag warnings caught here.
If you update the comment, can you write what you answered to me just below? Thanks.

jeromedockes (Author, Contributor) commented Jul 18, 2019:

Thanks! Updated the comment. You don't see warnings here because sag's behaviour in this configuration is tested separately in test_ridge_fit_intercept_sparse_sag.

agramfort (Member) left a comment:

Let's wait for another approval before merging. Thanks @jeromedockes

glemaitre (Contributor) left a comment:

A single nitpick and then I will merge. LGTM.

jeromedockes and others added 2 commits Jul 19, 2019
Co-Authored-By: Guillaume Lemaitre <g.lemaitre58@gmail.com>
@glemaitre glemaitre merged commit 89e6c96 into scikit-learn:master Jul 19, 2019
17 checks passed

LGTM analysis: C/C++ No code changes detected
LGTM analysis: JavaScript No code changes detected
LGTM analysis: Python No new or fixed alerts
ci/circleci: deploy Your tests passed on CircleCI!
ci/circleci: doc Your tests passed on CircleCI!
ci/circleci: doc-min-dependencies Your tests passed on CircleCI!
ci/circleci: lint Your tests passed on CircleCI!
codecov/patch 100% of diff hit (target 96.85%)
codecov/project Absolute coverage decreased by -2.42% but relative coverage increased by +3.14% compared to 01d0a80
scikit-learn.scikit-learn Build #20190719.34 succeeded
scikit-learn.scikit-learn (Linux py35_conda_openblas) succeeded
scikit-learn.scikit-learn (Linux py35_ubuntu_atlas) succeeded
scikit-learn.scikit-learn (Linux pylatest_conda_mkl_pandas) succeeded
scikit-learn.scikit-learn (Linux32 py35_ubuntu_atlas_32bit) succeeded
scikit-learn.scikit-learn (Windows py35_pip_openblas_32bit) succeeded
scikit-learn.scikit-learn (Windows py37_conda_mkl) succeeded
scikit-learn.scikit-learn (macOS pylatest_conda_mkl) succeeded
glemaitre (Contributor) commented Jul 19, 2019:

Thanks @jeromedockes

jeromedockes (Author, Contributor) commented Jul 19, 2019:

Thanks a lot for the help and advice @glemaitre @agramfort and @rth

@jeromedockes jeromedockes deleted the jeromedockes:fix_ridge_solver_selection branch Jul 19, 2019
@amueller amueller mentioned this pull request Jul 23, 2019
11 of 11 tasks complete