BUG use zero_division argument in classification_report (scikit-learn…
Bibyutatsu authored and Pan Jan committed Mar 3, 2020
1 parent e693a5a commit 425d2dc
Showing 3 changed files with 22 additions and 2 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -46,6 +46,10 @@ Changelog
  correctly to maximize contrast with its background. :pr:`15936` by
  `Thomas Fan`_ and :user:`DizietAsahi`.

+- |Fix| :func:`metrics.classification_report` no longer ignores the
+  value of the ``zero_division`` keyword argument. :pr:`15879`
+  by :user:`Bibhash Chandra Mitra <Bibyutatsu>`.
+
:mod:`sklearn.utils`
....................

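As an editorial aside (not part of the commit): a minimal sketch of the behavior the changelog entry describes, assuming scikit-learn 0.22 or later, where the zero_division keyword exists, and using made-up labels.

from sklearn.metrics import classification_report

# "b" never appears in y_pred, so its precision is a 0/0 division.
y_true = ["a", "b", "c"]
y_pred = ["a", "a", "c"]

# Passing an explicit zero_division reports the undefined precision as that
# value (0 here) instead of warning with the default "warn". Before this
# commit the keyword was dropped when computing the averaged rows; the
# one-line change in _classification.py below forwards it.
print(classification_report(y_true, y_pred, zero_division=0))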
3 changes: 2 additions & 1 deletion sklearn/metrics/_classification.py
@@ -1964,7 +1964,8 @@ class 2 1.00 0.67 0.80 3
    # compute averages with specified averaging method
    avg_p, avg_r, avg_f1, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=labels,
-       average=average, sample_weight=sample_weight)
+       average=average, sample_weight=sample_weight,
+       zero_division=zero_division)
    avg = [avg_p, avg_r, avg_f1, np.sum(s)]

    if output_dict:
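The averaged precision/recall/F1 row of the report comes from the precision_recall_fscore_support call patched above. A hedged sketch of that call in isolation (illustrative labels, scikit-learn 0.22+ assumed):

from sklearn.metrics import precision_recall_fscore_support

y_true = ["a", "b", "c"]
y_pred = ["a", "a", "c"]

# "b" is never predicted, so its precision is undefined; zero_division=1
# counts it as 1.0 in the macro average instead of warning and using 0.0.
avg_p, avg_r, avg_f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="macro", zero_division=1)
print(avg_p, avg_r, avg_f1)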
17 changes: 16 additions & 1 deletion sklearn/metrics/tests/test_classification.py
@@ -20,7 +20,6 @@
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import assert_allclose
from sklearn.utils._testing import assert_warns
from sklearn.utils._testing import assert_warns_div0
from sklearn.utils._testing import assert_no_warnings
from sklearn.utils._testing import assert_warns_message
@@ -154,6 +153,22 @@ def test_classification_report_dictionary_output():
    assert type(expected_report['macro avg']['support']) == int


+@pytest.mark.parametrize('zero_division', ["warn", 0, 1])
+def test_classification_report_zero_division_warning(zero_division):
+    y_true, y_pred = ["a", "b", "c"], ["a", "b", "d"]
+    with warnings.catch_warnings(record=True) as record:
+        classification_report(
+            y_true, y_pred, zero_division=zero_division, output_dict=True)
+        if zero_division == "warn":
+            assert len(record) > 1
+            for item in record:
+                msg = ("Use `zero_division` parameter to control this "
+                       "behavior.")
+                assert msg in str(item.message)
+        else:
+            assert not record
+
+
def test_multilabel_accuracy_score_subset_accuracy():
    # Dense label indicator matrix format
    y1 = np.array([[0, 1, 1], [1, 0, 1]])
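Beyond the new test above, the fix can be spot-checked interactively by escalating UndefinedMetricWarning to an error; a hedged sketch, not part of the commit:

import warnings

from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import classification_report

y_true, y_pred = ["a", "b", "c"], ["a", "b", "d"]

with warnings.catch_warnings():
    # Treat metric warnings as errors: with an explicit zero_division the
    # call should now complete silently, because the averaged rows also
    # receive the keyword.
    warnings.simplefilter("error", UndefinedMetricWarning)
    classification_report(y_true, y_pred, zero_division=0, output_dict=True)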
