Commit b0b9bca
fixed bug for precision recall curve when all labels are negative
varunagrawal committed Jul 12, 2017
1 parent cb1b6c4 commit b0b9bca
Showing 2 changed files with 16 additions and 13 deletions.
2 changes: 1 addition & 1 deletion sklearn/metrics/ranking.py
@@ -417,7 +417,7 @@ def precision_recall_curve(y_true, probas_pred, pos_label=None,
                                              sample_weight=sample_weight)
 
     precision = tps / (tps + fps)
-    recall = tps / tps[-1]
+    recall = np.ones(tps.size) if tps[-1] == 0 else tps / tps[-1]
 
     # stop when full recall attained
     # and reverse the outputs so recall is decreasing
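The change guards the division that defines recall: with all-negative labels, tps[-1] is 0 and the old expression divided by zero, so the fix defines recall as all ones in that case (with zero positives, every positive is trivially recovered). A minimal sketch of the new behaviour, with expected outputs taken from the updated test below:

    import numpy as np
    from sklearn.metrics import precision_recall_curve

    # y_true has no positive samples, so tps[-1] == 0. Before this commit
    # the old test asserted that this call raised an Exception; now it
    # returns a well-defined curve.
    y_true = np.array([0, 0])
    y_score = np.array([0.25, 0.75])
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    print(precision)  # [0. 1.]
    print(recall)     # [1. 0.]

The final (precision=1, recall=0) point is the sentinel precision_recall_curve always appends so the curve ends at zero recall.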
27 changes: 15 additions & 12 deletions sklearn/metrics/tests/test_ranking.py
@@ -309,15 +309,17 @@ def test_roc_curve_toydata():
     y_true = [0, 0]
     y_score = [0.25, 0.75]
     # assert UndefinedMetricWarning because of no positive sample in y_true
-    tpr, fpr, _ = assert_warns(UndefinedMetricWarning, roc_curve, y_true, y_score)
+    tpr, fpr, _ = assert_warns(UndefinedMetricWarning, roc_curve,
+                               y_true, y_score)
     assert_raises(ValueError, roc_auc_score, y_true, y_score)
     assert_array_almost_equal(tpr, [0., 0.5, 1.])
     assert_array_almost_equal(fpr, [np.nan, np.nan, np.nan])
 
     y_true = [1, 1]
     y_score = [0.25, 0.75]
     # assert UndefinedMetricWarning because of no negative sample in y_true
-    tpr, fpr, _ = assert_warns(UndefinedMetricWarning, roc_curve, y_true, y_score)
+    tpr, fpr, _ = assert_warns(UndefinedMetricWarning, roc_curve,
+                               y_true, y_score)
     assert_raises(ValueError, roc_auc_score, y_true, y_score)
     assert_array_almost_equal(tpr, [np.nan, np.nan])
     assert_array_almost_equal(fpr, [0.5, 1.])
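This hunk only re-wraps the calls to fit the line-length limit; the behaviour it pins down is unchanged. For reference, a standalone sketch of what roc_curve returns for degenerate labels (note that roc_curve returns (fpr, tpr, thresholds), while the test unpacks them as tpr, fpr, so the asserted arrays are swapped relative to their names):

    import warnings
    from sklearn.metrics import roc_curve

    # No positive samples: the true positive rate is 0/0, so roc_curve
    # emits UndefinedMetricWarning and fills tpr with NaN.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fpr, tpr, _ = roc_curve([0, 0], [0.25, 0.75])
    print(fpr)  # [0.  0.5 1. ]
    print(tpr)  # [nan nan nan]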
@@ -565,8 +567,9 @@ def test_precision_recall_curve_toydata():
 
     y_true = [0, 0]
     y_score = [0.25, 0.75]
-    assert_raises(Exception, precision_recall_curve, y_true, y_score)
-    assert_raises(Exception, average_precision_score, y_true, y_score)
+    p, r, _ = precision_recall_curve(y_true, y_score)
+    assert_array_equal(p, np.array([0.0, 1.0]))
+    assert_array_equal(r, np.array([1.0, 0.0]))
 
     y_true = [1, 1]
     y_score = [0.25, 0.75]
@@ -578,21 +581,21 @@
     # Multi-label classification task
     y_true = np.array([[0, 1], [0, 1]])
     y_score = np.array([[0, 1], [0, 1]])
-    assert_raises(Exception, average_precision_score, y_true, y_score,
-                  average="macro")
-    assert_raises(Exception, average_precision_score, y_true, y_score,
-                  average="weighted")
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="macro"), 0.75)
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="weighted"), 1.0)
     assert_almost_equal(average_precision_score(y_true, y_score,
                         average="samples"), 1.)
     assert_almost_equal(average_precision_score(y_true, y_score,
                         average="micro"), 1.)
 
     y_true = np.array([[0, 1], [0, 1]])
     y_score = np.array([[0, 1], [1, 0]])
-    assert_raises(Exception, average_precision_score, y_true, y_score,
-                  average="macro")
-    assert_raises(Exception, average_precision_score, y_true, y_score,
-                  average="weighted")
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="macro"), 0.75)
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="weighted"), 1.0)
     assert_almost_equal(average_precision_score(y_true, y_score,
                         average="samples"), 0.75)
     assert_almost_equal(average_precision_score(y_true, y_score,
(diff truncated)
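Both multi-label blocks contain an all-negative first column in y_true, which is exactly the case the fix makes well-defined: the macro and weighted averages previously raised and now come out finite. A sketch against the first block, with expected values taken from the assertions above; later scikit-learn releases changed how average precision is interpolated, so the exact numbers can differ on newer versions:

    import numpy as np
    from sklearn.metrics import average_precision_score

    y_true = np.array([[0, 1], [0, 1]])   # first label column: no positives
    y_score = np.array([[0, 1], [0, 1]])

    # The all-negative column now yields a finite per-label score, so the
    # label-averaged variants are defined instead of raising.
    print(average_precision_score(y_true, y_score, average="macro"))     # 0.75
    print(average_precision_score(y_true, y_score, average="weighted"))  # 1.0

The weighted average lands at 1.0 because labels are weighted by their positive support: the all-negative column carries zero weight, leaving only the perfectly ranked second column.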
