[MRG] Sample weights for ElasticNet #15436

Merged: 27 commits merged into scikit-learn:master from lorentzenchr:enet_sample_weights on Feb 16, 2020

Conversation

@lorentzenchr (Member) commented on Nov 1, 2019:

Reference Issues/PRs

Partially solves #3702: Adds sample_weight to ElasticNet and Lasso, but only for dense feature array X.

@lorentzenchr changed the title from [WIP] Sample weights for ElasticNet to [MRG] Sample weights for ElasticNet on Nov 6, 2019
@rth (Member) left a comment:
Thanks for working on this @lorentzenchr! I haven't done a full review, but a few comments are below.

@@ -37,7 +37,7 @@ def test_transform_target_regressor_error():
     regr.fit(X, y)
     # fit with sample_weight with a regressor which does not support it
     sample_weight = np.ones((y.shape[0],))
-    regr = TransformedTargetRegressor(regressor=Lasso(),
+    regr = TransformedTargetRegressor(regressor=OrthogonalMatchingPursuit(),
Member: Why this change?

@lorentzenchr (Member, Author):
Here, one needs an estimator that does not support sample_weight. As Lasso inherits sample_weight support from ElasticNet, I needed to change this.
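For illustration, a quick way to see the motivation (hypothetical snippet, not part of the PR): with this change, Lasso.fit accepts sample_weight via ElasticNet.fit, while OrthogonalMatchingPursuit.fit does not, which is what the test relies on.

import inspect

from sklearn.linear_model import Lasso, OrthogonalMatchingPursuit

# Check whether fit() exposes a sample_weight parameter.
for est in (Lasso(), OrthogonalMatchingPursuit()):
    has_sw = "sample_weight" in inspect.signature(est.fit).parameters
    print(type(est).__name__, "supports sample_weight:", has_sw)
# Expected with this PR merged:
#   Lasso supports sample_weight: True
#   OrthogonalMatchingPursuit supports sample_weight: False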

        # Ensure copying happens only once, don't do it again if done above
        should_copy = self.copy_X and not X_copied
        X, y, X_offset, y_offset, X_scale, precompute, Xy = \
            _pre_fit(X, y, None, self.precompute, self.normalize,
                     self.fit_intercept, copy=should_copy,
-                    check_input=check_input)
+                    check_input=check_input, sample_weight=sample_weight,
+                    order='F')
Member: Maybe worth adding a comment about why order="F" is more efficient here.

@lorentzenchr (Member, Author): It only makes the code at this place more efficient (simpler), not the performance better.

@rth (Member) left a comment:

So this adds support for sample weight by rescaling X by sqrt(sample_weight). I know that this is done in other linear models, but how can we be sure that it is the right thing to do here with L1 regularization?

It would have been nice to have a test comparing results using sample_weight with Ridge when l1_ratio=0. For l1_ratio > 0, do you have any suggestions for checking in unit tests that the output is correct (beyond being consistent with respect to expected invariances)?

For all the added conversions for array order and sparse format, those are optimizations orthogonal to adding sample_weights, right? It might be better to separate that code into a private function, used where necessary, with some justification that it actually improves performance (it's hard for me to judge which order is optimal without benchmarking it). This part could also be a separate PR.

    if sparse_X:
        # As of scipy 1.1.0, the new argument copy=False is the default.
        # This is what we want.
        X = X.asformat(sparse_format)
Member: There is a sklearn.utils.fixes._astype_copy_false that could be used for this:

Suggested change:
-        X = X.asformat(sparse_format)
+        X = X.asformat(sparse_format, **_astype_copy_false(X))

@@ -191,6 +211,20 @@ def _rescale_data(X, y, sample_weight):
                                    shape=(n_samples, n_samples))
         X = safe_sparse_dot(sw_matrix, X)
         y = safe_sparse_dot(sw_matrix, y)

+    if order is not None:
+        sparse_format = "csc" if order == "F" else "csr"
Member: I find this somewhat non-obvious when calling _rescale_data(..., order='C') on a sparse array. In fact, this doesn't need to be part of _rescale_data; it could be a separate private function enforcing the order when needed.

@lorentzenchr (Member, Author): I moved this into a private function _set_order in _coordinate_descent.py. It's the only place where it is used.
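A minimal sketch of what such a helper could look like (illustrative only; the actual _set_order in _coordinate_descent.py may differ in details):

import numpy as np
from scipy import sparse


def _set_order(X, y, order="C"):
    """Return X and y in the requested memory layout.

    For dense arrays this means C- or F-contiguity; for sparse X the
    closest analogue is CSR ("C", row-major) or CSC ("F", column-major).
    """
    if order not in ("C", "F"):
        raise ValueError("Unknown value for order: {}".format(order))
    if sparse.issparse(X):
        X = X.asformat("csc" if order == "F" else "csr")
    else:
        X = np.asarray(X, order=order)
    y = np.asarray(y, order=order)
    return X, y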

@lorentzenchr (Member, Author):

@rth You raise some good questions.

> How can we be sure that this is the right thing to do here with L1 regularization?

Because rescaling by the square root works, and I thought it was the minimal code change. If I were to code it from scratch, I would include the sample weights in the loss function, its gradient and Hessian, and pass those to the optimizer.
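As a quick illustration of why the square-root rescaling is sufficient (a sketch, not the PR's code): multiplying the rows of X and the entries of y by sqrt(sample_weight) turns the weighted squared error into an unweighted one on the rescaled data, while the penalty terms are untouched.

import numpy as np

rng = np.random.RandomState(0)
X = rng.rand(10, 3)
y = rng.rand(10)
sample_weight = rng.rand(10) + 0.5
coef = rng.rand(3)

sqrt_sw = np.sqrt(sample_weight)
X_r = X * sqrt_sw[:, np.newaxis]   # rescale rows of X
y_r = y * sqrt_sw                  # rescale y

weighted_loss = np.sum(sample_weight * (y - X @ coef) ** 2)
plain_loss_on_rescaled = np.sum((y_r - X_r @ coef) ** 2)
assert np.isclose(weighted_loss, plain_loss_on_rescaled)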

> It would have been nice to have a test comparing results using sample_weight with Ridge when l1_ratio=0.

I will try to include such a test.
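Such a test could look roughly like the sketch below (illustrative, not the final test; the alpha mapping assumes ElasticNet scales the data term by 1 / (2 * n_samples) while Ridge does not, and the weights are normalized to sum to n_samples so this mapping is unambiguous):

import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import ElasticNet, Ridge

X, y = make_regression(n_samples=50, n_features=5, noise=1.0, random_state=0)
n_samples = X.shape[0]
sw = np.random.RandomState(0).rand(n_samples)
sw = sw * n_samples / sw.sum()  # normalize so that sum(sw) == n_samples
alpha = 0.1

# l1_ratio=0 turns the ElasticNet penalty into a pure L2 penalty
enet = ElasticNet(alpha=alpha, l1_ratio=0, tol=1e-12, max_iter=10_000)
enet.fit(X, y, sample_weight=sw)

# Ridge minimizes ||y - Xw||^2_sw + alpha_ridge * ||w||^2,
# so alpha_ridge = alpha * n_samples matches ElasticNet's scaling.
ridge = Ridge(alpha=alpha * n_samples)
ridge.fit(X, y, sample_weight=sw)

np.testing.assert_allclose(enet.coef_, ridge.coef_, rtol=1e-5)
np.testing.assert_allclose(enet.intercept_, ridge.intercept_, rtol=1e-5, atol=1e-6)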

> For l1_ratio > 0 do you have any suggestions for checking in unit tests that the output is correct?

One can check against glmnet. What do you think?

> For all the added conversions for array order and sparse format, those are optimizations orthogonal to adding sample_weights, right?

The only solver for the L1 penalty in scikit-learn is coordinate descent (see also #12966). It is written in Cython and expects F-ordered arrays, which is optimal for this algorithm. All I did was avoid memory copies and pass F-ordered arrays to _cd_fast, though I might have missed something; I'm not an expert on memoryviews in Cython.
Therefore, I would prefer to keep this PR together.

@rth (Member) commented on Feb 6, 2020:

> How can we be sure that this is the right thing to do here with L1 regularization?

> Because rescaling by the square root works, and I thought it was the minimal code change.

Indeed. Actually, if we take as the definition of sample weights that a weight of 2 is equivalent to 2 repeated samples etc. (#15651 (comment)), then checking these invariances as you have done could be enough, and I'm actually fine with this without a comparison to glmnet or another external package. Other parameters (alpha, l1_ratio) should be invariant to the addition of sample weights, as far as I understand (to come back to the #15657 (comment) discussion).
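For concreteness, a hedged sketch of such invariance checks (illustrative names and data, not the PR's actual test_enet_sample_weight_consistency; the second check additionally assumes that the meaning of alpha is preserved under reweighting, i.e. that the implementation normalizes the weights consistently with the effective number of samples):

import numpy as np
from sklearn.linear_model import ElasticNet

rng = np.random.RandomState(42)
X = rng.rand(30, 5)
y = X @ np.array([1.0, 2.0, -1.0, 0.0, 0.5]) + 0.1 * rng.randn(30)
params = dict(alpha=0.01, l1_ratio=0.5, tol=1e-10, max_iter=10_000)

# 1) sample_weight=None must behave like unit weights
coef_none = ElasticNet(**params).fit(X, y).coef_
coef_ones = ElasticNet(**params).fit(X, y, sample_weight=np.ones(30)).coef_
np.testing.assert_allclose(coef_none, coef_ones)

# 2) a weight of 2 should act like the corresponding sample appearing twice
sw = np.ones(30)
sw[:5] = 2.0
coef_sw = ElasticNet(**params).fit(X, y, sample_weight=sw).coef_
X_rep = np.concatenate([X, X[:5]])
y_rep = np.concatenate([y, y[:5]])
coef_rep = ElasticNet(**params).fit(X_rep, y_rep).coef_
np.testing.assert_allclose(coef_sw, coef_rep, rtol=1e-5, atol=1e-10)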

Also cc @agramfort

> All I did was avoid memory copies and pass F-ordered arrays to _cd_fast.
> Therefore, I would prefer to keep this PR together.

OK, fair enough.

        if check_input:
            if sparse.issparse(X):
                raise ValueError("Sample weights do not (yet) support "
                                 "sparse matrices.")
Member: I see tests with sample weights and sparse matrices; why do you have this?

@lorentzenchr (Member, Author): There is test_enet_sample_weight_sparse, which tests that ElasticNet raises an error because sample weights are not supported with sparse X.
Secondly, there is test_enet_sample_weight_consistency, which is decorated with @pytest.mark.parametrize('sparseX', [False]) so that this test can easily be extended once sparse X is supported with sample weights.
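For reference, a minimal sketch of that first check against the state of this PR (names are illustrative; if sparse X gains sample_weight support later, this error is no longer raised and the check becomes obsolete):

import numpy as np
import pytest
from scipy import sparse
from sklearn.linear_model import ElasticNet


def test_sample_weight_rejected_for_sparse_X():
    # Mirrors the ValueError added in this PR for sparse X + sample_weight.
    X = sparse.random(10, 3, density=0.5, random_state=0, format="csc")
    y = np.arange(10, dtype=float)
    with pytest.raises(ValueError, match="Sample weights do not"):
        ElasticNet().fit(X, y, sample_weight=np.ones(10))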

Member: OK, I had not paid attention to @pytest.mark.parametrize('sparseX', [False]). We tend to add test code when it becomes relevant. Do you plan to do this next, once this is merged?

@lorentzenchr (Member, Author): I'd say that this is needed to close #3702 and that it is a useful feature. I'm tempted to try it in a new PR, but I'm afraid it needs a larger code change than this one.

Christian Lorentzen and others added 3 commits on February 9, 2020. (Co-Authored-By: Alexandre Gramfort <alexandre.gramfort@m4x.org>)
@agramfort (Member):

Two more things on my side. You have comments from @rth left to address or rule out. I would not add code/tests for things we don't use and where we don't know if the new code is actually correct. I would therefore remove the sparseX option in the test. Let's add code/tests when we actually use them. Thanks @lorentzenchr

@lorentzenchr (Member, Author):

@agramfort Thanks for pointing out the unaddressed issues. I hope it is in good shape now.
I suspect and hope that the test error of doc-min-dependencies is unrelated to this PR.

@agramfort (Member) left a comment:

LGTM

@rth feel free to have another look

Thanks @lorentzenchr

@agramfort (Member):

@lorentzenchr just don't forget to update the what's new page.

:mod:`sklearn.linear_model`
...........................

- |Feature| Support of `sample_weight` in :class:`linear_model.ElasticNet` and
  :class:`linear_model.Lasso` for a dense feature matrix `X`.
Member: You don't have it for the CV estimators?

@lorentzenchr (Member, Author):
No. The what's new is correct. Sample weight support is added only to ElasticNet and Lasso. Unfortunately, the logic in _coordinate_descent.py is such that this feature is not automatically inherited by ElasticNetCV and others. I do not know if there are advantages with the current structure.
In any case, I'd prefer to put adding sample weights to ElasticNetCV in another PR. What do you think?

@agramfort (Member) commented on Feb 14, 2020 via email.

@rth (Member) left a comment:

LGTM as well, thanks @lorentzenchr!

@rth merged commit 21686b7 into scikit-learn:master on Feb 16, 2020
thomasjpfan pushed a commit to thomasjpfan/scikit-learn that referenced this pull request Feb 22, 2020
@lorentzenchr deleted the enet_sample_weights branch on February 29, 2020
panpiort8 pushed a commit to panpiort8/scikit-learn that referenced this pull request Mar 3, 2020
gio8tisu pushed a commit to gio8tisu/scikit-learn that referenced this pull request May 15, 2020