FIX limit warnings for recall_score, precision_score, f1_score, #2592

Merged
merged 3 commits on Dec 5, 2013
35 changes: 26 additions & 9 deletions sklearn/metrics/metrics.py
@@ -1315,11 +1315,12 @@ def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1,
beta=beta,
labels=labels,
pos_label=pos_label,
average=average)
average=average,
warn_for=('f-score',))
return f


def _prf_divide(numerator, denominator, metric, modifier, average):
def _prf_divide(numerator, denominator, metric, modifier, average, warn_for):
"""Performs division and handles divide-by-zero.

On zero-division, sets the corresponding result elements to zero
@@ -1344,8 +1345,17 @@ def _prf_divide(numerator, denominator, metric, modifier, average):
if average == 'samples':
axis0, axis1 = axis1, axis0

msg = ('{0} and F-score are ill-defined and being set to 0.0 {{0}} '
'no {1} {2}s.'.format(metric.title(), modifier, axis0))
if metric in warn_for and 'f-score' in warn_for:
msg_start = '{0} and F-score are'.format(metric.title())
elif metric in warn_for:
msg_start = '{0} is'.format(metric.title())
elif 'f-score' in warn_for:
msg_start = 'F-score is'
else:
return result
Member:
Why doesn't it raise a ValueError here if metric is not in warn_for?

Member (author):
That's what we're handling: 'precision' won't be in warn_for when we're calling recall_score. The output of this division will be ignored anyway (and indeed, we could avoid that further upstream).

Member:

Thanks!

Member:

> Why doesn't it raise a ValueError here if metric is not in warn_for?

This is the way to disable warnings as you requested earlier, isn't it?

Member:

I was thinking about what happens when somebody puts a garbage string in warn_for and thought that it might generate an error. But the current behaviour is also fine.
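To make the behaviour discussed in this thread concrete, here is a small standalone reproduction of just the branching (adapted from the diff above; the helper name _warning_prefix is only for this illustration and is not part of scikit-learn):

def _warning_prefix(metric, warn_for):
    # Return the start of the warning message, or None when the caller
    # did not ask for this metric (the silent case discussed above).
    if metric in warn_for and 'f-score' in warn_for:
        return '{0} and F-score are'.format(metric.title())
    elif metric in warn_for:
        return '{0} is'.format(metric.title())
    elif 'f-score' in warn_for:
        return 'F-score is'
    return None

print(_warning_prefix('precision', ('recall',)))         # None: recall_score discards this division
print(_warning_prefix('precision', ('garbage',)))        # None: an unrecognised string just silences
print(_warning_prefix('recall', ('recall', 'f-score')))  # 'Recall and F-score are'

In other words, a typo in warn_for silences warnings rather than raising, which is the trade-off acknowledged in the last comment.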


msg = ('{0} ill-defined and being set to 0.0 {{0}} '
'no {1} {2}s.'.format(msg_start, modifier, axis0))
if len(mask) == 1:
msg = msg.format('due to')
else:
@@ -1355,7 +1365,9 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,


def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
pos_label=1, average=None):
pos_label=1, average=None,
warn_for=('precision', 'recall',
'f-score')):
"""Compute precision, recall, F-measure and support for each class

The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
@@ -1419,6 +1431,9 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
meaningful for multilabel classification where this differs from
:func:`accuracy_score`).

warn_for : tuple or set, for internal use
This determines which warnings will be made in the case that this
function is being used to return only one of its metrics.

Returns
-------
@@ -1547,9 +1562,9 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
# Oddly, we may get an "invalid" rather than a "divide" error
# here.
precision = _prf_divide(tp_sum, pred_sum,
'precision', 'predicted', average)
'precision', 'predicted', average, warn_for)
recall = _prf_divide(tp_sum, true_sum,
'recall', 'true', average)
'recall', 'true', average, warn_for)
# Don't need to warn for F: either P or R warned, or tp == 0 where pos
# and true are nonzero, in which case, F is well-defined and zero
f_score = ((1 + beta2) * precision * recall /
@@ -1654,7 +1669,8 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1,
p, _, _, _ = precision_recall_fscore_support(y_true, y_pred,
labels=labels,
pos_label=pos_label,
average=average)
average=average,
warn_for=('precision',))
return p


@@ -1729,7 +1745,8 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='weighted'):
_, r, _, _ = precision_recall_fscore_support(y_true, y_pred,
labels=labels,
pos_label=pos_label,
average=average)
average=average,
warn_for=('recall',))
return r


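With the wrapper changes above, the user-visible effect (assuming a scikit-learn build that includes this patch) is that a single-metric call only warns about the metric it actually returns. A rough check:

import warnings
import numpy as np
from sklearn.metrics import recall_score

y_true = np.array([[0, 0], [0, 0]])  # no true samples for any label
y_pred = np.array([[1, 1], [1, 1]])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    recall_score(y_true, y_pred, average='micro')

# Expect a single message about recall being ill-defined; the precision
# division performed internally no longer emits its own warning.
for w in caught:
    print(w.message)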
53 changes: 53 additions & 0 deletions sklearn/metrics/tests/test_metrics.py
@@ -23,6 +23,7 @@
assert_array_equal,
assert_array_almost_equal,
assert_warns,
assert_no_warnings,
assert_greater,
ignore_warnings)

@@ -1886,6 +1887,58 @@ def test_prf_warnings():
'being set to 0.0 due to no true samples.')


def test_recall_warnings():
assert_no_warnings(recall_score,
np.array([[1, 1], [1, 1]]),
np.array([[0, 0], [0, 0]]),
average='micro')

with warnings.catch_warnings(record=True) as record:
warnings.simplefilter('always')
recall_score(np.array([[0, 0], [0, 0]]),
np.array([[1, 1], [1, 1]]),
average='micro')
assert_equal(str(record.pop().message),
'Recall is ill-defined and '
'being set to 0.0 due to no true samples.')


def test_precision_warnings():
with warnings.catch_warnings(record=True) as record:
warnings.simplefilter('always')

precision_score(np.array([[1, 1], [1, 1]]),
np.array([[0, 0], [0, 0]]),
average='micro')
assert_equal(str(record.pop().message),
'Precision is ill-defined and '
'being set to 0.0 due to no predicted samples.')

assert_no_warnings(precision_score,
np.array([[0, 0], [0, 0]]),
np.array([[1, 1], [1, 1]]),
average='micro')


def test_fscore_warnings():
with warnings.catch_warnings(record=True) as record:
warnings.simplefilter('always')

for score in [f1_score, partial(fbeta_score, beta=2)]:
score(np.array([[1, 1], [1, 1]]),
np.array([[0, 0], [0, 0]]),
average='micro')
assert_equal(str(record.pop().message),
'F-score is ill-defined and '
'being set to 0.0 due to no predicted samples.')
score(np.array([[0, 0], [0, 0]]),
np.array([[1, 1], [1, 1]]),
average='micro')
assert_equal(str(record.pop().message),
'F-score is ill-defined and '
'being set to 0.0 due to no true samples.')


def test__check_clf_targets():
"""Check that _check_clf_targets correctly merges target types, squeezes
output and fails if input lengths differ."""
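For completeness, the new warn_for argument can also be passed to precision_recall_fscore_support directly, although the docstring marks it as being for internal use; a sketch of that usage, with the behaviour inferred from the diff above:

import numpy as np
from sklearn.metrics import precision_recall_fscore_support

y_true = np.array([[0, 0], [0, 0]])
y_pred = np.array([[1, 1], [1, 1]])

# warn_for=('recall',) means only the recall warning may fire; warnings
# about precision or F-score are suppressed even where those divisions
# are ill-defined.
p, r, f, _ = precision_recall_fscore_support(
    y_true, y_pred, average='micro', warn_for=('recall',))
print(p, r, f)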