
[MRG] FIX Add l1_ratio to set regularization with elasticnet penalty in Perceptron #18622

Merged (11 commits, Nov 2, 2020)
10 changes: 8 additions & 2 deletions sklearn/linear_model/_perceptron.py
@@ -20,6 +20,11 @@ class Perceptron(BaseSGDClassifier):
Constant that multiplies the regularization term if regularization is
used.

l1_ratio : float, default=0.15
The Elastic Net mixing parameter, with 0 <= l1_ratio <= 1.
l1_ratio=0 corresponds to L2 penalty, l1_ratio=1 to L1.
Only used if `penalty` is 'elasticnet'.

Member:
Let's add .. versionadded:: 0.24 here

Contributor Author:
I just added versionadded. Thanks :)
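
For reference, a minimal usage sketch of the new parameter (hypothetical data; assumes a scikit-learn build that includes this PR):

# Minimal sketch of how the new l1_ratio parameter is meant to be used.
# Assumes a scikit-learn build that includes this PR; the data is made up.
from sklearn.datasets import make_classification
from sklearn.linear_model import Perceptron

X, y = make_classification(n_samples=200, n_features=20, random_state=0)

# l1_ratio only matters when penalty='elasticnet':
# 0 means pure L2, 1 means pure L1, values in between mix the two penalties.
clf = Perceptron(penalty="elasticnet", alpha=0.0001, l1_ratio=0.15, random_state=0)
clf.fit(X, y)
print(clf.score(X, y))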

fit_intercept : bool, default=True
Whether the intercept should be estimated or not. If False, the
data is assumed to be already centered.
@@ -148,13 +153,14 @@ class Perceptron(BaseSGDClassifier):
https://en.wikipedia.org/wiki/Perceptron and references therein.
"""
@_deprecate_positional_args
- def __init__(self, *, penalty=None, alpha=0.0001, fit_intercept=True,
+ def __init__(self, *, penalty=None, alpha=0.0001, l1_ratio=0.15,
Member:

Unfortunately introducing l1_ratio=.15 might produce changes in the results of users who had set penalty='elasticnet', since it used to be hardcoded to 0 before.

We can:

  • set the default to 0, but that would be slightly weird since just setting penalty='elasticnet' would still be using L2. It would also be inconsistent with SGDClassifier.
  • set the default to 'warn' which is equivalent to 0, but indicates that the default will change from 0 to .15 in version 0.26
  • not care about all this and consider that this is a bugfix. In this case we can just set the default as it is: 0.15

I have a slight preference for option 3 since it's the simplest and it's likely that this won't affect many people.

Thoughts @rth @glemaitre @ogrisel @thomasjpfan?

Member:

I'm also in favor of option 3, as I tend to perceive it as a bugfix.

Member:

And why 0.15 as default, not, for instance, 0.5?

Member:

I also slightly prefer option 3 and consider this a bug fix.

Member:

Agreed, seems like a bugfix. And the default should be consistent with SGDClassifier, which has 0.15, probably for no real reason.

fit_intercept=True,
max_iter=1000, tol=1e-3, shuffle=True, verbose=0, eta0=1.0,
n_jobs=None, random_state=0, early_stopping=False,
validation_fraction=0.1, n_iter_no_change=5,
class_weight=None, warm_start=False):
super().__init__(
loss="perceptron", penalty=penalty, alpha=alpha, l1_ratio=0,
loss="perceptron", penalty=penalty, alpha=alpha, l1_ratio=l1_ratio,
fit_intercept=fit_intercept, max_iter=max_iter, tol=tol,
shuffle=shuffle, verbose=verbose, random_state=random_state,
learning_rate="constant", eta0=eta0, early_stopping=early_stopping,
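
For context on the default-value discussion above, a hedged sketch of the behavior change this PR introduces (hypothetical data; assumes a scikit-learn build that includes this patch):

# Sketch of the behavior change discussed in the review thread above.
# Before this PR, l1_ratio was hardcoded to 0, so penalty='elasticnet'
# effectively applied a pure L2 penalty. With the new default of 0.15,
# fitted coefficients can differ for users who had set penalty='elasticnet'.
# Hypothetical data; assumes a scikit-learn build that includes this patch.
import numpy as np
from sklearn.linear_model import Perceptron

rng = np.random.RandomState(0)
X = rng.randn(100, 5)
y = (X[:, 0] + X[:, 1] > 0).astype(int)

new_default = Perceptron(penalty="elasticnet", alpha=0.001, random_state=0).fit(X, y)
old_behavior = Perceptron(penalty="elasticnet", alpha=0.001, l1_ratio=0,
                          random_state=0).fit(X, y)

# The two coefficient vectors will generally not be identical.
print(new_default.coef_)
print(old_behavior.coef_)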