In [4]:
import time

import argparse
import joblib
import matplotlib.pyplot as plt
import numpy as np
# import pandas as pd
import pickle
# import torch

from collections import Counter
from scipy import stats, cluster

from utils.conformal_utils import *

In [28]:
# Choose settings

alpha = .1  # Class-coverage gaps below are measured against a target coverage of 1 - alpha
n_totalcal = 20 # Total number of calibration points (= # clustering examples + # conformal calibration examples)
# NOTE: increasing n_totalcal decreases the number of points we can use to compute coverage
# NOTE(review): recorded output shapes (21800 rows for n_totalcal=200) suggest this count
# is applied per class -- confirm against split_X_and_y.

# Choose data

# Registry of cached (softmax scores, labels) .npy pairs. Select one with `dataset_name`
# below instead of commenting path assignments in and out.
DATASETS = {
    # Enron - BoW
    'enron_bow': (
        "../class-conditional-conformal-datasets/notebooks/.cache/email_softmax_BoW.npy",
        "../class-conditional-conformal-datasets/notebooks/.cache/email_labels_BoW.npy",
    ),
    # Enron - GloVe (paths kept exactly as recorded; note there is no `notebooks/` component)
    'enron_glove': (
        "../class-conditional-conformal-datasets/.cache/email_softmax_glove.npy",
        "../class-conditional-conformal-datasets/.cache/email_labels_glove.npy",
    ),
    # Enron - BERT (n_train=800)
    'enron_bert_ntrain800': (
        "../class-conditional-conformal-datasets/notebooks/.cache/email_softmax_bert.npy",
        "../class-conditional-conformal-datasets/notebooks/.cache/email_labels_bert.npy",
    ),
    # Enron - BERT (n_train=500)
    'enron_bert_ntrain500': (
        "../class-conditional-conformal-datasets/notebooks/.cache/email_softmax_bert_ntrain=500.npy",
        "../class-conditional-conformal-datasets/notebooks/.cache/email_labels_bert_ntrain=500.npy",
    ),
    # CIFAR-100
    'cifar100': (
        "../class-conditional-conformal-datasets/notebooks/.cache/best-cifar100-model-valsoftmax_frac=0.3.npy",
        "../class-conditional-conformal-datasets/notebooks/.cache/best-cifar100-model-vallabels_frac=0.3.npy",
    ),
}

dataset_name = 'enron_bert_ntrain500'

if dataset_name not in DATASETS:
    raise KeyError(f'Unknown dataset_name={dataset_name!r}; choose one of {sorted(DATASETS)}')

# Downstream cells read these two names
softmax_path, labels_path = DATASETS[dataset_name]

In [29]:
## 1. Get data ============================
print('Loading softmax scores and labels...')

softmax_scores = np.load(softmax_path)
labels = np.load(labels_path)

# Fail fast if the two cache files do not belong together: the paths are chosen by hand,
# so a mismatched pair would otherwise only show up as nonsensical coverage numbers.
assert softmax_scores.ndim == 2, (
    f'expected a 2-D (examples x classes) score array, got shape {softmax_scores.shape}')
assert len(softmax_scores) == len(labels), (
    f'softmax_scores has {len(softmax_scores)} rows but labels has {len(labels)} entries')

# Classes are assumed to be labelled 0..K-1, so K is inferred from the largest label seen
num_classes = labels.max() + 1

# Every label needs a matching score column
assert softmax_scores.shape[1] >= num_classes, (
    f'labels imply {num_classes} classes but softmax_scores has only {softmax_scores.shape[1]} columns')

Loading softmax scores and labels...


# Option 1: For a single random seed

In [5]:
# Echo the configuration first so the printed log is self-describing
print('====== SETTINGS =====')
print(f'alpha={alpha}')
print(f'n_totalcal={n_totalcal}')
print('=====================')

