Example of confusion matrix usage to evaluate the quality of the output of a classifier on the iris data set. The diagonal elements represent the number of points for which the predicted label is equal to the true label, while off-diagonal elements are those that are mislabeled by the classifier. The higher the diagonal values of the confusion matrix the better, indicating many correct predictions.
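
As a minimal sketch of how the diagonal and off-diagonal entries are read, the snippet below builds a confusion matrix by hand with NumPy. The labels `y_true` and `y_pred` are hypothetical toy values, not the iris predictions, and the row = true label, column = predicted label layout follows the convention of scikit-learn's `confusion_matrix`.

```python
import numpy as np

# Hypothetical labels for a 3-class problem (illustration only)
y_true = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 0, 1, 2, 1, 2, 2, 1, 2])

n_classes = 3
cm = np.zeros((n_classes, n_classes), dtype=int)
# Count each (true, predicted) pair: row = true label, column = predicted label
np.add.at(cm, (y_true, y_pred), 1)

correct = np.trace(cm)           # sum of the diagonal: correct predictions
mislabeled = cm.sum() - correct  # everything off the diagonal
accuracy = correct / cm.sum()

print(cm)
print(correct, mislabeled, accuracy)
```

Here 7 of the 9 points fall on the diagonal, so the two off-diagonal counts are the only mislabeled points.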

The figures show the confusion matrix with and without normalization by class support size (number of elements in each class). This kind of normalization can be interesting in case of class imbalance to have a more visual interpretation of which class is being misclassified.
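
Normalization by class support can be sketched directly on the counts this example prints further down. Each row is divided by its own total, so every row becomes the distribution of predictions for one true class. The variable names `support` and `cm_normalized` are illustrative.

```python
import numpy as np

# Counts taken from the non-normalized matrix printed in this example
cm = np.array([[13, 0, 0],
               [0, 10, 6],
               [0, 0, 9]])

# Class support: number of test points whose true label is each class
support = cm.sum(axis=1, keepdims=True)

# Divide each row by its support so that every row sums to 1
cm_normalized = cm / support

print(support.ravel())
print(cm_normalized)
```

The middle row becomes 0.625 and 0.375, which is the 0.62 and 0.38 shown in the normalized output once NumPy prints with two decimals.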

Here the results are not as good as they could be because the regularization parameter C was deliberately set too low. In real-life applications this parameter is usually chosen by a hyper-parameter search; see [Tuning the hyper-parameters of an estimator](http://scikit-learn.org/stable/modules/grid_search.html#grid-search).
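
One way such a search might look is sketched below with scikit-learn's `GridSearchCV`. The candidate values in `param_grid` and the choice of 5-fold cross-validation are assumptions made for illustration; they are not part of the original example.

```python
from sklearn import datasets, svm
from sklearn.model_selection import GridSearchCV, train_test_split

iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0)

# Candidate values for C (an illustrative grid, not a recommendation)
param_grid = {'C': [0.01, 0.1, 1, 10, 100]}

# Cross-validate a linear SVC for each C, using the training split only
search = GridSearchCV(svm.SVC(kernel='linear'), param_grid, cv=5)
search.fit(X_train, y_train)

best_C = search.best_params_['C']
# Accuracy of the refitted best model on the held-out test split
test_accuracy = search.score(X_test, y_test)

print(best_C, test_accuracy)
```

The search only sees the training split, so the test split remains a fair check of the selected model.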

#### New to Plotly?
Plotly's Python library is free and open source! [Get started](https://plot.ly/python/getting-started/) by downloading the client and [reading the primer](https://plot.ly/python/getting-started/).
<br>You can set up Plotly to work in [online](https://plot.ly/python/getting-started/#initialization-for-online-plotting) or [offline](https://plot.ly/python/getting-started/#initialization-for-offline-plotting) mode, or in [jupyter notebooks](https://plot.ly/python/getting-started/#start-plotting-online).
<br>We also have a quick-reference [cheatsheet](https://images.plot.ly/plotly-documentation/images/python_cheat_sheet.pdf) (new!) to help you get started!

### Version

In [1]:
import sklearn
sklearn.__version__

'0.18.1'

### Imports

This tutorial imports [train_test_split](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html#sklearn.model_selection.train_test_split) and [confusion_matrix](http://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html#sklearn.metrics.confusion_matrix).

In [2]:
print(__doc__)

import plotly.plotly as py
import plotly.graph_objs as go

import itertools
import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

Automatically created module for IPython interactive environment


### Calculations

In [3]:
# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)

    
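def matplotlib_to_plotly(cmap, pl_entries):
    """Sample `pl_entries` evenly spaced colors from a matplotlib colormap
    and return them as a Plotly colorscale."""
    h = 1.0/(pl_entries-1)
    pl_colorscale = []
    
    for k in range(pl_entries):
        # cmap returns RGBA floats in [0, 1]; keep RGB and scale to 0-255 ints.
        # A list is built explicitly because map() is lazy in Python 3.
        C = [int(c) for c in np.array(cmap(k*h)[:3])*255]
        pl_colorscale.append([k*h, 'rgb'+str((C[0], C[1], C[2]))])
        
    return pl_colorscale

cmap = matplotlib_to_plotly(plt.cm.Blues, 5)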

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=cmap):
    """
    This function prints the confusion matrix and returns a Plotly figure
    that plots it as an annotated heatmap.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        # Divide each row by its class support so that every row sums to 1
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    # Build the heatmap after normalization so colors match the printed values
    trace = go.Heatmap(z=cm, colorscale=cmap,
                       showscale=False)

    tick_marks = np.arange(len(classes))

    # Write each cell's value on top of the heatmap
    anno = []
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        anno.append(dict(x=j, y=i, text=str(cm[i, j]), showarrow=False))

    # Rows of cm (y axis) are true labels; columns (x axis) are predictions
    layout = go.Layout(title=title, annotations=anno,
                       xaxis=dict(title='Predicted label',
                                  tickvals=tick_marks,
                                  ticktext=classes),
                       yaxis=dict(title='True label',
                                  tickvals=tick_marks,
                                  ticktext=classes))

    fig = go.Figure(data=[trace], layout=layout)

    return fig

### Plot Results

In [4]:
# Compute confusion matrix
cnf_matrix = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)

### Plot Non-Normalized Confusion Matrix

In [5]:
fig = plot_confusion_matrix(cnf_matrix, classes=class_names,
                            title='Confusion matrix, without normalization')


Confusion matrix, without normalization
[[13  0  0]
 [ 0 10  6]
 [ 0  0  9]]


In [6]:
py.iplot(fig)

### Plot Normalized Confusion Matrix

In [7]:
fig = plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                            title='Normalized confusion matrix')

Normalized confusion matrix
[[ 1.    0.    0.  ]
 [ 0.    0.62  0.38]
 [ 0.    0.    1.  ]]


In [8]:
py.iplot(fig)

