In [4]:
# jupyter notebook for creating confusion matrix plot
# Date: 02/10/2020

# import libraries
import pandas as pd
from sklearn import ensemble
import matplotlib.pyplot as plt
import numpy as np
import itertools
from sklearn import metrics

In [None]:
# function for plotting a confusion matrix
def plot_confusion_matrix(confusion_matrix, 
                          class_names,
                          normalize=False,
                          title='confusion matrix',
                          color_map=plt.cm.Blues):
    """
    Print a confusion matrix and draw it as an annotated heatmap.

    The heatmap is drawn with pyplot on the current figure/axes; the
    function does not create or show a figure itself.

    Parameters
    ----------
    confusion_matrix : array-like, 2-D
        Matrix to display. Row ``i`` is drawn at y position ``i`` (axis
        labelled 'True label') and column ``j`` at x position ``j`` (axis
        labelled 'Predicted label').
    class_names : sequence
        Tick labels, used for both axes.
    normalize : bool, default False
        If True, every row is divided by its row sum before printing and
        plotting, so the heatmap and the cell annotations show the same
        values. Rows whose sum is zero are left as all zeros.
    title : str
        Title of the plot.
    color_map : matplotlib colormap
        Colormap passed to ``plt.imshow``.

    Returns
    -------
    None
    """
    matrix = np.asarray(confusion_matrix)

    # Normalize *before* drawing so the colours and the annotations agree.
    if normalize:
        row_totals = matrix.sum(axis=1, keepdims=True)
        # `where` skips rows with a zero total; those keep the 0.0 from `out`
        # instead of turning into NaN.
        matrix = np.divide(matrix.astype(float), row_totals,
                           out=np.zeros(matrix.shape, dtype=float),
                           where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(matrix)

    image = plt.imshow(matrix, interpolation='nearest', cmap=color_map)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names)
    plt.yticks(tick_marks, class_names)

    # Integer counts are shown as-is; fractional values are rounded so the
    # annotation fits inside its cell.
    cell_format = 'd' if np.issubdtype(matrix.dtype, np.integer) else '.2f'

    # Annotate each cell with the colour from the opposite end of the
    # colormap, so the text contrasts with the cell whatever colormap is used.
    low_color, high_color = image.cmap(0.0), image.cmap(1.0)
    midpoint = (matrix.max() + matrix.min()) / 2.
    for i, j in np.ndindex(*matrix.shape):
        plt.text(j, i, format(matrix[i, j], cell_format),
                 horizontalalignment="center", fontsize=8,
                 color=low_color if matrix[i, j] > midpoint else high_color)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Called last so the axis labels are included in the layout computation.
    plt.tight_layout()

In [2]:
# load true labels and predicted labels
# NOTE(review): the two files are compared row by row; this assumes both CSVs
# list the same samples in the same order -- confirm against how the
# submission file was produced.
test = pd.read_csv('data/test/test.csv') # true labels
test_label = test['Cover_Type']
tree = pd.read_csv('data/output/finalSubmission.csv') # predicted labels
predict_label = tree['Cover_Type']

# Sorted union of the labels that occur in either column. Passing it explicitly
# keeps the matrix rows/columns and the tick labels derived from one source, so
# they cannot get out of sync if a class is missing from the data.
observed_labels = np.union1d(test_label, predict_label)

final_confusion_matrix = metrics.confusion_matrix(test_label, predict_label,
                                                  labels=observed_labels) # use sklearn package
class_names = [str(label) for label in observed_labels]

In [None]:
# scikit-learn's built-in rendering of the same matrix. `display_labels` is
# passed by keyword: recent scikit-learn releases reject it as a positional
# argument. The trailing `;` suppresses the repr of the returned display object.
metrics.ConfusionMatrixDisplay(final_confusion_matrix, display_labels=class_names).plot();

In [None]:
# render the confusion matrix on a fresh figure using the helper defined above
plt.figure()
plot_confusion_matrix(
    final_confusion_matrix,
    class_names=class_names,
    title='Confusion matrix',
)