for score_function in ('softmax', 'APS', 'RAPS'):

    print(f'====== score_function={score_function} ======')

    print('Computing conformal score...')
    if score_function == 'RAPS':
        # RAPS hyperparameters (currently using ImageNet defaults)
        lmbda, kreg = .01, 5
        scores_all = get_RAPS_scores_all(softmax_scores, lmbda, kreg, randomize=True)
    elif score_function == 'APS':
        scores_all = get_APS_scores_all(softmax_scores, randomize=True)
    elif score_function == 'softmax':
        scores_all = 1 - softmax_scores
    else:
        raise Exception('Undefined score function')

    print('Splitting data...')
    # Split into clustering+calibration data and validation data (fixed seed 0)
    (totalcal_scores_all, totalcal_labels,
     val_scores_all, val_labels) = split_X_and_y(
        scores_all, labels, n_totalcal, num_classes=num_classes, seed=0)

    ## 2. Compute baselines for evaluation ============================
    print('Evaluating baselines...')

    # A) Vanilla conformal
    vanilla_qhat = compute_qhat(totalcal_scores_all, totalcal_labels, alpha=alpha)
    vanilla_preds = create_prediction_sets(val_scores_all, vanilla_qhat)

    marginal_cov = compute_coverage(val_labels, vanilla_preds)
    print(f'Marginal coverage of Vanilla: {marginal_cov*100:.2f}%')
    vanilla_class_specific_cov = compute_class_specific_coverage(val_labels, vanilla_preds)

    # B) Naive class-balanced
    naivecb_qhats = compute_class_specific_qhats(
        totalcal_scores_all, totalcal_labels, alpha=alpha, default_qhat=np.inf)
    naivecb_preds = create_cb_prediction_sets(val_scores_all, naivecb_qhats)

    naivecb_marginal_cov = compute_coverage(val_labels, naivecb_preds)
    print(f'Marginal coverage of NaiveCC: {naivecb_marginal_cov*100:.2f}%')
    naivecb_class_specific_cov = compute_class_specific_coverage(val_labels, naivecb_preds)

    # Class-conditional coverage: mean absolute deviation of per-class coverage from 1 - alpha
    vanilla_l1_dist = np.mean(np.abs(vanilla_class_specific_cov - (1 - alpha)))
    naivecb_l1_dist = np.mean(np.abs(naivecb_class_specific_cov - (1 - alpha)))

    print(f'[Vanilla] Average class-coverage gap: {vanilla_l1_dist*100:.3f}')
    print(f'[NaiveCC] Average class-coverage gap: {naivecb_l1_dist*100:.3f}')

    ## Set size: mean plus a few quantiles of the prediction-set sizes
    vanilla_set_sizes = list(map(len, vanilla_preds))
    vanilla_set_size_metrics = {
        'mean': np.mean(vanilla_set_sizes),
        '[.25, .5, .75, .9] quantiles': np.quantile(vanilla_set_sizes, [.25, .5, .75, .9]),
    }
    naivecb_set_sizes = list(map(len, naivecb_preds))
    naivecb_set_size_metrics = {
        'mean': np.mean(naivecb_set_sizes),
        '[.25, .5, .75, .9] quantiles': np.quantile(naivecb_set_sizes, [.25, .5, .75, .9]),
    }
    print('[Vanilla] set size metrics:', vanilla_set_size_metrics)
    print('[NaiveCC] set size metrics:', naivecb_set_size_metrics)

alpha=0.1
n_totalcal=20
Loading softmax scores and labels...
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 90.04%
Marginal coverage of NaiveCC: 91.72%
[Vanilla] Average class-coverage gap: 1.905
[NaiveCC] Average class-coverage gap: 4.790
[Vanilla] set size metrics: {'mean': 36.560241069747, '[.25, .5, .75, .9] quantiles': array([29., 39., 47., 52.])}
[NaiveCC] set size metrics: {'mean': 45.25335928181305, '[.25, .5, .75, .9] quantiles': array([38., 49., 56., 61.])}
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 89.88%
Marginal coverage of NaiveCC: 91.65%
[Vanilla] Average class-coverage gap: 1.948
[NaiveCC] Average class-coverage gap: 4.767
[Vanilla] set size metrics: {'mean': 37.31397576746814, '[.25, .5, .75, .9] quantiles': array([30., 41., 48., 53.])}
[NaiveCC] set size metrics: {'mean': 46.7664636825915, '[.25, .5, .75, .9] quantiles': array([40., 50., 57., 62.])}
Computin

In [None]:
# ENRON 

