In [1]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from conformal_utils import *

%load_ext autoreload
%autoreload 2

Run this notebook only after running `get_posterior_quantiles_oracle.ipynb` (or the equivalent `get_posterior_quantiles_oracle.py` script).

# Load data

In [2]:
# Load data
# Cached ImageNet train-subset softmax outputs and their labels.
# Set the LOGITS_DIR environment variable to load them from elsewhere; the default is the
# original author's cache directory (a hardcoded absolute path does not work on other machines).
# torch.load unpickles its input, so only point LOGITS_DIR at files you trust.
LOGITS_DIR = Path(os.environ.get(
    'LOGITS_DIR', '/home/eecs/tiffany_ding/code/SimCLRv2-Pytorch/.cache/logits'))
cpu = torch.device('cpu')

softmax_scores = torch.load(LOGITS_DIR / 'imagenet_train_subset_softmax.pt', map_location=cpu)
softmax_scores = softmax_scores.numpy()
labels = torch.load(LOGITS_DIR / 'imagenet_train_subset_labels.pt', map_location=cpu)
labels = labels.type(torch.LongTensor).numpy()

In [3]:
# Conformal score = 1 - softmax probability, so a larger score means the model finds the class less likely.
scores = np.subtract(1, softmax_scores)

In [4]:
# Keep only the split of data NOT used for calibration; the calibration split is discarded.
# NOTE(review): the 20 is presumably the number of examples per class that went into the
# calibration split — confirm against split_X_and_y.
n_per_class_calib = 20
split_result = split_X_and_y(scores, labels, n_per_class_calib, num_classes=1000, seed=0)
_, _, unused_scores, unused_labels = split_result
del split_result  # release the reference to the discarded calibration arrays

In [5]:
# Partition the unused data into two parts:
#   1. data for estimating the conformal adjustment -> (scores1, labels1)
#   2. data for computing coverage                  -> (scores2, labels2)
# 10 examples per class (10,000 total examples) go to part 1; the rest go to part 2.
n_per_class_adjustment = 10
scores1, labels1, scores2, labels2 = split_X_and_y(
    unused_scores,
    unused_labels,
    n_per_class_adjustment,
    num_classes=1000,
    seed=0,
)

### Method 1: Adjust which quantile we take of the posterior distribution

We want to find $\tilde{\alpha}$ such that taking the $(1-\tilde{\alpha})$-quantile of each class's posterior score distribution as that class's threshold $\widehat{q}_k$ achieves marginal coverage of $1-\alpha$.

In [6]:
# Cached posterior score samples (presumably written by get_posterior_quantiles_oracle);
# row k holds the samples for class k.
cached_samples_path = '.cache/cached_samples_06-10-22.npy'
cached_samples = np.load(cached_samples_path)

In [7]:
alpha = 0.1
num_classes = 1000

# Restrict search to (1 - alpha) +/- search_halfwidth to start
search_halfwidth = 0.05
quantile_min = (1 - alpha) - search_halfwidth
quantile_max = (1 - alpha) + search_halfwidth
initial_quantile_max = quantile_max

# Row k = cached posterior score samples for class k. Slicing/reshaping is loop-invariant, so do it
# once; reshape(num_classes, -1) keeps the original per-class semantics (quantile over all of row k).
assert cached_samples.shape[0] >= num_classes, "cached_samples must have at least num_classes rows"
class_samples = cached_samples[:num_classes].reshape(num_classes, -1)

# ===== Perform binary search =====
# Convergence criteria: Either (1) marginal coverage is within tol of desired or (2)
# quantile_min and quantile_max differ by less than min_range_width, so there is no need to try
# to get a more precise estimate.
# Caveat: under (1) the final quantile may give coverage up to tol BELOW 1 - alpha;
# under (2) we return quantile_max, which is conservative.
tol = 0.0005
min_range_width = 0.0001

marginal_coverage = 0  # sentinel so the loop body runs at least once
while np.abs(marginal_coverage - (1-alpha)) > tol:
    
    quantile_guess = (quantile_min + quantile_max) / 2
    print(f"\nCurrent quantile guess: {quantile_guess:.4f}")
    
    # 1. Get qhats_k: the quantile_guess-quantile of each class's samples, in one vectorized call
    #    (same values as a per-class np.quantile loop).
    #    NOTE: `interpolation=` is deprecated since NumPy 1.22 in favor of `method=`.
    qhats = np.quantile(class_samples, quantile_guess, axis=1, interpolation='higher')
    
    # 2. Compute coverage using these qhats
    preds = create_cb_prediction_sets(scores1, qhats)
    marginal_coverage = compute_coverage(labels1, preds)
    print(f"Marginal coverage: {marginal_coverage:.4f}")
    
    # Coverage increases with the quantile, so over-coverage means the target lies below the guess
    if marginal_coverage > 1 - alpha:
        quantile_max = quantile_guess
    else:
        quantile_min = quantile_guess
    print(f"Search range: [{quantile_min}, {quantile_max}]")
        
    if quantile_max - quantile_min < min_range_width:
        if quantile_max == initial_quantile_max:
            # quantile_max was never evaluated, so its coverage is unverified
            print("WARNING: search never moved the upper bound; the target quantile may lie above "
                  "the initial range. Consider increasing search_halfwidth.")
        quantile_guess = quantile_max # Conservative estimate, which ensures coverage
        print("Adequate precision reached; stopping early.")
        break

print("\nFINAL QUANTILE:", quantile_guess)


Current quantile guess: 0.9000
Marginal coverage: 0.9050
Search range: [0.85, 0.9]

Current quantile guess: 0.8750
Marginal coverage: 0.8868
Search range: [0.875, 0.9]

Current quantile guess: 0.8875
Marginal coverage: 0.8971
Search range: [0.8875, 0.9]

Current quantile guess: 0.8938
Marginal coverage: 0.9008
Search range: [0.8875, 0.89375]

Current quantile guess: 0.8906
Marginal coverage: 0.8991
Search range: [0.890625, 0.89375]

Current quantile guess: 0.8922
Marginal coverage: 0.8996
Search range: [0.8921875, 0.89375]

FINAL QUANTILE: 0.8921875


#### Save the per-class qhats, i.e. the `quantile_guess`-quantile of each class's posterior score distribution

In [8]:
# Recompute the per-class thresholds at the final quantile found by the binary search and cache them.
print(f'Computing the {quantile_guess * 100:.3f}% quantile of the posterior score distribution...')
conformalized_qhats = []
for k in range(num_classes):
    conformalized_qhats.append(
        np.quantile(cached_samples[k, :], quantile_guess, interpolation='higher')
    )

save_to = '.cache/conformalized_qhats.npy'
np.save(save_to, conformalized_qhats)
print(f'Saved conformalized qhats to {save_to}')
print(f'Saved conformalized qhats to {save_to}')

Computing the 89.219% quantile of the posterior score distribution...
Saved conformalized qhats to .cache/conformalized_qhats.npy


### Method 2: Apply an additive (or multiplicative) offset to $\widehat{q}^{EB}$

In [None]:
# TODO