diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst
index f7b5fda459a09..e5cf5a29a8d52 100644
--- a/doc/whats_new/v0.22.rst
+++ b/doc/whats_new/v0.22.rst
@@ -46,6 +46,10 @@ Changelog
   correctly to maximize contrast with its background. :pr:`15936` by
   `Thomas Fan`_ and :user:`DizietAsahi`.
 
+- |Fix| :func:`metrics.classification_report` no longer ignores the
+  value of the ``zero_division`` keyword argument. :pr:`15879`
+  by :user:`Bibhash Chandra Mitra`.
+
 :mod:`sklearn.utils`
 ....................
 
diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index 343e63b6c0ae9..cba7f2c2e8fc8 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -1964,7 +1964,8 @@ class 2 1.00 0.67 0.80 3
         # compute averages with specified averaging method
         avg_p, avg_r, avg_f1, _ = precision_recall_fscore_support(
             y_true, y_pred, labels=labels,
-            average=average, sample_weight=sample_weight)
+            average=average, sample_weight=sample_weight,
+            zero_division=zero_division)
         avg = [avg_p, avg_r, avg_f1, np.sum(s)]
 
         if output_dict:
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index 197749d0ff2dd..947ca047438d8 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -20,7 +20,6 @@
 from sklearn.utils._testing import assert_array_equal
 from sklearn.utils._testing import assert_array_almost_equal
 from sklearn.utils._testing import assert_allclose
-from sklearn.utils._testing import assert_warns
 from sklearn.utils._testing import assert_warns_div0
 from sklearn.utils._testing import assert_no_warnings
 from sklearn.utils._testing import assert_warns_message
@@ -154,6 +153,22 @@ def test_classification_report_dictionary_output():
     assert type(expected_report['macro avg']['support']) == int
 
 
+@pytest.mark.parametrize('zero_division', ["warn", 0, 1])
+def test_classification_report_zero_division_warning(zero_division):
+    y_true, y_pred = ["a", "b", "c"], ["a", "b", "d"]
+    with warnings.catch_warnings(record=True) as record:
+        classification_report(
+            y_true, y_pred, zero_division=zero_division, output_dict=True)
+        if zero_division == "warn":
+            assert len(record) > 1
+            for item in record:
+                msg = ("Use `zero_division` parameter to control this "
+                       "behavior.")
+                assert msg in str(item.message)
+        else:
+            assert not record
+
+
 def test_multilabel_accuracy_score_subset_accuracy():
     # Dense label indicator matrix format
     y1 = np.array([[0, 1, 1], [1, 0, 1]])
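
For context, a minimal sketch (not part of the patch) of the behavior the fix enables. The labels are illustrative: "c" appears in y_true but is never predicted, so its precision is 0/0, and "d" is predicted but never true, so its recall is 0/0. Before this patch the summary rows were computed without forwarding ``zero_division``, so an UndefinedMetricWarning was emitted even when the caller passed an explicit value; with the fix applied, the call below is warning-free and the undefined scores take the substituted value.

    import warnings

    from sklearn.metrics import classification_report

    y_true = ["a", "b", "c"]
    y_pred = ["a", "b", "d"]  # "c" never predicted, "d" never true

    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        report = classification_report(
            y_true, y_pred, zero_division=1, output_dict=True)

    assert not record                       # no UndefinedMetricWarning raised
    assert report["c"]["precision"] == 1.0  # 0 / 0 replaced by zero_division
    assert report["d"]["recall"] == 1.0     # likewise for the missing label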