'''
-------- BoW embeddings -------- 

====== SETTINGS =====
n_totalcal = 10
Loading softmax scores and labels...
====== score_function=softmax ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 90.15%
Marginal coverage of NaiveCC: 91.17%
[Vanilla] L1 distance between desired and realized class-cond. coverage: 6.342
[NaiveCC] L1 distance between desired and realized class-cond. coverage: 7.426
Note: The average magnitude of deviation from the desired coverage is L1 dist/1000
[Vanilla] set size metrics: {'mean': 58.68494018378316, '[.25, .5, .75, .9] quantiles': array([51., 62., 69., 74.])}
[NaiveCC] set size metrics: {'mean': 64.16977402762527, '[.25, .5, .75, .9] quantiles': array([58., 68., 74., 78.])}
====== score_function=APS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 89.27%
Marginal coverage of NaiveCC: 90.77%
[Vanilla] L1 distance between desired and realized class-cond. coverage: 6.927
[NaiveCC] L1 distance between desired and realized class-cond. coverage: 7.888
Note: The average magnitude of deviation from the desired coverage is L1 dist/1000
[Vanilla] set size metrics: {'mean': 56.52613304051321, '[.25, .5, .75, .9] quantiles': array([56., 59., 60., 61.])}
[NaiveCC] set size metrics: {'mean': 64.58115008957984, '[.25, .5, .75, .9] quantiles': array([62., 66., 70., 73.])}
====== score_function=RAPS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 89.46%
Marginal coverage of NaiveCC: 90.68%
[Vanilla] L1 distance between desired and realized class-cond. coverage: 6.981
[NaiveCC] L1 distance between desired and realized class-cond. coverage: 8.155
Note: The average magnitude of deviation from the desired coverage is L1 dist/1000
[Vanilla] set size metrics: {'mean': 57.007580188406635, '[.25, .5, .75, .9] quantiles': array([57., 58., 58., 59.])}
[NaiveCC] set size metrics: {'mean': 64.05058775934809, '[.25, .5, .75, .9] quantiles': array([61., 65., 68., 71.])}

Increasing n_totalcal does not have much of an effect

-------- GloVe embeddings -------- 

====== SETTINGS =====
alpha=0.1
n_totalcal=100
=====================
Loading softmax scores and labels...
====== score_function=softmax ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 89.83%
Marginal coverage of NaiveCC: 89.88%
[Vanilla] Average class-coverage gap: 2.880
[NaiveCC] Average class-coverage gap: 3.269
[Vanilla] set size metrics: {'mean': 40.52667016685662, '[.25, .5, .75, .9] quantiles': array([26., 43., 56., 65.])}
[NaiveCC] set size metrics: {'mean': 42.123614757142306, '[.25, .5, .75, .9] quantiles': array([28., 46., 57., 66.])}
====== score_function=APS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 89.81%
Marginal coverage of NaiveCC: 89.71%
[Vanilla] Average class-coverage gap: 2.908
[NaiveCC] Average class-coverage gap: 3.269
[Vanilla] set size metrics: {'mean': 45.82951940684586, '[.25, .5, .75, .9] quantiles': array([30., 52., 64., 71.])}
[NaiveCC] set size metrics: {'mean': 46.83192685456711, '[.25, .5, .75, .9] quantiles': array([32., 53., 64., 71.])}
====== score_function=RAPS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 89.99%
Marginal coverage of NaiveCC: 90.04%
[Vanilla] Average class-coverage gap: 2.876
[NaiveCC] Average class-coverage gap: 3.378
[Vanilla] set size metrics: {'mean': 42.630730814050274, '[.25, .5, .75, .9] quantiles': array([37., 44., 48., 51.])}
[NaiveCC] set size metrics: {'mean': 44.66073427155498, '[.25, .5, .75, .9] quantiles': array([40., 46., 50., 53.])}

-------- BERT embeddings --------

====== SETTINGS =====
alpha=0.1
n_totalcal=100
=====================
Loading softmax scores and labels...
====== score_function=softmax ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 90.02%
Marginal coverage of NaiveCC: 90.75%
[Vanilla] Average class-coverage gap: 2.042
[NaiveCC] Average class-coverage gap: 2.266
[Vanilla] set size metrics: {'mean': 36.464243077752535, '[.25, .5, .75, .9] quantiles': array([29., 39., 47., 52.])}
[NaiveCC] set size metrics: {'mean': 38.024790439147125, '[.25, .5, .75, .9] quantiles': array([30., 41., 48., 54.])}
====== score_function=APS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 90.08%
Marginal coverage of NaiveCC: 90.63%
[Vanilla] Average class-coverage gap: 1.991
[NaiveCC] Average class-coverage gap: 2.093
[Vanilla] set size metrics: {'mean': 37.867757795150254, '[.25, .5, .75, .9] quantiles': array([30., 41., 49., 54.])}
[NaiveCC] set size metrics: {'mean': 39.29886137533536, '[.25, .5, .75, .9] quantiles': array([32., 43., 50., 55.])}
====== score_function=RAPS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 90.30%
Marginal coverage of NaiveCC: 90.74%
[Vanilla] Average class-coverage gap: 1.994
[NaiveCC] Average class-coverage gap: 2.140
[Vanilla] set size metrics: {'mean': 42.574855265015856, '[.25, .5, .75, .9] quantiles': array([40., 43., 46., 48.])}
[NaiveCC] set size metrics: {'mean': 43.94712006264361, '[.25, .5, .75, .9] quantiles': array([41., 44., 47., 50.])}



====== SETTINGS =====
alpha=0.1
n_totalcal=10
=====================
Loading softmax scores and labels...
====== score_function=softmax ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 89.98%
Marginal coverage of NaiveCC: 90.21%
[Vanilla] Average class-coverage gap: 1.950
[NaiveCC] Average class-coverage gap: 6.689
[Vanilla] set size metrics: {'mean': 36.35996393824424, '[.25, .5, .75, .9] quantiles': array([28., 39., 46., 52.])}
[NaiveCC] set size metrics: {'mean': 47.23285376206754, '[.25, .5, .75, .9] quantiles': array([41., 51., 57., 62.])}
====== score_function=APS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 89.78%
Marginal coverage of NaiveCC: 89.81%
[Vanilla] Average class-coverage gap: 2.026
[NaiveCC] Average class-coverage gap: 6.570
[Vanilla] set size metrics: {'mean': 37.01091118540501, '[.25, .5, .75, .9] quantiles': array([30., 40., 48., 53.])}
[NaiveCC] set size metrics: {'mean': 48.94149981843908, '[.25, .5, .75, .9] quantiles': array([43., 52., 58., 63.])}
====== score_function=RAPS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 90.28%
Marginal coverage of NaiveCC: 90.08%
[Vanilla] Average class-coverage gap: 1.925
[NaiveCC] Average class-coverage gap: 6.474
[Vanilla] set size metrics: {'mean': 42.515172232448066, '[.25, .5, .75, .9] quantiles': array([40., 43., 46., 48.])}
[NaiveCC] set size metrics: {'mean': 52.44974518863554, '[.25, .5, .75, .9] quantiles': array([50., 53., 56., 58.])}

====== SETTINGS =====
alpha=0.1
n_totalcal=20
=====================
Loading softmax scores and labels...
====== score_function=softmax ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 90.04%
Marginal coverage of NaiveCC: 91.72%
[Vanilla] Average class-coverage gap: 1.905
[NaiveCC] Average class-coverage gap: 4.790
[Vanilla] set size metrics: {'mean': 36.560241069747, '[.25, .5, .75, .9] quantiles': array([29., 39., 47., 52.])}
[NaiveCC] set size metrics: {'mean': 45.25335928181305, '[.25, .5, .75, .9] quantiles': array([38., 49., 56., 61.])}
====== score_function=APS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 89.88%
Marginal coverage of NaiveCC: 91.65%
[Vanilla] Average class-coverage gap: 1.948
[NaiveCC] Average class-coverage gap: 4.767
[Vanilla] set size metrics: {'mean': 37.31397576746814, '[.25, .5, .75, .9] quantiles': array([30., 41., 48., 53.])}
[NaiveCC] set size metrics: {'mean': 46.7664636825915, '[.25, .5, .75, .9] quantiles': array([40., 50., 57., 62.])}
====== score_function=RAPS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 90.43%
Marginal coverage of NaiveCC: 91.86%
[Vanilla] Average class-coverage gap: 1.914
[NaiveCC] Average class-coverage gap: 4.945
[Vanilla] set size metrics: {'mean': 43.11203716491933, '[.25, .5, .75, .9] quantiles': array([41., 44., 46., 48.])}
[NaiveCC] set size metrics: {'mean': 51.28442965660117, '[.25, .5, .75, .9] quantiles': array([48., 52., 55., 57.])}
'''
None


