In [1]:
# import glob # For getting file names
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
# import seaborn as sns
# import torch

from collections import Counter
# from gap_statistic import OptimalK
from scipy import stats, cluster
from sklearn.cluster import KMeans
# from yellowbrick.cluster import KElbowVisualizer

from utils.clustering_utils import *
from utils.conformal_utils import *

%load_ext autoreload
%autoreload 2

## 0) Set up and load data

In [2]:
# Level passed as `alpha` to the conformal routines further down.
alpha = .1
# Calibration budget (= # clustering examples + # conformal calibration examples).
# NOTE(review): this appears to be a per-class count (the heuristic's docstring
# below calls it "the number of examples per class") -- confirm in split_X_and_y.
n_totalcal = 10


# Alternative dataset, kept for reference:
# # Enron - BERT (n_train=500)
# softmax_path = "../class-conditional-conformal-datasets/notebooks/.cache/email_softmax_bert_ntrain=500.npy"
# labels_path = "../class-conditional-conformal-datasets/notebooks/.cache/email_labels_bert_ntrain=500.npy"

# ImageNet
# The data directory defaults to the original location but can be redirected
# without editing the notebook by setting the IMAGENET_DATA_DIR environment
# variable before starting the kernel.
imagenet_data_dir = os.environ.get('IMAGENET_DATA_DIR', '/home/tding/data/finetuned_imagenet')
softmax_path = os.path.join(imagenet_data_dir, 'imagenet_train_subset_softmax.npy')
labels_path = os.path.join(imagenet_data_dir, 'imagenet_train_subset_labels.npy')

In [3]:
# Load the saved arrays from the paths configured in the previous cell.
softmax_scores = np.load(softmax_path)
labels = np.load(labels_path)

# Class count is inferred from the largest label value.
# Assumes labels are 0-indexed integers and that the highest class id is
# present in this subset -- TODO confirm against softmax_scores' class axis.
num_classes = labels.max() + 1

In [13]:
# Choose which conformal score array `scores_all` holds.
# Disabled alternative: softmax-based score computed from softmax_scores.
# score_function = 'softmax'

# if score_function == 'softmax':
#     scores_all = 1 - softmax_scores

# Active choice: precomputed APS scores loaded from disk.
# NOTE(review): `score_function` is only a label here; nothing in the visible
# cells branches on it, so changing it alone does not change `scores_all`.
score_function = 'APS'
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
APS_path = '/home/tding/data/finetuned_imagenet/imagenet_train_subset_APS.npy'
scores_all = np.load(APS_path)

In [23]:
# Sanity check: smallest number of examples any class has in the
# clustering+calibration split.
# NOTE(review): `totalcal_labels` is assigned by the split_X_and_y cell that
# appears *below* this one, so on a fresh kernel (Restart & Run All) this cell
# raises NameError. Run it after the split cell, or move it there.
cts = Counter(totalcal_labels)
min(cts.values())

10

In [14]:
# Split into clustering+calibration data and validation data
# The first two outputs form the "total calibration" pool; the last two are
# held out for validation. The split is seeded (seed=7).
# Presumably n_totalcal examples are drawn per class for the first pool --
# verify in utils (split_X_and_y is not defined in this notebook).
totalcal_scores_all, totalcal_labels, val_scores_all, val_labels = split_X_and_y(scores_all, labels, n_totalcal, num_classes=num_classes, seed=7)

## 1) Generate candidate parameter values 

**Option 1**: Use a heuristic based on the number of classes $K$ and the number of calibration examples per class $n_{\text{totalcal}}$.

