In [1]:
import numpy as np

from utils.conformal_utils import clustered_conformal, random_split
from utils.experiment_utils import load_dataset

This notebook shows how to apply _Clustered Conformal Prediction_ to a set of softmax scores and labels.

## 0) Specify desired coverage level

In [2]:
# Miscoverage level: alpha = 0.1 corresponds to a 90% target coverage (1 - alpha)
alpha = 0.1

## 1) Get conformal scores
* softmax_scores: `(num_instances, num_classes)` array
* labels: `(num_instances,)` array

In [3]:
# Load the softmax scores and labels for the 'imagenet' dataset
softmax_scores, labels = load_dataset('imagenet')

In [4]:
# Conformal score = 1 - softmax score, so a more confident class gets a lower score
scores_all = 1 - softmax_scores

## 2) Split into calibration and validation datasets

In [5]:
# Size of the calibration dataset, given as the average number of examples per class
n_avg = 30
(cal_scores_all, cal_labels,
 val_scores_all, val_labels) = random_split(scores_all, labels, n_avg)

## 3) Use the calibration dataset to estimate conformal quantiles

In [6]:
# Estimate the conformal quantiles from the calibration scores at miscoverage level alpha.
# NOTE(review): q_hats is later compared element-wise against a per-class score vector,
# so it is presumably one quantile per class -- confirm against clustered_conformal().
q_hats = clustered_conformal(cal_scores_all, cal_labels, alpha)

n_clustering=12, num_clusters=6
0 of 1000 classes are rare in the clustering set and will be assigned to the null cluster
Cluster sizes: [186, 185, 180, 171, 153, 125]


In [7]:
# Wrap the quantiles in a function that maps one example's scores to its prediction set
def get_pred_set(softmax_vec):
    """Return the indices of the classes in the prediction set for one example.

    NOTE: despite its name, `softmax_vec` must hold *conformal scores*
    (1 - softmax, as in `scores_all`), not raw softmax probabilities, because
    `q_hats` was estimated from conformal scores.

    A class is included when its score is <= the matching entry of `q_hats`.
    `q_hats` is looked up at call time, so re-running the calibration cell
    takes effect without redefining this function.
    """
    return np.where(softmax_vec <= q_hats)[0]

## 4) Apply prediction set function to new examples 

You can rerun the following cell to generate prediction sets for different randomly sampled test points

In [8]:
# Draw a random example from the validation dataset. Its row of `val_scores_all`
# holds conformal scores (1 - softmax), which is what get_pred_set() compares to q_hats.
i = np.random.choice(len(val_labels))
score_vec = val_scores_all[i]
true_label = val_labels[i]

print('Prediction set:', get_pred_set(score_vec))
print('True label:', true_label)

Prediction set: [433 457 463 529 615 631 638 667 773 804 837 868 898 911 999]
True label: 433


### Evaluation

To compute coverage and set size metrics, you can pass `val_scores_all` and `val_labels` into the call to `clustered_conformal()`:

In [9]:
# Passing the validation scores and labels makes clustered_conformal() also return
# the predictions plus the class-coverage and set-size metrics
qhats, preds, class_cov_metrics, set_size_metrics = clustered_conformal(
    cal_scores_all,
    cal_labels,
    alpha,
    val_scores_all=val_scores_all,
    val_labels=val_labels,
)

n_clustering=12, num_clusters=6
0 of 1000 classes are rare in the clustering set and will be assigned to the null cluster
Cluster sizes: [186, 185, 180, 171, 153, 125]
CLASS COVERAGE GAP: 0.03313341096404464
AVERAGE SET SIZE: 2.808151188147288


Additional metrics can be found in `class_cov_metrics` and `set_size_metrics`