In [0]:
from sklearn import datasets, svm
import numpy as np
import pandas as pd
import copy
import pickle as pk
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import f1_score
from imblearn.metrics import specificity_score
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import train_test_split
from scipy import interp
from matplotlib import pyplot as plt
import seaborn as sns

In [0]:
# Meaning of the `average_type` options (scikit-learn's `average` parameter):
# 'micro': Calculate metrics globally by counting the total true positives, false negatives and false positives.
# 'macro': Calculate metrics for each label and take their unweighted mean.
# 'weighted': Calculate metrics for each label and take their mean weighted by support (the number of true instances for each label). This alters 'macro' to account for label imbalance; it can result in an F-score that is not between precision and recall.

In [0]:
# # Import some data to play with
# dataset = datasets.load_iris()
# #dataset = datasets.load_breast_cancer()
# X = dataset.data
# y = dataset.target
# classes_label = datasets

# # Binarize the output
# y = label_binarize(y, classes=list(range(dataset.target_names.__len__())))
# n_classes = y.shape[1]

# if n_classes + 1 == 2:
#     y = np.append(y, 1-y,axis=1)

# # Add noisy features to make the problem harder
# random_state = np.random.RandomState(0)
# # n_samples, n_features = X.shape
# # X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]

# # shuffle and split training and test sets
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.8,
#                                                     random_state=0)

# # Learn to predict each class against the other
# classifier = OneVsRestClassifier(svm.SVC(kernel='linear', probability=True,
#                                  random_state=random_state))
# y_score = classifier.fit(X_train, y_train).decision_function(X_test)

# y_pred = classifier.predict(X_test).argmax(axis=1)
# y_true = y_test.argmax(axis=1)

# with open('test.pkl', 'rb') as f:
#     data = pk.load(f)

# y_score = data['y_score']
# y_pred = data['y_pred']
# y_true = data['y_true']
# classes_labels = data['classes_labels']

In [0]:
def plot_confusion_matrix(confusion_mat, 
                          classes, 
                          figure_axis,
                          title=None,
                          cmap=plt.cm.Blues):
    """Draw a (pre-computed) confusion matrix as an annotated heat map.

    confusion_mat: 2-D array-like. Rows are drawn along the y axis
        ("True Label") and columns along the x axis ("Predicted Label").
    classes: tick labels used for both axes, or None to hide all ticks
        and tick labels.
    figure_axis: matplotlib Axes to draw on; nothing is drawn elsewhere
        except the colorbar attached to this axes' figure.
    title: optional axes title.
    cmap: colormap for the heat map.
    """
    cm = np.asarray(confusion_mat)
    ax = figure_axis

    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    # Act on the axes we were given, not on pyplot's "current" axes.
    ax.grid(False)
    ax.figure.colorbar(im, ax=ax)

    if classes is not None:
        # Show every tick and label it with the class name.
        ax.set(xticks=np.arange(cm.shape[1]),
               yticks=np.arange(cm.shape[0]),
               xticklabels=classes, yticklabels=classes,
               ylabel='True Label',
               xlabel='Predicted Label')

        # Rotate the x tick labels so long class names do not overlap.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
                 rotation_mode="anchor")
    else:
        # No class names: hide tick marks and tick labels entirely.
        for artists in (ax.get_xticklabels(), ax.get_yticklabels(),
                        ax.get_xticklines(), ax.get_yticklines()):
            plt.setp(artists, visible=False)

    ax.set(title=title)

    # Write each cell's value, switching to white text on dark cells.
    # nanmax ignores undefined (NaN) cells so the threshold stays usable.
    thresh = np.nanmax(cm) / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, cm[i, j],
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")