**Option 2**: Use gap statistic? [Not yet implemented, partly because I don't have a good way of choosing n_clustering]

In [15]:
def get_clustering_parameters_v1(num_classes, n_totalcal):
    '''
    Heuristic guess of good values for n_clustering and num_clusters, based
    solely on the number of classes and the number of examples per class.

    Input:
        - num_classes: number of classes (K)
        - n_totalcal: number of examples per class (N)

    Output:
        - n_clustering: N*K/(75+K), truncated to an int
        - num_clusters: half of n_clustering, rounded down
    '''
    # Inactive earlier draft, summarized: start from 5 clustering points and
    # size num_clusters so each cluster averages ~150 calibration examples.
    examples_for_clustering = int(n_totalcal * num_classes / (75 + num_classes))
    cluster_count = examples_for_clustering // 2

    return examples_for_clustering, cluster_count

In [16]:
# Apply the heuristic to the current dataset size and report its proposal.
n_clustering, num_clusters = get_clustering_parameters_v1(num_classes, n_totalcal)

print('Proposed n_clustering:', n_clustering)
print('Proposed num_clusters:', num_clusters)

Proposed n_clustering: 9
Proposed num_clusters: 4


In [17]:
# Split data between clustering and calibration
scores1_all, labels1, scores2_all, labels2 = split_X_and_y(totalcal_scores_all, 
                                                       totalcal_labels, 
                                                       n_clustering, 
                                                       num_classes=num_classes, 
                                                       seed=0)

## 2) Test null hypothesis that there is one cluster 

In [18]:
# Significance level below which the one-cluster null is rejected (used by the
# branching cell further down).
pval_threshold = .01

# Test the null on the clustering portion of the data, using 100 seeded trials.
# NOTE(review): if the p-value is an empirical fraction over num_trials=100
# null draws, its resolution is 0.01, so `pval < .01` could only hold when
# pval == 0 -- confirm how test_one_cluster_null computes it.
pval = test_one_cluster_null(scores1_all, labels1, num_classes, num_clusters, 
                            num_trials=100, seed=0, print_results=True)

Observed metric: 0.9484328961022088
Metric under null: [0.94044854 0.99953304 0.9440248  0.94691334 0.95568019 0.93357114
 1.0053342  0.94352293 0.97150637 0.95524301 0.99726252 0.91334101
 1.01489484 1.03453568 0.94045151 0.9220946  0.91857114 0.97129557
 0.99701779 0.99809679 0.98997127 0.99425791 1.07064511 0.9436496
 0.96758005 0.94825179 0.98665192 0.98911945 0.96384549 0.98513256
 0.94008082 0.99417501 0.98982456 0.96830999 0.9272792  0.94145872
 1.02272391 1.01526252 0.9317927  0.97085659 0.9884877  0.98236891
 0.98069593 0.93640098 0.96083897 0.94133421 0.95868621 0.99174661
 0.97746185 0.95678082 0.95778038 0.92047147 0.92785707 0.93717537
 0.92429454 0.99299401 1.01012032 0.91538481 1.02972787 0.95350313
 0.95113288 0.93343394 0.96553386 1.03157994 0.91949065 0.92586229
 0.95173162 0.99641311 0.94361593 0.94676831 0.94620846 0.95982645
 0.92273011 1.01728438 0.98915332 0.97709663 0.96340258 1.02469157
 0.94653578 0.93802999 0.95006118 0.94874262 0.94130528 0.9407825
 0.971577

## 3) If the null hypothesis is rejected, run clustered conformal.
Otherwise, run standard conformal. (Open question: should standard conformal use only the data not used for clustering? That would throw away a lot of data, so for now it is run on all of the total calibration data.)

In [24]:
# Choose the conformal procedure based on the one-cluster test.
# Both branches assign coverage_metrics and set_size_metrics, but the other
# result names differ: (qhats, preds) here vs (standard_qhat, standard_preds)
# in the else branch, so later cells can only rely on the two metrics names.
if pval < pval_threshold: 
    print(f'p={pval} for one cluster null hypothesis, so running Clustered Conformal')
    # Run clustered conformal and return prediction sets
    qhats, preds, coverage_metrics, set_size_metrics = clustered_conformal(totalcal_scores_all, totalcal_labels,
                                                                alpha,
                                                                n_clustering, num_clusters,
                                                                val_scores_all=val_scores_all, val_labels=val_labels)
else:
    print(f'p={pval} for one cluster null hypothesis, so running Standard Conformal')
    # Run Standard Conformal and return prediction sets 
    # Calibrates on the full total-calibration pool, including the examples
    # that were set aside for clustering.
    standard_qhat = compute_qhat(totalcal_scores_all, totalcal_labels, alpha=alpha)
    standard_preds = create_prediction_sets(val_scores_all, standard_qhat)
    
    # Evaluate the resulting prediction sets on the validation labels.
    coverage_metrics, set_size_metrics = compute_all_metrics(val_labels, standard_preds, alpha)

p=0.61 for one cluster null hypothesis, so running Standard Conformal


In [20]:
# Display the metrics produced by the conformal cell above.
# NOTE(review): this cell's execution count (20) is older than that of the
# conformal cell above (24), so the saved output may be stale -- re-run it.
coverage_metrics, set_size_metrics

({'mean_class_cov_gap': 0.025917968400063357,
  'undercov_gap': 0.03492569792924323,
  'overcov_gap': 0.021024880754582934,
  'marginal_cov': 0.9015214544232935,
  'raw_class_coverages': array([0.91783217, 0.92109777, 0.88245462, 0.89234651, 0.91072961,
         0.87008547, 0.92294666, 0.90691716, 0.882146  , 0.92215569,
         0.91760624, 0.93281654, 0.94132873, 0.93431288, 0.90641248,
         0.92602263, 0.94961571, 0.92682927, 0.93214589, 0.92521739,
         0.88638335, 0.94691781, 0.9226087 , 0.95697074, 0.92048401,
         0.91184097, 0.84982935, 0.89132266, 0.88644068, 0.88870008,
         0.91551724, 0.91498685, 0.88538933, 0.90204429, 0.86284722,
         0.83678756, 0.84488735, 0.91364421, 0.89852559, 0.93967715,
         0.90948276, 0.90616855, 0.90350877, 0.90348259, 0.9012766 ,
         0.90831919, 0.74137931, 0.92851064, 0.93658955, 0.84497445,
         0.92041522, 0.93097345, 0.84687767, 0.9021645 , 0.86502547,
         0.81581234, 0.92307692, 0.96397942, 0.91644909,