Class mismatch in skplt.plot_confusion_matrix when test has fewer classes than training #28

Closed
ArmandGiraud opened this issue Apr 30, 2017 · 4 comments

ArmandGiraud commented Apr 30, 2017

Hello,
I have an issue when trying to plot a confusion matrix with fewer classes in my test set than in training.
The class with 12,000+ occurrences in my sample should be labelled 'O'.
Is it possible to get around this, or to pass the label set manually as an input?

image
It's not a big issue, but it would be nice if we could fix it.
Thanks for your help
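
For reference, a minimal sketch of what "passing the label set manually" looks like with scikit-learn's own `confusion_matrix`, which accepts a `labels` argument. The toy arrays and the name `all_labels` are made up for illustration, and whether scikit-plot exposed the same option at the time is not stated in this thread.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy data (hypothetical): class 0 never appears among the true labels,
# but the model predicts it once.
y_true = np.array([1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 2, 2, 1])

# Passing the full label set pins the row/column order explicitly,
# so the class missing from y_true still gets its own (empty) row.
all_labels = [0, 1, 2]
cm = confusion_matrix(y_true, y_pred, labels=all_labels)
print(cm)
```

Here row 0 of `cm` is all zeros, because no true sample belongs to class 0, while column 0 still records the one sample that was predicted as 0.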

@reiinakano
Owner

Hi @ArmandGiraud, could you give a small code sample demonstrating the problem?

@ArmandGiraud
Author

Hi @reiinakano, I tried to reproduce the issue with the digits dataset:

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from scikitplot import plotters as skplt
import pandas as pd

# Jupyter/IPython magic; drop this line when running as a plain script.
%matplotlib inline

digits = load_digits()
X = digits.data
y = digits.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
test = pd.DataFrame(X_test)
test['target'] = y_test

# Remove every sample of class 0 from the test set.
# (.loc is used instead of .ix, which later pandas releases removed.)
test_reduced = test.loc[test.target != 0, :]
X_test_reduced = test_reduced.drop('target', axis=1)
y_test_reduced = test_reduced.target
print(set(y_test_reduced))  # 0 is no longer among the true labels

# I set a high tolerance so that the classifier still predicts some 0's.
lr = LogisticRegression(tol=5)
lr.fit(X_train, y_train)

y_pred = lr.predict(X_test_reduced)
skplt.plot_confusion_matrix(y_test_reduced, y_pred)

image
The matrix reads as if the true values contain no occurrences of 1, but they actually do; the first row should refer to 0.
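
A small sketch of how this kind of label shift can arise. This is an assumption about the mechanism, not a confirmed reading of the scikit-plot source: if the matrix is computed over the labels seen in either array but the tick labels are taken from the true labels only, every row ends up labelled one class too high. The toy arrays and variable names are hypothetical.

```python
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels

# Toy data (hypothetical): no 0 among the true labels, but 0 is predicted once.
y_true = np.array([1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 2, 2, 1])

# With no explicit labels, scikit-learn builds the matrix over every label
# that appears in y_true or y_pred, so it has one row per class 0, 1, 2.
cm = confusion_matrix(y_true, y_pred)

# Tick labels derived from y_true alone are one short of the matrix size ...
ticks_from_true = np.unique(y_true)
# ... while labels derived from both arrays line up with the matrix rows.
ticks_from_both = unique_labels(y_true, y_pred)

print(cm.shape, ticks_from_true, ticks_from_both)
```

With only two tick labels for a three-row matrix, the first row (really class 0) would be drawn with the label 1, which matches the symptom described above.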

Sorry for the ugly syntax, I'm kind of new to Python.
Thanks

@reiinakano
Owner

reiinakano commented Apr 30, 2017

Hi @ArmandGiraud, you're absolutely right, this was a bug in the implementation. Rest assured this has been fixed in #29 and is now in the v0.2.5 release. Just run pip install scikit-plot --upgrade and you should be good to go. :)

Thanks for using scikit-plot!

@ArmandGiraud
Author

@reiinakano
Thanks a lot, that was a fast fix!

Thank you as well for developing this useful package, it saves a lot of time!
