Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix format of values in confusion matrix plot. #16159

Merged
merged 107 commits into from Mar 1, 2020
Merged
Show file tree
Hide file tree
Changes from 104 commits
Commits
Show all changes
107 commits
Select commit Hold shift + click to select a range
f3ee74f
"Changes to the values_format for plotting the confusion_matrix. This
Jan 17, 2020
5fb226b
"Fixes for confusion_matrix values_format."
Jan 17, 2020
cb0083e
"<= value for 0"
Jan 17, 2020
551efa8
"Changed back to == 0"
Jan 17, 2020
e253373
"Import math log10 for linting error"
Jan 17, 2020
02995c5
"Linting and hard-coded value checks now working"
Jan 17, 2020
e8019fb
"Use log(10) instead of 1e7 (python version error?)"
Jan 17, 2020
8d6f8a3
"If cmd.type == f test included"
Jan 17, 2020
3b28b17
"Added type test in the correct place"
Jan 17, 2020
6208f85
"Different IF statement, see if it fixes azure pipeline errors"
Jan 18, 2020
c32cc01
"Changed order and use np.log10 ofcourse"
Jan 18, 2020
18137c1
"Removed brackets"
Jan 18, 2020
409c251
"Added testing in test_plot_confusion_matrix (values_format)"
Jan 20, 2020
3e7168e
"Wrong value format"
Jan 20, 2020
23b3a20
"Changed the test_plot_confusion_matrix."
Jan 20, 2020
d31518a
"Add test plot confusion matrix"
Jan 20, 2020
5e9b925
"Added extra test cases for test_plot_confusion_matrix"
Jan 20, 2020
46fef3d
"Fixed testing values"
Jan 20, 2020
0c9cb36
"Added argument to the test function, duh"
Jan 20, 2020
d21a7ff
"New tests to also test if for values > 1e7"
Jan 20, 2020
93cdb7a
"New test version, made sure the values are now loaded into the model
Jan 21, 2020
af47489
"Improved test file"
Jan 21, 2020
c404c77
"New changes"
Jan 21, 2020
1f766ca
"Fixed the ConfusionMatrixDisplay class, where values where not
Jan 21, 2020
690736f
"Fixing of loop and test"
Jan 22, 2020
89e6d1a
"Changed if statement in the for loop and new test function"
Jan 22, 2020
ff0cd1a
"Changes to test function and added the if condition in the for-loop
Jan 22, 2020
745725e
"Removed the assert equal size statement."
Jan 22, 2020
7b007d3
"Forgot to add the improved ConfusionMatrixDisplay"
Jan 22, 2020
bbd9037
"Cleaned test file and now all inside the for-loop."
Jan 22, 2020
0baf844
"Declare variable before the if-statement check"
Jan 22, 2020
f672d22
"Made the if-statement cleaner."
Jan 22, 2020
4b791c0
"Updated test file, and added the statement where if the length of the
Jan 24, 2020
f01baf6
"Changed back to values larger than 1e7"
Jan 24, 2020
f8ab0ff
"Test if the failed tests on azure pipelines are because of my edits?"
Jan 24, 2020
8c11333
"Improved test file, and included expected values from the test."
Jan 27, 2020
c72ec0f
"Improved test of formats (explicitly check strings)"
Jan 30, 2020
0cf3a7a
"Modified the test function for test_plot_confusion_matrix"
Jan 30, 2020
f419191
"Improved test file, now check for literal string array equal."
Jan 31, 2020
a94ae2a
"Use ravel instead of flatten"
Jan 31, 2020
e5925b4
"Removed array and used list only with bare assert statement."
Feb 3, 2020
651b4c6
"Cleaned up the test file, no parameterize and cleaner code using
Feb 6, 2020
d4b28ff
"Added pyplot to the function parameter so it passes tests."
Feb 6, 2020
dd5875f
"Use values_format in parameter test, to see if it resolves
Feb 7, 2020
f100cb6
"Shorten the test file, rename of some variables, list comprehension"
Feb 10, 2020
0b3e021
"Changed to 1e6"
Feb 11, 2020
bc51fd3
"Now picks the shortest length format."
Feb 11, 2020
ae40567
"Small changes in confusion plot values"
Feb 12, 2020
960b96e
Merge branch 'confusion_fix'
Feb 12, 2020
a3d7470
"Changes to the values_format for plotting the confusion_matrix. This
Jan 17, 2020
db163e2
"Fixes for confusion_matrix values_format."
Jan 17, 2020
8029558
"<= value for 0"
Jan 17, 2020
1158bc9
"Changed back to == 0"
Jan 17, 2020
ccf94db
"Import math log10 for linting error"
Jan 17, 2020
44bc54d
"Linting and hard-coded value checks now working"
Jan 17, 2020
3896c9c
"Use log(10) instead of 1e7 (python version error?)"
Jan 17, 2020
0cc1ce0
"If cmd.type == f test included"
Jan 17, 2020
941ff90
"Added type test in the correct place"
Jan 17, 2020
553d550
"Different IF statement, see if it fixes azure pipeline errors"
Jan 18, 2020
bf74685
"Changed order and use np.log10 ofcourse"
Jan 18, 2020
a92f36b
"Removed brackets"
Jan 18, 2020
8f3aae7
"Added testing in test_plot_confusion_matrix (values_format)"
Jan 20, 2020
c9fe195
"Wrong value format"
Jan 20, 2020
868b99e
"Changed the test_plot_confusion_matrix."
Jan 20, 2020
255992c
"Add test plot confusion matrix"
Jan 20, 2020
69e0068
"Added extra test cases for test_plot_confusion_matrix"
Jan 20, 2020
177df61
"Fixed testing values"
Jan 20, 2020
594efe0
"Added argument to the test function, duh"
Jan 20, 2020
5f250fe
"New tests to also test if for values > 1e7"
Jan 20, 2020
2cc9ef3
"New test version, made sure the values are now loaded into the model
Jan 21, 2020
d0de1f3
"Improved test file"
Jan 21, 2020
1c6244c
"New changes"
Jan 21, 2020
81054c2
"Fixed the ConfusionMatrixDisplay class, where values where not
Jan 21, 2020
2106a4d
"Fixing of loop and test"
Jan 22, 2020
37abea4
"Changed if statement in the for loop and new test function"
Jan 22, 2020
3ec1ffa
"Changes to test function and added the if condition in the for-loop
Jan 22, 2020
e14c4a5
"Removed the assert equal size statement."
Jan 22, 2020
eb724db
"Forgot to add the improved ConfusionMatrixDisplay"
Jan 22, 2020
932a077
"Cleaned test file and now all inside the for-loop."
Jan 22, 2020
6029363
"Declare variable before the if-statement check"
Jan 22, 2020
67cf462
"Made the if-statement cleaner."
Jan 22, 2020
60253ea
"Updated test file, and added the statement where if the length of the
Jan 24, 2020
769a2f6
"Changed back to values larger than 1e7"
Jan 24, 2020
613cd80
"Test if the failed tests on azure pipelines are because of my edits?"
Jan 24, 2020
4531040
"Improved test file, and included expected values from the test."
Jan 27, 2020
52818d8
"Improved test of formats (explicitly check strings)"
Jan 30, 2020
32e64e8
"Modified the test function for test_plot_confusion_matrix"
Jan 30, 2020
3e60478
"Improved test file, now check for literal string array equal."
Jan 31, 2020
3aca5b4
"Use ravel instead of flatten"
Jan 31, 2020
2a857c8
"Removed array and used list only with bare assert statement."
Feb 3, 2020
f28914b
"Cleaned up the test file, no parameterize and cleaner code using
Feb 6, 2020
4465f5a
"Added pyplot to the function parameter so it passes tests."
Feb 6, 2020
c334548
"Use values_format in parameter test, to see if it resolves
Feb 7, 2020
dc35958
"Shorten the test file, rename of some variables, list comprehension"
Feb 10, 2020
080bb16
"Changed to 1e6"
Feb 11, 2020
7ac8e02
"Now picks the shortest length format."
Feb 11, 2020
d4ca82a
"Small changes in confusion plot values"
Feb 12, 2020
b80d38c
Merge branch 'confusion_fix' of https://github.com/Rick-Mackenbach/sc…
Feb 12, 2020
74f7244
"Fix of wrong merge"
Feb 12, 2020
a71357d
Merge branch 'confusion_fix' of https://github.com/Rick-Mackenbach/sc…
Feb 12, 2020
6c37b11
"Small change in test file"
Feb 17, 2020
19319e6
"Updated the logic for interpretability. Also small changes to test"
Feb 21, 2020
de4e76e
"Do the formatting of values in the if-statement, so shorter code :)"
Feb 24, 2020
666246d
"Changed test file to test for 'd' values"
Feb 24, 2020
7895260
"Added changes to the v0.23 what's news"
Feb 25, 2020
e920a7b
"Modified doc/whats_new/v0.23.rst"
Feb 25, 2020
3365587
Merge remote-tracking branch 'upstream/master' into pr/16159
thomasjpfan Feb 29, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 17 additions & 9 deletions sklearn/metrics/_plot/confusion_matrix.py
Expand Up @@ -61,7 +61,7 @@ def plot(self, include_values=True, cmap='viridis',

values_format : str, default=None
Format specification for values in confusion matrix. If `None`,
the format specification is '.2g'.
the format specification is 'd' or '.2g' whichever is shorter.

ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
Expand All @@ -83,22 +83,30 @@ def plot(self, include_values=True, cmap='viridis',
n_classes = cm.shape[0]
self.im_ = ax.imshow(cm, interpolation='nearest', cmap=cmap)
self.text_ = None

cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(256)

if include_values:
self.text_ = np.empty_like(cm, dtype=object)
if values_format is None:
values_format = '.2g'

# print text with appropriate color depending on background
thresh = (cm.max() + cm.min()) / 2.0

for i, j in product(range(n_classes), range(n_classes)):
color = cmap_max if cm[i, j] < thresh else cmap_min
self.text_[i, j] = ax.text(j, i,
format(cm[i, j], values_format),
ha="center", va="center",
color=color)

if values_format is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think trying .2g and then trying d if applicable, is slightly clearer

This suggestion is to address #16159 (comment)

if values_format is None:
    text_cm = format(cm[i, j], '.2g')
    if cm.dtype.kind != 'f':
        text_d = format(cm[i, j], 'd')
        if len(text_d) < len(text_cm):
            text_cm = text_d
else:
    text_cm = format(cm[i, j], values_format)

The test would need to updated to reflect this behavior.

text_cm = format(cm[i, j], '.2g')
if cm.dtype.kind != 'f':
text_d = format(cm[i, j], 'd')
if len(text_d) < len(text_cm):
text_cm = text_d
else:
text_cm = format(cm[i, j], values_format)

self.text_[i, j] = ax.text(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I was not more clear. The current PR may call format three times. I am trying to push for calling it at most twice:

if values_format is None:
    text_cm = format(cm[i, j], '.2g')
    if cm.dtype.kind != 'f':
        text_d = format(cm[i, j], 'd')
        if len(text_d) < len(text_cm):
            text_cm = text_d
else:
    text_cm = format(cm[i, j], values_format)

self.text_[i, j] = ax.text(j, i, text_cm, ...)

j, i, text_cm,
ha="center", va="center",
color=color)

fig.colorbar(self.im_, ax=ax)
ax.set(xticks=np.arange(n_classes),
Expand Down Expand Up @@ -164,7 +172,7 @@ def plot_confusion_matrix(estimator, X, y_true, labels=None,

values_format : str, default=None
Format specification for values in confusion matrix. If `None`,
the format specification is '.2g'.
the format specification is 'd' or '.2g' whichever is shorter.

cmap : str or matplotlib Colormap, default='viridis'
Colormap recognized by matplotlib.
Expand Down
20 changes: 18 additions & 2 deletions sklearn/metrics/_plot/tests/test_plot_confusion_matrix.py
Expand Up @@ -21,6 +21,7 @@
"ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
"matplotlib.*")


@pytest.fixture(scope="module")
def n_classes():
return 5
Expand Down Expand Up @@ -226,8 +227,6 @@ def test_confusion_matrix_contrast(pyplot):
assert_allclose(disp.text_[1, 1].get_color(), min_color)




@pytest.mark.parametrize(
"clf", [LogisticRegression(),
make_pipeline(StandardScaler(), LogisticRegression()),
Expand Down Expand Up @@ -264,3 +263,20 @@ def test_confusion_matrix_text_format(pyplot, data, y_pred, n_classes,
text_text = np.array([
t.get_text() for t in disp.text_.ravel()])
assert_array_equal(expected_text, text_text)


def test_confusion_matrix_standard_format(pyplot):
cm = np.array([[10000000, 0], [123456, 12345678]])
plotted_text = ConfusionMatrixDisplay(cm, [False, True]).plot().text_
# Values should be shown as whole numbers 'd',
# except the first number which should be shown as 1e+07 (longer length)
# and the last number will be showns as 1.2e+07 (longer length)
test = [t.get_text() for t in plotted_text.ravel()]
assert test == ['1e+07', '0', '123456', '1.2e+07']

cm = np.array([[0.1, 10], [100, 0.525]])
plotted_text = ConfusionMatrixDisplay(cm, [False, True]).plot().text_
# Values should now formatted as '.2g', since there's a float in
# Values are have two dec places max, (e.g 100 becomes 1e+02)
test = [t.get_text() for t in plotted_text.ravel()]
assert test == ['0.1', '10', '1e+02', '0.53']