Skip to content

Commit

Permalink
Fix _check_targets where y_true and y_pred are both binary
Browse files Browse the repository at this point in the history
but the union of them is multiclass.
  • Loading branch information
lesteve committed Feb 16, 2017
1 parent 80c1bf1 commit e01f779
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
4 changes: 4 additions & 0 deletions sklearn/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def _check_targets(y_true, y_pred):
if y_type in ["binary", "multiclass"]:
y_true = column_or_1d(y_true)
y_pred = column_or_1d(y_pred)
if y_type == "binary":
unique_values = np.union1d(y_true, y_pred)
if len(unique_values) > 2:
y_type = "multiclass"

if y_type.startswith('multilabel'):
y_true = csr_matrix(y_true)
Expand Down
13 changes: 10 additions & 3 deletions sklearn/metrics/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,8 @@ def test_matthews_corrcoef():
y_true_inv = ["b" if i == "a" else "a" for i in y_true]

assert_almost_equal(matthews_corrcoef(y_true, y_true_inv), -1)
y_true_inv2 = label_binarize(y_true, ["a", "b"]) * -1
y_true_inv2 = label_binarize(y_true, ["a", "b"])
y_true_inv2 = np.where(y_true_inv2, 'a', 'b')
assert_almost_equal(matthews_corrcoef(y_true, y_true_inv2), -1)

# For the zero vector case, the corrcoef cannot be calculated and should
Expand All @@ -379,8 +380,7 @@ def test_matthews_corrcoef():

# And also for any other vector with 0 variance
mcc = assert_warns_message(RuntimeWarning, 'invalid value encountered',
matthews_corrcoef, y_true,
rng.randint(-100, 100) * np.ones(20, dtype=int))
matthews_corrcoef, y_true, ['a'] * len(y_true))

# But will output 0
assert_almost_equal(mcc, 0.)
Expand Down Expand Up @@ -1267,6 +1267,13 @@ def test__check_targets():
assert_raise_message(ValueError, msg, _check_targets, y1, y2)


def test__check_targets_multiclass_with_both_y_true_and_y_pred_binary():
# https://github.com/scikit-learn/scikit-learn/issues/8098
y_true = [0, 1]
y_pred = [0, -1]
assert_equal(_check_targets(y_true, y_pred)[0], 'multiclass')


def test_hinge_loss_binary():
y_true = np.array([-1, 1, 1, -1])
pred_decision = np.array([-8.5, 0.5, 1.5, -0.3])
Expand Down

0 comments on commit e01f779

Please sign in to comment.