## 1. Classification evaluation

### 1) Confusion matrix

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

def confusion_mat(probs, y, labels=('Non-cancer', 'Cancer')):
    """Plot a confusion matrix of predicted versus true class labels.

    The predicted class of each sample is the arg-max over axis 1 of
    ``probs``.  Rows of the plotted matrix are predicted labels and columns
    are true labels; every cell is annotated with its sample count.

    Args:
        probs: 2-D array-like of per-class scores, one row per sample.
        y: Array-like of true class indices, one per sample.  Values are
            cast to ``int`` so float-typed label arrays are accepted.
        labels: Class names used as tick labels.  Their number determines
            the size of the matrix.  Defaults to the two original classes.

    Returns:
        None.  The figure is shown with ``plt.show()``.

    Raises:
        IndexError: If a predicted or true class index falls outside
            ``range(len(labels))``.
        ValueError: If ``probs`` and ``y`` describe a different number of
            samples.
    """
    labels = list(labels)
    n_classes = len(labels)

    pred = np.argmax(probs, axis=1)
    true = np.asarray(y).astype(int).ravel()

    # Integer counts so the cell annotations read "12" rather than "12.0".
    counts = np.zeros((n_classes, n_classes), dtype=int)
    # Unbuffered add: repeated (pred, true) pairs each contribute one count.
    np.add.at(counts, (pred, true), 1)

    _, ax = plt.subplots(figsize=[10, 10], squeeze=True)
    im = ax.imshow(counts, cmap='YlGn')

    # Colorbar lives in its own inset axes just to the right of the image.
    axins = inset_axes(ax,
                       width="5%",
                       height="50%",
                       loc='lower left',
                       bbox_to_anchor=(1.05, 0., 1, 1),
                       bbox_transform=ax.transAxes,
                       borderpad=0)
    plt.colorbar(im, cax=axins)

    ticks = np.arange(n_classes)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels, fontsize=25)
    ax.set_yticklabels(labels, fontsize=25)
    ax.set_ylabel('Predicted labels', fontsize=25)
    ax.set_xlabel('True labels', fontsize=25)

    # Matplotlib 3.1.1 clipped the first and last image rows.  Setting the
    # limits absolutely (instead of widening the current ones by 0.5) gives
    # the full rows on that version without adding blank margins on versions
    # that do not have the bug.
    ax.set_ylim(n_classes - 0.5, -0.5)

    for row in range(n_classes):
        for col in range(n_classes):
            ax.text(col, row, str(counts[row, col]),
                    ha="center", va="center", color="k", fontsize=35)

    plt.show()

### 2) ROC analysis 

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve

def roc(labels, prob):
    """Compute ROC statistics for the class encoded as ``1``.

    Args:
        labels: Array of class labels; entries equal to ``1`` are treated as
            positive, everything else as negative.
        prob: 2-D array of class scores; column 1 is used as the score of
            the positive class.

    Returns:
        Tuple ``(fpr, tpr, auc, sen, spe)``: the ROC curve coordinates, the
        area under the curve, and the sensitivity / specificity at the curve
        point where ``tpr - fpr`` is largest.
    """
    # Always ensure the right category is assigned as the positive label.
    y = np.zeros(labels.shape[0])
    y[labels == 1] = 1

    positive_score = prob[:, 1]  # probability of '1'
    fpr, tpr, _ = roc_curve(y, positive_score)
    auc = roc_auc_score(y, positive_score)

    # Operating point: largest gap between true- and false-positive rate.
    best = np.argmax(tpr - fpr)
    sen = tpr[best]
    spe = 1 - fpr[best]
    return fpr, tpr, auc, sen, spe

In [None]:
def roc_plot(labels, prob):
    """Plot the ROC curve for the class encoded as ``1``.

    The curve statistics come from ``roc`` so the plotted numbers always
    agree with the values that function returns.  The figure shows the ROC
    curve, the chance diagonal, and a text box with the AUC plus the
    sensitivity and specificity at the point where ``tpr - fpr`` is largest.

    Args:
        labels: Array of class labels; entries equal to ``1`` are treated as
            positive, everything else as negative.
        prob: 2-D array of class scores; column 1 is used as the score of
            the positive class.

    Returns:
        None.  The figure is shown with ``plt.show()``.
    """
    fpr, tpr, auc, sen, spe = roc(labels, prob)

    plt.figure(figsize=(7, 7))
    lw = 2
    plt.plot(fpr, tpr, color='darkorange', lw=2 * lw)
    # Chance-level reference line.
    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    # Small margin left of x=0 and above y=1 so the curve is not drawn on
    # top of the axes frame.
    plt.xlim([-0.05, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('1 - Specificity', fontsize=25)
    plt.ylabel('Sensitivity', fontsize=25)
    plt.xticks(fontsize=25)
    plt.yticks(fontsize=25)
    plt.text(0.5, 0.3,
             'AUC = %0.3f\nSEN = %0.3f\nSPE = %0.3f' % (auc, sen, spe),
             fontsize=20)
    plt.tight_layout()
    plt.show()

### 3) Precision-recall analysis

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score

def pr_plot(labels, prob):
    """Plot the precision-recall curve for the class encoded as ``1``.

    The figure shows the precision-recall curve with the area beneath it
    shaded, and a dashed horizontal line at the average precision score.

    Args:
        labels: Array of class labels; entries equal to ``1`` are treated as
            positive, everything else as negative.
        prob: 2-D array of class scores; column 1 is used as the score of
            the positive class.

    Returns:
        None.  The figure is shown with ``plt.show()``.
    """
    # Always ensure the right category is assigned as the positive label.
    y = np.zeros(labels.shape[0])
    y[np.where(labels == 1)] = 1

    positive_score = prob[:, 1]  # probability of '1'
    precision, recall, _ = precision_recall_curve(y, positive_score)
    ave_score = average_precision_score(y, positive_score)

    plt.figure(figsize=(7, 7))
    lw = 2
    # Labels are attached to the two lines themselves.  Passing only a label
    # list to legend() pairs labels with artists by creation order, which can
    # hand the second label to the shaded area instead of the dashed line.
    plt.plot(recall, precision, color='deepskyblue', lw=lw,
             label='Precision-recall (PR) curve')
    # Shade between the lowest precision value and the curve; the fill has no
    # label, so it is left out of the legend.
    plt.fill_between(recall, np.min(precision), precision,
                     fc='skyblue', alpha=0.3)
    plt.plot([0, 1], [ave_score, ave_score], color='red', lw=1.5 * lw,
             linestyle='--', label='Average PR score: %0.2f' % ave_score)
    plt.xlim([-0.05, 1.05])
    plt.xlabel('Recall', fontsize=25)
    plt.ylabel('Precision', fontsize=25)
    plt.xticks(fontsize=25)
    plt.yticks(fontsize=25)
    plt.legend(loc="lower left", fontsize=20)
    plt.tight_layout()
    plt.show()