In [2]:
import matplotlib.pyplot as plt

In [3]:
def reliability_curve(y_true, y_score, pred_labels, bins=5, normalize=False):
    """Draw a reliability diagram and return the expected calibration error.

    Scores are grouped into ``bins`` equal-width bins covering [0, 1]. For
    each bin the mean score (confidence) and the fraction of samples whose
    predicted label equals the true label (accuracy) are computed. Both are
    drawn as bars on matplotlib figure number 0, together with the diagonal
    of perfect calibration and an "ECE" text box.

    Parameters
    ----------
    y_true : array-like
        True labels, one per sample.
    y_score : array-like
        Score of each prediction. Only scores inside [0, 1] are assigned to
        a bin; anything else (including NaN) is ignored.
    pred_labels : array-like
        Predicted labels, compared element-wise with ``y_true``.
    bins : int, default 5
        Number of equal-width bins.
    normalize : bool, default False
        Accepted for interface compatibility; it currently has no effect.

    Returns
    -------
    float
        Expected calibration error: the sum over bins of
        ``|mean score - accuracy|`` weighted by the fraction of binned
        samples in that bin. ``0.0`` when no sample falls into any bin.

    Raises
    ------
    ValueError
        If ``bins`` is smaller than 1 or the inputs differ in length.
    """
    if bins < 1:
        raise ValueError("bins must be a positive integer")

    # Flatten to 1-D arrays so lists and column vectors are accepted too.
    y_true = np.asarray(y_true).ravel()
    y_score = np.asarray(y_score).ravel()
    pred_labels = np.asarray(pred_labels).ravel()
    if not (y_true.shape == y_score.shape == pred_labels.shape):
        raise ValueError(
            "y_true, y_score and pred_labels must have the same length"
        )

    bin_width = 1.0 / bins
    # Explicit edges (instead of centre +/- half width) keep the outer edges
    # at exactly 0.0 and 1.0 and make neighbouring bins share one boundary.
    bin_edges = np.linspace(0.0, 1.0, bins + 1)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Zero-initialised so that empty bins contribute nothing to plot or ECE.
    y_score_bin_mean = np.zeros(bins)
    empirical_acc_pos = np.zeros(bins)
    count_bins = np.zeros(bins)

    for i, (lower, upper) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        # Bins are half-open (lower, upper] ...
        bin_idx = np.logical_and(lower < y_score, y_score <= upper)
        if i == 0:
            # ... except the first one, which also takes scores equal to 0.
            bin_idx = np.logical_or(bin_idx, y_score == lower)

        n_in_bin = bin_idx.sum()
        if n_in_bin == 0:
            # Skip empty bins: mean() of nothing and 0/0 would yield NaN.
            continue

        count_bins[i] = n_in_bin
        y_score_bin_mean[i] = y_score[bin_idx].mean()
        empirical_acc_pos[i] = np.mean(y_true[bin_idx] == pred_labels[bin_idx])

    ## Plotting
    plt.figure(0, figsize=(8, 8))
    plt.plot([0.0, 1.0], [0.0, 1.0], 'k', label="Perfect Calibration")
    # Bars are as wide as the bins so they tile the x axis for any `bins`.
    plt.bar(bin_centers, y_score_bin_mean, width=bin_width, align='center',
            alpha=0.1, ec='black')
    plt.bar(bin_centers, empirical_acc_pos, width=bin_width, align='center',
            alpha=0.7, ec='black')
    plt.xlim(0, 1)

    # Decorations: a dashed guide from the y axis to each bin centre at the
    # bin's mean score, and the numeric value for every non-empty bin.
    for center, confidence in zip(bin_centers, y_score_bin_mean):
        plt.hlines(confidence, 0, center, linestyles='dashed')
        if confidence > 0:
            plt.annotate("{:.3f}".format(confidence),
                         xy=(center - bin_width / 2, confidence - 0.03),
                         weight='bold', textcoords='offset points')

    ## Calculating ECE
    # Each bin's confidence/accuracy gap is weighted by its share of the
    # binned samples, so the result does not grow with the dataset size.
    total_binned = count_bins.sum()
    if total_binned > 0:
        gap_weighted = np.absolute(y_score_bin_mean - empirical_acc_pos) * count_bins
        ece = float(gap_weighted.sum() / total_binned)
    else:
        ece = 0.0

    ## Decorations
    bbox_props = dict(boxstyle="round", fc="lightgrey", ec="brown", lw=2)
    plt.text(0.2, 0.9, "ECE: {:.2f}".format(ece), ha="center", va="center",
             size=20, weight='bold', bbox=bbox_props)

    plt.title("Reliability Diagram", size=20)
    plt.ylabel("Accuracy (P[y]", size=18)
    plt.xlabel("Confidence", size=18)
    return ece