In [None]:
# CIFAR-100
'''
====== SETTINGS =====
alpha=0.1
n_totalcal=10
=====================
Loading softmax scores and labels...
====== score_function=softmax ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 91.85%
Marginal coverage of NaiveCC: 91.09%
[Vanilla] Average class-coverage gap: 4.049
[NaiveCC] Average class-coverage gap: 8.071
[Vanilla] set size metrics: {'mean': 8.02764705882353, '[.25, .5, .75, .9] quantiles': array([ 1.,  4., 11., 21.])}
[NaiveCC] set size metrics: {'mean': 24.23029411764706, '[.25, .5, .75, .9] quantiles': array([14., 21., 32., 43.])}
====== score_function=APS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 91.95%
Marginal coverage of NaiveCC: 90.66%
[Vanilla] Average class-coverage gap: 3.825
[NaiveCC] Average class-coverage gap: 8.175
[Vanilla] set size metrics: {'mean': 11.498117647058823, '[.25, .5, .75, .9] quantiles': array([ 1.,  4., 16., 35.])}
[NaiveCC] set size metrics: {'mean': 26.020470588235295, '[.25, .5, .75, .9] quantiles': array([13., 23., 36., 49.])}
====== score_function=RAPS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 91.44%
Marginal coverage of NaiveCC: 90.67%
[Vanilla] Average class-coverage gap: 4.715
[NaiveCC] Average class-coverage gap: 8.043
[Vanilla] set size metrics: {'mean': 7.58535294117647, '[.25, .5, .75, .9] quantiles': array([7., 7., 8., 9.])}
[NaiveCC] set size metrics: {'mean': 18.80641176470588, '[.25, .5, .75, .9] quantiles': array([17., 19., 21., 22.])}




====== SETTINGS =====
alpha=0.1
n_totalcal=20
=====================
Loading softmax scores and labels...
====== score_function=softmax ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 90.75%
Marginal coverage of NaiveCC: 91.05%
[Vanilla] Average class-coverage gap: 3.884
[NaiveCC] Average class-coverage gap: 5.392
[Vanilla] set size metrics: {'mean': 6.823625, '[.25, .5, .75, .9] quantiles': array([ 1.,  3.,  9., 18.])}
[NaiveCC] set size metrics: {'mean': 12.7353125, '[.25, .5, .75, .9] quantiles': array([ 4.,  9., 18., 29.])}
====== score_function=APS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 90.60%
Marginal coverage of NaiveCC: 91.07%
[Vanilla] Average class-coverage gap: 3.911
[NaiveCC] Average class-coverage gap: 5.007
[Vanilla] set size metrics: {'mean': 9.5258125, '[.25, .5, .75, .9] quantiles': array([ 1.,  3., 12., 29.])}
[NaiveCC] set size metrics: {'mean': 15.344875, '[.25, .5, .75, .9] quantiles': array([ 4., 10., 23., 38.])}
====== score_function=RAPS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 90.78%
Marginal coverage of NaiveCC: 91.59%
[Vanilla] Average class-coverage gap: 4.786
[NaiveCC] Average class-coverage gap: 5.196
[Vanilla] set size metrics: {'mean': 6.7675, '[.25, .5, .75, .9] quantiles': array([6., 6., 7., 8.])}
[NaiveCC] set size metrics: {'mean': 12.8803125, '[.25, .5, .75, .9] quantiles': array([11., 13., 15., 16.])}




====== SETTINGS =====
alpha=0.1
n_totalcal=100
=====================
Loading softmax scores and labels...
====== score_function=softmax ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 90.30%
Marginal coverage of NaiveCC: 90.39%
[Vanilla] Average class-coverage gap: 3.978
[NaiveCC] Average class-coverage gap: 3.180
[Vanilla] set size metrics: {'mean': 6.286875, '[.25, .5, .75, .9] quantiles': array([ 1.,  3.,  8., 17.])}
[NaiveCC] set size metrics: {'mean': 8.093625, '[.25, .5, .75, .9] quantiles': array([ 1.,  4., 12., 22.])}
====== score_function=APS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 90.20%
Marginal coverage of NaiveCC: 90.51%
[Vanilla] Average class-coverage gap: 4.018
[NaiveCC] Average class-coverage gap: 2.944
[Vanilla] set size metrics: {'mean': 8.964375, '[.25, .5, .75, .9] quantiles': array([ 1.,  3., 11., 28.])}
[NaiveCC] set size metrics: {'mean': 10.829125, '[.25, .5, .75, .9] quantiles': array([ 1. ,  4.5, 16. , 31. ])}
====== score_function=RAPS ======
Computing conformal score...
Splitting data...
Evaluating baselines...
Marginal coverage of Vanilla: 90.59%
Marginal coverage of NaiveCC: 90.54%
[Vanilla] Average class-coverage gap: 5.298
[NaiveCC] Average class-coverage gap: 3.236
[Vanilla] set size metrics: {'mean': 6.487625, '[.25, .5, .75, .9] quantiles': array([6., 6., 6., 8.])}
[NaiveCC] set size metrics: {'mean': 9.014125, '[.25, .5, .75, .9] quantiles': array([ 7.,  9., 11., 12.])}
'''
None

