Skip to content

Commit

Permalink
[MRG+1] FIX precision/recall/f1-score for truncated range(n_labels) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
gxyd authored and glemaitre committed Jan 11, 2018
1 parent c5706e6 commit 60b0cf8
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 0 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v0.20.rst
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@ Decomposition, manifold learning and clustering

Metrics

- Fixed a bug in :func:`metrics.precision_precision_recall_fscore_support`
when truncated `range(n_labels)` is passed as value for `labels`.
:issue:`10377` by :user:`Gaurav Dhingra <gxyd>`.

- Fixed a bug due to floating point error in :func:`metrics.roc_auc_score` with
non-integer sample weights. :issue:`9786` by :user:`Hanmin Qin <qinhanmin2014>`.

Expand Down
1 change: 1 addition & 0 deletions sklearn/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
raise ValueError('All labels must be in [0, n labels). '
'Got %d < 0' % np.min(labels))

if n_labels is not None:
y_true = y_true[:, labels[:n_labels]]
y_pred = y_pred[:, labels[:n_labels]]

Expand Down
8 changes: 8 additions & 0 deletions sklearn/metrics/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,14 @@ def test_precision_recall_f_extra_labels():
assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
labels=np.arange(-1, 4), average=average)

# tests non-regression on issue #10307
y_true = np.array([[0, 1, 1], [1, 0, 0]])
y_pred = np.array([[1, 1, 1], [1, 0, 1]])
p, r, f, _ = precision_recall_fscore_support(y_true, y_pred,
average='samples',
labels=[0, 1])
assert_almost_equal(np.array([p, r, f]), np.array([3 / 4, 1, 5 / 6]))


@ignore_warnings
def test_precision_recall_f_ignored_labels():
Expand Down

0 comments on commit 60b0cf8

Please sign in to comment.