
[MRG] Modified the confusion matrix example #3454

Merged
merged 1 commit into from

6 participants

@ldirer

No description provided.

examples/model_selection/plot_confusion_matrix.py
((9 lines not shown))
y_pred = classifier.fit(X_train, y_train).predict(X_test)
# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)
+
+def normalize_by_support(m):
+ # Normalize a matrix by row
+ m = m.astype('float')
+ m /= m.sum(axis=1)
+ return m
+
+# Normalize the confusion matrix by row (i.e. by the number of samples
+# in each class)
+cm = normalize_by_support(cm)
@arjoly Owner
arjoly added a note

You don't need to create a function for that :-) since it's used only once.

@ogrisel Owner
ogrisel added a note
cm_normalized = cm.astype('float') / cm.sum(axis=1)
@coveralls

Coverage Status

Coverage remained the same when pulling 5c66f05 on ldirer:confusion_matrix_example into 8dab222 on scikit-learn:master.

@arjoly
Owner

I would add what a row and a column means in the confusion matrix.

@ldirer ldirer changed the title from [MRG] Modified the confusion matrix example to [WIP] Modified the confusion matrix example
@coveralls

Coverage Status

Coverage remained the same when pulling 856059d on ldirer:confusion_matrix_example into 8dab222 on scikit-learn:master.

examples/model_selection/plot_confusion_matrix.py
@@ -10,10 +10,21 @@
off-diagonal elements are those that are mislabeled by the
classifier. The higher the diagonal values of the confusion
matrix the better, indicating many correct predictions.
+
+Here the results are not as good as they could be as our
+choice for the regularization parameter C was not the best.
+In real life applications this parameter is usually chosen
+using grid_search.
@ogrisel Owner
ogrisel added a note

You can use :ref:`grid_search` to make the online example add a link to the model selection section of the narrative documentation.

examples/model_selection/plot_confusion_matrix.py
@@ -10,10 +10,21 @@
off-diagonal elements are those that are mislabeled by the
classifier. The higher the diagonal values of the confusion
matrix the better, indicating many correct predictions.
+
+Here the results are not as good as they could be as our
+choice for the regularization parameter C was not the best.
+In real life applications this parameter is usually chosen
+using grid_search.
+
+.. note::
+
+ See also :ref:`grid_search`
@ogrisel Owner
ogrisel added a note

Alright... then this would be redundant. Please fold the link into the previous paragraph.

examples/model_selection/plot_confusion_matrix.py
((9 lines not shown))
y_pred = classifier.fit(X_train, y_train).predict(X_test)
# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)
+# Normalize the confusion matrix by row (i.e. by the number of samples
+# in each class)
+cm = cm.astype('float')
+cm /= cm.sum(axis=1)
@ogrisel Owner
ogrisel added a note

This can fit in one line and it's better to store the result in a new variable called cm_normalized. Then do the 2 plots: one for the absolute CM and one for the normalized CM.

Use plt.figure() at the beginning of each plot (and plt.show() at the end of each plot).
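The structure suggested here could be sketched as follows (a minimal sketch: the matrix values are made up to stand in for the computed confusion matrix, and the colormap choice is illustrative):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# Hard-coded stand-in for the confusion matrix the example computes.
cm = np.array([[13, 0, 0],
               [0, 10, 6],
               [0, 0, 9]])

# One-liner normalization stored in a new variable, as suggested
# (keepdims keeps the row sums as a column so the division is row-wise).
cm_normalized = cm.astype(float) / cm.sum(axis=1, keepdims=True)

# First figure: absolute counts.
plt.figure()
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion matrix')
plt.colorbar()

# Second figure: row-normalized matrix.
plt.figure()
plt.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Normalized confusion matrix')
plt.colorbar()
plt.show()
```

Starting each plot with `plt.figure()` gives the two matrices separate windows instead of drawing one over the other.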

examples/model_selection/plot_confusion_matrix.py
@@ -10,10 +10,21 @@
off-diagonal elements are those that are mislabeled by the
classifier. The higher the diagonal values of the confusion
matrix the better, indicating many correct predictions.
+
+Here the results are not as good as they could be as our
+choice for the regularization parameter C was not the best.
@jnothman Owner

Why do you not just choose a better C for this example?

@jnothman Owner

Ah. I see your note below.

@jnothman
Owner

Is there great benefit to normalizing the matrix? Surely it could be meaningfully normalized by row or by column, but (until visualised or similar) the matrix is more informative with actual counts in it...

@ldirer

As you mentioned, the normalization is more about the visualization of the matrix.
In the current example, I find it a little weird that a perfectly classified class (like the '0' class) is not displayed with the highest "color value" in the plot.

@arjoly
Owner

Is it in MRG or WIP state?

@ldirer ldirer changed the title from [WIP] Modified the confusion matrix example to [MRG] Modified the confusion matrix example
@ldirer ldirer changed the title from [MRG] Modified the confusion matrix example to [WIP] Modified the confusion matrix example
@ldirer ldirer changed the title from [WIP] Modified the confusion matrix example to [MRG] Modified the confusion matrix example
@ogrisel
Owner

Please rotate the xtick labels to make them vertical so that this utility function can be quickly adapted to work on problems with a larger number of classes.

plt.tight_layout() might help solve issues where xlabel and xticks labels overlap. If not you can also have a look at:

http://stackoverflow.com/questions/6705581/rotating-xticks-causes-the-ticks-partially-hidden-in-matplotlib
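The rotation plus `tight_layout` combination is only a few lines (a sketch; the labels and the identity matrix are illustrative, and `rotation=90` gives the fully vertical labels requested here, though the merged example settled on 60):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# Illustrative class labels; the real example uses iris.target_names.
labels = ['setosa', 'versicolor', 'virginica']
tick_marks = np.arange(len(labels))

plt.figure()
plt.imshow(np.eye(len(labels)), interpolation='nearest', cmap=plt.cm.Blues)
plt.xticks(tick_marks, labels, rotation=90)  # vertical labels scale to many classes
plt.yticks(tick_marks, labels)
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()  # avoid overlap between rotated tick labels and the xlabel
```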

@ogrisel
Owner

The rotation is only applied to one of the 2 figures. Furthermore it's truncated:

[screenshot: truncated confusion matrix]

I changed the state of the PR back to WIP in the mean time.

@ogrisel ogrisel changed the title from [MRG] Modified the confusion matrix example to [WIP] Modified the confusion matrix example
@ldirer

Sorry, I pushed mechanically without considering that it was in MRG state and you would review it.

I now included the rotation for both figures and tried to address the truncation issue.
However I cannot reproduce the truncation on my side so I am unsure whether what I changed is enough.

@ldirer ldirer changed the title from [WIP] Modified the confusion matrix example to [MRG] Modified the confusion matrix example
@ogrisel
Owner

It's a bit too spaced (right and bottom) on my box now. But it probably depends on the matplotlib GUI renderer. Can you please check if the default layout is good when you build the documentation (cd doc && make html)? It will take a while if this is the first time you build the scikit-learn documentation though.

@ogrisel
Owner

Actually, with plt.tight_layout() only, it's good. Here is what I changed:

diff --git a/examples/model_selection/plot_confusion_matrix.py b/examples/model_selection/plot_confusion_matrix.py
index 5955d78..357f9e9 100644
--- a/examples/model_selection/plot_confusion_matrix.py
+++ b/examples/model_selection/plot_confusion_matrix.py
@@ -62,7 +62,7 @@ plt.ylabel('True label')
 plt.xlabel('Predicted label')
 # Convenience function to adjust plot parameters for a clear layout.
 plt.tight_layout()
-plt.subplots_adjust(bottom=0.3)

 # Normalize the confusion matrix by row (i.e. by the number of samples
 # in each class)
@@ -81,5 +81,5 @@ plt.yticks(tick_marks, iris.target_names)
 plt.ylabel('True label')
 plt.xlabel('Predicted label')
 plt.tight_layout()
-plt.subplots_adjust(bottom=0.3)
 plt.show()
@ogrisel
Owner

Please remove the plt.subplots_adjust calls and we are good to go.

@arjoly
Owner

Can you show the generated figure? Otherwise +1 for the code and explanation.

@ldirer

Unfortunately I can't get sphinx to build the documentation for me (cd doc && make html).
I tried various versions and settled on 1.2.2, which complains "No module named sklearn.externals.six".
I ran make in the top folder before trying again, but it did not work.
From what I read it could be an issue with the imported version being different from the built version, but I am not sure how to solve it.

ogrisel's comment suggested he built the documentation and the figure was fine.
Can you confirm @ogrisel?
I will try my luck with Sphinx again, but it might be better if the pull request does not have to wait for me.

@GaelVaroquaux
@ogrisel
Owner

@ldirer here is the typical sequence to make sure that scikit-learn is installed in your Python site-packages in "editable" mode, meaning that it will use the current source folder built with python setup.py build_ext -i:

cd scikit-learn
pip uninstall scikit-learn  # repeat until no more scikit-learn is found
pip install --editable .
(cd doc && make html)
examples/model_selection/plot_confusion_matrix.py
((20 lines not shown))
plt.title('Confusion matrix')
plt.colorbar()
+tick_marks = np.arange(len(iris.target_names))
+plt.xticks(tick_marks, iris.target_names, rotation=60)
+plt.yticks(tick_marks, iris.target_names)
+plt.ylabel('True label')
+plt.xlabel('Predicted label')
+# Convenience function to adjust plot parameters for a clear layout.
+plt.tight_layout()
+
+# Normalize the confusion matrix by row (i.e. by the number of samples
+# in each class)
+cm_normalized = cm.astype('float')/cm.sum(axis=1)
@ogrisel Owner
ogrisel added a note

pep8: spaces around "/"

@ogrisel
Owner

Also maybe the Blues colormap would look more beautiful while still being monochrome hence readable. http://wiki.scipy.org/Cookbook/Matplotlib/Show_colormaps

@ogrisel
Owner

Actually, it might be interesting to use a MultinomialNB model (because it's fast) on the 20 newsgroups dataset, randomly subsampled to introduce some class imbalance (e.g. take between 30% and 100% of the data for each class), to better show the value of this kind of visualization.
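A per-class subsampling helper for that experiment might look like the sketch below (`subsample_by_class` and its parameters are hypothetical, not scikit-learn API; only the 30%-100% range comes from the suggestion above):

```python
import numpy as np


def subsample_by_class(X, y, low=0.3, high=1.0, random_state=0):
    """Keep a random fraction (between `low` and `high`) of each class
    to introduce artificial class imbalance. Hypothetical helper."""
    rng = np.random.RandomState(random_state)
    keep = []
    for label in np.unique(y):
        idx = np.flatnonzero(y == label)
        frac = rng.uniform(low, high)          # per-class keep fraction
        n_keep = max(1, int(frac * len(idx)))  # never drop a class entirely
        keep.append(rng.choice(idx, size=n_keep, replace=False))
    keep = np.concatenate(keep)
    return X[keep], y[keep]
```

The resulting imbalanced class supports are exactly what row normalization of the confusion matrix is meant to make visible.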

@ogrisel
Owner

But the 20 newsgroups idea could be addressed in a separate PR; the current state (once the last batch of nitpicks is addressed) is already a great improvement over the existing example, which uses the infamous jet colormap.

@coveralls

Coverage Status

Coverage increased (+0.14%) when pulling c3c9606 on ldirer:confusion_matrix_example into 7a7fca8 on scikit-learn:master.

@ldirer

Thanks for your advice, I finally got sphinx to work and build the doc.
I think I had an installation problem at first and then I was stuck on issue 3475.

The figure is not truncated when I build the documentation:
[screenshot: rendered confusion matrix figure, not truncated]

@coveralls

Coverage Status

Coverage remained the same when pulling f505221 on ldirer:confusion_matrix_example into c0afd46 on scikit-learn:master.

@ogrisel
Owner

Thanks, it looks great. I will merge by rebase.

@ogrisel ogrisel merged commit e181932 into from
@ogrisel
Owner

It was already rebased, I merged. Thanks again @ldirer!

@jnothman
Owner

The confusion matrix example is now rendering poorly at http://scikit-learn.org/dev/auto_examples/model_selection/plot_confusion_matrix.html. @ldirer, would you like to offer a fix?

@ldirer

Hello,
Indeed the rendering is quite disappointing.
I tried to look into it but I don't really know where I can find what causes the plot to render so poorly.

The html produced is fine when I build the doc locally. Do you have any advice on where I could try and look to find the cause of the issue?

I would at least need to be able to see the (real) end result to propose a fix.

Thanks!

Commits on Jul 27, 2014
  1. @ldirer

    Modified the confusion matrix example: included a normalized matrix,
    changed the colors and added class labels.

    ldirer authored
Showing with 44 additions and 4 deletions.
  1. +44 −4 examples/model_selection/plot_confusion_matrix.py
48 examples/model_selection/plot_confusion_matrix.py
@@ -10,10 +10,24 @@
off-diagonal elements are those that are mislabeled by the
classifier. The higher the diagonal values of the confusion
matrix the better, indicating many correct predictions.
+
+The figures show the confusion matrix with and without
+normalization by class support size (number of elements
+in each class). This kind of normalization can be
+interesting in case of class imbalance to have a more
+visual interpretation of which class is being misclassified.
+
+Here the results are not as good as they could be as our
+choice for the regularization parameter C was not the best.
+In real life applications this parameter is usually chosen
+using :ref:`grid_search`.
+
"""
print(__doc__)
+import numpy as np
+
from sklearn import svm, datasets
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix
@@ -28,19 +42,45 @@
# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
-# Run classifier
-classifier = svm.SVC(kernel='linear')
+# Run classifier, using a model that is too regularized (C too low) to see
+# the impact on the results
+classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)
# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)
-
+print('Confusion matrix, without normalization')
print(cm)
# Show confusion matrix in a separate window
-plt.matshow(cm)
+plt.imshow(cm, interpolation='nearest', cmap=plt.cm.binary)
plt.title('Confusion matrix')
+plt.set_cmap('Blues')
+plt.colorbar()
+tick_marks = np.arange(len(iris.target_names))
+plt.xticks(tick_marks, iris.target_names, rotation=60)
+plt.yticks(tick_marks, iris.target_names)
+plt.ylabel('True label')
+plt.xlabel('Predicted label')
+# Convenience function to adjust plot parameters for a clear layout.
+plt.tight_layout()
+
+# Normalize the confusion matrix by row (i.e. by the number of samples
+# in each class)
+cm_normalized = cm.astype('float') / cm.sum(axis=1)
+
+print('Normalized confusion matrix')
+print(cm_normalized)
+
+# Show normalized confusion matrix in a separate window
+plt.figure()
+plt.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.binary)
+plt.title('Normalized confusion matrix')
+plt.set_cmap('Blues')
plt.colorbar()
+plt.xticks(tick_marks, iris.target_names, rotation=60)
+plt.yticks(tick_marks, iris.target_names)
plt.ylabel('True label')
plt.xlabel('Predicted label')
+plt.tight_layout()
plt.show()
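One subtlety worth flagging in the normalization line above: for a 2-D array, `cm.astype('float') / cm.sum(axis=1)` broadcasts the row-sum vector along the last axis, so entry (i, j) is divided by the sum of row j rather than row i. Row-wise normalization needs the sums kept as a column vector, e.g. with `keepdims=True` (a sketch with made-up matrix values):

```python
import numpy as np

cm = np.array([[13, 0, 0],
               [0, 10, 6],
               [0, 0, 9]])

# Broadcasting pitfall: this divides column j by the sum of row j.
wrong = cm.astype(float) / cm.sum(axis=1)

# Row-wise normalization: keep the row sums as an (n, 1) column vector.
cm_normalized = cm.astype(float) / cm.sum(axis=1, keepdims=True)
```

With this matrix, the row sums of `cm_normalized` are all 1, while the rows of `wrong` are not, which is easy to miss when the per-class supports happen to be similar.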