BoW embeddings yield an average set size of about 60, while GloVe embeddings yield an average set size in the 40s. That is better, but I think it is still too large to be practically useful.

# Option 2: Averaged over multiple random seeds

In [35]:
def _print_metrics(metrics):
    """Print the mean and standard error of each summary metric.

    Args:
        metrics: dict with keys 'marginal_cov', 'avg_class_cov_gap' and
            'avg_set_size', each mapping to a 1-D numpy array of per-seed values.

    One line is printed per metric, in the form ``<name>: <mean> (<standard error>)``.
    Metrics whose name contains 'cov' are multiplied by 100 first, so they are
    reported on a percent scale. With a single value the standard error is
    undefined and is printed as ``nan``.
    """
    for m in ['marginal_cov', 'avg_class_cov_gap', 'avg_set_size']:
        arr = metrics[m]
        if 'cov' in m: # scale coverage metrics to % scale
            arr = arr * 100

        n = len(arr)
        if n > 1:
            # np.std is the population std (ddof=0); dividing by sqrt(n-1) gives the
            # usual standard error of the mean, i.e. sample std / sqrt(n).
            se = np.std(arr) / np.sqrt(n - 1)
        else:
            # Avoid the 0/0 RuntimeWarning: one seed carries no spread information
            se = float('nan')

        print(f'{m}: {np.mean(arr)} ({se})')

def _summarize_prediction_sets(val_labels, preds, alpha):
    """Return (marginal coverage, average class-coverage gap, average set size) for `preds`."""
    marginal_cov = compute_coverage(val_labels, preds)
    class_specific_cov = compute_class_specific_coverage(val_labels, preds)
    # Mean absolute deviation of the per-class coverage from the 1 - alpha target
    avg_class_cov_gap = np.mean(np.abs(class_specific_cov - (1 - alpha)))
    return marginal_cov, avg_class_cov_gap, compute_avg_set_size(preds)


