
BUG Fixes error with multiclass roc auc scorer #15274

Merged: 10 commits into scikit-learn:master on Nov 2, 2019

Conversation

@thomasjpfan thomasjpfan (Member) commented Oct 16, 2019

Adds multiclass support to the threshold-based scorer for roc_auc_score.
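
For context, a minimal sketch of the kind of call this change is meant to make work (the dataset and estimator here are illustrative, not taken from the PR):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

X, y = make_classification(n_classes=3, n_informative=3, random_state=0)
clf = LogisticRegression(multi_class="multinomial")
# With this fix, the multiclass ROC AUC scorers can obtain class
# probabilities via predict_proba during scoring:
scores = cross_val_score(clf, X, y, scoring="roc_auc_ovo", cv=3)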

@thomasjpfan thomasjpfan added this to the 0.22 milestone Oct 25, 2019
@@ -323,6 +323,12 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):
                                      self._score_func.__name__))
         elif isinstance(y_pred, list):
             y_pred = np.vstack([p[:, -1] for p in y_pred]).T
+        else:  # multiclass
+            try:
+                y_pred = method_caller(clf, "predict_proba", X)
@qinhanmin2014 qinhanmin2014 (Member) commented Oct 28, 2019

Should we try decision_function first? _ThresholdScorer means that we "Evaluate decision function output for X relative to y_true".

@ogrisel ogrisel (Member) commented Oct 28, 2019

I agree, _ThresholdScorer should be consistent and always try to use decision_function first, whatever the type of output: it should work for any kind of model that has un-normalized but continuous class assignment scores.

If ROC AUC with multiclass averaging needs normalized class assignment probabilities instead, we should encode this requirement in the definitions of the scorer instances instead:

Currently, scorer.py has:

# Score functions that need decision values
roc_auc_scorer = make_scorer(roc_auc_score, greater_is_better=True,
                             needs_threshold=True)
average_precision_scorer = make_scorer(average_precision_score,
                                       needs_threshold=True)
roc_auc_ovo_scorer = make_scorer(roc_auc_score, needs_threshold=True,
                                 multi_class='ovo')
roc_auc_ovo_weighted_scorer = make_scorer(roc_auc_score, needs_threshold=True,
                                          multi_class='ovo',
                                          average='weighted')
roc_auc_ovr_scorer = make_scorer(roc_auc_score, needs_threshold=True,
                                 multi_class='ovr')
roc_auc_ovr_weighted_scorer = make_scorer(roc_auc_score, needs_threshold=True,
                                          multi_class='ovr',
                                          average='weighted')

If needed, we should replace needs_threshold=True by needs_proba=True but I do not see why this would be the case for ROC AUC.
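
For illustration, a hedged sketch of the replacement ogrisel describes (he questions whether it is needed), using the public make_scorer API; this shows the hypothetical needs_proba=True variant, not the definitions as they stood:

from sklearn.metrics import make_scorer, roc_auc_score

# Request probabilities explicitly so the scorer calls predict_proba
# rather than trying decision_function first:
roc_auc_ovo_scorer = make_scorer(roc_auc_score, needs_proba=True,
                                 multi_class='ovo')
roc_auc_ovr_scorer = make_scorer(roc_auc_score, needs_proba=True,
                                 multi_class='ovr')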

@qinhanmin2014 qinhanmin2014 (Member) commented Oct 29, 2019

If needed, we should replace needs_threshold=True by needs_proba=True but I do not see why this would be the case for ROC AUC.

I see, so we should use needs_proba=True; we do not consider decision_function in _ProbaScorer.

For roc_auc_scorer, I think we should use _ThresholdScorer because we accept the output from decision_function.
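
To summarize the distinction discussed above, a simplified sketch (not the actual sklearn source) of how the two scorer classes obtain their predictions:

# _ProbaScorer: always uses probability estimates.
def proba_predictions(clf, X):
    return clf.predict_proba(X)

# _ThresholdScorer: prefers continuous decision values and falls back to
# probabilities when the estimator has no decision_function.
def threshold_predictions(clf, X):
    try:
        return clf.decision_function(X)
    except (AttributeError, NotImplementedError):
        return clf.predict_proba(X)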

@ogrisel ogrisel (Member) commented Oct 29, 2019

In my previous comment I said:

If needed, we should replace needs_threshold=True by needs_proba=True but I do not see why this would be the case for ROC AUC.

But are you sure we really need to use proba for multiclass ROC AUC scores? Why couldn't we use unnormalized class assignment scores also in the multiclass case?

@ogrisel ogrisel (Member) commented Oct 29, 2019

Replying to myself: when calling _multiclass_roc_auc_score on un-normalized decision function scores, one gets the following message:

ValueError: Target scores need to be probabilities for multiclass roc_auc, i.e. they should sum up to 1.0 over classes

So indeed the fix implemented in this PR is correct.
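
A short reproduction sketch of that error (the estimator and data are illustrative):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

X, y = make_classification(n_classes=3, n_informative=3, random_state=0)
clf = LogisticRegression(multi_class="multinomial").fit(X, y)
# decision_function returns un-normalized scores that do not sum to 1,
# so this raises the ValueError quoted above:
roc_auc_score(y, clf.decision_function(X), multi_class='ovo')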

@ogrisel ogrisel (Member) left a comment

Please don't forget to add an entry to the changelog.

Also, a few more nits:

Two review threads on sklearn/metrics/tests/test_score_objects.py (resolved).
@adrinjalali adrinjalali added this to In progress in Meeting Issues via automation Oct 29, 2019
@adrinjalali adrinjalali moved this from In progress to Review in progress in Meeting Issues Oct 29, 2019
@ogrisel ogrisel (Member) left a comment

LGTM!

Meeting Issues automation moved this from Review in progress to Reviewer approved Oct 30, 2019
@ogrisel ogrisel (Member) commented Oct 30, 2019

@qinhanmin2014 any further comment?

@@ -493,6 +493,10 @@ Changelog
     ``multioutput`` parameter.
     :pr:`14732` by :user:`Agamemnon Krasoulis <agamemnonc>`.

+- |Fix| The scorers: 'roc_auc_ovr', 'roc_auc_ovo', 'roc_auc_ovr_weighted',
@jnothman jnothman (Member) commented Oct 30, 2019

Since these were never released, it doesn't really make sense to advertise them as a separate changelog entry, does it?

@ogrisel ogrisel (Member) commented Oct 30, 2019

I had not realized that. Indeed we should remove that changelog entry.

@thomasjpfan thomasjpfan (Member, Author) commented Oct 30, 2019

Moved the mention of the scorers to the changelog entry about the introduction of the multiclass roc_auc metric.

@qinhanmin2014 qinhanmin2014 (Member) commented Oct 31, 2019

I think there's still an annoying issue here.
In roc_auc_score, we have

if y_type == "multiclass" or (y_type == "binary" and
                              y_score.ndim == 2 and
                              y_score.shape[1] > 2):
    # multiclass case
else:
    # binary case and multilabel indicator case

which means that when y_type is binary, this can still be a multiclass problem, but in _ProbaScorer

if y_type == "binary":
    if y_pred.shape[1] == 2:
        y_pred = y_pred[:, 1]
    else:
        raise ValueError(...)

So we'll get a ValueError.

@qinhanmin2014 qinhanmin2014 (Member) commented Oct 31, 2019

One possible solution is to remove the input validation from the scorer. We can do input validation in the function where we calculate the score.

@thomasjpfan thomasjpfan (Member, Author) commented Oct 31, 2019

For reference, the weird condition in roc_auc_score is there to support the following use case:

import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([0, 1, 0, 1])

y_score = np.array([[0.1 , 0.8 , 0.1 ],
                    [0.3 , 0.4 , 0.3 ],
                    [0.35, 0.5 , 0.15],
                    [0.  , 0.2 , 0.8 ]])

roc_auc_score(y_true, y_score, labels=[0, 1, 2], multi_class='ovo')

Without the labels this will fail. I have a feeling we need to think of a better API for this. (#12385)

@thomasjpfan thomasjpfan (Member, Author) commented Oct 31, 2019

Even if we remove the restriction in _ProbaScorer, a user would need to create their own scorer for this to work:

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import make_scorer, roc_auc_score

scorer = make_scorer(roc_auc_score, multi_class='ovo',
                     labels=[0, 1, 2], needs_proba=True)
X, y = make_classification(n_classes=3, n_informative=3, n_samples=20,
                           random_state=0)
lr = LogisticRegression(multi_class="multinomial").fit(X, y)
scorer(lr, X, y == 0)

@qinhanmin2014 qinhanmin2014 (Member) commented Nov 1, 2019

@thomasjpfan When defining built-in scorers, we often use the default values of the parameters. If users do not want the default values, they'll need to define a scorer themselves through make_scorer. I think that's reasonable.

In multiclass roc_auc_score, we infer the labels by default. When y_true does not contain all the labels, it's impossible to infer them, so it's reasonable to require users to define a scorer themselves.

The current issue is that when y_true only contains two labels, users can't define a scorer themselves because of the input validation in _ProbaScorer, so I think we need to remove that.

@thomasjpfan thomasjpfan (Member, Author) commented Nov 1, 2019

Updated the check in _ProbaScorer so that it verifies whether y_pred looks multiclass before raising the error.

Edit: If it does not look multiclass, then it raises the ValueError.

@qinhanmin2014 qinhanmin2014 (Member) left a comment

Maybe update _ThresholdScorer simultaneously?

@@ -247,7 +247,7 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):
         if y_type == "binary":
             if y_pred.shape[1] == 2:
                 y_pred = y_pred[:, 1]
-            else:
+            elif y_pred.shape[1] == 1:  # not multiclass
@qinhanmin2014 qinhanmin2014 (Member) commented Nov 1, 2019

Hmm, why is this useful?

@thomasjpfan thomasjpfan (Member, Author) commented Nov 1, 2019

Looking at the blame, this was added in #12486 to resolve #7598

It looks like it was trying to get a better error message for the y_pred.shape[1]==1 case.
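
For illustration, one hypothetical way to end up with a single-column y_pred (my example, not taken from the linked issues): a classifier fit on data that happens to contain only one class, e.g. from an unlucky cross-validation split.

import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Training data containing a single class:
clf = DecisionTreeClassifier().fit(np.array([[0.], [1.]]), np.array([1, 1]))
# predict_proba then has one column per seen class, i.e. shape (n, 1):
print(clf.predict_proba(np.array([[0.5]])).shape)  # (1, 1)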

@pytest.mark.parametrize('scorer_name', [
'roc_auc_ovr', 'roc_auc_ovo',
'roc_auc_ovr_weighted', 'roc_auc_ovo_weighted'])
def test_multiclass_roc_no_proba_scorer_errors(scorer_name):
@qinhanmin2014 qinhanmin2014 (Member) commented Nov 1, 2019

Could you please tell me why multiclass roc_auc_score does not support the output of decision_function? Thanks.

@thomasjpfan thomasjpfan (Member, Author) commented Nov 1, 2019

The paper this was based on used probabilities for ranking: https://link.springer.com/content/pdf/10.1023%2FA%3A1010920819831.pdf

For reference this was discussed in the original issue: #7663 (comment)

@qinhanmin2014 qinhanmin2014 merged commit 96c1a5b into scikit-learn:master Nov 2, 2019
20 checks passed
Meeting Issues automation moved this from Reviewer approved to Done Nov 2, 2019
@qinhanmin2014 qinhanmin2014 (Member) commented Nov 2, 2019

Thanks, though I still think that multiclass roc_auc could accept the output of decision_function. Not all estimators in scikit-learn have predict_proba :)
