In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go

In [None]:
def get_prediction_metrics(data_path, num_buckets=10):
    """Confusion-matrix metrics for a sweep of confidence thresholds.

    Reads a CSV with ``exact_matches`` and ``confidences`` columns. Threshold
    ``i`` (for ``i`` in ``0 .. num_buckets - 1``) is the ``i / num_buckets``
    quantile of the confidences. An example is "included" (predicted positive)
    when its confidence is ``>=`` that threshold. It counts as a positive
    label when its exact-match score is ``>= 0.5``. So bucket ``i`` keeps
    roughly the most confident ``1 - i / num_buckets`` share of the data.

    Parameters
    ----------
    data_path : str, path-like or file-like
        Anything accepted by ``pd.read_csv``.
    num_buckets : int, default 10
        Number of quantile thresholds to evaluate (must be >= 1).

    Returns
    -------
    dict[str, np.ndarray]
        Float arrays of length ``num_buckets``, keyed by ``'tp'``, ``'tn'``,
        ``'fp'``, ``'fn'``, ``'included'``, ``'prediction_acc'``,
        ``'precision'``, ``'recall'`` and ``'F1'``.

        - A bucket whose threshold excludes fewer examples than its nominal
          ``i / num_buckets`` share (e.g. because many confidences are tied)
          has every metric set to 0.
        - ``'recall'`` is NaN when there are no positive examples.

    Raises
    ------
    ValueError
        If ``num_buckets < 1`` or the CSV has no rows.
    """
    if num_buckets < 1:
        raise ValueError(f"num_buckets must be >= 1, got {num_buckets}")
    df = pd.read_csv(data_path)
    gold = df['exact_matches'].to_numpy()
    conf = df['confidences'].to_numpy()
    n = len(conf)
    if n == 0:
        raise ValueError(f"no rows in {data_path!r}")

    # Dividing an integer range (rather than np.arange with a float step)
    # guarantees exactly num_buckets quantile levels; a float step can
    # overshoot and produce one extra threshold.
    thresholds = np.quantile(conf, np.arange(num_buckets) / num_buckets)

    # Two separate comparisons (instead of ~is_pos) so that a NaN exact-match
    # score counts as fp when included and as fn when excluded.
    is_pos = gold >= 0.5
    is_neg = gold < 0.5

    prediction_metrics = {name: np.zeros(num_buckets) for name in ('tp', 'tn', 'fp', 'fn')}
    for bucket, threshold in enumerate(thresholds):
        included = conf >= threshold
        excluded = ~included
        prediction_metrics['tp'][bucket] = np.count_nonzero(included & is_pos)
        prediction_metrics['fp'][bucket] = np.count_nonzero(included & ~is_pos)
        prediction_metrics['tn'][bucket] = np.count_nonzero(excluded & is_neg)
        prediction_metrics['fn'][bucket] = np.count_nonzero(excluded & ~is_neg)

    tp = prediction_metrics['tp']
    tn = prediction_metrics['tn']
    fp = prediction_metrics['fp']
    fn = prediction_metrics['fn']
    total = tp + tn + fp + fn  # every example lands in exactly one cell, so == n
    # recall is 0/0 when there are no positives; report NaN without a warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        prediction_metrics['included'] = (tp + fp) / total
        prediction_metrics['prediction_acc'] = (tp + tn) / total
        prediction_metrics['precision'] = tp / (tp + fp)
        prediction_metrics['recall'] = tp / (tp + fn)
        prediction_metrics['F1'] = 2 * tp / (2 * tp + fp + fn)

    for bucket in range(num_buckets):
        # Compare counts in integer arithmetic: the float form
        # `1 - included < bucket / num_buckets` misfires on rounding
        # (1 - 0.9 == 0.0999... < 0.1) and zeroes healthy buckets.
        nominal_excluded = (bucket * n) // num_buckets
        if tn[bucket] + fn[bucket] < nominal_excluded:
            for values in prediction_metrics.values():
                values[bucket] = 0

    return prediction_metrics

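## Coverage vs. precision

Precision of the predictions that are kept when only the most confident fraction of the data is included (10 confidence-quantile buckets per model).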

In [None]:
# Ten quantile buckets per model: bucket i keeps the examples whose confidence
# is at or above the i/10 quantile.
baseline_metrics, dropout_metrics = (
    get_prediction_metrics(csv_path, num_buckets=10)
    for csv_path in ('../diagrams/baseline_raw.csv', '../diagrams/dropout_2_beam_raw.csv')
)

In [None]:
# get_prediction_metrics orders buckets from "keep everything" (bucket 0) down
# to "keep only the most confident ~1/num_buckets", so flipping puts the
# smallest coverage first.
#
# After the flip, bar j has a nominal coverage of (j + 1) / num_buckets. There
# is no 0% bucket, so x starts at 1 / num_buckets.
#
# The bucket count is read from the arrays, so this cell still works if the
# metrics were recomputed with a different num_buckets.
fig = go.Figure()
for metrics, label in [(baseline_metrics, 'Baseline calibration'),
                       (dropout_metrics, 'Best calibration')]:
    precision_by_coverage = np.flip(metrics['precision'])
    n_buckets = len(precision_by_coverage)
    coverage = np.arange(1, n_buckets + 1) / n_buckets
    fig.add_trace(go.Bar(x=coverage, y=precision_by_coverage, name=label))
fig.update_layout(title='Precision vs. coverage',
                  xaxis_title='Fraction of data included',
                  yaxis_title='Precision')
fig.update_yaxes(range=(0, 1))
fig.update_xaxes(tickmode='array', tickformat=',.0%', tickvals=coverage)
fig.show()

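## Precision vs. recall

Precision-recall trade-off across 100 confidence thresholds. A prediction counts as positive when its confidence is at or above the threshold.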

In [None]:
# A finer sweep (100 quantile thresholds per model) for the precision-recall curve.
baseline_metrics, dropout_metrics = (
    get_prediction_metrics(csv_path, num_buckets=100)
    for csv_path in ('../diagrams/baseline_raw.csv', '../diagrams/dropout_2_beam_raw.csv')
)

In [None]:
# A bucket with included == 0 kept no examples: either get_prediction_metrics
# zeroed it out (it sets every metric, including 'included', to 0), or its
# threshold selected nothing. Dropping such buckets keeps them from showing up
# as fake points at (0, 0).
fig = go.Figure()
for metrics, label in [(baseline_metrics, 'Baseline'),
                       (dropout_metrics, 'Best Calibrator')]:
    kept = metrics['included'] > 0
    fig.add_trace(go.Scatter(x=metrics['recall'][kept], y=metrics['precision'][kept],
                             name=label, mode='markers'))
fig.update_layout(title='Calibration Precision-Recall',
                  xaxis_title='Recall',
                  yaxis_title='Precision')
fig.show()