def compute_baselines(scores_all, labels,
                      n_totalcal, alpha,
                      seeds=(0, 1, 2, 3, 4),
                      num_classes=None):
    """Evaluate the Vanilla and Naive class-conditional baselines over several seeds.

    For each seed the data is re-split with `split_X_and_y`, both baselines are
    calibrated on the calibration split, and their prediction sets are scored on
    the validation split. A summary (mean and standard error per metric) is printed.

    Args:
        scores_all: conformal scores for every example, passed to `split_X_and_y`.
        labels: labels aligned with `scores_all`.
        n_totalcal: calibration-set size argument forwarded to `split_X_and_y`.
        alpha: miscoverage level; per-class coverage is compared against 1 - alpha.
        seeds: sequence of random seeds, one data split per seed.
        num_classes: number of classes forwarded to `split_X_and_y`. Defaults to
            `max(labels) + 1`, which matches how the notebook derives it, so the
            function no longer relies on a notebook-level global.

    Returns:
        (vanilla_metrics, naiveCC_metrics): two dicts with keys 'marginal_cov',
        'avg_class_cov_gap' and 'avg_set_size', each a numpy array with one entry
        per seed.
    """
    if num_classes is None:
        num_classes = np.max(labels) + 1

    metric_names = ('marginal_cov', 'avg_class_cov_gap', 'avg_set_size')
    vanilla_metrics = {name: np.zeros((len(seeds),)) for name in metric_names}
    naiveCC_metrics = {name: np.zeros((len(seeds),)) for name in metric_names}

    for i, seed in enumerate(seeds):
        # Split into clustering+calibration data and validation data
        totalcal_scores_all, totalcal_labels, val_scores_all, val_labels = split_X_and_y(
            scores_all, labels, n_totalcal, num_classes=num_classes, seed=seed)

        # A) Vanilla conformal
        vanilla_qhat = compute_qhat(totalcal_scores_all, totalcal_labels, alpha=alpha)
        vanilla_preds = create_prediction_sets(val_scores_all, vanilla_qhat)
        vanilla_summary = _summarize_prediction_sets(val_labels, vanilla_preds, alpha)
        for name, value in zip(metric_names, vanilla_summary):
            vanilla_metrics[name][i] = value

        # B) Naive class-conditional
        naivecb_qhats = compute_class_specific_qhats(
            totalcal_scores_all, totalcal_labels, alpha=alpha, default_qhat=np.inf)
        naivecb_preds = create_cb_prediction_sets(val_scores_all, naivecb_qhats)
        naivecb_summary = _summarize_prediction_sets(val_labels, naivecb_preds, alpha)
        for name, value in zip(metric_names, naivecb_summary):
            naiveCC_metrics[name][i] = value

    print('===== VANILLA =====')
    _print_metrics(vanilla_metrics)

    print('===== NAIVECC =====')
    _print_metrics(naiveCC_metrics)

    print(f'Metrics computed by averaging over {len(seeds)} seeds (SE in parentheses)')

    return vanilla_metrics, naiveCC_metrics
  

In [34]:
n_totalcal = 200

print(f'n_totalcal={n_totalcal}')
# RAPS is left out of this sweep; the full list would be ['softmax', 'APS', 'RAPS']
for score_function in ('softmax', 'APS'):

    print(f'\n====== score_function={score_function} ======')

    if score_function == 'RAPS':
        # RAPS hyperparameters (currently using ImageNet defaults)
        lmbda, kreg = .01, 5
        scores_all = get_RAPS_scores_all(softmax_scores, lmbda, kreg, randomize=True)
    elif score_function == 'APS':
        scores_all = get_APS_scores_all(softmax_scores, randomize=True)
    elif score_function == 'softmax':
        scores_all = 1 - softmax_scores
    else:
        raise Exception('Undefined score function')

    # Prints the per-method summary; the returned metric dicts are not kept here
    compute_baselines(scores_all, labels, n_totalcal, alpha, seeds=list(range(5)))

n_totalcal=200

totalcal_scores_all (21800, 109)
totalcal_scores_all (21800, 109)
totalcal_scores_all (21800, 109)
totalcal_scores_all (21800, 109)
totalcal_scores_all (21800, 109)
===== VANILLA =====
marginal_cov: 90.53391081090811 (0.035722341724913266)
avg_class_cov_gap: 1.878680373219605 (0.01579572727085521)
avg_set_size: 40.12694892682751 (0.1050872306236167)
===== NAIVECC =====
marginal_cov: 90.40683375124472 (0.13035416242013262)
avg_class_cov_gap: 1.7385439147702335 (0.07593210189380016)
avg_set_size: 40.68802821801778 (0.13337690811771402)
Metrics computed by averaging over 5 seeds (SE in parentheses)

totalcal_scores_all (21800, 109)
totalcal_scores_all (21800, 109)
totalcal_scores_all (21800, 109)
totalcal_scores_all (21800, 109)
totalcal_scores_all (21800, 109)
===== VANILLA =====
marginal_cov: 90.43265227771711 (0.029220405760522677)
avg_class_cov_gap: 1.8460585995966994 (0.015366024763583662)
avg_set_size: 41.1116536094348 (0.0910296155735499)
===== NAIVECC =====
margina

