ENH Adds plot_confusion matrix (#15083)
thomasjpfan authored and glemaitre committed Nov 14, 2019
1 parent 25e72d3 commit e650a20
Showing 8 changed files with 479 additions and 80 deletions.
2 changes: 2 additions & 0 deletions doc/modules/classes.rst
@@ -1082,13 +1082,15 @@ See the :ref:`visualizations` section of the user guide for further details.
:toctree: generated/
:template: function.rst

metrics.plot_confusion_matrix
metrics.plot_precision_recall_curve
metrics.plot_roc_curve

.. autosummary::
:toctree: generated/
:template: class.rst

metrics.ConfusionMatrixDisplay
metrics.PrecisionRecallDisplay
metrics.RocCurveDisplay

6 changes: 4 additions & 2 deletions doc/modules/model_evaluation.rst
@@ -573,8 +573,10 @@ predicted to be in group :math:`j`. Here is an example::
[0, 0, 1],
[1, 0, 2]])

Here is a visual representation of such a confusion matrix (this figure comes
from the :ref:`sphx_glr_auto_examples_model_selection_plot_confusion_matrix.py` example):
:func:`plot_confusion_matrix` can be used to visually represent a confusion
matrix as shown in the
:ref:`sphx_glr_auto_examples_model_selection_plot_confusion_matrix.py`
example, which creates the following figure:

.. image:: ../auto_examples/model_selection/images/sphx_glr_plot_confusion_matrix_001.png
:target: ../auto_examples/model_selection/plot_confusion_matrix.html
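
As a quick illustration of the call that produces a figure like the one above, here is a minimal sketch (variable names such as clf are illustrative; the only assumption is the signature introduced by this commit, plot_confusion_matrix(estimator, X, y)):

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import plot_confusion_matrix

# Fit a simple classifier on a held-out split of the iris data.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = SVC(kernel='linear', C=0.01).fit(X_train, y_train)

# One call computes the matrix from the estimator's predictions and draws it.
plot_confusion_matrix(clf, X_test, y_test)  # returns a ConfusionMatrixDisplay
plt.show()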
2 changes: 2 additions & 0 deletions doc/visualizations.rst
@@ -72,6 +72,7 @@ Functions
.. autosummary::

inspection.plot_partial_dependence
metrics.plot_confusion_matrix
metrics.plot_precision_recall_curve
metrics.plot_roc_curve

@@ -84,5 +85,6 @@ Display Objects
.. autosummary::

inspection.PartialDependenceDisplay
metrics.ConfusionMatrixDisplay
metrics.PrecisionRecallDisplay
metrics.RocCurveDisplay
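
The new ConfusionMatrixDisplay follows the same pattern as the other Display objects listed above: it can be built from an already computed matrix and re-plotted without refitting an estimator. A minimal sketch, assuming the constructor takes confusion_matrix and display_labels as added in this commit (the toy labels are the ones used for the confusion_matrix example in model_evaluation.rst):

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
cm = confusion_matrix(y_true, y_pred)  # precomputed once

# Re-plot the stored matrix without touching the estimator or the data again.
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[0, 1, 2])
disp.plot()
plt.show()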
27 changes: 14 additions & 13 deletions examples/classification/plot_digits_classification.py
@@ -31,12 +31,12 @@
# matplotlib.pyplot.imread. Note that each image must have the same size. For these
# images, we know which digit they represent: it is given in the 'target' of
# the dataset.
_, axes = plt.subplots(2, 4)
images_and_labels = list(zip(digits.images, digits.target))
for index, (image, label) in enumerate(images_and_labels[:4]):
plt.subplot(2, 4, index + 1)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.title('Training: %i' % label)
for ax, (image, label) in zip(axes[0, :], images_and_labels[:4]):
ax.set_axis_off()
ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
ax.set_title('Training: %i' % label)

# To apply a classifier to this data, we need to flatten the images, to
# turn the data into a (samples, features) matrix:
@@ -56,15 +56,16 @@
# Now predict the value of the digit on the second half:
predicted = classifier.predict(X_test)

images_and_predictions = list(zip(digits.images[n_samples // 2:], predicted))
for ax, (image, prediction) in zip(axes[1, :], images_and_predictions[:4]):
ax.set_axis_off()
ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
ax.set_title('Prediction: %i' % prediction)

print("Classification report for classifier %s:\n%s\n"
% (classifier, metrics.classification_report(y_test, predicted)))
print("Confusion matrix:\n%s" % metrics.confusion_matrix(y_test, predicted))

images_and_predictions = list(zip(digits.images[n_samples // 2:], predicted))
for index, (image, prediction) in enumerate(images_and_predictions[:4]):
plt.subplot(2, 4, index + 5)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.title('Prediction: %i' % prediction)
disp = metrics.plot_confusion_matrix(classifier, X_test, y_test)
disp.figure_.suptitle("Confusion Matrix")
print("Confusion matrix:\n%s" % disp.confusion_matrix)

plt.show()
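
A small sanity check that could follow the lines above (continuing with the predicted and disp variables already defined in this example; numpy is assumed available): the matrix stored on the display should be identical to the one computed directly from the predictions.

import numpy as np

# disp.confusion_matrix is computed from classifier.predict(X_test) internally,
# so it should match the matrix built from the explicit predictions.
assert np.array_equal(disp.confusion_matrix,
                      metrics.confusion_matrix(y_test, predicted))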
78 changes: 13 additions & 65 deletions examples/model_selection/plot_confusion_matrix.py
@@ -31,8 +31,7 @@

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels
from sklearn.metrics import plot_confusion_matrix

# import some data to play with
iris = datasets.load_iris()
@@ -45,72 +44,21 @@

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)


def plot_confusion_matrix(y_true, y_pred, classes,
normalize=False,
title=None,
cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
if not title:
if normalize:
title = 'Normalized confusion matrix'
else:
title = 'Confusion matrix, without normalization'

# Compute confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Only use the labels that appear in the data
classes = classes[unique_labels(y_true, y_pred)]
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')

print(cm)

fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
ax.figure.colorbar(im, ax=ax)
# We want to show all ticks...
ax.set(xticks=np.arange(cm.shape[1]),
yticks=np.arange(cm.shape[0]),
# ... and label them with the respective list entries
xticklabels=classes, yticklabels=classes,
title=title,
ylabel='True label',
xlabel='Predicted label')

# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")

# Loop over data dimensions and create text annotations.
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(j, i, format(cm[i, j], fmt),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
fig.tight_layout()
return ax

classifier = svm.SVC(kernel='linear', C=0.01).fit(X_train, y_train)

np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plot_confusion_matrix(y_test, y_pred, classes=class_names,
title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plot_confusion_matrix(y_test, y_pred, classes=class_names, normalize=True,
title='Normalized confusion matrix')
titles_options = [("Confusion matrix, without normalization", None),
("Normalized confusion matrix", 'true')]
for title, normalize in titles_options:
disp = plot_confusion_matrix(classifier, X_test, y_test,
display_labels=class_names,
cmap=plt.cm.Blues,
normalize=normalize)
disp.ax_.set_title(title)

print(title)
print(disp.confusion_matrix)

plt.show()
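
For reference, normalize='true' in the loop above corresponds to the row-wise normalization that the removed helper performed by hand. A short sketch of that equivalence, continuing with the classifier, X_test and y_test defined earlier in this example:

import numpy as np
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_test, classifier.predict(X_test))
# Row-wise normalization, as in the deleted plot_confusion_matrix helper;
# this should match disp.confusion_matrix when normalize='true' is passed.
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print(np.round(cm_normalized, 2))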
5 changes: 5 additions & 0 deletions sklearn/metrics/__init__.py
@@ -82,6 +82,9 @@
from ._plot.precision_recall_curve import plot_precision_recall_curve
from ._plot.precision_recall_curve import PrecisionRecallDisplay

from ._plot.confusion_matrix import plot_confusion_matrix
from ._plot.confusion_matrix import ConfusionMatrixDisplay


__all__ = [
'accuracy_score',
@@ -97,6 +100,7 @@
'cluster',
'cohen_kappa_score',
'completeness_score',
'ConfusionMatrixDisplay',
'confusion_matrix',
'consensus_score',
'coverage_error',
@@ -137,6 +141,7 @@
'pairwise_distances_argmin_min',
'pairwise_distances_chunked',
'pairwise_kernels',
'plot_confusion_matrix',
'plot_precision_recall_curve',
'plot_roc_curve',
'PrecisionRecallDisplay',
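
With these additions, both names are importable from the public sklearn.metrics namespace; a trivial usage sketch:

from sklearn.metrics import ConfusionMatrixDisplay, plot_confusion_matrix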
