In [1]:
# HIDDEN
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from ipywidgets import Dropdown, FloatSlider, interact
from sklearn.metrics import (
    accuracy_score,
    auc,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_curve,
)

%matplotlib inline

In [2]:
# HIDDEN
def generate_output(separability, threshold):
    """Simulate classifier scores, then plot and print metrics at ``threshold``.

    Draws three stacked panels (the per-class score densities with the
    cutoff, the ROC curve, and the confusion matrix) and prints AUC,
    precision, recall, accuracy and F1. Used as the ``interact`` callback, so
    its two parameters are the current widget values.

    Parameters
    ----------
    separability : {"low", "medium", "high"}
        How far apart the simulated class score distributions are.
    threshold : float
        Score above which a sample is predicted positive.
    """
    N = 10000    # samples drawn per class
    sigma = 0.1  # std-dev of each class's score distribution
    class_names = ["negative", "positive"]  # index 0: negative, 1: positive
    pos_pred, pos_gt, neg_pred, neg_gt = generate_data(
        N, sigma, separability=separability)

    fig, ax = plt.subplots(3, 1, figsize=(10, 15))

    # Panel 1: per-class score densities and the cutoff.
    plot_pdf(pos_pred, neg_pred, threshold, ax[0])

    # Pool both classes for the performance calculations.
    preds = np.hstack([pos_pred, neg_pred])
    gt = np.hstack([pos_gt, neg_gt])

    fpr, tpr, roc_auc, cm, acc, precision, recall, f1 = compute_metrics(
        gt, preds, threshold)

    # Panel 2: ROC curve.
    plot_roc(fpr, tpr, roc_auc, ax[1])

    # Panel 3: confusion matrix.
    plot_confusion_matrix(
        cm,
        title=f"Confusion matrix - Total num of instances: {len(gt)}",
        classes=class_names,
        ax=ax[2])

    fig.tight_layout()

    print(f"AUC : {roc_auc:0.2f}")
    print(f"Precision - tp/(tp + fp) : {precision:.2f}")
    print(f"Recall (Sensitivity) - tp/(tp + fn) : {recall:.2f}")
    # Accuracy counts correct predictions of both classes: (tp + tn) / total.
    print(f"Accuracy - (tp + tn)/total : {acc:.2f}")
    print(
        f"F1 score - 2 * (precision * recall) / (precision + recall) : {f1:.2f}")


def compute_metrics(y_true, y_pred, threshold):
    """Compute ROC data and threshold-dependent classification metrics.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels (0 = negative, 1 = positive).
    y_pred : numpy.ndarray
        Predicted scores. A sample counts as predicted positive when its
        score is strictly greater than ``threshold``.
    threshold : float
        Decision cutoff applied to ``y_pred``.

    Returns
    -------
    tuple
        ``(fpr, tpr, roc_auc, cm, acc, precision, recall, f1)``:
        ``fpr``/``tpr`` come from ``roc_curve`` on the raw scores,
        ``roc_auc`` is the area under that curve, and ``cm`` is the 2x2
        confusion matrix (rows = true label, columns = predicted label,
        order 0 then 1). The remaining values are evaluated at
        ``threshold``.
    """
    # ROC/AUC sweep over every cutoff, so they use the raw scores.
    fpr, tpr, _ = roc_curve(y_true, y_pred)
    roc_auc = auc(fpr, tpr)

    # Binarize once and reuse for every threshold-dependent metric.
    y_hat = y_pred > threshold

    # Pin the label order so the matrix is always [[TN, FP], [FN, TP]], even
    # if one class never appears among the predictions.
    cm = confusion_matrix(y_true, y_hat, labels=[0, 1])
    acc = accuracy_score(y_true, y_hat, normalize=True)
    precision = precision_score(y_true, y_hat, average="binary")
    recall = recall_score(y_true, y_hat, average="binary")
    f1 = f1_score(y_true, y_hat, average="binary")
    return fpr, tpr, roc_auc, cm, acc, precision, recall, f1