In [None]:
# Enron
'''
-------- BERT embeddings (n_train=500) --------

n_totalcal=150
====== score_function=softmax ======
===== VANILLA =====
marginal_cov: 90.48680143532171 (0.06546761914274572)
avg_class_cov_gap: 1.8864276082729865 (0.020494274010293245)
avg_set_size: 40.01368482135455 (0.18951713861790687)
===== NAIVECC =====
marginal_cov: 90.44937885624282 (0.14439435169191045)
avg_class_cov_gap: 2.003907462032149 (0.07422164082620206)
avg_set_size: 40.97450006513424 (0.17955973624546587)
Metrics computed by averaging over 5 seeds (SE in parentheses)

====== score_function=APS ======
===== VANILLA =====
marginal_cov: 90.40617709406568 (0.06082681696194567)
avg_class_cov_gap: 1.8409886696364341 (0.019554290076265652)
avg_set_size: 41.048725619071305 (0.1943012525460151)
===== NAIVECC =====
marginal_cov: 90.37259151360121 (0.1665021691762598)
avg_class_cov_gap: 2.0127667738410766 (0.06827721108427805)
avg_set_size: 42.27086794329769 (0.2068246336286253)
Metrics computed by averaging over 5 seeds (SE in parentheses)


n_totalcal=200

====== score_function=softmax ======
===== VANILLA =====
marginal_cov: 90.53391081090811 (0.035722341724913266)
avg_class_cov_gap: 1.878680373219605 (0.01579572727085521)
avg_set_size: 40.12694892682751 (0.1050872306236167)
===== NAIVECC =====
marginal_cov: 90.40683375124472 (0.13035416242013262)
avg_class_cov_gap: 1.7385439147702335 (0.07593210189380016)
avg_set_size: 40.68802821801778 (0.13337690811771402)
Metrics computed by averaging over 5 seeds (SE in parentheses)

====== score_function=APS ======
===== VANILLA =====
marginal_cov: 90.43265227771711 (0.029220405760522677)
avg_class_cov_gap: 1.8460585995966994 (0.015366024763583662)
avg_set_size: 41.1116536094348 (0.0910296155735499)
===== NAIVECC =====
marginal_cov: 90.40400235150148 (0.15636108684207367)
avg_class_cov_gap: 1.7599595240594639 (0.08369940669304853)
avg_set_size: 42.078498878237816 (0.13106139621061055)
Metrics computed by averaging over 5 seeds (SE in parentheses)


n_totalcal=250
====== score_function=softmax ======
===== VANILLA =====
marginal_cov: 90.54414607164999 (0.03198213906457043)
avg_class_cov_gap: 1.8918563478579826 (0.024718925471561848)
avg_set_size: 40.13315082481371 (0.09057867602856831)
===== NAIVECC =====
marginal_cov: 90.29848895588432 (0.13221790362743965)
avg_class_cov_gap: 1.597103402545033 (0.07103903560778761)
avg_set_size: 40.45930586790824 (0.03808188095911403)
Metrics computed by averaging over 5 seeds (SE in parentheses)

====== score_function=APS ======
===== VANILLA =====
marginal_cov: 90.45073607177156 (0.027157756066156558)
avg_class_cov_gap: 1.8590334345394144 (0.02491780742926131)
avg_set_size: 41.15158386112496 (0.0842572645200043)
===== NAIVECC =====
marginal_cov: 90.2974678158544 (0.13682646964891795)
avg_class_cov_gap: 1.6122227877184165 (0.06733109627353673)
avg_set_size: 41.7284881049343 (0.03126805994310151)
Metrics computed by averaging over 5 seeds (SE in parentheses)


n_totalcal=350
====== score_function=softmax ======
===== VANILLA =====
marginal_cov: 90.5784393301782 (0.04856066928977158)
avg_class_cov_gap: 1.9194896851513779 (0.02613888607229236)
avg_set_size: 40.17791635968582 (0.13420773892829138)
===== NAIVECC =====
marginal_cov: 90.33713365217719 (0.08424080682814071)
avg_class_cov_gap: 1.4763081884176832 (0.0361494938283182)
avg_set_size: 40.3149880745745 (0.1701395810443179)
Metrics computed by averaging over 5 seeds (SE in parentheses)

====== score_function=APS ======
===== VANILLA =====
marginal_cov: 90.50166706209963 (0.03825253789236101)
avg_class_cov_gap: 1.8954759676106552 (0.023244763499857025)
avg_set_size: 41.26287184225971 (0.11392421319309358)
===== NAIVECC =====
marginal_cov: 90.39412594747819 (0.07206572799085763)
avg_class_cov_gap: 1.459135447427874 (0.042657744118414324)
avg_set_size: 41.58059814437881 (0.17015187221361716)
Metrics computed by averaging over 5 seeds (SE in parentheses)


n_totalcal=450

====== score_function=softmax ======
===== VANILLA =====
marginal_cov: 90.53557720696783 (0.045130499266916126)
avg_class_cov_gap: 2.0301959349388503 (0.04563000677334334)
avg_set_size: 40.01349045583497 (0.12183421498921586)
===== NAIVECC =====
marginal_cov: 90.33753096879374 (0.0759855366242013)
avg_class_cov_gap: 1.3954808721815908 (0.04947208784067652)
avg_set_size: 40.04415809809886 (0.071232814498573)
Metrics computed by averaging over 5 seeds (SE in parentheses)

====== score_function=APS ======
===== VANILLA =====
marginal_cov: 90.46805560904222 (0.022883326352460997)
avg_class_cov_gap: 1.9874550823350017 (0.0528002891240556)
avg_set_size: 41.11355284271062 (0.065336725156444)
===== NAIVECC =====
marginal_cov: 90.34323051051976 (0.08540490346622083)
avg_class_cov_gap: 1.4173910806438221 (0.06056472457303975)
avg_set_size: 41.306360637219036 (0.06192332104373804)
Metrics computed by averaging over 5 seeds (SE in parentheses)

-------- BERT embeddings (n_train=800) --------

n_totalcal=10
====== score_function=softmax ======
Computing conformal score...
===== VANILLA =====
marginal_cov: 90.47393661645567 (0.4939775937453174)
avg_class_cov_gap: 2.143065176277878 (0.15625823234502592)
avg_set_size: 38.03896773224147 (1.664616962735908)
===== NAIVECC =====
marginal_cov: 91.59074915793296 (0.5967923727253125)
avg_class_cov_gap: 6.396983229315494 (0.09386008738423558)
avg_set_size: 50.66145223695578 (1.4228016757195918)
Metrics computed by averaging over 5 seeds (SE in parentheses)
====== score_function=APS ======
Computing conformal score...
===== VANILLA =====
marginal_cov: 90.39149543593405 (0.46440117702353395)
avg_class_cov_gap: 2.101091324690293 (0.13627933664752462)
avg_set_size: 39.087830910434114 (1.6217297759774798)
===== NAIVECC =====
marginal_cov: 91.57952994503087 (0.6635380440657627)
avg_class_cov_gap: 6.418561144897946 (0.14746044963851468)
avg_set_size: 53.04136984586104 (1.60231364953857)
Metrics computed by averaging over 5 seeds (SE in parentheses)


n_totalcal=20
====== score_function=softmax ======
Computing conformal score...
===== VANILLA =====
marginal_cov: 90.3333040366627 (0.2513511269638004)
avg_class_cov_gap: 1.9590018829431375 (0.03394895837019498)
avg_set_size: 37.45009931571348 (0.7763464587874087)
===== NAIVECC =====
marginal_cov: 91.52715173582774 (0.6929364859753497)
avg_class_cov_gap: 4.8809271589656875 (0.19543907601683008)
avg_set_size: 45.03143901060959 (0.8508284410242934)
Metrics computed by averaging over 5 seeds (SE in parentheses)

====== score_function=APS ======
Computing conformal score...
===== VANILLA =====
marginal_cov: 90.22964404545169 (0.20972844895552395)
avg_class_cov_gap: 1.9435673371507396 (0.01760926098742099)
avg_set_size: 38.397608387218284 (0.6634853990448387)
===== NAIVECC =====
marginal_cov: 91.74647498273589 (0.7412457749811378)
avg_class_cov_gap: 4.7745902308436055 (0.13926866196289903)
avg_set_size: 47.17858396635068 (0.8150853014963166)
Metrics computed by averaging over 5 seeds (SE in parentheses)

n_totalcal=100
====== score_function=softmax ======
totalcal_scores_all (10900, 109)
totalcal_scores_all (10900, 109)
totalcal_scores_all (10900, 109)
totalcal_scores_all (10900, 109)
totalcal_scores_all (10900, 109)
===== VANILLA =====
marginal_cov: 89.97075775663983 (0.060796857264701756)
avg_class_cov_gap: 1.9758831687927974 (0.03988509461193606)
avg_set_size: 36.31398210549287 (0.18640371125937635)
===== NAIVECC =====
marginal_cov: 89.87823006123156 (0.4011024992654121)
avg_class_cov_gap: 2.445312623668701 (0.09540599687209772)
avg_set_size: 37.82148932619607 (0.13568041738742495)
Metrics computed by averaging over 5 seeds (SE in parentheses)

====== score_function=APS ======
totalcal_scores_all (10900, 109)
totalcal_scores_all (10900, 109)
totalcal_scores_all (10900, 109)
totalcal_scores_all (10900, 109)
totalcal_scores_all (10900, 109)
===== VANILLA =====
marginal_cov: 89.94056558965868 (0.050386841580239726)
avg_class_cov_gap: 1.9610886815877895 (0.029446807292877387)
avg_set_size: 37.46327672302024 (0.13840142480222611)
===== NAIVECC =====
marginal_cov: 90.03011514614704 (0.285851877594036)
avg_class_cov_gap: 2.368120626939109 (0.12159378978869739)
avg_set_size: 39.283733456566665 (0.09272364050193399)
Metrics computed by averaging over 5 seeds (SE in parentheses)


n_totalcal=150
====== score_function=softmax ======
totalcal_scores_all (16350, 109)
totalcal_scores_all (16350, 109)
totalcal_scores_all (16350, 109)
totalcal_scores_all (16350, 109)
totalcal_scores_all (16350, 109)
===== VANILLA =====
marginal_cov: 89.93071304891227 (0.030049923042817436)
avg_class_cov_gap: 2.0601651918713704 (0.054840613257038685)
avg_set_size: 36.17820676725991 (0.09479708577991632)
===== NAIVECC =====
marginal_cov: 89.88665685904364 (0.1921669521635439)
avg_class_cov_gap: 2.2472081561693997 (0.08739539086019157)
avg_set_size: 37.03277395164755 (0.12383294351543024)
Metrics computed by averaging over 5 seeds (SE in parentheses)

====== score_function=APS ======
totalcal_scores_all (16350, 109)
totalcal_scores_all (16350, 109)
totalcal_scores_all (16350, 109)
totalcal_scores_all (16350, 109)
totalcal_scores_all (16350, 109)
===== VANILLA =====
marginal_cov: 89.89790524794626 (0.026680514149285735)
avg_class_cov_gap: 2.0446571334058015 (0.05966453518722999)
avg_set_size: 37.344325161760686 (0.07412850793050112)
===== NAIVECC =====
marginal_cov: 89.98919425603103 (0.17317044974703308)
avg_class_cov_gap: 2.208967897189223 (0.09215286872510568)
avg_set_size: 38.50458397885719 (0.10951960533455299)
Metrics computed by averaging over 5 seeds (SE in parentheses)
'''