def plot_confusion_table_sample(figure_axis, cmap=plt.cm.Blues):
    """Draw a 2x2 legend showing where TP / FP / FN / TN sit in a confusion table.

    Row 0 holds the predicted-positive cells (TP, FP), column 0 the
    actually-positive cells (TP, FN).

    figure_axis: matplotlib Axes to draw on.
    cmap: colormap used to shade the cells.
    """
    # Only used for shading: diagonal cells get the high value,
    # off-diagonal cells the middle value.
    shade = np.array([[1, 0.5],
                      [0.5, 1]])

    labels = [['TP', 'FP'],
              ['FN', 'TN']]

    ax = figure_axis
    ax.imshow(shade, interpolation='nearest', cmap=cmap)
    # Act on the axes we were given, not on pyplot's "current" axes.
    ax.grid(False)
    for artists in (ax.get_xticklabels(), ax.get_yticklabels(),
                    ax.get_xticklines(), ax.get_yticklines()):
        plt.setp(artists, visible=False)

    # Write the cell names, white on the darker (high-value) cells.
    thresh = shade.max() / 2.
    for i in range(shade.shape[0]):
        for j in range(shade.shape[1]):
            ax.text(j, i, labels[i][j],
                    ha="center", va="center",
                    color="white" if shade[i, j] > thresh else "black")