def generate_data(N, sigma, separability, rng=None):
    """Simulate classifier scores for equally sized positive and negative classes.

    Each class's scores are drawn from a normal distribution with standard
    deviation ``sigma``. The class means sit symmetrically around 0.5 and are
    pushed apart according to ``separability``.

    Parameters
    ----------
    N : int
        Number of samples drawn per class.
    sigma : float
        Standard deviation of both score distributions.
    separability : {"low", "medium", "high"}
        Distance between the class means: ``"low"`` puts both at 0.5,
        ``"medium"`` at 0.6 / 0.4, and ``"high"`` at 0.7 / 0.3 (positive /
        negative).
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Pass a seeded generator for reproducible
        output. When ``None`` (the default), the global ``np.random`` state
        is used.

    Returns
    -------
    tuple of numpy.ndarray
        ``(pos_pred, pos_gt, neg_pred, neg_gt)``: the scores and float labels
        (1.0 for positives, 0.0 for negatives), each of length ``N``.

    Raises
    ------
    ValueError
        If ``separability`` is not one of the supported levels.
    """
    # Half-distance between the two class means for each separability level.
    offsets = {"low": 0.0, "medium": 0.1, "high": 0.2}
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if separability not in offsets:
        raise ValueError(
            f"separability must be one of {list(offsets)}, got {separability!r}")
    mu_pos = 0.5 + offsets[separability]
    mu_neg = 0.5 - offsets[separability]

    normal = np.random.normal if rng is None else rng.normal

    # Scores for the positive class, all labelled 1.
    pos_pred = normal(mu_pos, sigma, N)
    pos_gt = np.ones_like(pos_pred)

    # Scores for the negative class, all labelled 0.
    neg_pred = normal(mu_neg, sigma, N)
    neg_gt = np.zeros_like(neg_pred)

    return pos_pred, pos_gt, neg_pred, neg_gt


def plot_confusion_matrix(cm, title, classes, ax):
    """Draw a confusion matrix as a heat map with a count in each cell.

    Parameters
    ----------
    cm : numpy.ndarray
        Integer confusion matrix with true labels on the rows and predicted
        labels on the columns.
    title : str
        Axes title.
    classes : sequence of str
        Tick labels, in the same order as the rows/columns of ``cm``.
    ax : matplotlib.axes.Axes
        Axes to draw on.

    For a 2x2 matrix laid out as ``[[TN, FP], [FN, TP]]``, each cell is
    annotated with its name and count. For any other shape, only the counts
    are shown.
    """
    cmap = plt.get_cmap("coolwarm")
    ax.imshow(cm, cmap=cmap)

    n_rows, n_cols = cm.shape
    ax.set(
        xticks=np.arange(n_cols),
        yticks=np.arange(n_rows),
        xticklabels=classes,
        yticklabels=classes,
        title=title,
        ylabel='True label',
        xlabel='Predicted label')

    # Rotate the x tick labels so long class names do not overlap.
    plt.setp(ax.get_xticklabels(),
             rotation=45,
             ha="right",
             rotation_mode="anchor")

    # Quadrant names for a binary matrix laid out as [[TN, FP], [FN, TP]].
    cell_names = {
        (0, 0): "True Negative",
        (0, 1): "False Positive",
        (1, 0): "False Negative",
        (1, 1): "True Positive",
    }
    is_binary = cm.shape == (2, 2)

    # imshow's default scaling maps cm.min() to the bottom of the colormap and
    # cm.max() to the top (everything to the bottom if all values are equal).
    # Reproduce it so each label can contrast with its cell colour.
    lo, hi = cm.min(), cm.max()
    span = hi - lo

    for i in range(n_rows):
        for j in range(n_cols):
            count = format(cm[i, j], 'd')
            text = f"{cell_names[i, j]}\n{count}" if is_binary else count
            # coolwarm is light in the middle and dark at both ends, so use
            # black text on light cells and white text on dark ones.
            frac = (cm[i, j] - lo) / span if span else 0.0
            r, g, b, _ = cmap(frac)
            luminance = 0.299 * r + 0.587 * g + 0.114 * b
            ax.text(x=(j + 0.5) / n_cols,
                    y=1 - (i + 0.5) / n_rows,  # imshow draws row 0 at the top
                    s=text,
                    transform=ax.transAxes,
                    ha="center",
                    va="center",
                    color="black" if luminance > 0.6 else "white",
                    weight='bold',
                    fontsize=14)

    # Lay out the figure that owns `ax`, not whatever pyplot considers current.
    ax.figure.tight_layout()


