In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from scipy.special import comb
from torch import nn
from tqdm import tqdm_notebook

import models
import utils.datautils as dutils
import utils.trainutils as tutils
import utils.uqutils as uqutils

# SETUP GPU
# cuDNN autotuning: lets cuDNN benchmark and pick conv algorithms for the input shapes seen.
torch.backends.cudnn.benchmark = True
# Fall back to CPU so the notebook can still run on a machine without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Dataset root; set the DATA_BASE environment variable instead of editing this path.
base = os.environ.get("DATA_BASE", "/home/data/")

In [2]:
def res34(num_class):
    """Return an untrained torchvision ResNet-34 whose final layer outputs `num_class` logits.

    torchvision is not imported in the setup cell, so it is imported here. The width of the
    replacement head is read from the stock `fc` layer instead of being hard-coded to 512.
    """
    import torchvision

    model = torchvision.models.resnet34()
    model.fc = nn.Linear(model.fc.in_features, num_class)
    return model

In [18]:
# Build CIFAR10 loaders with shuffling disabled, seed numpy's global RNG, and run the
# saved model file(s) matching `model_file_pattern` over the test loader.
# NOTE(review): evalmode=False presumably keeps the network in training mode during
# inference (BatchNorm/dropout) — confirm that is intended.
loader_dict, num_class = dutils.return_loaders(
    base=base,
    dataset='CIFAR10',
    start=1000,
    end=1500,
    train_shuffle=False,
    valid_shuffle=False,
)
np.random.seed(1)
model_file_pattern = 'CIFAR10_ntrain-1000_MixUpAlpha-0.5_id-*.model'
model = models.FastResNet().to(device)
test_probs, targets, model_files = tutils.infer_ensemble(
    model=model,
    model_file_pattern=model_file_pattern,
    dataloader=loader_dict['test'],
    evalmode=False,
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/1 [00:00<?, ?it/s]

Number of files found: 1


100%|██████████| 1/1 [00:00<00:00,  1.16it/s]


In [19]:
# Pool the members by averaging their probabilities, then score the pooled prediction.
pooled_test_probs = np.mean(test_probs, axis=0)
uqutils.get_all_scores(pooled_test_probs, targets)

(0.6609, 0.04309134163558483, 1.0365074, 0.46303225807161197)

In [None]:
# For every ensemble size, score each sampled member subset (pooled by averaging
# probabilities) and keep the mean and std of those scores across subsets.
all_result = []
for subsets in random_samples:
    subset_scores = [get_acc_ece_nll(np.mean(probs[subset, :, :], axis=0), targets) for subset in subsets]
    all_result.append((np.mean(subset_scores, axis=0), np.std(subset_scores, axis=0)))

In [None]:
# One line per ensemble size: mean scores, then their std across subsets.
for mean_scores, std_scores in all_result:
    print(mean_scores, std_scores)

In [None]:
def _ensemble_probs_and_targets(loader):
    """Run `ens_model.forward` on `loader` (eval=True, returning targets).

    Returns (logits, softmax-over-dim-2 probabilities as numpy, targets as numpy).
    """
    member_logits, loader_targets = ens_model.forward(loader, eval=True, return_target=True)
    member_probs = F.softmax(member_logits, dim=2)
    return member_logits, member_probs.cpu().numpy(), loader_targets.cpu().numpy()


# c10 predictions come from loader1, c100 predictions from loader2.
c10logits, c10probs, c10targets = _ensemble_probs_and_targets(loader1['test'])
c100logits, c100probs, c100targets = _ensemble_probs_and_targets(loader2['test'])

In [None]:
# Predictive-entropy histograms of c10 vs. the selected c100 classes for three predictors:
# a single member (member 0), the pooled mean, and the pooled mean after power scaling.
# c100probs is recomputed from the logits so re-running the cell does not filter twice.
c100probs = F.softmax(c100logits, dim=2)
c100probs = c100probs.cpu().numpy()
interesting_labels = [0, 1, 16, 17, 20, 21, 29, 39, 40, 49, 57, 71, 72, 73, 76]
c100probs = c100probs[:, np.isin(c100targets, interesting_labels)]

# NOTE(review): presumably a temperature fitted elsewhere on validation data; the
# pooled probabilities are raised to 1/POOLED_TEMPERATURE — confirm provenance.
POOLED_TEMPERATURE = 0.49237999028431956


def _entropy(p):
    """Row-wise Shannon entropy (nats) of a 2-D probability array; 1e-5 guards log(0)."""
    return -np.sum(p * np.log(p + 1e-5), axis=1)


c10prob_mean, c100prob_mean = np.mean(c10probs, axis=0), np.mean(c100probs, axis=0)
c10entrop_mean, c100entrop_mean = _entropy(c10prob_mean), _entropy(c100prob_mean)
c10entrop_single, c100entrop_single = _entropy(c10probs[0]), _entropy(c100probs[0])

scaled_c10prob_mean = raise_prob_power(c10prob_mean, 1 / POOLED_TEMPERATURE, -1)
scaled_c100prob_mean = raise_prob_power(c100prob_mean, 1 / POOLED_TEMPERATURE, -1)
scaled_c10entrop_mean = _entropy(scaled_c10prob_mean)
scaled_c100entrop_mean = _entropy(scaled_c100prob_mean)

entropy_panels = [
    ('single member', c10entrop_single, c100entrop_single),
    ('pooled mean', c10entrop_mean, c100entrop_mean),
    ('pooled mean, power-scaled', scaled_c10entrop_mean, scaled_c100entrop_mean),
]
fig, ax = plt.subplots(1, 3, figsize=(15, 3))
for axi, (panel_title, c10_entropy, c100_entropy) in zip(ax, entropy_panels):
    axi.hist(c10_entropy, bins=100, alpha=0.7, density=True, label='c10')
    axi.hist(c100_entropy, bins=100, alpha=0.7, density=True, label='c100 subset')
    axi.set_ylim((0, 2.7))
    axi.set(title=panel_title, xlabel='predictive entropy (nats)')
ax[0].legend()

In [None]:
# How far the c100-subset entropies are shifted from the c10 entropies, for the
# single-member, pooled and power-scaled pooled predictors (one column each):
#   line 1: Jeffreys (symmetric KL) divergence of the binned distributions,
#   line 2: difference of means (c100 - c10), line 3: difference of medians.
eps = 1e-3
numbin = 20


def _binned_jeffreys(c10_sample, c100_sample, numbin=numbin, eps=eps):
    """sum_i (p_i - q_i) * log((p_i + eps) / (q_i + eps)) over `numbin` shared bins.

    Both samples are histogrammed on the same bins spanning their joint range, and the
    counts are normalised to probability masses. The previous version used bin *densities*
    on edges fitted to the c10 sample only. That dropped c100 values outside the c10 range,
    and each column was scaled by its own bin width, so the three numbers were not comparable.
    """
    lo = min(np.min(c10_sample), np.min(c100_sample))
    hi = max(np.max(c10_sample), np.max(c100_sample))
    p = np.histogram(c10_sample, bins=numbin, range=(lo, hi))[0] / len(c10_sample)
    q = np.histogram(c100_sample, bins=numbin, range=(lo, hi))[0] / len(c100_sample)
    return np.sum((p - q) * np.log((p + eps) / (q + eps)))


entropy_pairs = [
    (c10entrop_single, c100entrop_single),
    (c10entrop_mean, c100entrop_mean),
    (scaled_c10entrop_mean, scaled_c100entrop_mean),
]
print(*[_binned_jeffreys(c10_s, c100_s) for c10_s, c100_s in entropy_pairs])
print(*[np.mean(c100_s) - np.mean(c10_s) for c10_s, c100_s in entropy_pairs])
print(*[np.median(c100_s) - np.median(c10_s) for c10_s, c100_s in entropy_pairs])

In [None]:
# Median-entropy gap (c100 subset minus c10) for each of the first 30 members.
# Cell-local names are used so the globals c10entrop_single / c100entrop_single
# (member 0, consumed by the ROC cell) are not silently replaced by member 29's values.
res = []
for i in range(30):
    member_c10_entropy = -np.sum(c10probs[i] * np.log(c10probs[i] + 1e-5), axis=1)
    member_c100_entropy = -np.sum(c100probs[i] * np.log(c100probs[i] + 1e-5), axis=1)
    res.append(np.median(member_c100_entropy) - np.median(member_c10_entropy))

In [None]:
# Mean and spread across members of the median-entropy gaps collected above.
res_values = np.asarray(res)
res_values.mean(), res_values.std()

In [None]:
from sklearn.metrics import auc

# OOD detection by thresholding predictive entropy: an input is flagged when its entropy
# exceeds the threshold. FPR = fraction of c10 inputs flagged, TPR = fraction of
# c100-subset inputs flagged.


def _entropy_roc(c10_entropy, c100_entropy, thresholds):
    """Return [fprs, tprs] for the given entropy thresholds.

    All three curves now share this [fpr, tpr] layout. Previously the single-member curve was
    stored as [fpr, tpr] but the pooled curves as [tpr, fpr]. The single-member plot therefore
    had swapped axes, and its reported AUC was the area on the wrong side of the curve.
    """
    fprs = [np.sum(c10_entropy > t) / len(c10_entropy) for t in thresholds]
    tprs = [np.sum(c100_entropy > t) / len(c100_entropy) for t in thresholds]
    return [fprs, tprs]


thresholds = np.linspace(0, 3)  # 50 evenly spaced thresholds
single_roc = _entropy_roc(c10entrop_single, c100entrop_single, thresholds)
mean_roc = _entropy_roc(c10entrop_mean, c100entrop_mean, thresholds)
scaled_roc = _entropy_roc(scaled_c10entrop_mean, scaled_c100entrop_mean, thresholds)

roc_fig, roc_ax = plt.subplots()
for curve_name, (fpr, tpr) in [('single member', single_roc),
                               ('pooled mean', mean_roc),
                               ('pooled mean, scaled', scaled_roc)]:
    roc_ax.plot(fpr, tpr, label=curve_name)
roc_ax.set(xlabel='FPR (c10 flagged)', ylabel='TPR (c100 flagged)', title='Entropy-based OOD detection')
roc_ax.legend()
# sklearn's auc accepts monotonically decreasing x (FPR falls as the threshold rises).
print(auc(*single_roc), auc(*mean_roc), auc(*scaled_roc))

In [None]:
num_samples = 30
randomize = True
random_sample_count = 5


def _draw_member_subsets(size):
    """Member-index subsets of the given size used to evaluate an ensemble of that size.

    Without randomisation: the single prefix [0, size). With randomisation: the full set
    once when size == num_samples, otherwise `random_sample_count` draws without replacement.
    """
    if not randomize:
        return [list(range(size))]
    if size == num_samples:
        return [list(range(num_samples))]
    return [np.random.choice(a=np.arange(num_samples), size=size, replace=False)
            for _ in range(random_sample_count)]


# random_samples[k] holds the subsets for an ensemble of k + 1 members.
random_samples = [_draw_member_subsets(size) for size in range(1, num_samples + 1)]

In [None]:
# Persist the per-member temperature-scaled test probabilities gathered by the scaling loop.
# NOTE(review): the entries are torch tensors that may be on the GPU; loading the file may
# then require CUDA — confirm.
with open('old_files/distance_plot_test_preds.pkl', 'wb') as out_file:
    pickle.dump(indiv_scaled_test_probs, out_file)

In [None]:
valid_size = 50
# Consecutive, non-overlapping chunks of `valid_size` validation examples (a trailing
# partial chunk is dropped). A random-subset alternative that was tried:
#     torch.randperm(valid_logits.shape[1])[:valid_size]
rand_indices = [valid_size * k + torch.arange(valid_size)
                for k in range(valid_logits.shape[1] // valid_size)]

# Per-chunk class counts, to eyeball how balanced each chunk is.
for idx in rand_indices:
    class_counts = np.unique(valid_targets[idx].cpu().numpy(), return_counts=True)[1]
    print(class_counts)

In [None]:
# Temperature-scaling experiments for one member subset. For each validation chunk in
# `rand_indices`, fit temperatures on that chunk and score the scaled test predictions.
# Each method's scores go into `robust_dict` (one get_acc_ece_nll result per chunk; that
# helper is unpacked elsewhere as (acc, ece, nll, brier)), and one dict per subset is
# appended to `all_robust`.
from torch.nn import NLLLoss, CrossEntropyLoss
nllcrit, crosscrit = NLLLoss(), CrossEntropyLoss()
# Pool members: softmax over dim 2, then average over dim 0.
softmax_mean = lambda x: torch.mean(F.softmax(x, dim=2), dim=0)
def softmax_median(x):
    """Median over dim 0 of the dim-2 softmax, renormalised so each row sums to 1."""
    med_prob = torch.median(F.softmax(x, dim=2), dim=0)[0]
    return med_prob / torch.sum(med_prob, dim=1).unsqueeze(1)
    
all_robust = []
indiv_scaled_test_probs = []
# 200 candidate temperatures, log-spaced from e^-3 to about e^3, searched by grid.
temp_range = np.exp(np.linspace(start=-3, stop=3.0001, num=200))
valid_probs = F.softmax(valid_logits, dim=2)

# Only random_samples[29] is evaluated (with randomize=True that is all 30 members);
# switch to range(30) to sweep every ensemble size.
for index in [29]:  #range(30)
    robust_dict = {'scale_LP': [], 'scale_MP': [], 'scale_trimmed_LP': [], 
               'joint_scale_LP': [], 'joint_scale_MP': [], 'joint_scale_trimmed_LP': [], 
               'post_LP_scale': [], 'post_MP_scale': [], 'post_trimmed_MP_scale': []}
    randperm = random_samples[index][0]
    # Unscaled pooled test predictions: member mean, and member median renormalised per row.
    test_mean_probs = np.mean(probs[randperm,:,:], axis=0)
    test_median_probs = np.median(probs[randperm,:,:], axis=0)
    test_median_probs = test_median_probs / np.sum(test_median_probs, axis=1)[:,None]
    # Members cut from each end of return_entropy_sorted's ordering; -1 disables trimming.
    num_out = -1  # len(randperm)//6
    if num_out > 0:
        test_trimmed_logits = return_entropy_sorted(logits[randperm,:,:])[num_out:-num_out]
    else:
        test_trimmed_logits = logits[randperm,:,:]
    test_trimmed_probs = F.softmax(test_trimmed_logits, dim=2).cpu().numpy()

    for randind in tqdm_notebook(rand_indices[:10]):
        # Restrict validation tensors to this chunk of examples, then to the subset's members.
        curr_valid_logits = valid_logits[:,randind,:][randperm,:,:]
        curr_valid_probs = valid_probs[:,randind,:][randperm,:,:]
        curr_valid_targets = valid_targets[randind].type(torch.LongTensor).to(device)
        curr_valid_mean_probs = torch.mean(curr_valid_probs, dim=0)
        curr_valid_median_probs = softmax_median(curr_valid_logits)
        if num_out > 0:
            curr_valid_trimmed_logits = return_entropy_sorted(curr_valid_logits)[num_out:-num_out]
        else:
            curr_valid_trimmed_logits = curr_valid_logits
        curr_valid_trimmed_mean_probs = torch.mean(F.softmax(curr_valid_trimmed_logits, dim=2), dim=0)

        indiv_temps = []
        for k in range(len(randperm)):
            # Per-member grid search. CrossEntropyLoss applies log_softmax to log(p)/T, i.e.
            # it scores the member's probabilities raised to 1/T and renormalised.
            losses = np.array([crosscrit(torch.log(curr_valid_probs[k])/temp, curr_valid_targets).item() 
                               for temp in temp_range])
            losses = np.array(losses)
            # Ignore temperatures whose loss is NaN.
            # NOTE(review): if every loss is NaN, np.argmin on the empty array raises.
            temps = temp_range[np.isnan(losses)==False]
            losses = losses[np.isnan(losses)==False]
            indiv_temps.append(temps[np.argmin(losses)])

        # 'scale_LP': divide each member's test logits by its own temperature, then pool by mean.
        scaled_test_logits = logits[randperm,:,:]/torch.tensor(indiv_temps).to(device).view(len(randperm),1,1)
        scaled_test_probs = F.softmax(scaled_test_logits, dim=2)
        if num_out > 0:
            scaled_trimmed_test_probs = F.softmax(return_entropy_sorted(scaled_test_logits)[num_out:-num_out], dim=2).cpu().numpy()
        else:
            # NOTE(review): sorted but not trimmed here; only the commented-out
            # 'scale_trimmed_LP' line below would use it.
            scaled_trimmed_test_probs = F.softmax(return_entropy_sorted(scaled_test_logits), dim=2).cpu().numpy()
        # NOTE(review): keeps a device tensor per chunk (10 chunks x members x test set); the
        # list is pickled by another cell — consider .cpu() if memory becomes a problem.
        indiv_scaled_test_probs.append(scaled_test_probs)
        scaled_test_probs = scaled_test_probs.cpu().numpy()
        robust_dict['scale_LP'].append(get_acc_ece_nll(np.mean(scaled_test_probs, axis=0), targets))
#         median_prob = np.median(scaled_test_probs, axis=0)
#         median_prob = median_prob / np.sum(median_prob, axis=1)[:,None]
#         robust_dict['scale_MP'].append(get_acc_ece_nll(median_prob, targets))
#         robust_dict['scale_trimmed_LP'].append(get_acc_ece_nll(np.mean(scaled_trimmed_test_probs, axis=0), targets))

#         losses = [nllcrit(torch.log(softmax_mean(curr_valid_logits/temp)), curr_valid_targets).item() for temp in temp_range]
#         losses = np.array(losses)
#         temps = temp_range[np.isnan(losses)==False]
#         losses = losses[np.isnan(losses)==False]
#         final_temp = temps[np.argmin(losses)]
#         ensemble_temp = final_temp
#         scaled_test_probs = F.softmax(logits[randperm,:,:]/final_temp, dim=2).cpu().numpy()
#         robust_dict['joint_scale_LP'].append(get_acc_ece_nll(np.mean(scaled_test_probs, axis=0), targets))

#         losses = [nllcrit(torch.log(softmax_mean(curr_valid_trimmed_logits/temp)), curr_valid_targets).item() 
#                   for temp in temp_range]
#         losses = np.array(losses)
#         temps = temp_range[np.isnan(losses)==False]
#         losses = losses[np.isnan(losses)==False]
#         final_temp = temps[np.argmin(losses)]
#         scaled_test_probs = F.softmax(test_trimmed_logits/final_temp, dim=2).cpu().numpy()
#         robust_dict['joint_scale_trimmed_LP'].append(get_acc_ece_nll(np.mean(scaled_test_probs, axis=0), targets))

#         losses = [nllcrit(torch.log(softmax_median(curr_valid_logits/temp)), curr_valid_targets).item() for temp in temp_range]
#         losses = np.array(losses)
#         temps = temp_range[np.isnan(losses)==False]
#         losses = losses[np.isnan(losses)==False]
#         final_temp = temps[np.argmin(losses)]
#         scaled_test_probs = softmax_median(logits[randperm,:,:]/final_temp).cpu().numpy()
#         robust_dict['joint_scale_MP'].append(get_acc_ece_nll(scaled_test_probs, targets))

        # 'post_LP_scale': fit one temperature on the pooled (mean) validation probabilities,
        # then raise the pooled test probabilities to 1/T and renormalise each row.
        losses = [crosscrit(torch.log(curr_valid_mean_probs)/temp, curr_valid_targets).item() for temp in temp_range]
        losses = np.array(losses)
        temps = temp_range[np.isnan(losses)==False]
        losses = losses[np.isnan(losses)==False]
        final_temp = temps[np.argmin(losses)]
        scaled_test_probs = test_mean_probs**(1/final_temp)
        scaled_test_probs = scaled_test_probs / np.sum(scaled_test_probs, axis=1)[:,None]
        robust_dict['post_LP_scale'].append(get_acc_ece_nll(scaled_test_probs, targets))

#         losses = [crosscrit(torch.log(curr_valid_median_probs)/temp, curr_valid_targets).item() for temp in temp_range]
#         losses = np.array(losses)
#         temps = temp_range[np.isnan(losses)==False]
#         losses = losses[np.isnan(losses)==False]
#         final_temp = temps[np.argmin(losses)]
#         scaled_test_probs = test_median_probs**(1/final_temp)
#         scaled_test_probs = scaled_test_probs / np.sum(scaled_test_probs, axis=1)[:,None]
#         robust_dict['post_MP_scale'].append(get_acc_ece_nll(scaled_test_probs, targets))

#         losses = [crosscrit(torch.log(curr_valid_trimmed_mean_probs)/temp, curr_valid_targets).item() for temp in temp_range]
#         losses = np.array(losses)
#         temps = temp_range[np.isnan(losses)==False]
#         losses = losses[np.isnan(losses)==False]
#         final_temp = temps[np.argmin(losses)]
#         scaled_test_probs = np.mean(test_trimmed_probs, axis=0)**(1/final_temp)
#         scaled_test_probs = scaled_test_probs / np.sum(scaled_test_probs, axis=1)[:,None]
#         robust_dict['post_trimmed_MP_scale'].append(get_acc_ece_nll(scaled_test_probs, targets))
    all_robust.append(robust_dict)

In [None]:
# Summarise each method in the current `robust_dict`, whose rows come from get_acc_ece_nll
# (unpacked elsewhere as (acc, ece, nll, brier)). Print the per-column mean, then after a
# blank separator the per-column std. Methods left commented out in the scaling cell have no
# rows and are skipped. The unused `minloss` of the old version is gone; the 1.5x-min-NLL
# outlier filter lives in the next cell.
def _print_robust_columns(reduce):
    """Print `key  v0 v1 ...` with `reduce(rows, axis=0)` for every method that has rows."""
    for key, rows in robust_dict.items():
        arr = np.array(rows)
        if len(arr) > 0:
            print(key + '  ' + ' '.join(map(str, reduce(arr, axis=0))))


_print_robust_columns(np.mean)
print('\n')
_print_robust_columns(np.std)

In [None]:
# For every subset's results, drop 'post_LP_scale' rows whose NLL (column 2 of the
# (acc, ece, nll, brier) rows) is not below 1.5x the best NLL, then report mean/std.
# A dedicated loop name keeps the global `robust_dict` from being rebound here.
NLL_OUTLIER_FACTOR = 1.5
for size_results in all_robust:
    arr = np.array(size_results['post_LP_scale'])
    if len(arr) > 0:
        minloss = min(arr[:, 2])
        print(minloss * NLL_OUTLIER_FACTOR)
        arr = arr[arr[:, 2] < NLL_OUTLIER_FACTOR * minloss]
        print(np.mean(arr, axis=0), np.std(arr, axis=0))


In [None]:
# Raw per-chunk (acc, ece, nll, brier) rows of the post-pooling scaling method.
post_lp_rows = robust_dict['post_LP_scale']
post_lp_rows

In [None]:
# Load per-member temperatures and one ensemble temperature saved by an earlier run, apply
# both to the current test logits, and append [per-member scaled, ensemble scaled, targets]
# to `all_scaling`. The temperature file was written with:
#     with open('all_temperatures_cifar10', 'wb') as openf:
#         pickle.dump([indiv_temps, ensemble_temp], openf)
# This cell is meant to be run once per dataset, so `all_scaling` accumulates across runs.
TEMPERATURE_FILE = 'all_temperatures_imagenette'
if 'all_scaling' not in globals():
    all_scaling = []  # was commented out, so a fresh kernel raised NameError on append

with open(TEMPERATURE_FILE, 'rb') as openf:
    tlist = pickle.load(openf)  # trusted local file: never pickle.load untrusted input

print(np.mean(tlist[0]), np.std(tlist[0]), tlist[1])
scales = torch.tensor(tlist[0]).to(device).view(-1, 1, 1)

indiv_scaled_probs = F.softmax(logits / scales, dim=2).cpu().numpy()
final_scaled_probs = F.softmax(logits / tlist[1], dim=2).cpu().numpy()
all_scaling.append([indiv_scaled_probs, final_scaled_probs, targets])

In [None]:
from matplotlib.lines import Line2D

# One panel per dataset in `all_scaling`. Each panel shows calibration-gap curves
# (bin accuracy minus bin edge) for the individually scaled members and their pool
# (blue), and for the ensemble-temperature scaled members and their pool (orange/red).
bins = [(p + 1) / 30.0 for p in range(30)]
xs = [(i + 1) / 30. for i in range(30)]
headers = ['CIFAR10', 'CIFAR100', 'Imagenette']


def _plot_calibration_gap(axi, probs_, targets_, fmt, **plot_kw):
    """Plot, on `axi`, per-bin accuracy minus the bin edge in `xs` for one prediction set.

    NOTE(review): assumes nbutils.calculate_ECE returns per-bin correct counts at index 3 and
    per-bin sample counts at index 4 — confirm. nbutils is not imported in this notebook.
    """
    res = nbutils.calculate_ECE(probs_, targets_, ECE_bin=bins)
    ys = [(a / b) if b > 0 else 0 for a, b in zip(res[3], res[4])]
    axi.plot(xs, [y - x for x, y in zip(xs, ys)], fmt, **plot_kw)


fig, ax = plt.subplots(1, 3, figsize=(15, 3))
for data, axi in enumerate(ax):
    # Dataset-local names: the old unpacking rebound the global `targets` (and the
    # scaled-probability globals) to the last dataset's arrays.
    indiv_probs_d, final_probs_d, targets_d = all_scaling[data]
    axi.set_xlim((0.3, 1))
    axi.set_ylim((-.31, .25))
    axi.set_title(headers[data], fontsize=20)
    axi.tick_params(axis='x', which='major', labelsize=12)
    axi.tick_params(axis='y', which='major', labelsize=16)

    for i in range(30):
        _plot_calibration_gap(axi, indiv_probs_d[i, :, :], targets_d, 'o-',
                              color='C0', linewidth=2, alpha=0.2, markersize=4)
    _plot_calibration_gap(axi, np.mean(indiv_probs_d, axis=0), targets_d, 'bo-',
                          linewidth=5, alpha=0.5, markersize=8)

    for i in range(30):
        _plot_calibration_gap(axi, final_probs_d[i, :, :], targets_d, 'o-',
                              color='tab:orange', linewidth=2, alpha=0.2, markersize=4)
    _plot_calibration_gap(axi, np.mean(final_probs_d, axis=0), targets_d, 'ro-',
                          linewidth=5, alpha=0.5, markersize=8)

    axi.plot([0, 1], [0, 0], 'k-', linewidth=3, alpha=0.4)
    axi.grid(axis='both', color='k', linewidth=3, alpha=0.1)
    axi.set_xticks([0.4, 0.7, 1.0])
    axi.set_xticklabels([])

custom_lines = [Line2D([0], [0], color='C0', lw=8, alpha=0.5),
                Line2D([0], [0], color='b', lw=8, alpha=0.5),
                Line2D([0], [0], color='tab:orange', lw=8, alpha=0.5),
                Line2D([0], [0], color='r', lw=8, alpha=0.5),
                ]

fig.legend(custom_lines, ['Individual [B] scaled models',
                          'Pooled [B] scaled models',
                          'Individual [C] scaled models',
                          'Pooled [C] scaled models'
                          ], loc='lower center',
           prop={'size': 16}, ncol=2, handlelength=1, borderaxespad=-.5, frameon=False)
plt.savefig('plots/temp_scaled_calibration_curve.pdf', bbox_inches='tight')

In [None]:
# Two-panel illustration for member 0: calibration gap per confidence bin (bin accuracy
# minus bin edge) and the histogram of its max-class confidence.
bins = [(p + 1) / 30.0 for p in range(30)]
xs = [(i + 1) / 30. for i in range(30)]
fig, axe = plt.subplots(1, 2, figsize=(10, 3))
calib_ax, hist_ax = axe

ret = nbutils.calculate_ECE(probs[0, :, :], targets, ECE_bin=bins)
ys = [a / b if b > 0 else 0 for a, b in zip(ret[3], ret[4])]
calib_ax.set_title('Calibration Plot', fontsize=20)
calib_ax.plot([0, 1], [0, 0], 'k--')
calib_ax.plot(xs, [y - x for x, y in zip(xs, ys)], 'o-', color='C3', linewidth=5, alpha=0.6, markersize=8)

hist_ax.set_title('Confidence Histogram', fontsize=20)
_ = hist_ax.hist(np.max(probs[0, :, :], axis=1), bins=50, density=True)

for axi in axe:
    axi.tick_params(axis='x', which='major', labelsize=12)
    axi.tick_params(axis='y', which='major', labelsize=16)
    axi.grid(axis='both', color='k', linewidth=3, alpha=0.1)
plt.savefig('calibration_example.pdf', bbox_inches='tight')

In [None]:
# Column-wise mean/std of the `robust` result rows, keeping only rows whose first column
# exceeds 0.3. Results go into new names instead of repeatedly overwriting `robust`.
# NOTE(review): `robust` is never assigned in this notebook — it must come from an earlier
# session; presumably column 0 is accuracy — confirm.
robust_arr = np.array(robust)
robust_kept = robust_arr[robust_arr[:, 0] > 0.3, :]
np.mean(robust_kept, axis=0), np.std(robust_kept, axis=0)

In [None]:
# Reload cached validation/test logits and targets (same keys the save cell writes).
# pickle.load can execute arbitrary code — only open files this project produced.
with open('network_mixup_1.0_IMAGENETTE_valid_test.pickle', 'rb') as openf:
    dict_to_save = pickle.load(openf)

valid_logits = dict_to_save['valid_logits']
logits = dict_to_save['test_logits']
probs = F.softmax(logits, dim=2).cpu().numpy()
valid_targets = dict_to_save['valid_targets']
targets = dict_to_save['test_targets']

In [None]:
from matplotlib.ticker import FormatStrFormatter
from matplotlib.ticker import PercentFormatter
pcfmt0 = PercentFormatter(xmax=1, decimals=0, symbol='%', is_latex=False)
pcfmt1 = PercentFormatter(xmax=1, decimals=1, symbol='%', is_latex=False)

# ECE / NLL / Brier of each scaling method against ensemble size, one panel per metric.
# all_robust[i][key] holds get_acc_ece_nll rows (acc, ece, nll, brier); panel `index`
# plots column `index`, and all_robust[i] is assumed to correspond to i + 1 members.
titles = ['Accuracy', 'ECE', 'NLL', 'Brier']
# Legend label per robust_dict key, matching the order the old positional label list assumed.
key_labels = {
    'scale_LP': '[B] Linear', 'scale_MP': '[B] Median', 'scale_trimmed_LP': '[B] Trimmed Linear',
    'joint_scale_LP': '[C] Linear', 'joint_scale_MP': '[C] Median',
    'joint_scale_trimmed_LP': '[C] Trimmed Linear',
    'post_LP_scale': '[D] Linear', 'post_MP_scale': '[D] Median',
    'post_trimmed_MP_scale': '[D] Trimmed Linear',
}
xs = list(range(1, num_samples + 1))
if len(all_robust) < num_samples:
    raise ValueError(
        f'all_robust has {len(all_robust)} entries but {num_samples} ensemble sizes are plotted; '
        'run the temperature-scaling cell over range(num_samples) first'
    )

# A single row of three panels (the old plt.subplots(3) grid sat unused beneath them).
fig, ax = plt.subplots(1, 3, figsize=(25, 8))
handles, labels_ = [], []
for axi, index in zip(ax, [1, 2, 3]):
    if index == 1:
        axi.yaxis.set_major_formatter(pcfmt0)
    axi.tick_params(axis='y', which='major', labelsize=24)
    axi.tick_params(axis='x', which='major', labelsize=18)
    axi.set_title(titles[index], fontdict={'fontsize': 30})
    axi.grid(axis='both', color='k', linewidth=3, alpha=0.1)

    for key in all_robust[0]:
        # Methods commented out in the scaling cell have no rows; np.mean([]) is a scalar
        # NaN and indexing it used to raise IndexError.
        if not all(len(all_robust[idx][key]) for idx in range(num_samples)):
            continue
        means = [np.mean(all_robust[idx][key], axis=0)[index] for idx in range(num_samples)]
        # Error bars are intentionally off (yerr=None).
        axi.errorbar(xs, y=means, yerr=None, capsize=4.0, label=key_labels.get(key, key),
                     linewidth=5, alpha=0.8)

    handles, labels_ = axi.get_legend_handles_labels()

leg = fig.legend(handles, labels_, ncol=5, loc='lower center',
                 handlelength=1, borderaxespad=-0.45, prop={'size': 26}, frameon=False, markerscale=4)
# `legendHandles` was renamed `legend_handles` in matplotlib 3.7 and removed in 3.9.
legend_lines = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles
for legobj in legend_lines:
    legobj.set_linewidth(10)

# plt.savefig('plots/Cifar10_groups_and_pools.pdf', bbox_inches='tight')

In [None]:
# Inspect the 'joint_scale_LP' rows stored at position 25 of all_robust.
# NOTE(review): needs the scaling cell to have run over range(30); with its current
# `for index in [29]` loop, all_robust has one entry and this raises IndexError.
robust_entry = all_robust[25]
robust_entry['joint_scale_LP']

In [None]:
# Cache validation/test logits and targets for reuse by the loading cell. Tensors are stored
# as-is (no .tolist()); training logits/targets are not saved.
# NOTE(review): the file name says cifar10 while the loading cell reads an IMAGENETTE file —
# make sure the name matches the dataset currently in memory.
dict_to_save = dict(
    valid_logits=valid_logits,
    test_logits=logits,
    valid_targets=valid_targets,
    test_targets=targets,
)
with open('network_mixup_1.0_cifar10_valid_test.pickle', 'wb') as openf:
    pickle.dump(dict_to_save, openf)

In [None]:
# Save the per-dataset probability stacks and targets used by the pooling plot.
# `final_prob` / `all_targets` are assembled by hand across dataset runs, e.g.
#     final_prob = [probs1[:30,:,:], probs2[:30,:,:], probs3[:30,:,:]]
#     all_targets = [targets1, targets2, targets3]
# with probsK / targetsK copied from `probs` / `targets` after loading each dataset.
dict2save = dict(probs=final_prob, targets=all_targets)
with open('pooling_underconfident_plot_data.pkl', 'wb') as openf:
    pickle.dump(dict2save, openf)

In [None]:
from matplotlib.lines import Line2D

# Per dataset: calibration-gap curves (bin accuracy minus bin edge) of the 30 individual
# members (blue) and of the probability-averaged pooled model (red).
bins = [(p + 1) / 30.0 for p in range(30)]
xs = [(i + 1) / 30. for i in range(30)]
headers = ['CIFAR10 under-confident', 'CIFAR100 over-confident ', 'Imagewoof near-calibrated']
xlims = [(0.22, 1), (0.22, 1), (0.22, 1)]
ylims = [(-.3, .23), (-0.3, .23), (-0.3, .23)]


def _gap_curve(probs_, targets_):
    """Per-bin accuracy minus the bin edge in `xs` for one set of predictions.

    NOTE(review): assumes nbutils.calculate_ECE returns per-bin correct counts at index 3
    and per-bin sample counts at index 4 — confirm. nbutils is not imported in this notebook.
    """
    res = nbutils.calculate_ECE(probs_, targets_, ECE_bin=bins)
    ys = [(a / b) if b > 0 else 0 for a, b in zip(res[3], res[4])]
    return [y - x for x, y in zip(xs, ys)]


# A row of three panels (previously one full-figure Axes with three subplots drawn over it).
fig, ax = plt.subplots(1, 3, figsize=(15, 3))
for data, axi in enumerate(ax):
    # Dataset-local names: the old cell assigned the globals `probs` and `targets` here,
    # replacing the arrays the rest of the notebook works on.
    data_probs = final_prob[data]
    data_targets = all_targets[data]
    axi.set_xlim(xlims[data])
    axi.set_ylim(ylims[data])
    axi.tick_params(axis='both', which='major', labelsize=16)
    axi.tick_params(axis='both', which='minor', labelsize=16)
    axi.set_title(headers[data], fontdict={'fontsize': 18})

    for i in range(30):
        axi.plot(xs, _gap_curve(data_probs[i, :, :], data_targets), 'bo-',
                 linewidth=2, alpha=0.2, markersize=4)

    axi.plot([0, 1], [0, 0], 'k-', linewidth=3, alpha=0.4)
    axi.grid(axis='both', color='k', linewidth=3, alpha=0.1)
    axi.plot(xs, _gap_curve(np.mean(data_probs, axis=0), data_targets), 'ro-',
             linewidth=5, alpha=0.5, markersize=8)

custom_lines = [Line2D([0], [0], color='b', lw=8, alpha=0.5),
                Line2D([0], [0], color='r', lw=8, alpha=0.5),
                ]

fig.legend(custom_lines, ['Individual Model', 'Pooled Model'], loc='lower center',
           prop={'size': 18}, ncol=2, handlelength=1, borderaxespad=-.4, frameon=False)
plt.savefig('plots/pooling_underconfident_V3.pdf', bbox_inches='tight')

In [None]:
# Score every individual member on the test set and collect
# (member number, acc, ece, nll, brier) rows in acc_ece_nll_indiv.
# NOTE(review): `scaled_probs` is not defined anywhere in this notebook, so it must come
# from an earlier session. `acc_ece_nll_avg` stays empty while the pooling block below is
# commented out, yet the final plotting cell reads it (and the log-avg/percentile lists).
acc_ece_nll_avg = []
# acc_ece_nll_percentile = []
# acc_ece_nll_log_avg = []
acc_ece_nll_indiv = []
# acc_ece_nll_harmon = []

for index in tqdm_notebook(range(num_samples)):
#     acc, ece, nll, brier = get_acc_ece_nll(probs[index,:,:], targets)
    # Member `index`'s scaled probabilities, scored on their own.
    acc, ece, nll, brier = get_acc_ece_nll(scaled_probs[index,:,:], targets)

    acc_ece_nll_indiv.append((index+1, acc, ece, nll, brier))
    # Disabled: linear, percentile, geometric and harmonic pooling over each sampled subset.
#     temp_avg, temp_logavg, temp_harmon = [], [], []
#     temp_percentile = [[] for i in range(3)]
    
#     for perm in random_samples[index]:
#         current_prob = probs[perm,:,:]
        
#         avg_prob = np.mean(current_prob, axis=0)
#         acc, ece, nll, brier = get_acc_ece_nll(avg_prob, targets)
#         temp_avg.append((acc, ece, nll, brier))

#         for q in [25,50,75]:
#             percentile_prob = np.percentile(current_prob, q=q, axis=0)
#             percentile_prob = percentile_prob/np.sum(percentile_prob, axis=1)[:, None]
#             acc, ece, nll, brier = get_acc_ece_nll(percentile_prob, targets)
#             temp_percentile[(q-25)//25].append((acc, ece, nll, brier))

#         log_avg_prob = np.exp(np.mean(np.log(current_prob), axis=0))
#         log_avg_prob = log_avg_prob/np.sum(log_avg_prob, axis=1)[:, None]
#         acc, ece, nll, brier = get_acc_ece_nll(log_avg_prob, targets)
#         temp_logavg.append((acc, ece, nll, brier))

#         harmon_prob = np.mean(1/current_prob, axis=0)
#         harmon_prob = harmon_prob/np.sum(harmon_prob, axis=1)[:, None]
#         acc, ece, nll, brier = get_acc_ece_nll(harmon_prob, targets)
#         temp_harmon.append((acc, ece, nll, brier))
        
#     acc_ece_nll_avg.append({'mean': np.mean(temp_avg, axis=0), 'std': np.std(temp_avg, axis=0)})
#     acc_ece_nll_log_avg.append({'mean': np.mean(temp_logavg, axis=0), 'std': np.std(temp_logavg, axis=0)})
#     acc_ece_nll_harmon.append({'mean': np.mean(temp_harmon, axis=0), 'std': np.std(temp_harmon, axis=0)})
#     acc_ece_nll_percentile.append([{'mean': np.mean(temp_percentile[q], axis=0),
#                                     'std': np.std(temp_percentile[q], axis=0)}
#                                    for q in range(3)])

In [None]:
# Mean/std over members of (acc, ece, nll, brier); column 0 is the member number, so drop it.
arr = np.array(acc_ece_nll_indiv)
column_means = np.mean(arr, axis=0)
column_stds = np.std(arr, axis=0)
column_means[1:], column_stds[1:]

In [None]:
# Ensemble temperature scaling fitted on validation logits for every sampled member subset,
# optionally after trimming members via return_entropy_sorted. Each entry stores the mean/std
# over subsets of get_acc_ece_nll on the pooled scaled test probabilities.
acc_ece_nll_tempscaling = []
tempscale = EnsembleTempScaling(model=ens_model, num_class=num_class, device=device)
# Fraction of the subset's members to drop, split evenly between both ends of the
# (presumably entropy-based) ordering from return_entropy_sorted; None disables trimming.
remove_entropy = 0.33
# validation = loader_dict['no-augment_valid']

for i in range(num_samples):
    temp_temper = []
    for perm in random_samples[i]:
        curr_test_logits = logits[perm,:,:]
        curr_valid_logits = valid_logits[perm,:,:]
        if remove_entropy is not None:
            num_remove = int(curr_valid_logits.shape[0] * remove_entropy)
            each_side = num_remove // 2
            if each_side > 0:
                # NOTE(review): validation and test logits are sorted independently, so the
                # members dropped from each may differ — confirm that is intended.
                curr_valid_logits = return_entropy_sorted(curr_valid_logits)[each_side:-each_side]
                curr_test_logits = return_entropy_sorted(curr_test_logits)[each_side:-each_side]
        # NOTE(review): the same tempscale object is re-optimised for every subset; confirm
        # optimize() re-initialises its parameters rather than warm-starting.
        tempscale.optimize(None, lr=.1, epoch=3000, eval=False, logits=curr_valid_logits, targets=valid_targets)
        # probs = F.softmax(tempscale.forward(dataloader=dataloader_test, ens_eval=False), dim=2).detach().cpu().numpy()
        tscale_probs = F.softmax(tempscale.transform(curr_test_logits), dim=2).detach().cpu().numpy()
        temp_temper.append(tuple(get_acc_ece_nll(np.mean(tscale_probs, axis=0), targets)))
    acc_ece_nll_tempscaling.append({'mean': np.mean(temp_temper, axis=0), 'std': np.std(temp_temper, axis=0)})

In [None]:
# Matrix scaling fitted on validation logits for every sampled member subset. Each entry
# stores the mean/std over subsets of get_acc_ece_nll on the pooled scaled test probabilities.
# (An earlier variant fitted on a DataLoader over 500 CIFAR100 test images instead.)
acc_ece_nll_matscale = []
matscale = EnsembleMatrixScaling(model=ens_model, num_class=num_class, device=device)
for i in range(num_samples):
    subset_scores = []
    for member_idx in random_samples[i]:
        subset_test_logits = logits[member_idx, :, :]
        subset_valid_logits = valid_logits[member_idx, :, :]
        matscale.optimize(None, lr=.01, epoch=10000, eval=False,
                          logits=subset_valid_logits, targets=valid_targets)
        fitted_test_logits = matscale.transform(subset_test_logits)
        matscale_probs = F.softmax(fitted_test_logits, dim=2).detach().cpu().numpy()
        pooled_probs = np.mean(matscale_probs, axis=0)
        subset_scores.append(tuple(get_acc_ece_nll(pooled_probs, targets)))
    acc_ece_nll_matscale.append({'mean': np.mean(subset_scores, axis=0),
                                 'std': np.std(subset_scores, axis=0)})

In [None]:
# Intended as the vector-scaling counterpart of the matrix-scaling cell, scored the same way.
# NOTE(review): this instantiates EnsembleMatrixScaling, exactly like the matrix-scaling cell,
# so acc_ece_nll_vecscaling re-runs matrix scaling. This looks like a copy-paste slip — confirm
# whether a vector-scaling class should be used here.
acc_ece_nll_vecscaling = []
vecscale = EnsembleMatrixScaling(model=ens_model, num_class=num_class, device=device)
# validation = loader_dict['no-augment_valid']
# cifar_test_dataset = datutils.reformed_CIFAR100(datutils.base, train=False,
#                                                 transform=datutils.test_augment, download=False, end=500)
# validation = torch.utils.data.DataLoader(cifar_test_dataset, batch_size=500,
#                                                shuffle=False, num_workers=5)
for i in range(num_samples):
    temp_vector = []
    for perm in random_samples[i]:
        curr_test_logits = logits[perm,:,:]
        curr_valid_logits = valid_logits[perm,:,:]
        vecscale.optimize(None, lr=.01, epoch=10000, eval=False, logits=curr_valid_logits, targets=valid_targets)
        # probs = F.softmax(tempscale.forward(dataloader=dataloader_test, ens_eval=False), dim=2).detach().cpu().numpy()
        vecscale_probs = F.softmax(vecscale.transform(curr_test_logits), dim=2).detach().cpu().numpy()
        temp_vector.append(tuple(get_acc_ece_nll(np.mean(vecscale_probs, axis=0), targets)))
    acc_ece_nll_vecscaling.append({'mean': np.mean(temp_vector, axis=0), 'std': np.std(temp_vector, axis=0)})

In [None]:
from matplotlib.ticker import FormatStrFormatter
from matplotlib.ticker import PercentFormatter
pcfmt0 = PercentFormatter(xmax=1, decimals=0, symbol='%', is_latex=False)
pcfmt1 = PercentFormatter(xmax=1, decimals=1, symbol='%', is_latex=False)
quantiles = [25,50,75]

# np.array(...) on an existing array just copies it, so re-running this cell is safe.
acc_ece_nll_avg = np.array(acc_ece_nll_avg)
acc_ece_nll_percentile = np.array(acc_ece_nll_percentile)
acc_ece_nll_indiv = np.array(acc_ece_nll_indiv)
acc_ece_nll_log_avg = np.array(acc_ece_nll_log_avg)
# acc_ece_nll_harmon = np.array(acc_ece_nll_harmon)
# acc_ece_nll_tempscaling = np.array(acc_ece_nll_tempscaling)
# acc_ece_nll_matscale = np.array(acc_ece_nll_matscale)
# acc_ece_nll_vecscaling = np.array(acc_ece_nll_vecscaling)
# acc_ece_nll_scalescale = np.array(acc_ece_nll_scalescale)

titles = ['Accuracy', 'ECE', 'NLL'] #, 'Brier']
# y-axis tick formatter for each panel, in the same order as `titles`.
y_formatters = [pcfmt1, pcfmt0, FormatStrFormatter('%.3f')]
xs = list(range(1,num_samples+1))

# One row of three panels, drawn directly on the Axes returned by subplots().
# Combining plt.subplots(3) (a 3x1 grid) with plt.subplot(1, 3, k) creates a second,
# overlapping set of Axes; since matplotlib 3.6 the unused 3x1 Axes are no longer
# removed automatically and remain visible underneath the plotted ones.
fig, ax = plt.subplots(1, 3, figsize=(25,8))

for index in range(3):
    axi = ax[index]
#     axi.plot(xs, get_mean_array(acc_ece_nll_avg, index), 'k', label='Linear avg', linewidth=5, alpha=0.8)
    axi.errorbar(xs, y=get_mean_array(acc_ece_nll_avg, index), 
                 yerr=get_std_array(acc_ece_nll_avg, index), capsize=4.0,
                 c='k', label='Linear avg', linewidth=5, alpha=0.8)
    axi.errorbar(xs, y=get_mean_array(acc_ece_nll_log_avg, index), 
                 yerr=get_std_array(acc_ece_nll_log_avg, index), capsize=4.0,
                 label='Geometric pool', linewidth=5, alpha=0.6)
#     axi.plot(acc_ece_nll_harmon[:,0], acc_ece_nll_harmon[:,index+1], label='Harmonic pool', linewidth=4, alpha=0.6)
#     axi.plot(xs, acc_ece_nll_indiv[:,index+1], 'k--', label='Individual', linewidth=5, alpha=0.6)
#     axi.errorbar(xs, y=get_mean_array(acc_ece_nll_matscale, index), 
#                  yerr=get_std_array(acc_ece_nll_matscale, index), capsize=4.0,
#                  label='Matrix Scale', linewidth=4, alpha=0.6)
#     axi.errorbar(xs, y=get_mean_array(acc_ece_nll_vecscaling, index), 
#                  yerr=get_std_array(acc_ece_nll_vecscaling, index), capsize=4.0,
#                  label='Vector Scale', linewidth=4, alpha=0.6)
#     axi.errorbar(xs, get_mean_array(acc_ece_nll_tempscaling, index), 
#                  yerr=get_std_array(acc_ece_nll_tempscaling, index), capsize=4.0,
#                  label='Temp scale', linewidth=5, alpha=0.6)

    axi.yaxis.set_major_formatter(y_formatters[index])
    axi.tick_params(axis='y', which='major', labelsize=22)
    axi.tick_params(axis='x', which='major', labelsize=18)
    axi.set_title(titles[index], fontdict={'fontsize': 24})
    axi.grid(axis='y', color='k', linewidth=3, alpha=0.1)

    # One error-bar series per percentile pool (column i of the percentile results).
    for i in range(3):
        axi.errorbar(xs, get_mean_array(acc_ece_nll_percentile[:,i], index),
                 yerr=get_std_array(acc_ece_nll_percentile[:,i], index), capsize=4.0,
                 label='%dth perc'%(quantiles[i]), linewidth=5, alpha=0.6)

# Every panel draws the same series, so one panel's handles suffice for a shared legend.
handles, labels = ax[-1].get_legend_handles_labels()
fig.legend(handles, labels, ncol=5, loc='lower center',
           handlelength=1, borderaxespad=-0.30, prop={'size': 22}, frameon=False)
# plt.savefig('plots/Imagenette_1k_alternate_pooling_errorbars.pdf', bbox_inches='tight')

In [None]:
acc_ece_nll_avg = np.array(acc_ece_nll_avg)
acc_ece_nll_percentile = np.array(acc_ece_nll_percentile)
# acc_ece_nll_indiv = np.array(acc_ece_nll_indiv)
acc_ece_nll_log_avg = np.array(acc_ece_nll_log_avg)
# NOTE(review): needs acc_ece_nll_harmon from a cell outside this view; it is only
# consumed by the commented-out 'Harmonic' print below.
acc_ece_nll_harmon = np.array(acc_ece_nll_harmon)
# acc_ece_nll_tempscaling = np.array(acc_ece_nll_tempscaling)
# acc_ece_nll_matscale = np.array(acc_ece_nll_matscale)
# acc_ece_nll_vecscaling = np.array(acc_ece_nll_vecscaling)
quantiles = [25,50,75]

# with open('saved_models/Imagenette_1k_5_samples.ensemble', 'wb') as openfile:
#     pickle.dump({'avg': acc_ece_nll_avg,
#                  'perc': acc_ece_nll_percentile,
#                  'indiv': acc_ece_nll_indiv,
#                  'logavg': acc_ece_nll_log_avg,
#                  'harmon': acc_ece_nll_harmon,
#                  'temp': acc_ece_nll_tempscaling,
#                  'mat': acc_ece_nll_matscale,
#                  'vec': acc_ece_nll_vecscaling}, openfile)


def format_scores(scores, fmt='%.3f'):
    """Return each value in `scores` formatted with `fmt`.

    '%.3f' gives three decimals; '%3f' would only set a minimum field width of 3
    and still print six decimals.
    """
    return [fmt % item for item in scores]


# Mean scores of the last entry of each result list.
print('Average\t\t\t\t', format_scores(acc_ece_nll_avg[-1]['mean']))
print('Geometric\t\t\t', format_scores(acc_ece_nll_log_avg[-1]['mean']))
# print('Harmonic\t\t\t', format_scores(acc_ece_nll_harmon[-1]['mean']))
# print('Temp scale\t\t\t', format_scores(acc_ece_nll_tempscaling[-1]['mean']))
# print('Matrix scaling\t\t\t', format_scores(acc_ece_nll_matscale[-1]['mean']))
# print('Vector scaling\t\t\t', format_scores(acc_ece_nll_vecscaling[-1]['mean']))
for perc in range(3):
    print(quantiles[perc],'percentile\t\t\t', format_scores(acc_ece_nll_percentile[-1,perc]['mean']))

In [None]:
# Collect ensemble logits and targets over the test loader, then move both off the
# device: logits stay a (detached) CPU tensor, targets become a numpy array.
logits, targets = ens_model.forward(dataloader_test, eval=False, return_target=True)
logits = logits.detach().cpu()
targets = targets.detach().cpu().numpy()

In [None]:
# 30 equal-width confidence bins: `bins` holds the upper edges, `start_bin` the lower edges.
n_bins = 30
bins = [(p + 1) / n_bins for p in range(n_bins)]
start_bin = [0] + bins[:-1]
mid_bins = [0.5 * (lo + hi) for lo, hi in zip(start_bin, bins)]

# Plot for one randomly chosen ensemble member. Sample over the member axis
# (logits.shape[0]) instead of a hard-coded 34, so any ensemble size works.
i = np.random.randint(low=0, high=logits.shape[0])
# NOTE(review): 0.5275 is a hard-coded temperature, presumably fitted elsewhere — TODO confirm.
probs = F.softmax(logits[i] / 0.5275, dim=1).numpy()
print(get_acc_ece_nll(probs, targets))
# Assumes the last two outputs of calculate_ECE are per-bin correct counts and
# per-bin totals — TODO confirm.
corr, tot = nbutils.calculate_ECE(probs, targets, ECE_bin=bins)[-2:]

fig, ax = plt.subplots()
ax.plot(mid_bins, [a / b if b > 0 else 0 for a, b in zip(corr, tot)], label='member %d' % i)
ax.plot(mid_bins, mid_bins, label='identity')
ax.set(xlabel='bin midpoint', ylabel='corr / tot', title='Reliability, member %d' % i)
ax.legend()
plt.show()

In [None]:
percentile_tempscale = []
valid_size = 50

# The test-side 25th-percentile pool does not depend on the validation subset,
# so compute it once rather than on every iteration. Members are axis 0.
test_25_probs = F.softmax(logits, dim=2).cpu().numpy()
test_percent_prob = np.percentile(test_25_probs, q=25, axis=0)
test_percent_prob = test_percent_prob / np.sum(test_percent_prob, axis=1)[:,None]
test_log_percent_prob = np.log(test_percent_prob)

# Repeat on len(valid)//valid_size random validation subsets of size valid_size.
for _ in range(valid_logits.shape[1]//valid_size):
    randind = torch.randperm(valid_logits.shape[1])[:valid_size]
    curr_valid_logits = valid_logits[:,randind,:]
    curr_valid_targets = valid_targets[randind].type(torch.LongTensor).to(device)

    # Pool members by the per-class 25th percentile, then renormalise each row.
    valid_25_probs = F.softmax(curr_valid_logits, dim=2).cpu().numpy()
    percent_prob = np.percentile(valid_25_probs, q=25, axis=0)
    valid_25_probs = torch.tensor(percent_prob / np.sum(percent_prob, axis=1)[:,None]).to(device)

    # Fit a single temperature on the pooled validation probabilities by NLL.
    t = torch.nn.Parameter(1.5 * torch.ones(1, device=device))
    optimizer = torch.optim.SGD([t], lr=.1)
    criteria = torch.nn.NLLLoss()

    for epoch in range(3000):
        # log_softmax is the numerically stable form of log(softmax(...)).
        logprob = F.log_softmax(torch.log(valid_25_probs) / t, dim=1)
        loss = criteria(logprob, curr_valid_targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(loss.item())

    # Temperature-scale the fixed test pool with the fitted t and renormalise.
    scaled_test_prob = np.exp(test_log_percent_prob / t.item())
    scaled_test_prob = scaled_test_prob / np.sum(scaled_test_prob, axis=1)[:,None]
    percentile_tempscale.append(get_acc_ece_nll(scaled_test_prob, targets))

In [None]:
# Keep only runs whose accuracy (column 0) exceeds 0.4, then summarise them.
# NOTE(review): 0.4 is an ad-hoc cut-off for discarding failed fits — document why.
percentile_tempscale = np.asarray(percentile_tempscale)
percentile_tempscale = percentile_tempscale[percentile_tempscale[:, 0] > 0.4]
percentile_tempscale.mean(axis=0), percentile_tempscale.std(axis=0)

In [None]:
# Display acc_ece_nll_indiv (built in a cell outside this view).
# Disabled: loads a saved model checkpoint and inspects its stored 'acc' entry.
# with open('saved_models/100AugmentedResNet34IMAGEWOOF_HighMixUp23.model', 'rb') as openfile:
#     dictloaded = torch.load(openfile)
# dictloaded['acc']
acc_ece_nll_indiv

In [None]:
# Re-import the project `models` module to pick up source edits.
# `imp` is deprecated and was removed in Python 3.12; importlib.reload is the
# replacement. (%autoreload 2 in the setup cell normally makes this unnecessary.)
import importlib
importlib.reload(models)

In [None]:
# Number of examples in the dataset behind loader_dict['no-augment_valid'].
len(loader_dict['no-augment_valid'].dataset)

In [None]:
import os
import pickle

directory = 'saved_models/'
# Result files ending in '.mcmc' whose name contains 'full', in sorted order.
files = sorted(
    directory + name
    for name in os.listdir(directory)
    if name.endswith('.mcmc') and 'full' in name
)

for file in files:
    with open(file, 'rb') as openfile:
        print('\n' + file)
        # NOTE: pickle.load can execute arbitrary code — only use on trusted local files.
        a = pickle.load(openfile)
        for key in a.keys():
            if key == 'performance':
                # 'performance' is itself a mapping: print each of its entries.
                for nest_key in a[key]:
                    print(nest_key, '\t\t\t', a[key][nest_key])
            elif key in ('samples', 'stats'):
                continue  # these entries are not printed
            else:
                print('{:<20s}'.format(key), '\t\t', a[key])

In [None]:
import json

# Load saved logits/targets from the JSON dump (file name suggests the MixUp
# variant of the 950/50 split).
with open('cifar10_950_50_split_mixup.json', mode='r') as openf:
    dict_to_load = json.load(fp=openf)

In [None]:
import json

# Load saved logits/targets from the JSON dump (file name suggests the
# no-augmentation variant). Rebinds dict_to_load, so whichever JSON cell ran
# last determines the data used below.
with open('cifar10_950_50_split_no_augment.json', mode='r') as openf:
    dict_to_load = json.load(fp=openf)

In [None]:
# Both logit arrays go to the device; valid_targets stays a CPU tensor and
# test_targets is kept as the raw list from the JSON.
valid_logits, test_logits = (
    torch.tensor(dict_to_load[split_key]).to(device)
    for split_key in ('valid_logits', 'test_logits')
)
valid_targets = torch.tensor(dict_to_load['valid_targets'])
test_targets = dict_to_load['test_targets']

In [None]:
valid_probs = F.softmax(valid_logits, dim=2)
# One ECE value (element [1] of get_acc_ece_nll) per ensemble member, iterating
# over axis 0; ece_transform is the identity here.
eces = [
    get_acc_ece_nll(member_probs.cpu().numpy(), np.array(valid_targets),
                    ece_transform=lambda x: x)[1]
    for member_probs in valid_probs
]

In [None]:
min_ece = min(eces)
# Per-member ECE relative to the best member, shaped (members, 1, 1) so it
# broadcasts over the sample and class axes of the probability tensor.
ece_tensor = (torch.tensor(eces).to(device) - min_ece)[:, None, None]

In [None]:
# Integer class labels on the device, as required by NLLLoss.
new_targets = valid_targets.long().to(device)

In [None]:
# Two learnable scalars, both initialised to 1.5: `t` is the shared exponent and
# `d` scales the per-member ECE offset.
t, d = (torch.nn.Parameter(torch.full((1,), 1.5, device=device)) for _ in range(2))

In [None]:
criteria = torch.nn.NLLLoss()
# Plain SGD (lr=10) jointly over t and d.
optimizer = torch.optim.SGD(params=[t, d], lr=10)

In [None]:
# Each member's probabilities are raised to t + d * (its ECE minus the best
# member's ECE), renormalised per sample, then averaged over members; t and d
# are fitted by NLL of that averaged prediction.
for epoch in range(10000):
    scaled_probs = valid_probs.pow(t + ece_tensor * d)
    probs = scaled_probs / scaled_probs.sum(dim=2, keepdim=True)
    avg_prob = probs.mean(dim=0)
    logprob = avg_prob.log()
    loss = criteria(logprob, new_targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if not epoch % 1000:
        print(t.item(), d.item(), loss.item())

In [None]:
# Per-member exponent offset learned by the power-scaling fit.
d * ece_tensor

In [None]:
# Apply the fitted per-member power transform to the test logits and score the
# member average. This is inference only, so skip building an autograd graph.
with torch.no_grad():
    test_probs = F.softmax(test_logits, dim=2)
    test_probs = torch.pow(test_probs, t + ece_tensor * d)
    test_probs = test_probs / torch.sum(test_probs, dim=2, keepdim=True)
    avg_test_prob = torch.mean(test_probs, dim=0)
get_acc_ece_nll(avg_test_prob.detach().cpu().numpy(), np.array(test_targets))

In [None]:
# Plain temperature scaling: divide every member's validation logits by the shared
# scalar `t`, average the softmaxed members, and fit `t` by NLL.
# NOTE(review): reuses `t`, `optimizer` (which also holds `d`) and `criteria` from
# earlier cells, so `t` starts wherever the last fit left it unless those cells
# are re-run first.
for epoch in range(2000):
    # Dedicated names: assigning to `logits` here would overwrite the notebook-level
    # test logits (from ens_model.forward on dataloader_test) used later.
    scaled_logits = valid_logits/t
    member_probs = F.softmax(scaled_logits, dim=2)
    avg_prob = torch.mean(member_probs, dim=0)
    logprob = torch.log(avg_prob)
    loss = criteria(logprob, new_targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Log periodically (plus the final epoch) instead of 2000 output lines.
    if epoch % 200 == 0 or epoch == 1999:
        print(t.item(), loss.item())

In [None]:
# Score the temperature-scaled JSON test logits against the targets loaded with
# them (test_targets). `targets` came from ens_model.forward on dataloader_test,
# not from the JSON, so it does not correspond to these logits.
with torch.no_grad():
    test_probs = F.softmax(test_logits/t, dim=2)
    avg_test_prob = torch.mean(test_probs, dim=0)
get_acc_ece_nll(avg_test_prob.detach().cpu().numpy(), np.array(test_targets))

In [None]:
# Re-import the `nbutils` helper module to pick up source edits.
# `imp` is deprecated and was removed in Python 3.12; importlib.reload replaces it.
import importlib
importlib.reload(nbutils)

In [None]:
import faiss

# NOTE(review): hard-coded absolute path — consider moving it to a config cell.
cifar_features = np.array(torch.load("/data02/rahul/cifar10_ensemble/cifar10_latent_feature/feature_mat_from_sim_clr.t"))
# Index dimensionality comes from the saved features instead of a hard-coded 128,
# so the IndexFlatL2 always matches the embedding width.
dim = cifar_features.shape[1]
index = faiss.IndexFlatL2(dim)   # exact (brute-force) L2 index
print(index.is_trained)
# Index the first 1000 rows — presumably the 1000 training examples — TODO confirm.
index.add(cifar_features[:1000,:])

nb_neighbors = 1
# Query with rows 50000 onward — presumably the test images that follow the
# 50000 training rows — TODO confirm. IndexFlatL2.search returns squared L2
# distances, so dist_to_train holds squared nearest-neighbour distances.
dist_to_train, _ = index.search(cifar_features[50000:,:], nb_neighbors)
dist_to_train = dist_to_train.flatten()

In [None]:
nbins_dist = 10
# Bin edges at the 0th, 10th, ..., 100th percentiles of the distance-to-train
# values, giving roughly equal-count bins.
dist_quantiles = np.percentile(dist_to_train.ravel(), np.linspace(0, 100, nbins_dist + 1))

def indices_given_distance(d_min, d_max, dists=None):
    """
    Filter the test set by distance to the train set.

    Returns, in ascending order, the indices of test samples whose distance to
    the train set lies in the closed interval [d_min, d_max].

    Parameters
    ----------
    d_min, d_max : float
        Inclusive lower and upper distance bounds.
    dists : array-like, optional
        Per-sample distances to filter; defaults to the notebook-level
        ``dist_to_train``. Returned indices are positions in this array.

    Both bounds are inclusive, so a sample whose distance equals an interior
    bin edge is returned for both adjacent bins.
    """
    if dists is None:
        dists = dist_to_train
    dists = np.asarray(dists).ravel()
    # Indices are derived from the data rather than a hard-coded test-set size.
    return np.flatnonzero((dists >= d_min) & (dists <= d_max))

In [None]:
from scipy.special import entr

# Per distance-to-train bin: mean predictive entropy, ECE, NLL and accuracy of
# the member-averaged ensemble prediction.
temp_entropy, temp_ece, temp_nll, temp_acc = [], [], [], []
ens_preds = np.mean(F.softmax(logits, dim=2).cpu().numpy(), axis=0)

for d_min, d_max in zip(dist_quantiles[:-1], dist_quantiles[1:]):
    ind = indices_given_distance(d_min, d_max)
    preds = ens_preds[ind,:]
    temp_t = targets[ind]

    # entr(p) = -p*log(p) with entr(0) = 0, so probabilities that underflow to
    # exactly 0 contribute nothing instead of producing 0*log(0) = nan.
    entropy = np.sum(entr(preds), axis=1)
    # NOTE(review): ece_transform=abs here, while the validation ECEs elsewhere use
    # the identity transform — confirm which ECE variant is intended.
    res = get_acc_ece_nll(preds, temp_t, ece_transform=abs)

    temp_entropy.append( np.mean(entropy) )
    temp_ece.append( res[1] )
    temp_nll.append( res[2] )
    temp_acc.append( res[0] )


In [None]:
# Unpickle the saved temperature results.
# NOTE: pickle.load can execute arbitrary code — only use on trusted files.
with open(
    '/data02/rahul/ensemble_calibration_jsons/all_temperatures_cifar10',
    mode='rb',
) as openf:
    content = pickle.load(file=openf)

In [None]:
# Display the object unpickled from all_temperatures_cifar10.
content