
BUG fix behaviour in confusion_matrix with empty array-like as input #16442

Merged
merged 12 commits on Feb 23, 2020
4 changes: 4 additions & 0 deletions doc/whats_new/v0.23.rst
@@ -202,6 +202,10 @@ Changelog
- |Fix| Fixed a bug in :func:`metrics.mutual_info_score` where negative
  scores could be returned. :pr:`16362` by `Thomas Fan`_.

- |Fix| Fixed a bug in :func:`metrics.confusion_matrix` that would raise
  an error when `y_true` and `y_pred` were length zero and `labels` was
  not `None`. :pr:`16442` by :user:`Kyle Parsons <parsons-kyle-89>`.
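The behaviour this entry describes can be illustrated with a minimal numpy-only stand-in (a sketch for illustration only; `tiny_confusion_matrix` is a hypothetical name, not scikit-learn's implementation):

```python
import numpy as np

def tiny_confusion_matrix(y_true, y_pred, labels=None):
    """Numpy-only sketch of a confusion matrix (illustration only)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if labels is None:
        labels = np.unique(np.concatenate([y_true, y_pred]))
    elif y_true.size == 0:
        # The fixed behaviour: empty inputs with explicit labels yield an
        # all-zero matrix instead of raising an error.
        n = len(labels)
        return np.zeros((n, n), dtype=int)
    index = {label: i for i, label in enumerate(labels)}
    cm = np.zeros((len(labels), len(labels)), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[index[t], index[p]] += 1
    return cm

print(tiny_confusion_matrix([], [], labels=[0, 1]))
# [[0 0]
#  [0 0]]
```

Before the fix, the empty-input-with-labels case fell through to code that assumed at least one sample; the early return above mirrors the branch this PR adds.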

:mod:`sklearn.model_selection`
..............................

3 changes: 3 additions & 0 deletions sklearn/metrics/_classification.py
@@ -274,6 +274,9 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None,

if labels is None:
    labels = unique_labels(y_true, y_pred)
elif y_true.size == 0:
    n_labels = len(labels)
    return np.zeros((n_labels, n_labels), dtype=np.int)
else:
    labels = np.asarray(labels)
    if np.all([l not in y_true for l in labels]):
7 changes: 7 additions & 0 deletions sklearn/metrics/_plot/tests/test_plot_confusion_matrix.py
@@ -225,7 +225,14 @@ def test_confusion_matrix_contrast(pyplot):
    assert_allclose(disp.text_[1, 0].get_color(), max_color)
    assert_allclose(disp.text_[1, 1].get_color(), min_color)

    # Non-regression test for #16442
    cm = np.array([[0, 0], [0, 0]])
    disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])

    disp.plot(cmap=pyplot.cm.Blues)
    min_color = pyplot.cm.Blues(0)
    for text in disp.text_.ravel():
        assert_allclose(text.get_color(), min_color)


@pytest.mark.parametrize(
12 changes: 12 additions & 0 deletions sklearn/metrics/tests/test_classification.py
@@ -914,6 +914,18 @@ def test_confusion_matrix_multiclass_subset_labels():
labels=[extra_label, extra_label + 1])


@pytest.mark.parametrize(
    'labels', (None, [0, 1], [0, 1, 2]),
    ids=['None', 'binary', 'multiclass']
)
def test_confusion_matrix_on_zero_length_input(labels):
    labels = [] if not labels else labels
    n_classes = len(labels)
    expected = np.zeros((n_classes, n_classes), dtype=np.int)
    cm = confusion_matrix([], [], labels)
    assert_array_equal(cm, expected)
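The invariant this parametrised test encodes can be run standalone as a numpy-only sketch (`expected_zero_cm` is a hypothetical helper for illustration, not part of the test suite):

```python
import numpy as np

def expected_zero_cm(labels):
    # Mirrors the test setup: labels=None is treated as an empty label
    # list, and the expected result is an all-zero square matrix whose
    # side equals the number of labels.
    labels = [] if not labels else labels
    n_classes = len(labels)
    return np.zeros((n_classes, n_classes), dtype=int)

for labels in (None, [0, 1], [0, 1, 2]):
    cm = expected_zero_cm(labels)
    assert cm.shape[0] == cm.shape[1] == len(labels or [])
    assert cm.sum() == 0
```

Note that the diff itself uses `np.int`, which was valid at the time of this PR but has since been deprecated and removed from NumPy; the sketch uses the builtin `int` instead.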


def test_confusion_matrix_dtype():
y = [0, 1, 1]
weight = np.ones(len(y))