[MRG] ENH Adds plot_confusion_matrix #15083

Merged
merged 32 commits into scikit-learn:master on Nov 14, 2019

Conversation

@thomasjpfan (Member) commented Sep 24, 2019

Reference Issues/PRs

Related to #7116

What does this implement/fix? Explain your changes.

Adds plotting function for the confusion matrix.

thomasjpfan added 2 commits Aug 22, 2019
@amueller (Member) commented Sep 24, 2019

lint ;)

thomasjpfan added 2 commits Sep 25, 2019
…trix_v2
fmt = '.2f' if self.normalize else 'd'
thresh = cm.max() / 2.
for i, j in product(range(cm.shape[0]), range(cm.shape[1])):
color = "white" if cm[i, j] < thresh else "black"

@amueller (Member) commented Sep 25, 2019

I think that's weird, as it doesn't depend on the colormap.
Here's how I usually do it:
https://github.com/amueller/mglearn/blob/master/mglearn/tools.py#L76

Without depending on the colormap there's no way this works, right? Because someone could use Greys and Greys_r, and they clearly need the opposite colors.
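A minimal sketch (not the PR's code) of deriving the text color from the colormap itself, in the spirit of the mglearn helper linked above; the function name and luminance threshold are illustrative assumptions:

```python
import matplotlib.pyplot as plt


def _text_color(value, vmin, vmax, cmap):
    """Pick black or white text based on the luminance of the cell color."""
    rgba = plt.get_cmap(cmap)((value - vmin) / (vmax - vmin))
    # Perceived luminance of the background color (ITU-R BT.601 weights)
    luminance = 0.299 * rgba[0] + 0.587 * rgba[1] + 0.114 * rgba[2]
    return "black" if luminance > 0.5 else "white"


# Works for both orientations of a colormap:
print(_text_color(1.0, 0.0, 1.0, "Greys"))    # dark cell -> white text
print(_text_color(1.0, 0.0, 1.0, "Greys_r"))  # light cell -> black text
```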

@amueller (Member) commented Sep 25, 2019

I think it should be pcolormesh not pcolor, though.

@amueller (Member) commented Sep 25, 2019

Also: shouldn't this go in a separate helper function? It's probably not the only time we want to show a heatmap (grid search will need this as well). The main question then is whether that helper would be public or not :-/


@amueller (Member) commented Sep 25, 2019

Looks reasonable.
Can you maybe add a test? For example, call ConfusionMatrixDisplay with np.eye(2) and plt.cm.Greys and check that the text colors are black, white, black, white; then with plt.cm.Greys_r, check that the text colors are white, black, white, black.
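A rough sketch of such a test, assuming the display exposes the rendered matplotlib Text artists on a text_ attribute (attribute and parameter names follow this PR but are not guaranteed); rather than hard-coding the exact colors, it checks that reversing the colormap flips the text color of every cell:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend, e.g. for CI
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
from sklearn.metrics import ConfusionMatrixDisplay


def test_text_color_follows_colormap():
    cm = np.eye(2)

    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])
    disp.plot(cmap=plt.cm.Greys)
    colors = [to_rgba(t.get_color()) for t in disp.text_.ravel()]

    disp_r = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1])
    disp_r.plot(cmap=plt.cm.Greys_r)
    colors_r = [to_rgba(t.get_color()) for t in disp_r.text_.ravel()]

    # Reversing the colormap should swap the text color of every cell.
    assert all(c != c_r for c, c_r in zip(colors, colors_r))
```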

titles_options = [("Confusion matrix, without normalization", False),
("Normalized confusion matrix", True)]
for title, normalize in titles_options:
fig, ax = plt.subplots()

@amueller (Member) commented Sep 25, 2019

Why? There's no reason to pass ax, right?
For setting the title you could just do plt.gca().set_title(title).
Or do you not like using the state like that?

@thomasjpfan (Member, Author) commented Sep 25, 2019

Updated so that ax no longer needs to be defined and passed in; the example now uses the axes stored in the Display object.

@amueller (Member) commented Sep 25, 2019

ah, even better.
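A sketch of the updated example pattern thomasjpfan describes above; the parameter names (normalize as a string, the ax_ attribute) follow the final state of this PR, and the exact example code in the docs may differ:

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.metrics import plot_confusion_matrix

iris = load_iris()
clf = SVC().fit(iris.data, iris.target)

titles_options = [("Confusion matrix, without normalization", None),
                  ("Normalized confusion matrix", 'true')]
for title, normalize in titles_options:
    # No need to create the axes up front; use the ones the Display stores.
    disp = plot_confusion_matrix(clf, iris.data, iris.target,
                                 cmap=plt.cm.Blues, normalize=normalize)
    disp.ax_.set_title(title)
plt.show()
```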

thomasjpfan added 2 commits Sep 25, 2019

thomasjpfan added 6 commits Oct 10, 2019
…trix_v2
select a subset of labels. If `None` is given, those that appear at
least once in `y_true` or `y_pred` are used in sorted order.
target_names : array-like of shape (n_classes,), default=None

@jnothman (Member) commented Oct 10, 2019

Don't call this target names; that implies multiple targets. Would display_labels be sufficient instead?

@thomasjpfan (Member, Author) commented Oct 10, 2019

Hmmm, do you think classes or class_names would be better?

Includes values in confusion matrix.
normalize : bool, default=False
Normalizes confusion matrix.

@jnothman (Member) commented Oct 10, 2019

The user might want to normalise over either axis, or altogether.

@thomasjpfan (Member, Author) commented Oct 10, 2019

Four options? I guess we can do 'row', 'column', 'all', None?

@jnothman (Member) commented Oct 11, 2019

I'm okay with not providing this flexibility, too. Another way to specify it is "all", "recall", "precision", None.
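For reference, a quick sketch (not the PR's code) of what the different normalization options would compute, assuming true labels on the rows and predictions on the columns:

```python
import numpy as np

cm = np.array([[5., 1.],
               [2., 4.]])  # rows: true classes, columns: predicted classes

cm_recall = cm / cm.sum(axis=1, keepdims=True)     # per true class (rows)
cm_precision = cm / cm.sum(axis=0, keepdims=True)  # per predicted class (columns)
cm_all = cm / cm.sum()                             # over all samples
```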

@glemaitre (Contributor) commented Oct 31, 2019

Would it make sense to use "truth" and "predicted" instead of "recall" and "precision"?

@thomasjpfan (Member, Author) commented Nov 6, 2019

Updated PR to use 'truth' and 'predicted'. Almost feels like this should be in confusion_matrix itself.

@thomasjpfan (Member, Author) commented Oct 23, 2019

Updated to use display_labels and to infer whether the matrix is normalized from the dtype of the confusion matrix.

@thomasjpfan thomasjpfan added this to the 0.22 milestone Oct 25, 2019

@NicolasHug (Contributor) left a comment

cmap='viridis', ax=None):
"""Plot Confusion Matrix.
Read more in the :ref:`User Guide <visualizations>`.

thomasjpfan added 4 commits Oct 28, 2019
…trix_v2
WIP
WIP
@adrinjalali adrinjalali added this to In progress in Meeting Issues via automation Oct 29, 2019

@glemaitre (Contributor) left a comment

With the latest changes, LGTM

Rotation of xtick labels.
values_format : str, default=None
Format specification for values in confusion matrix. If None,

@glemaitre (Contributor) commented Nov 7, 2019

Just this nitpick.

@glemaitre (Contributor) commented Nov 7, 2019

> @glemaitre yes our classifiers work with lists of strings, but our simple example using load_iris returns the integer encoding and not the strings. A user using load_iris will need to pass display_labels=iris.target_names to get the expected labeling.

So it seems that we need them in case we want to overwrite it. So we can keep it as it is until, by default, we don't need to specify it.
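A usage sketch of the point quoted above (the estimator and plotting call are illustrative; parameter names follow this PR): with load_iris the targets are integer-encoded, so the human-readable names have to be passed explicitly:

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.metrics import plot_confusion_matrix

iris = load_iris()
clf = SVC().fit(iris.data, iris.target)

# Without display_labels the tick labels are 0, 1, 2; with it they are the
# species names from iris.target_names.
plot_confusion_matrix(clf, iris.data, iris.target,
                      display_labels=iris.target_names)
plt.show()
```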

@NicolasHug (Contributor) left a comment

Thanks @thomasjpfan, mostly looks good.

I'm slightly concerned about testing time and coupling, though.

Includes values in confusion matrix.
normalize : {'true', 'pred', 'all'}, default=None
Normalizes confusion matrix over the true, predicted conditions or

@NicolasHug (Contributor) commented Nov 7, 2019

Just a suggestion

Suggested change:
- Normalizes confusion matrix over the true, predicted conditions or
+ Normalizes confusion matrix over the true (rows), predicted conditions (columns) or
labels=labels)

if normalize == 'true':
    cm = cm.astype('float') / cm.sum(axis=1, keepdims=True)

@NicolasHug (Contributor) commented Nov 7, 2019

I think we should not convert to float (see the other message about high coupling).


cm = self.confusion_matrix
n_classes = cm.shape[0]
normalized = np.issubdtype(cm.dtype, np.float_)

@NicolasHug (Contributor) commented Nov 7, 2019

This logic involves a strong coupling between

confusion_matrix -> plot_confusion_matrix -> ConfusionMatrixDisplay

and might cause silent bugs in the future.

I would rather pass an is_normalized parameter (or remove it; see below).

if include_values:
    self.text_ = np.empty_like(cm, dtype=object)
    if values_format is None:
        values_format = '.2f' if normalized else 'd'

@NicolasHug (Contributor) commented Nov 7, 2019

I think that the .2g option is what we need, and you wouldn't have to use the normalized variable anymore:

In [15]: "{:.2g} -- {:.2g} -- {:.2g}".format(2, 2.0000, 2.23425)                                                                        
Out[15]: '2 -- 2 -- 2.2'
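To spell out the idea (illustrative values, not from the PR): with '.2g' the same format specifier works for raw counts and for normalized fractions, so no normalized flag is needed to choose between 'd' and '.2f':

```python
counts = [50, 2, 3, 45]
total = sum(counts)

print([format(v, '.2g') for v in counts])          # ['50', '2', '3', '45']
print([format(v / total, '.2g') for v in counts])  # ['0.5', '0.02', '0.03', '0.45']
```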
@pytest.mark.parametrize("normalize", ['true', 'pred', 'all', None])
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("with_labels", [True, False])
@pytest.mark.parametrize("cmap", ['viridis', 'plasma'])
@pytest.mark.parametrize("with_custom_axes", [True, False])
@pytest.mark.parametrize("with_display_labels", [True, False])
@pytest.mark.parametrize("include_values", [True, False])
Comment on lines 54 to 60

@NicolasHug (Contributor) commented Nov 7, 2019

Do we really need each of these combinations to be tested independently?

It seems to me that most of the checks in this test could be independent test functions. Parametrization is nice, but it seems way overkill here.

This will test 256 instances (4 x 2^6 combinations), and it takes about 10s on my machine, which is not negligible: small increments in testing time really add up over time.

@thomasjpfan (Member, Author) commented Nov 7, 2019

To be consistent with the plot_roc_curve, how do you feel about names or display_names instead of names?

thomasjpfan added 2 commits Nov 7, 2019
@thomasjpfan (Member, Author) commented Nov 7, 2019

Ah, display_labels is okay, since this is a different context. Updated PR to reduce the number of tests and to address comments.

Interpretability / Plotting / Interactive dev automation moved this from Review in progress to Reviewer approved Nov 8, 2019
@NicolasHug (Contributor) left a comment

last nits

assert disp.ax_ == ax

if normalize == 'true':
    cm = cm.astype('float') / cm.sum(axis=1, keepdims=True)

@NicolasHug (Contributor) commented Nov 8, 2019

You don't need the conversion anymore, right?

@pytest.mark.parametrize("normalize", ['true', 'pred', 'all', None])
@pytest.mark.parametrize("with_labels", [True, False])
@pytest.mark.parametrize("with_display_labels", [True, False])
@pytest.mark.parametrize("include_values", [True, False])
Comment on lines 54 to 57

@NicolasHug (Contributor) commented Nov 8, 2019

The main reason I'm not a fan of this is that such parametrization suggests that all four of these parameters are intertwined and dependent on one another, but in reality this isn't the case.

I think we could still remove some parametrizations, but that's fine.

create a :class:`ConfusionMatrixDisplay`. All parameters are stored as
attributes.
Read more in the :ref:`User Guide <confusion_matrix>`.

@NicolasHug (Contributor) commented Nov 8, 2019

Shouldn't this link to the visualization User Guide?

Meeting Issues automation moved this from Review in progress to Reviewer approved Nov 8, 2019
include_values : bool, default=True
Includes values in confusion matrix.
normalize : {'true', 'pred', 'all'}, default=None

@qinhanmin2014 (Member) commented Nov 13, 2019

If we decide to support normalize here, perhaps we should also support it in confusion_matrix (see #14478).
Also, I can't understand why we need normalize="all".

@glemaitre (Contributor) commented Nov 14, 2019

Good remark. normalize='all' will normalize by the total support.

@glemaitre (Contributor) commented Nov 14, 2019

However, I would suggest adding it in another PR.

@glemaitre glemaitre self-assigned this Nov 14, 2019
@glemaitre (Contributor) commented Nov 14, 2019

I pushed to resolve the conflicts.

@glemaitre (Contributor) commented Nov 14, 2019

I also added a test for pipelines, similar to the one for the other plotting functions.
@qinhanmin2014 feel free to merge when it is green.

@glemaitre glemaitre merged commit e650a20 into scikit-learn:master Nov 14, 2019
21 checks passed
LGTM analysis: C/C++ No code changes detected
LGTM analysis: JavaScript No code changes detected
LGTM analysis: Python No new or fixed alerts
ci/circleci: deploy Your tests passed on CircleCI!
ci/circleci: doc Your tests passed on CircleCI!
ci/circleci: doc artifact Link to 0/doc/_changed.html
ci/circleci: doc-min-dependencies Your tests passed on CircleCI!
ci/circleci: lint Your tests passed on CircleCI!
codecov/patch 100% of diff hit (target 97.21%)
codecov/project 97.22% (+<.01%) compared to 25e72d3
scikit-learn.scikit-learn Build #20191114.23 succeeded
scikit-learn.scikit-learn (Linting) Linting succeeded
scikit-learn.scikit-learn (Linux py35_conda_openblas) Linux py35_conda_openblas succeeded
scikit-learn.scikit-learn (Linux py35_ubuntu_atlas) Linux py35_ubuntu_atlas succeeded
scikit-learn.scikit-learn (Linux pylatest_pip_openblas_pandas) Linux pylatest_pip_openblas_pandas succeeded
scikit-learn.scikit-learn (Linux32 py35_ubuntu_atlas_32bit) Linux32 py35_ubuntu_atlas_32bit succeeded
scikit-learn.scikit-learn (Linux_Runs pylatest_conda_mkl) Linux_Runs pylatest_conda_mkl succeeded
scikit-learn.scikit-learn (Windows py35_pip_openblas_32bit) Windows py35_pip_openblas_32bit succeeded
scikit-learn.scikit-learn (Windows py37_conda_mkl) Windows py37_conda_mkl succeeded
scikit-learn.scikit-learn (macOS pylatest_conda_mkl) macOS pylatest_conda_mkl succeeded
scikit-learn.scikit-learn (macOS pylatest_conda_mkl_no_openmp) macOS pylatest_conda_mkl_no_openmp succeeded
Interpretability / Plotting / Interactive dev automation moved this from Reviewer approved to Done Nov 14, 2019
Meeting Issues automation moved this from Reviewer approved to Done Nov 14, 2019
@glemaitre (Contributor) commented Nov 14, 2019

OK merging this one. I will open a new PR to address the problem raised by @qinhanmin2014 in #15083 (comment)

adrinjalali added a commit to adrinjalali/scikit-learn that referenced this pull request Nov 18, 2019
adrinjalali added a commit to adrinjalali/scikit-learn that referenced this pull request Nov 18, 2019
adrinjalali added a commit that referenced this pull request Nov 19, 2019