
[MRG] FIX Add l1_ratio to set regularization with elasticnet penalty in Perceptron #18622

Merged
merged 11 commits into from Nov 2, 2020
6 changes: 6 additions & 0 deletions doc/whats_new/v0.24.rst
@@ -25,6 +25,7 @@ random sampling procedures.
- |Fix| :class:`decomposition.KernelPCA` behaviour is now more consistent
between 32-bits and 64-bits data when the kernel has small positive
eigenvalues.
- |Fix| :class:`linear_model.Perceptron` when `penalty='elasticnet'`.

Details are listed in the changelog below.

@@ -393,6 +394,11 @@ Changelog
`X_offset_` and `X_scale_` were undefined.
:pr:`18607` by :user:`fhaselbeck <fhaselbeck>`.

- |Fix| Added the missing `l1_ratio` parameter in
:class:`linear_model.Perceptron`, to be used when `penalty='elasticnet'`.
This changes the default from 0 to 0.15. :pr:`18622` by
:user:`Haesun Park <rickiepark>`.
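A minimal sketch of what this fix changes in practice (the dataset and `random_state` are illustrative choices, not taken from the PR): before the fix, `l1_ratio` was hardcoded to 0, so any two `Perceptron(penalty='elasticnet')` models trained identically regardless of the requested mixing ratio.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import Perceptron

X, y = load_iris(return_X_y=True)

# l1_ratio=0 is pure L2; l1_ratio=0.15 mixes in some L1.
clf_l2_like = Perceptron(penalty='elasticnet', l1_ratio=0.0,
                         random_state=0).fit(X, y)
clf_mixed = Perceptron(penalty='elasticnet', l1_ratio=0.15,
                       random_state=0).fit(X, y)

# With the fix, different ratios produce different coefficients;
# before it, these two models were indistinguishable.
coefs_differ = (clf_l2_like.coef_ != clf_mixed.coef_).any()
```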

:mod:`sklearn.manifold`
.......................

12 changes: 10 additions & 2 deletions sklearn/linear_model/_perceptron.py
@@ -20,6 +20,13 @@ class Perceptron(BaseSGDClassifier):
Constant that multiplies the regularization term if regularization is
used.

l1_ratio : float, default=0.15
The Elastic Net mixing parameter, with `0 <= l1_ratio <= 1`.
`l1_ratio=0` corresponds to L2 penalty, `l1_ratio=1` to L1.
Only used if `penalty='elasticnet'`.

.. versionadded:: 0.24
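For reference, the elastic-net term is a convex mix of the L1 and L2 penalties, which is why `l1_ratio=0` reduces to L2 and `l1_ratio=1` to pure L1. A small sketch of that definition (the helper name and the 0.5 scaling of the L2 term are illustrative assumptions, not code from this PR):

```python
import numpy as np

def elasticnet_penalty(w, alpha=0.0001, l1_ratio=0.15):
    # Convex combination of L1 and L2 regularization:
    # l1_ratio=0 keeps only the L2 part, l1_ratio=1 only the L1 part.
    # NOTE: the 0.5 factor on the L2 term is an assumption for illustration.
    l1 = np.abs(w).sum()
    l2 = 0.5 * np.dot(w, w)
    return alpha * (l1_ratio * l1 + (1.0 - l1_ratio) * l2)

w = np.array([1.0, -2.0, 0.5])
pure_l1 = elasticnet_penalty(w, l1_ratio=1.0)  # alpha * sum(|w|)
pure_l2 = elasticnet_penalty(w, l1_ratio=0.0)  # alpha * 0.5 * sum(w**2)
```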

Member: Let's add `.. versionadded:: 0.24` here.

Contributor Author: I just added `versionadded`. Thanks :)

fit_intercept : bool, default=True
Whether the intercept should be estimated or not. If False, the
data is assumed to be already centered.
@@ -148,13 +155,14 @@ class Perceptron(BaseSGDClassifier):
https://en.wikipedia.org/wiki/Perceptron and references therein.
"""
@_deprecate_positional_args
def __init__(self, *, penalty=None, alpha=0.0001, fit_intercept=True,
def __init__(self, *, penalty=None, alpha=0.0001, l1_ratio=0.15,
Member:
Unfortunately, introducing `l1_ratio=0.15` might produce changes in the results of users who had set `penalty='elasticnet'`, since it used to be hardcoded to 0 before.

We can:

  • set the default to 0, but that would be slightly weird since just setting `penalty='elasticnet'` would still be using L2. It would also be inconsistent with `SGDClassifier`.
  • set the default to `'warn'`, which is equivalent to 0 but indicates that the default will change from 0 to 0.15 in version 0.26.
  • not care about all this and consider that this is a bugfix. In this case we can just set the default as it is: 0.15.

I have a slight preference for option 3 since it's the simplest and it's likely that this won't affect many people.

Thoughts @rth @glemaitre @ogrisel @thomasjpfan?

Member:
I'm also in favor of option 3, as I tend to perceive it as a bugfix.

Member:
And why 0.15 as default, not, for instance, 0.5?

Member:
I also slightly prefer option 3 and consider this a bug fix.

Member:
Agreed, seems like a bugfix. And the default should be consistent with SGDClassifier, which has 0.15, probably for no real reason.

fit_intercept=True,
max_iter=1000, tol=1e-3, shuffle=True, verbose=0, eta0=1.0,
n_jobs=None, random_state=0, early_stopping=False,
validation_fraction=0.1, n_iter_no_change=5,
class_weight=None, warm_start=False):
super().__init__(
loss="perceptron", penalty=penalty, alpha=alpha, l1_ratio=0,
loss="perceptron", penalty=penalty, alpha=alpha, l1_ratio=l1_ratio,
fit_intercept=fit_intercept, max_iter=max_iter, tol=tol,
shuffle=shuffle, verbose=verbose, random_state=random_state,
learning_rate="constant", eta0=eta0, early_stopping=early_stopping,
11 changes: 11 additions & 0 deletions sklearn/linear_model/tests/test_perceptron.py
@@ -67,3 +67,14 @@ def test_undefined_methods():
clf = Perceptron(max_iter=100)
for meth in ("predict_proba", "predict_log_proba"):
assert_raises(AttributeError, lambda x: getattr(clf, x), meth)


def test_perceptron_l1_ratio():
Member:
Can we get a better test?

Also, I feel like the perceptron documentation could probably be improved. Our perceptron class doesn't really implement only a perceptron. Has anyone else implemented something like this?

Member:
Not that I'm an example to follow on this, but let's try to be more descriptive in our comments ;) What kind of test would you like to see @amueller?

    """Check that `l1_ratio` has an impact when `penalty='elasticnet'`."""
    clf1 = Perceptron(l1_ratio=0, penalty='elasticnet')
    clf1.fit(X, y)

    clf2 = Perceptron(l1_ratio=0.15, penalty='elasticnet')
    clf2.fit(X, y)

    assert clf1.score(X, y) != clf2.score(X, y)
Member:
It would be better to assert that the scores are different with a significant difference (e.g. something like 0.1 or 0.05).

`assert clf1.score(X, y) != clf2.score(X, y)` would pass even if the score difference was 1e-15, and that's not strict enough.

Contributor Author:
`clf1.score(X, y)` is 0.96 and `clf2.score(X, y)` is 0.906.
How about `assert clf1.score(X, y) - clf2.score(X, y) > 0.01`?

Member:
that sounds fine. Maybe check the absolute difference?

Contributor Author:
clf1 is overfitted more than clf2. Is it normal that clf2.score(X, y) is greater than clf1.score(X, y)?

Member:
Would it make sense to check `l1_ratio=0` against `penalty="l2"`, and `l1_ratio=1` against `penalty="l1"`?

Contributor Author:
Hi @glemaitre,
Do you mean something like
`assert Perceptron(l1_ratio=0, penalty='elasticnet').fit(X, y).score(X, y) != Perceptron(l1_ratio=0, penalty='l2').fit(X, y).score(X, y)`
and
`assert Perceptron(l1_ratio=1, penalty='elasticnet').fit(X, y).score(X, y) != Perceptron(l1_ratio=1, penalty='l1').fit(X, y).score(X, y)`?

Member (@glemaitre, Nov 2, 2020):
I pushed a small test to check that this is passing across all platforms.
We should be able to check that the coefficients with `penalty='l1'` match those with `l1_ratio=1`, and that `penalty='l2'` matches `l1_ratio=0`.
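A sketch of that coefficient check (the dataset choice is an assumption for illustration; the committed test may differ): elastic net at its two extremes should reproduce the pure L1 and pure L2 penalties.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import Perceptron

X, y = load_iris(return_X_y=True)

# elasticnet with l1_ratio=1 should behave like a pure L1 penalty...
clf_en_l1 = Perceptron(penalty='elasticnet', l1_ratio=1.0,
                       random_state=0).fit(X, y)
clf_l1 = Perceptron(penalty='l1', random_state=0).fit(X, y)
np.testing.assert_allclose(clf_en_l1.coef_, clf_l1.coef_)

# ...and elasticnet with l1_ratio=0 like a pure L2 penalty.
clf_en_l2 = Perceptron(penalty='elasticnet', l1_ratio=0.0,
                       random_state=0).fit(X, y)
clf_l2 = Perceptron(penalty='l2', random_state=0).fit(X, y)
np.testing.assert_allclose(clf_en_l2.coef_, clf_l2.coef_)
```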