
[MRG+1] Fix SGDClassifier never has the attribute "predict_proba" (even with log or modified_huber loss) #12222

Merged · 14 commits · Apr 19, 2019

Conversation

rebekahkim (Contributor)

Fixes #10113

  • predict_proba checks self.estimators_[0], not self.estimator
  • test for predict_proba

@amueller (Member) left a comment:

Looks good, thanks!

@TomDLT (Member)

TomDLT commented Oct 1, 2018

Thanks @rebekahkim, but your test is not failing on master.
Actually, the bug seems to have been fixed in #10961.
I propose to close this PR and the corresponding issue.

@amueller (Member)

amueller commented Oct 1, 2018

Hm, I'm not sure that means we can close the issue, as I think there's a bug in the multioutput classifier as well.
I guess it depends on whether we require duck-typing to be available before fit.

But this bug here could theoretically be triggered if you grid-search between log loss and hinge loss and then put the estimator into MultiOutputClassifier, because then it's impossible to duck-type before fit has been called.
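The pre-fit duck-typing problem can be sketched without scikit-learn. In the hypothetical mock below (all names are illustrative, not real API), whether predict_proba exists depends on which loss the search settles on, which is only known after fit:

```python
# Hedged, scikit-learn-free mock of the scenario: the loss (and hence
# whether predict_proba makes sense) is only decided when fit() runs,
# so no pre-fit hasattr check can give the right answer.
class MockLossSearch:
    """Hypothetical stand-in for a grid search over hinge vs. log loss."""

    def __init__(self):
        self.best_loss_ = None  # unknown until fit

    def fit(self):
        self.best_loss_ = "log"  # imagine log loss wins the search
        return self

    @property
    def predict_proba(self):
        # Raising AttributeError from a property makes hasattr(...) False.
        if self.best_loss_ not in ("log", "modified_huber"):
            raise AttributeError("loss does not support probabilities")
        return lambda X: [[0.5, 0.5] for _ in X]


est = MockLossSearch()
assert not hasattr(est, "predict_proba")  # undecidable before fit
est.fit()
assert hasattr(est, "predict_proba")  # available once log loss is chosen
```

This is a sketch of the timing issue only; the real SGDClassifier and GridSearchCV behave analogously but with actual training.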

@amueller (Member)

amueller commented Oct 1, 2018

I wonder if this can also fail in a less obscure situation. @rebekahkim you could also just create a new estimator for the test that only has predict_proba after fitting. I don't think there's harm in checking this as late as possible.

@ogrisel (Member)

ogrisel commented Oct 1, 2018

I don't understand "create a new estimator for the test that only has predict_proba after fitting": this is already the case for SGDClassifier(loss='log'), no?

@ogrisel (Member) left a comment:

This LGTM, but it would require an entry in doc/whats_new/v0.21.rst.

    multi_target_linear.fit(X, y)
    multi_target_linear.predict_proba(X)

    sgd_linear_clf = SGDClassifier(random_state=1, max_iter=5)

@ogrisel (Member) commented on this line:
Could you please add an inline comment before this line to make it explicit that SGDClassifier uses loss='hinge' by default, which is not a probabilistic loss function and therefore does not expose a predict_proba method?

Reply (Member):

I don't understand "create a new estimator for the test that only has predict_proba after fitting": this is already the case for SGDClassifier(loss='log'), no?

No, as @TomDLT said above, this is not actually the case for SGDClassifier in master, and the current test is not failing in master.

@sergulaydore (Contributor)

Hello @rebekahkim ,

Thank you for participating in the WiMLDS/scikit sprint. We would love to merge all the PRs that were submitted. It would be great if you could follow up on the work that you started! For the PR you submitted, would you please update and re-submit? Please include #wimlds in your PR conversation.

If you have any questions:

  • see workflow for reference
  • ask on this PR conversation or the issue tracker
  • ask on wimlds gitter with a reference to this PR

cc: @reshamas

@reshamas (Member)

@rebekahkim
Will you be completing this PR?

@rebekahkim (Contributor, Author)

I wonder if this can also fail in a less obscure situation. @rebekahkim you could also just create a new estimator for the test that only has predict_proba after fitting. I don't think there's harm in checking this as late as possible.

@amueller I'm having trouble finding estimators without the predict_proba attribute before fitting. Do you mind pointing me in the right direction?

@jnothman (Member)

jnothman commented Dec 25, 2018 via email

@reshamas (Member)

@psorianom @GaelVaroquaux
This PR has been languishing for quite some time, and I have been unable to find someone to complete it in our wimlds community. Can it be tagged "help wanted"?
Thank you.

@jnothman (Member)

I think only grid searches currently make their predict_proba appear after fitting, as we cannot be sure whether the best model is probabilistic:

    GridSearchCV(Pipeline([('clf', None)]),
                 {'clf': [LogisticRegression(), RandomForestClassifier()]})

It's certainly a weird edge case.

@jnothman (Member)

Although maybe then I've not interpreted the raised issue correctly.

@rebekahkim (Contributor, Author)

@jnothman I think you are right; it seems like SGDClassifier with an appropriate loss (log or modified_huber) and SVC already have predict_proba before fit (at least in current master as well as in version 0.20.0).

As @amueller said, the bug would be triggered by a really obscure case. For example, if each estimator in estimators_ from

    MultiOutputClassifier(GridSearchCV(
        SGDClassifier(),
        param_grid={'loss': ('hinge', 'log', 'modified_huber')}))

has loss='log' or 'modified_huber' (only then can you have a valid predict_proba), it would still fail MultiOutputClassifier's predict_proba check

    def predict_proba(self, X):
        check_is_fitted(self, 'estimators_')
        if not hasattr(self.estimator, "predict_proba"):  # would fail here
            raise ValueError(...)

in master.

This actually doesn't happen in iris, wine, breast_cancer, or digits datasets; the estimators don't line up (I tested them all). How should we go about doing this? Try to find a dataset where this happens? Or is there a way to "force" GridSearchCV to choose a certain parameter without setting it on the base estimator in the first place?

@jnothman (Member)

Well you can certainly force GridSearchCV to choose things if they're fake estimators... E.g. define score so that the one with proba returns 1 and the one without returns 0.
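A minimal sketch of that trick (pure Python, all names hypothetical): a scorer that returns 1.0 only for estimators exposing predict_proba makes the probabilistic candidate win whatever model selection it is plugged into. In a real test this function would be passed to GridSearchCV via scoring=...; here a plain max() stands in for the search:

```python
# Hedged sketch: a custom scorer that prefers probabilistic estimators,
# so a search over fake estimators is forced to select the one with
# predict_proba.
def prefers_proba(estimator, X=None, y=None):
    """Score 1.0 for estimators exposing predict_proba, else 0.0."""
    return 1.0 if hasattr(estimator, "predict_proba") else 0.0


class WithProba:  # hypothetical probabilistic candidate
    def predict_proba(self, X):
        return [[0.5, 0.5] for _ in X]


class WithoutProba:  # hypothetical non-probabilistic candidate
    pass


candidates = [WithoutProba(), WithProba()]
best = max(candidates, key=prefers_proba)
assert isinstance(best, WithProba)  # the search is forced to pick it
```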

@rebekahkim (Contributor, Author)

@jnothman that's really awesome; I didn't know you could do that!

I made the appropriate changes. Do you mind taking a look and seeing if it's ready to merge?

@agramfort agramfort changed the title Fix SGDClassifier never has the attribute "predict_proba" (even with log or modified_huber loss) [MRG+1] Fix SGDClassifier never has the attribute "predict_proba" (even with log or modified_huber loss) Feb 28, 2019
@jnothman (Member) left a comment:

You can do all sorts of contrived and unrealistic things in tests! :) Sometimes even in real code 😮

        return 0.0
    grid_clf = GridSearchCV(sgd_linear_clf, param_grid=param,
                            scoring=custom_scorer, cv=3, error_score=np.nan)
    multi_target_linear = MultiOutputClassifier(grid_clf)
@jnothman (Member) commented on these lines:

I would like to assert not hasattr(..., 'predict_proba') before doing this fit, so that the intention of the test is a bit clearer.

@rebekahkim (Contributor, Author) replied:

You mean for multi_target_linear.estimator, right? Technically, the estimator still wouldn't have predict_proba after fit because the underlying estimator (SGDClassifier with default loss='hinge') doesn't have predict_proba. But all estimators in estimators_ here would (after fit, of course).

If you mean the multi_target_linear itself, it would have predict_proba before and after fit; it just won't be valid (it raises ValueError).

@rebekahkim (Contributor, Author):

@jnothman thoughts?

@jnothman (Member) left a comment:

I've realised what my issue is here. I think this PR is an improvement, and we can merge it as a quick fix, but what we should really be doing here is defining predict_proba such that hasattr(multioutputclf, 'predict_proba') is False if the underlying estimator does not have a predict_proba attribute. See BaseSearchCV.predict_proba for instance. Would you like to help implementing that, @rebekahkim?

@rebekahkim (Contributor, Author)

@jnothman I'd like to help with the implementation! I just want to make sure I'm understanding what you're saying and proposing a correct solution.

We don't want a multioutput class instance whose base estimator(s) don't have predict_proba to have the predict_proba attribute: i.e. hasattr(clf, 'predict_proba') should return False.
I need to do some more digging, but it seems like BaseSearchCV takes advantage of the if_delegate_has_method decorator to do this and handle the hasattr for sub-estimator(s). I assume we can do something similar with predict_proba in the multioutput classifier. Is this the right approach?

As a side note, while searching for some examples, I saw that SVC might also want a fix for its predict_proba check. See link.

@jnothman (Member)

jnothman commented Apr 11, 2019 via email

@NicolasHug (Member)

NicolasHug commented Apr 12, 2019

We've been looking at this with @thomasjpfan.

You won't be able to use @if_delegate_has_method since you'd need to pass it self.estimators_[0], and self.estimators_[0] isn't an attribute (it's just the first element of an attribute which is a list).

We think the right solution here is to mimic what SVC is doing: define predict_proba as a property and raise an AttributeError if predict_proba isn't defined in any of the estimators in self.estimators_, or in self.estimator.

This way, hasattr(multioutputclf, 'predict_proba') behaves properly.

All that being said, as Joel said, we'd be fine merging the PR as-is as a good-enough fix for now, and opening another issue to address the hasattr matter.
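The property-based fix described here can be sketched in plain Python (no scikit-learn required; the class names below are illustrative). Raising AttributeError from the property getter is what makes hasattr report False:

```python
# Hedged sketch of the proposed fix: predict_proba as a property that
# raises AttributeError when the sub-estimators cannot predict
# probabilities, so hasattr(...) reports the truth.
class LogEstimator:  # hypothetical sub-estimator with predict_proba
    def predict_proba(self, X):
        return [[0.5, 0.5] for _ in X]


class HingeEstimator:  # hypothetical sub-estimator without predict_proba
    pass


class MultiOutputSketch:
    def __init__(self, estimators):
        self.estimators_ = estimators  # pretend these are already fitted

    @property
    def predict_proba(self):
        if not all(hasattr(e, "predict_proba") for e in self.estimators_):
            raise AttributeError(
                "underlying estimators do not implement predict_proba")

        def _predict_proba(X):
            return [e.predict_proba(X) for e in self.estimators_]

        return _predict_proba


assert hasattr(MultiOutputSketch([LogEstimator()]), "predict_proba")
assert not hasattr(MultiOutputSketch([HingeEstimator()]), "predict_proba")
```

The key design point is that hasattr treats any AttributeError raised during attribute lookup, including one from a property getter, as "attribute absent".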

@reshamas (Member)

@rebekahkim We have scheduled the 2019 WiMLDS sprint for Sunday, Aug 25, if you would like to schedule the work up to and around that date.
cc: @NicolasHug @thomasjpfan

@rebekahkim (Contributor, Author)

@NicolasHug @thomasjpfan Good point!

I'll leave the decision up to the sklearn dev team whether to merge this PR (@jnothman). I'll open a new issue to correct the hasattr behavior; let me do some more digging and testing and look into SVC. Thanks for the pointer!

@jnothman (Member)

We think the right solution here is to mimic what SVC is doing: define predict_proba as a property and raise an AttributeError if predict_proba isn't defined in any of the estimators in self.estimators_, or in self.estimator. This way, hasattr(multioutputclf, 'predict_proba') behaves properly

That's fine. An alternative (not sure which is more elegant): you could use if_delegate_has_method as long as you define

    @property
    def _example_fitted_estimator(self):
        return self.estimators_[0]

    @if_delegate_has_method(['_example_fitted_estimator', 'estimator'])
    def predict_proba(self, X):
        ...
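This delegation idea can likewise be sketched in plain Python (no scikit-learn; class names are hypothetical): looking up predict_proba on the first fitted estimator when available, and on the unfitted estimator otherwise, reproduces the hasattr behaviour the decorator would give:

```python
# Hedged sketch of the delegation alternative: resolve predict_proba on
# the first fitted sub-estimator if fit has run, otherwise on the unfitted
# estimator. An AttributeError on either target propagates out of the
# property, so hasattr(...) stays honest.
class HasProba:  # hypothetical estimator exposing predict_proba
    def predict_proba(self, X):
        return [[0.5, 0.5] for _ in X]


class NoProba:  # hypothetical estimator without predict_proba
    pass


class DelegatingSketch:
    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self):
        self.estimators_ = [self.estimator]  # pretend fitting happened
        return self

    @property
    def _example_fitted_estimator(self):
        return self.estimators_[0]

    @property
    def predict_proba(self):
        target = (self._example_fitted_estimator
                  if hasattr(self, "estimators_") else self.estimator)
        return target.predict_proba  # AttributeError -> hasattr is False


assert hasattr(DelegatingSketch(HasProba()), "predict_proba")
assert not hasattr(DelegatingSketch(NoProba()), "predict_proba")
assert hasattr(DelegatingSketch(HasProba()).fit(), "predict_proba")
```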

@jnothman jnothman added this to the 0.21 milestone Apr 17, 2019
@jnothman (Member)

Please add a |Fix| entry to the change log at doc/whats_new/v0.21.rst. Like the other entries there, please reference this pull request with :issue: and credit yourself (and other contributors if applicable) with :user:.

@NicolasHug (Member) left a comment:

LGTM otherwise

Review comment on sklearn/tests/test_multioutput.py (outdated, resolved).
Co-Authored-By: rebekahkim <rebekah.kim@columbia.edu>
@rebekahkim (Contributor, Author)

@NicolasHug thanks for the style suggestion. It seems like the codecov/patch check fails because I've made changes to documentation. Can we ignore this?

@NicolasHug (Member)

As far as I can tell the proposed changes are tested so it should be fine.

Merging, thanks @rebekahkim !

@NicolasHug NicolasHug merged commit d903436 into scikit-learn:master Apr 19, 2019
@rebekahkim rebekahkim deleted the moc-predict-proba branch April 20, 2019 02:52
jeremiedbb pushed a commit to jeremiedbb/scikit-learn that referenced this pull request Apr 25, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
koenvandevelde pushed a commit to koenvandevelde/scikit-learn that referenced this pull request Jul 12, 2019

Successfully merging this pull request may close these issues.

SGDClassifier never has the attribute "predict_proba" (even with log or modified_huber loss)
10 participants