[MRG + 2] predict_proba should use the softmax function in the multinomial case #5182

Merged
merged 3 commits into scikit-learn:master from MechCoder:predict_proba_fix on Aug 30, 2015

7 participants
@MechCoder
Member

MechCoder commented Aug 28, 2015

Fixes #5176

@@ -238,16 +238,32 @@ def _predict_proba_lr(self, X):
1. / (1. + np.exp(-self.decision_function(X)));
multiclass is handled by normalizing that over all classes.
"""
from sklearn.linear_model.logistic import (
LogisticRegression, LogisticRegressionCV)

@mblondel

mblondel Aug 29, 2015

Member

That's pretty ugly. I would rather override predict_proba in LogisticRegression.

def predict_proba(self, X):
    if self.multi_class == "multinomial":
        [...]
    else:
        return super(LogisticRegression, self).predict_proba(X)
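For illustration, the elided multinomial branch could be sketched with a hypothetical softmax helper (not an existing sklearn function at the time; the `[...]` above was left unspecified):

```python
import numpy as np

def softmax(scores):
    # Exponentiate each decision value and normalize each row so the
    # class probabilities sum to one.
    exp_scores = np.exp(scores)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)

# Inside predict_proba, the multinomial branch would then be roughly:
#     return softmax(self.decision_function(X))
p = softmax(np.array([[1.0, 2.0, 3.0]]))
```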

@larsmans

larsmans Aug 29, 2015

Member

That should get rid of the isinstance on self as well.

@MechCoder

Member

MechCoder commented Aug 30, 2015

@mblondel Sorry for the hasty hack. I've addressed your comment.

@mblondel

Member

mblondel commented Aug 30, 2015

I guess predict_proba's argmax = predict is already tested in the common tests?

LGTM.

@MechCoder

Member

MechCoder commented Aug 30, 2015

Yes, it is.

@MechCoder MechCoder changed the title from [BUG] predict_proba should use the softmax function in the multinomial case to [MRG + 1] predict_proba should use the softmax function in the multinomial case Aug 30, 2015

@mblondel

Member

mblondel commented Aug 30, 2015

Does the test you added fail without the patch?

@MechCoder

Member

MechCoder commented Aug 30, 2015

I was thinking about that. Would it be sufficient to check that the predicted probability values differ between the two cases?

@MechCoder

Member

MechCoder commented Aug 30, 2015

Is it always true that clf.predict_proba(X).max(axis=0) is greater in the multinomial case?

@mblondel

Member

mblondel commented Aug 30, 2015

You could try to compute the multinomial log loss:
http://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html

Hopefully, the loss should be smaller with the right probabilities (although this might not be true due to the l2 regularization term on the coefficients...).
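A sketch of that comparison on synthetic data (no fitted model here; `scores` stands in for decision_function output, and the OvR probabilities are per-class sigmoids renormalized as in _predict_proba_lr):

```python
import numpy as np
from sklearn.metrics import log_loss

rng = np.random.RandomState(0)
scores = rng.randn(100, 3)           # stand-in for decision_function output
y = scores.argmax(axis=1)            # labels consistent with the scores

# OvR-style probabilities: per-class sigmoid, renormalized to sum to one.
sig = 1.0 / (1.0 + np.exp(-scores))
ovr = sig / sig.sum(axis=1, keepdims=True)

# Multinomial probabilities: softmax over the same scores.
e = np.exp(scores - scores.max(axis=1, keepdims=True))
soft = e / e.sum(axis=1, keepdims=True)

loss_ovr = log_loss(y, ovr)
loss_soft = log_loss(y, soft)
```

On data like this the softmax probabilities are typically sharper around the argmax class, so the multinomial log loss tends to come out lower, though as noted above regularization can blur that in a real fitted model.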

@MechCoder

Member

MechCoder commented Aug 30, 2015

Thanks for the tip. I've added the test.

@agramfort

Member

agramfort commented Aug 30, 2015

+1 for merge when Travis is happy. Thanks @MechCoder

@MechCoder

Member

MechCoder commented Aug 30, 2015

Ah, I see. I added that, but I kept the previous test as well because I thought it might be interesting.

@mblondel

Member

mblondel commented Aug 30, 2015

Thanks. LGTM now :)

@MechCoder

Member

MechCoder commented Aug 30, 2015

Also fixes #5134

@GaelVaroquaux GaelVaroquaux changed the title from [MRG + 1] predict_proba should use the softmax function in the multinomial case to [MRG + 2] predict_proba should use the softmax function in the multinomial case Aug 30, 2015

@GaelVaroquaux

Member

GaelVaroquaux commented Aug 30, 2015

Two +1s. We're only waiting for Appveyor (which is currently very slow).

GaelVaroquaux added a commit that referenced this pull request Aug 30, 2015

Merge pull request #5182 from MechCoder/predict_proba_fix
[MRG + 2] predict_proba should use the softmax function in the multinomial case

@GaelVaroquaux GaelVaroquaux merged commit 4f713ce into scikit-learn:master Aug 30, 2015

2 checks passed

continuous-integration/appveyor — AppVeyor build succeeded
continuous-integration/travis-ci/pr — The Travis CI build passed
@akxlr

akxlr commented Aug 31, 2015

Minor comment: would it be more stable to calculate the log probability (using scipy.misc.logsumexp for the denominator) and then exponentiate? At least for predict_log_proba this could be done.
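For illustration, a stable log-softmax along those lines might look like the sketch below (using `scipy.special.logsumexp`, where the function mentioned above now lives in modern SciPy):

```python
import numpy as np
from scipy.special import logsumexp  # scipy.misc.logsumexp in older SciPy

def log_softmax(scores):
    # log p_i = x_i - log(sum_j exp(x_j)), computed without overflow
    # because logsumexp shifts by the max internally.
    return scores - logsumexp(scores, axis=1, keepdims=True)

log_p = log_softmax(np.array([[750.0, 749.0, 748.0]]))
p = np.exp(log_p)   # exponentiate only at the end, as suggested
```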

@jagapiou

jagapiou commented Sep 7, 2015

This will have overflow issues for large decision_function values (e.g. [750, 749, 748]). This can be fixed by subtracting the max m of the decision_function output (in this case 750), since:
exp(x_i) / sum_j{ exp(x_j) } = exp(x_i - m) exp(m) / sum_j{ exp(x_j - m) exp(m) } = exp(x_i - m) / sum_j{ exp(x_j - m) }.
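The overflow and the fix can be demonstrated numerically on the example values above (a standalone sketch, not the sklearn code path):

```python
import numpy as np

scores = np.array([[750.0, 749.0, 748.0]])

# Naive softmax overflows: exp(750) is inf in float64, so inf/inf gives nan.
naive = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)

# Shifting by the row max leaves the result unchanged mathematically
# but keeps every exponent <= 0, so nothing overflows.
shifted = scores - scores.max(axis=1, keepdims=True)
stable = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
```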

@MechCoder MechCoder deleted the MechCoder:predict_proba_fix branch Sep 8, 2015

@MechCoder

Member

MechCoder commented Sep 8, 2015

Thanks for the tip! See #5225