def plot_pdf(pos_pred, neg_pred, threshold, ax):
    """Plot kernel density estimates of the scores for each class.

    Parameters
    ----------
    pos_pred, neg_pred : array-like
        Predicted scores for the positive and negative samples.
    threshold : float
        Decision cutoff, drawn as a dashed vertical line.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    """
    # NOTE(review): seaborn >= 0.11 deprecates `shade` in favour of `fill`;
    # switch once the seaborn version this notebook targets is confirmed.
    sns.kdeplot(pos_pred, label='positive', shade=True, color="red", ax=ax)
    sns.kdeplot(neg_pred, label='negative', shade=True, color="green", ax=ax)
    ax.axvline(threshold, label="threshold", linestyle="--", color="black")
    ax.set_title("Probability distribution of data")
    # kdeplot draws an estimated probability density, not raw counts.
    ax.set_ylabel("Density", fontsize=12)
    ax.set_xlabel('P(X="positive")', fontsize=12)
    ax.legend()


def plot_roc(fpr, tpr, roc_auc, ax):
    """Draw an ROC curve alongside the chance diagonal.

    Parameters
    ----------
    fpr, tpr : array-like
        False and true positive rates, e.g. as returned by ``roc_curve``.
    roc_auc : float
        Area under the curve, shown in the legend entry.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    """
    curve_label = f"ROC (AUC = {roc_auc:0.2f})"
    ax.plot(fpr, tpr, lw=1, alpha=0.3, label=curve_label)

    # A classifier that guesses at random lies on the diagonal TPR == FPR.
    chance_style = dict(linestyle='--', lw=2, color='r', label='Chance', alpha=0.8)
    ax.plot([0, 1], [0, 1], **chance_style)

    x_limits = [-0.05, 1.00]
    y_limits = [0.0, 1.05]
    ax.set_xlim(x_limits)
    ax.set_ylim(y_limits)

    axis_font = 12
    ax.set_xlabel('False Positive Rate (FPR)', fontsize=axis_font)
    ax.set_ylabel('True Positive Rate (TPR)', fontsize=axis_font)
    ax.set_title('Receiver operating characteristic (ROC) curve')
    ax.legend(loc="lower right")

In [3]:
# HIDDEN
style = {'description_width': 'initial'}
# Re-runs generate_output whenever a widget value is committed
# (continuous_update=False: the slider fires on release, not while dragging).
# The trailing `;` stops Jupyter from echoing the returned function object.
interact(generate_output,
         separability=Dropdown(
             options=['low', 'medium', 'high'],
             value='medium',
             description='Model performance:',
             style=style,
             disabled=False,
         ),
         threshold=FloatSlider(
             value=0.5,
             min=0.3,
             max=0.7,
             step=0.1,
             description='Probability cutoff:',
             disabled=False,
             style=style,
             continuous_update=False,
             orientation='horizontal',
             readout=True,
             readout_format='.1f',
         ));

interactive(children=(Dropdown(description='Model performance:', index=1, options=('low', 'medium', 'high'), s…

<function __main__.generate_output(separability, threshold)>