Skip to content

Commit

Permalink
FIX mislabelling multiclass target when labels is provided in top_k_a…
Browse files Browse the repository at this point in the history
…ccuracy_score (#19721)
  • Loading branch information
joclement authored and glemaitre committed Apr 28, 2021
1 parent 0143fe4 commit ab21254
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
8 changes: 8 additions & 0 deletions doc/whats_new/v0.24.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ Changelog
- |Fix|: Fixed a bug in :class:`linear_model.LogisticRegression`: the
sample_weight object is not modified anymore. :pr:`19182` by
:user:`Yosuke KOBAYASHI <m7142yosuke>`.

:mod:`sklearn.metrics`
......................

- |Fix| :func:`metrics.top_k_accuracy_score` now supports multiclass
problems where only two classes appear in `y_true` and all the classes
are specified in `labels`.
:pr:`19721` by :user:`Joris Clement <flyingdutchman23>`.

:mod:`sklearn.model_selection`
..............................
Expand Down
4 changes: 3 additions & 1 deletion sklearn/metrics/_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -1589,7 +1589,7 @@ def top_k_accuracy_score(y_true, y_score, *, k=2, normalize=True,
non-thresholded decision values (as returned by
:term:`decision_function` on some classifiers). The binary case expects
scores with shape (n_samples,) while the multiclass case expects scores
with shape (n_samples, n_classes). In the nulticlass case, the order of
with shape (n_samples, n_classes). In the multiclass case, the order of
the class scores must correspond to the order of ``labels``, if
provided, or else to the numerical or lexicographical order of the
labels in ``y_true``.
Expand Down Expand Up @@ -1646,6 +1646,8 @@ def top_k_accuracy_score(y_true, y_score, *, k=2, normalize=True,
y_true = check_array(y_true, ensure_2d=False, dtype=None)
y_true = column_or_1d(y_true)
y_type = type_of_target(y_true)
if y_type == "binary" and labels is not None and len(labels) > 2:
y_type = "multiclass"
y_score = check_array(y_score, ensure_2d=False)
y_score = column_or_1d(y_score) if y_type == 'binary' else y_score
check_consistent_length(y_true, y_score, sample_weight)
Expand Down
24 changes: 24 additions & 0 deletions sklearn/metrics/tests/test_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,6 +1650,30 @@ def test_top_k_accuracy_score_binary(y_score, k, true_score):
assert score == score_acc == pytest.approx(true_score)


@pytest.mark.parametrize('y_true, true_score, labels', [
    (np.array([0, 1, 1, 2]), 0.75, [0, 1, 2, 3]),
    (np.array([0, 1, 1, 1]), 0.5, [0, 1, 2, 3]),
    (np.array([1, 1, 1, 1]), 0.5, [0, 1, 2, 3]),
    (np.array(['a', 'e', 'e', 'a']), 0.75, ['a', 'b', 'd', 'e']),
])
@pytest.mark.parametrize("labels_as_ndarray", [True, False])
def test_top_k_accuracy_score_multiclass_with_labels(
    y_true, true_score, labels, labels_as_ndarray
):
    """Check top-2 accuracy when ``labels`` lists more classes than ``y_true``.

    Every case passes four labels together with a four-column ``y_score``
    while ``y_true`` contains at most three distinct values (in several
    cases only one or two). Each case is run with ``labels`` given both as
    a plain list and as an ndarray.
    """
    # One row of scores per sample, one column per entry of ``labels``.
    y_score = np.array([
        [0.4, 0.3, 0.2, 0.1],
        [0.1, 0.3, 0.4, 0.2],
        [0.4, 0.1, 0.2, 0.3],
        [0.3, 0.2, 0.4, 0.1],
    ])
    class_labels = np.asarray(labels) if labels_as_ndarray else labels

    observed = top_k_accuracy_score(
        y_true, y_score, k=2, labels=class_labels
    )
    assert observed == pytest.approx(true_score)


def test_top_k_accuracy_score_increasing():
# Make sure increasing k leads to a higher score
X, y = datasets.make_classification(n_classes=10, n_samples=1000,
Expand Down

0 comments on commit ab21254

Please sign in to comment.