In [0]:
class ClassifierReport:
    """Compute and display evaluation metrics for a (multi-class) classifier.

    Every metric is computed once in ``__init__`` and stored as an attribute;
    the ``show_*`` methods render the stored values with matplotlib.
    Classes are identified by the integers ``0 .. number_of_classes - 1``.
    """

    def __init__(self, y_true, y_pred, y_score, number_of_classes, average_type='macro',
                 digits_count_fp=3, classes_labels=None):
        """Store the inputs and compute every metric.

            y_true: list or numpy array (number of samples) of class indices
            y_pred: list or numpy array (number of samples) of class indices
            y_score: numpy array with the actual outputs before the decision,
                shape (number of samples, number of classes). With two classes a
                1-D array of positive-class scores is also accepted.
            number_of_classes: total number of classes
            average_type: how the overall recall / precision / F1 / specificity
                are averaged (sklearn ``average``: 'micro', 'macro', 'weighted', ...)
            digits_count_fp: number of digits kept after the decimal point
            classes_labels: list of class display names (default 'Class <i>')
        """
        self.y_true = np.array(y_true)
        self.y_pred = np.array(y_pred)
        self.y_score = np.array(y_score)
        self.number_of_classes = number_of_classes
        self.y_true_one_hot = label_binarize(self.y_true, classes=self._class_indices())

        self.number_of_samples = len(self.y_true)
        self.number_of_samples_per_class = [(self.y_true == c).sum()
                                            for c in self._class_indices()]

        if classes_labels is None:
            classes_labels = ['Class ' + str(c) for c in self._class_indices()]
        self.classes_labels = classes_labels

        self.digits_count_fp = digits_count_fp
        self.average_type = average_type

        self.TP_TN_FP_FN()
        self.calculate_confusion_matrix()
        self.calculate_confusion_tables()
        self.accuracy()
        self.recall()
        self.precision()
        self.f1_score()
        self.specificity()
        self.cohen_kappa()
        self.calculate_roc_auc()

    # ------------------------------------------------------------------ helpers

    def _class_indices(self):
        """Return the list of class indices ``0 .. number_of_classes - 1``."""
        return list(range(self.number_of_classes))

    def _round(self, value):
        """Round a scalar or array to ``digits_count_fp`` decimals.

        ``np.round`` is used instead of ``value.round()`` because metric
        functions may return plain Python floats, which have no ``.round()``.
        """
        return np.round(value, self.digits_count_fp)

    @staticmethod
    def _safe_divide(numerator, denominator):
        """Broadcasting division that yields 0 (not NaN) where the denominator is 0."""
        numerator = np.asarray(numerator, dtype=float)
        denominator = np.asarray(denominator, dtype=float)
        result = np.zeros(np.broadcast(numerator, denominator).shape)
        np.divide(numerator, denominator, out=result, where=denominator != 0)
        return result

    def _overall_and_per_class(self, metric):
        """Return rounded (overall, per-class) values of a sklearn-style metric.

        ``labels`` is passed explicitly so the per-class result always has one
        entry per class (even for classes absent from the data) and the overall
        average is taken over the same set of classes.
        """
        labels = self._class_indices()
        overall = metric(self.y_true, self.y_pred, labels=labels,
                         average=self.average_type)
        per_class = metric(self.y_true, self.y_pred, labels=labels, average=None)
        return self._round(overall), self._round(per_class)

    # ------------------------------------------------------------ computations

    def TP_TN_FP_FN(self):
        """Count true/false positives/negatives for each class (one-vs-rest)."""
        self.TP = np.zeros(self.number_of_classes)
        self.FP = np.zeros(self.number_of_classes)
        self.TN = np.zeros(self.number_of_classes)
        self.FN = np.zeros(self.number_of_classes)

        for cls in self._class_indices():
            is_true = self.y_true == cls
            is_pred = self.y_pred == cls
            self.TP[cls] = np.sum(is_true & is_pred)
            self.FN[cls] = np.sum(is_true & ~is_pred)
            self.FP[cls] = np.sum(~is_true & is_pred)
            self.TN[cls] = np.sum(~is_true & ~is_pred)

    def calculate_confusion_matrix(self):
        """Compute the confusion matrix and its row-normalized version.

        Rows are true classes, columns predicted classes; each row of the
        normalized matrix is divided by that class's number of samples
        (rows of classes with no samples become 0).
        """
        self.confusion_matrix = confusion_matrix(
            self.y_true, self.y_pred, labels=self._class_indices())

        row_totals = self.confusion_matrix.sum(axis=1, keepdims=True)
        self.normalized_confusion_matrix = self._round(
            self._safe_divide(self.confusion_matrix, row_totals))

    def calculate_confusion_tables(self):
        """Build a 2x2 table ``[[TP, FP], [FN, TN]]`` for every class.

        The normalized tables divide each column by its total, giving
        ``[[TPR, FPR], [FNR, TNR]]``; a column with total 0 becomes 0.
        """
        tables = np.zeros((self.number_of_classes, 2, 2))
        tables[:, 0, 0] = self.TP
        tables[:, 0, 1] = self.FP
        tables[:, 1, 0] = self.FN
        tables[:, 1, 1] = self.TN

        # Column sums per table: shape (number_of_classes, 1, 2).
        column_totals = tables.sum(axis=1, keepdims=True)
        self.normalized_confusion_tables = self._round(
            self._safe_divide(tables, column_totals))

        self.confusion_tables = tables.astype(int)

    def accuracy(self, sample_weight=None):
        """Overall accuracy (refer to sklearn ``accuracy_score`` for full doc)."""
        if sample_weight is None:
            sample_weight = np.ones(self.number_of_samples)
        self.overall_accuracy = self._round(
            accuracy_score(self.y_true, self.y_pred, sample_weight=sample_weight))

    def recall(self):
        """Recall is also known as Sensitivity and True Positive Rate."""
        self.overall_recall, self.classes_recall = \
            self._overall_and_per_class(recall_score)

    def precision(self):
        """Precision or Positive Predictive Value."""
        self.overall_precision, self.classes_precision = \
            self._overall_and_per_class(precision_score)

    def f1_score(self):
        """F1 score: harmonic mean of recall and precision."""
        # Bare name lookup inside a method resolves to the module-level
        # sklearn function, not to this method.
        self.overall_f1_score, self.classes_f1_score = \
            self._overall_and_per_class(f1_score)

    def specificity(self):
        """Specificity is also known as True Negative Rate."""
        self.overall_specificity, self.classes_specificity = \
            self._overall_and_per_class(specificity_score)

    def cohen_kappa(self):
        """Cohen's kappa agreement between ``y_true`` and ``y_pred``."""
        self.overall_cohen_kappa = self._round(cohen_kappa_score(self.y_true, self.y_pred))

    def calculate_roc_auc(self):
        """ROC curves and AUCs.

        Refer to: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html

        Binary: ``fpr``/``tpr``/``thresholds`` are arrays and ``roc_auc`` a scalar.
        Multi-class: they are dicts keyed by class index plus 'micro' and 'macro'.
        All AUCs are computed from the unrounded curves; only the stored
        values are rounded.
        """
        if self.number_of_classes == 2:
            # Accept a (n_samples, 2) score matrix or a 1-D positive-class score.
            scores = self.y_score if self.y_score.ndim == 1 else self.y_score[:, 1]
            fpr, tpr, thresholds = roc_curve(self.y_true, scores)
            self.roc_auc = self._round(auc(fpr, tpr))
            self.fpr = self._round(fpr)
            self.tpr = self._round(tpr)
            self.thresholds = self._round(thresholds)
            return

        raw_fpr = {}    # fpr: False positive rate
        raw_tpr = {}    # tpr: True positive rate
        self.roc_auc = {}

        # Per-class (one-vs-rest) curves.
        for i in self._class_indices():
            raw_fpr[i], raw_tpr[i], _ = roc_curve(
                self.y_true_one_hot[:, i], self.y_score[:, i])
            self.roc_auc[i] = self._round(auc(raw_fpr[i], raw_tpr[i]))

        # Micro average: treat every (sample, class) pair as one binary decision.
        raw_fpr["micro"], raw_tpr["micro"], _ = roc_curve(
            self.y_true_one_hot.ravel(), self.y_score.ravel())
        self.roc_auc["micro"] = self._round(auc(raw_fpr["micro"], raw_tpr["micro"]))

        # Macro average: interpolate every class curve on the union of all FPR
        # points, then average the TPRs.
        all_fpr = np.unique(np.concatenate([raw_fpr[i] for i in self._class_indices()]))
        mean_tpr = np.mean([np.interp(all_fpr, raw_fpr[i], raw_tpr[i])
                            for i in self._class_indices()], axis=0)
        raw_fpr["macro"] = all_fpr
        raw_tpr["macro"] = mean_tpr
        self.roc_auc["macro"] = self._round(auc(all_fpr, mean_tpr))

        self.fpr = {key: self._round(value) for key, value in raw_fpr.items()}
        self.tpr = {key: self._round(value) for key, value in raw_tpr.items()}

    # ---------------------------------------------------------------- display

    def show_confusion_matrix(self):
        """Plot the raw and the row-normalized confusion matrix side by side."""
        self.calculate_confusion_matrix()
        fig = plt.figure(figsize=(10, 4))

        ax = plt.subplot(1, 2, 1)
        plot_confusion_matrix(self.confusion_matrix,
                              self.classes_labels,
                              title='Confusion Matrix',
                              cmap=plt.cm.Blues,
                              figure_axis=ax)

        ax = plt.subplot(1, 2, 2)
        plot_confusion_matrix(self.normalized_confusion_matrix,
                              self.classes_labels,
                              title='Normalized Confusion Matrix',
                              cmap=plt.cm.Blues,
                              figure_axis=ax)
        fig.tight_layout()
        plt.show()

    def show_confusion_tables(self):
        """One row per class: layout legend, raw 2x2 table, normalized 2x2 table."""
        self.calculate_confusion_tables()
        fig = plt.figure(figsize=(8, self.number_of_classes * 2))
        for cls in self._class_indices():
            first_cell = 3 * cls + 1

            ax = plt.subplot(self.number_of_classes, 3, first_cell)
            ax.grid(False)
            plot_confusion_table_sample(ax)

            ax = plt.subplot(self.number_of_classes, 3, first_cell + 1)
            plot_confusion_matrix(self.confusion_tables[cls],
                                  None,
                                  title=self.classes_labels[cls],
                                  cmap=plt.cm.Blues,
                                  figure_axis=ax)

            ax = plt.subplot(self.number_of_classes, 3, first_cell + 2)
            plot_confusion_matrix(self.normalized_confusion_tables[cls],
                                  None,
                                  title=self.classes_labels[cls],
                                  cmap=plt.cm.Blues,
                                  figure_axis=ax)
        fig.tight_layout()
        plt.show()

    def show_overall_metrics(self, required_metrics=('Accuracy', 'Recall', 'Precision',
                                                     'F1_score', 'Specificity', 'Cohen_Kappa')):
        """Bar chart and table of the overall metrics.

        required_metrics: metric names to show, in order; unknown names are skipped.
        """
        available = {
            'Accuracy': self.overall_accuracy,
            'Recall': self.overall_recall,
            'Precision': self.overall_precision,
            'F1_score': self.overall_f1_score,
            'Specificity': self.overall_specificity,
            'Cohen_Kappa': self.overall_cohen_kappa
            }

        overall_metrics = pd.Series(
            {m: available[m] for m in required_metrics if m in available})

        plt.figure(figsize=(10, 5))
        ax = plt.subplot(1, 2, 1)
        overall_metrics.plot.bar(ax=ax)
        ax.set_title('Overall Metrics')
        ax.grid(True)
        ax.set_ylim((0, 1.5))

        ax = plt.subplot(1, 2, 2)
        ax.axis('tight')
        ax.axis('off')
        table = ax.table(cellText=overall_metrics.values.reshape((len(overall_metrics), 1)),
                         colLabels=['Values'],
                         rowLabels=overall_metrics.index,
                         loc='center',
                         colWidths=[0.2])
        table.scale(2, 2)
        table.set_fontsize(12)
        plt.show()

    def show_classes_metrics(self, required_metrics=('Recall', 'Precision', 'F1_score',
                                                     'Specificity', 'Cohen_Kappa')):
        """Table and bar charts of the per-class metrics.

        required_metrics: metric names to show, in order; names without a
        per-class value (e.g. 'Cohen_Kappa') are skipped.
        """
        available = {
            'Recall': self.classes_recall,
            'Precision': self.classes_precision,
            'F1_score': self.classes_f1_score,
            'Specificity': self.classes_specificity
            }

        classes_metrics = pd.DataFrame(
            {m: available[m] for m in required_metrics if m in available})
        classes_metrics.index = self.classes_labels

        # Data table
        ax = plt.subplot(1, 1, 1)
        ax.axis('tight')
        ax.axis('off')
        table = ax.table(cellText=classes_metrics.values,
                         colLabels=classes_metrics.columns,
                         rowLabels=classes_metrics.index,
                         loc='center',
                         colWidths=[0.2] * len(classes_metrics.columns))
        table.scale(2, 2)
        table.set_fontsize(15)
        plt.grid(True)
        plt.show()

        # Compare between metrics
        classes_metrics.plot.bar()
        plt.title('Metrics Comparison')
        plt.grid(True)
        plt.legend(loc=0)
        plt.ylim((0, 1.5))
        plt.show()

        # Compare between classes
        classes_metrics.T.plot.bar()
        plt.title('Classes Comparison')
        plt.grid(True)
        plt.ylim((0, 1.5))
        plt.show()

    def show_roc_curve(self):
        """Plot the ROC curve(s): one curve for binary, else per class + micro + macro."""
        plt.figure(figsize=(7, 5))

        lw = 2
        plt.plot([0, 1], [0, 1], 'k--', lw=lw)
        if self.number_of_classes == 2:
            plt.plot(self.fpr, self.tpr,
                     label='ROC curve (area = {0:0.2f})'.format(self.roc_auc),
                     color='crimson', linestyle=':', linewidth=4)
        else:
            # Micro average ROC
            plt.plot(self.fpr["micro"], self.tpr["micro"],
                     label='micro-average ROC curve (area = {0:0.2f})'
                     .format(self.roc_auc["micro"]),
                     color='deeppink', linestyle=':', linewidth=4)

            # Macro average ROC
            plt.plot(self.fpr["macro"], self.tpr["macro"],
                     label='macro-average ROC curve (area = {0:0.2f})'
                     .format(self.roc_auc["macro"]),
                     color='navy', linestyle=':', linewidth=4)

            # Classes ROC
            for i in self._class_indices():
                plt.plot(self.fpr[i], self.tpr[i], lw=lw,
                         label='ROC curve of ({0}) (area = {1:0.2f})'
                         .format(self.classes_labels[i], self.roc_auc[i]))

        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curves')
        plt.legend(loc="lower right")
        plt.show()

    def show_all(self):
        """Render every report section in turn, each preceded by a heading."""
        new_lines = '\n'

        print('Confusion Matrix' + new_lines)
        self.show_confusion_matrix()

        print(new_lines + 'Confusion Tables' + new_lines)
        self.show_confusion_tables()

        print(new_lines + 'Overall Metrics' + new_lines)
        self.show_overall_metrics()

        print(new_lines + 'Classes Metrics' + new_lines)
        self.show_classes_metrics()

        print(new_lines + 'ROC Curve' + new_lines)
        self.show_roc_curve()

In [0]:
# a = ClassifierReport(y_true, y_pred, y_score, 
#                       classes_labels=dataset.target_names.tolist())
# a.show_all()

In [0]:
# a = ClassifierReport(y_true, y_pred, y_score,
#                       classes_labels=classes_labels)
# a.show_all()

In [0]:
# print(np.array(a.TP, dtype=int))
# print(np.array(a.FP, dtype=int))
# print(np.array(a.FN, dtype=int))
# print(np.array(a.TN, dtype=int))

In [0]:
# Signal that the helper cells above have been executed.
print('Importing Done', '...')