Corrected the missing zero_division argument #15879

Merged (6 commits, Dec 21, 2019)
4 changes: 4 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -46,6 +46,10 @@ Changelog
   correctly to maximize contrast with its background. :pr:`15936` by
   `Thomas Fan`_ and :user:`DizietAsahi`.
 
+- |Fix| :func:`metrics.classification_report` no longer ignores the
+  value of the ``zero_division`` keyword argument. :pr:`15879`
+  by :user:`Bibhash Chandra Mitra <Bibyutatsu>`.
+
 :mod:`sklearn.utils`
 ....................
 
3 changes: 2 additions & 1 deletion sklearn/metrics/_classification.py
@@ -1964,7 +1964,8 @@ class 2 1.00 0.67 0.80 3
         # compute averages with specified averaging method
         avg_p, avg_r, avg_f1, _ = precision_recall_fscore_support(
             y_true, y_pred, labels=labels,
-            average=average, sample_weight=sample_weight)
+            average=average, sample_weight=sample_weight,
+            zero_division=zero_division)
         avg = [avg_p, avg_r, avg_f1, np.sum(s)]
 
         if output_dict:
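The fix above simply threads the ``zero_division`` keyword through to ``precision_recall_fscore_support`` when computing the averaged row. The semantics of that keyword can be sketched in plain Python. This is a hypothetical helper (``safe_precision`` is not a scikit-learn function), shown only to illustrate what the parameter controls when a metric's denominator is zero:

```python
import warnings


def safe_precision(tp, predicted_positives, zero_division="warn"):
    # Hypothetical stand-in mirroring the semantics of scikit-learn's
    # ``zero_division`` keyword: when precision would be 0/0 (no positive
    # predictions at all), the keyword decides the result.
    if predicted_positives == 0:
        if zero_division == "warn":
            # "warn" behaves like 0, but also emits the warning whose text
            # the new test checks for.
            warnings.warn("Precision is ill-defined and being set to 0.0. "
                          "Use `zero_division` parameter to control this "
                          "behavior.", UserWarning)
            return 0.0
        # The caller explicitly chose 0 or 1: return it silently.
        return float(zero_division)
    return tp / predicted_positives
```

For example, `safe_precision(0, 0, zero_division=1)` returns `1.0` with no warning, while `safe_precision(0, 0)` returns `0.0` and warns; the bug fixed here was that `classification_report` computed its averages as if the caller had always left the default.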
17 changes: 16 additions & 1 deletion sklearn/metrics/tests/test_classification.py
@@ -20,7 +20,6 @@
 from sklearn.utils._testing import assert_array_equal
 from sklearn.utils._testing import assert_array_almost_equal
 from sklearn.utils._testing import assert_allclose
-from sklearn.utils._testing import assert_warns
 from sklearn.utils._testing import assert_warns_div0
 from sklearn.utils._testing import assert_no_warnings
 from sklearn.utils._testing import assert_warns_message
@@ -154,6 +153,22 @@ def test_classification_report_dictionary_output():
     assert type(expected_report['macro avg']['support']) == int
 
 
+@pytest.mark.parametrize('zero_division', ["warn", 0, 1])
+def test_classification_report_zero_division_warning(zero_division):
+    y_true, y_pred = ["a", "b", "c"], ["a", "b", "d"]
+    with warnings.catch_warnings(record=True) as record:
+        classification_report(
+            y_true, y_pred, zero_division=zero_division, output_dict=True)
+        if zero_division == "warn":
+            assert len(record) > 1
+            for item in record:
+                msg = ("Use `zero_division` parameter to control this "
+                       "behavior.")
+                assert msg in str(item.message)
+        else:
+            assert not record
+
+
 def test_multilabel_accuracy_score_subset_accuracy():
     # Dense label indicator matrix format
     y1 = np.array([[0, 1, 1], [1, 0, 1]])
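The new test relies on `warnings.catch_warnings(record=True)` to capture emitted warnings as a list and assert on their messages. A minimal standalone illustration of that pattern, using a hypothetical `report` stub in place of `classification_report`:

```python
import warnings


def report(zero_division="warn"):
    # Stand-in for classification_report: warns only in "warn" mode.
    if zero_division == "warn":
        warnings.warn("Use `zero_division` parameter to control this "
                      "behavior.", UserWarning)
    return {}


with warnings.catch_warnings(record=True) as record:
    # "always" ensures no pre-existing filter swallows repeat warnings,
    # so every emission lands in ``record``.
    warnings.simplefilter("always")
    report(zero_division="warn")

assert len(record) == 1
assert "zero_division" in str(record[0].message)

with warnings.catch_warnings(record=True) as record:
    warnings.simplefilter("always")
    report(zero_division=0)

assert not record  # an explicit zero_division silences the warning
```

Because `catch_warnings` restores the previous filter state on exit, each `with` block observes only the warnings raised inside it, which is why the parametrized test can assert `not record` for the non-default cases.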