Skip to content

Commit

Permalink
updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Feb 4, 2017
1 parent 92f129d commit 50fbd04
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 43 deletions.
1 change: 0 additions & 1 deletion sklearn/metrics/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,6 @@ def precision_recall_curve(y_true, probas_pred, pos_label=None,
sample_weight=sample_weight)

precision = tps / (tps + fps)
# recall = tps / tps[-1]
recall = np.ones(tps.size) if tps[-1] == 0 else tps / tps[-1]

# stop when full recall attained
Expand Down
59 changes: 17 additions & 42 deletions sklearn/metrics/tests/test_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,10 +536,11 @@ def test_precision_recall_curve_toydata():
assert_array_almost_equal(r, [1, 0.])
assert_almost_equal(auc_prc, .75)

# y_true = [0, 0]
# y_score = [0.25, 0.75]
# assert_raises(Exception, precision_recall_curve, y_true, y_score)
# assert_raises(Exception, average_precision_score, y_true, y_score)
y_true = [0, 0]
y_score = [0.25, 0.75]
p, r, _ = precision_recall_curve(y_true, y_score)
assert_array_equal(p, np.array([0.0, 1.0]))
assert_array_equal(r, np.array([1.0, 0.0]))

y_true = [1, 1]
y_score = [0.25, 0.75]
Expand All @@ -549,24 +550,21 @@ def test_precision_recall_curve_toydata():
assert_array_almost_equal(r, [1, 0.5, 0.])

# Multi-label classification task
# y_true = np.array([[0, 1], [0, 1]])
# y_score = np.array([[0, 1], [0, 1]])
# assert_raises(Exception, average_precision_score, y_true, y_score,
# average="macro")
# assert_raises(Exception, average_precision_score, y_true, y_score,
# average="weighted")
# assert_almost_equal(average_precision_score(y_true, y_score,
# average="samples"), 1.)
# assert_almost_equal(average_precision_score(y_true, y_score,
# average="micro"), 1.)
#
y_true = np.array([[0, 1], [0, 1]])
y_score = np.array([[0, 1], [0, 1]])
assert_almost_equal(average_precision_score(y_true, y_score,
average="macro"), 0.75)
assert_almost_equal(average_precision_score(y_true, y_score,
average="weighted"), 1.0)
assert_almost_equal(average_precision_score(y_true, y_score,
average="samples"), 1.)
assert_almost_equal(average_precision_score(y_true, y_score,
average="micro"), 1.)

y_true = np.array([[0, 1], [0, 1]])
y_score = np.array([[0, 1], [1, 0]])
# assert_raises(Exception, average_precision_score, y_true, y_score,
# average="macro")
# assert_raises(Exception, average_precision_score, y_true, y_score,
# average="weighted")
assert_almost_equal(average_precision_score(y_true, y_score, average="macro"), 0.75)
assert_almost_equal(average_precision_score(y_true, y_score, average="weighted"), 1.0)
assert_almost_equal(average_precision_score(y_true, y_score,
average="samples"), 0.625)
assert_almost_equal(average_precision_score(y_true, y_score,
Expand Down Expand Up @@ -980,26 +978,3 @@ def test_ranking_loss_ties_handling():
assert_almost_equal(label_ranking_loss([[1, 0, 1]], [[0.25, 0.5, 0.5]]), 1)
assert_almost_equal(label_ranking_loss([[1, 1, 0]], [[0.25, 0.5, 0.5]]), 1)


def test_precision_recall_curve_all_negatives():
    """
    Test edge case for `precision_recall_curve`
    if all the ground truth labels are negative.

    With no positive samples ``tps[-1] == 0``, so the naive
    ``recall = tps / tps[-1]`` would be 0/0 (NaN); the implementation
    defines recall as all ones in that case, so the returned recall
    array must contain no NaN values.
    """
    y_true = [0] * 10
    probas_pred = [np.random.rand() for _ in range(10)]
    _, recall, _ = precision_recall_curve(y_true, probas_pred)
    # NOTE: the previous `assert_not_equal(recall[0], np.nan)` always
    # passed because NaN != NaN is True; np.isnan is the meaningful
    # check, and we verify the whole array rather than one element.
    assert not np.isnan(recall).any()


def test_precision_recall_curve_all_positives():
    """
    Test edge case for `precision_recall_curve`
    if all the ground truth labels are positive.

    With no negative samples there can never be a false positive, so
    every precision value must be exactly 1, regardless of the scores.
    """
    y_true = [1] * 10
    # Fixed seed keeps the fixture reproducible; the score values
    # themselves do not affect the expected precision.
    rng = np.random.RandomState(0)
    probas_pred = rng.rand(10).tolist()
    precision, _, _ = precision_recall_curve(y_true, probas_pred)

    assert_array_equal(precision, np.ones_like(precision))

0 comments on commit 50fbd04

Please sign in to comment.