In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pickle
from scipy.special import softmax

import torch
from torchvision import transforms, datasets
# from sklearn.metrics import mean_squared_error

## Dataset

In [None]:
# Materialise the CIFAR-10 evaluation split (train=False) as NumPy arrays so the
# metric cells further down can index labels directly.
batch_size = 128

# Convert images to tensors, then normalise each of the three channels with
# mean 0.5 and std 0.5.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
valset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
# shuffle=False: batches are read in dataset order, so y_dev keeps the order
# of the underlying dataset.
valloader = torch.utils.data.DataLoader(
    valset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=False,
    num_workers=3,
)

# Collect every batch, then stitch the batches together along axis 0.
image_batches, label_batches = [], []
for x, y in valloader:
    image_batches.append(x.cpu().numpy())
    label_batches.append(y.cpu().numpy())

x_dev = np.concatenate(image_batches)
y_dev = np.concatenate(label_batches)
print(x_dev.shape, y_dev.shape, sep="\n")

## Metrics

In [None]:
# Log likelihood
def get_ll(preds, targets):
    """Return the mean log-probability that `preds` assigns to the target class.

    Args:
        preds: 2-D array indexed as [sample, class]; expected to hold predicted
            class probabilities.
        targets: integer class index for each sample.

    Returns:
        Mean over samples of log(1e-12 + p_target). The 1e-12 offset keeps the
        logarithm finite when the target class was given probability 0.
    """
    sample_idx = np.arange(len(targets))
    target_probs = preds[sample_idx, targets]
    log_probs = np.log(1e-12 + target_probs)
    return log_probs.mean()

# def get_RMSE(preds, targets):
#     pass

# Brier score
# gentler than log loss in penalizing inaccurate predictions.
def get_brier(preds, targets):
    """Return the multi-class Brier score of `preds` against `targets`.

    Args:
        preds: 2-D array indexed as [sample, class]; expected to hold predicted
            class probabilities.
        targets: integer class index for each sample.

    Returns:
        The squared difference between `preds` and the one-hot encoding of
        `targets`, summed over classes (axis 1) and averaged over samples.
    """
    # Build the one-hot label matrix with the same shape as the predictions.
    onehot = np.zeros(preds.shape)
    rows = np.arange(len(targets))
    onehot[rows, targets] = 1.0

    per_sample_error = np.sum((preds - onehot) ** 2, axis=1)
    return np.mean(per_sample_error)

def get_accuracy(preds, targets):
    """Return the fraction of samples whose arg-max class equals the target.

    Args:
        preds: 2-D array of per-class scores, one row per sample.
        targets: class index for each sample.

    Returns:
        Mean of the element-wise comparison between the arg-max over axis 1
        of `preds` and `targets`.
    """
    is_correct = np.argmax(preds, 1) == targets
    return np.mean(is_correct)

### CIFAR-10

In [None]:
# One results directory per method; `names` holds the matching row labels for
# the summary table, in the same order.
result_name_list = [
    "Results/Regular_results",
    "Results/MCdrop_results",
    "Results/Ensemble_results",
    "Results/BBP_results/bbb",
    "Results/BBP_results/lrt",
    "Results/Contrastive_Reasoning_results",
    "Results/Tent_results",
]
names = [
    "Regular",
    'MC Dropout',
    'Bootstrap Ensemble',
    'BBP: without lrt',
    'BBP: with lrt',
    'Contrastive Reasoning',
    'TENT',
]

targets = y_dev

# Row i of `metrics` holds (log-likelihood, Brier score, accuracy) for method i.
metrics = np.zeros((len(result_name_list), 3))

for idx, dir_name in enumerate(result_name_list):
    all_preds = np.load(dir_name + '/all_preds.npy')
    # Keep only index 0 along axis 1, i.e. the predictions at rotation 0.
    preds = all_preds[:, 0, :]

    metrics[idx, :] = (
        get_ll(preds, targets),
        get_brier(preds, targets),
        get_accuracy(preds, targets),
    )

pd.DataFrame(metrics, columns=["Log-likelihood", "Brier score", "Accuracy"], index=names)

### CIFAR-10 rotations

In [None]:
# One results directory per method; `names` holds the matching row labels for
# the summary table, in the same order.
result_name_list = ["Results/Regular_results", "Results/MCdrop_results", "Results/Ensemble_results", "Results/BBP_results/bbb",
                 "Results/BBP_results/lrt", "Results/Contrastive_Reasoning_results", "Results/Tent_results"]
names = ["Regular", 'MC Dropout', 'Bootstrap Ensemble', 'BBP: without lrt' , 'BBP: with lrt', 'Contrastive Reasoning', 'TENT']

# Row i of `metrics` holds (log-likelihood, Brier score, accuracy) for method i.
metrics = np.zeros((len(result_name_list),3))

for idx, dir_name in enumerate(result_name_list):

    all_preds = np.load(dir_name + '/all_preds.npy')
    # Axis 1 indexes the rotations (see the rotation-0 slice in the CIFAR-10
    # cell); the last axis is the per-class prediction.
    N, R, C = all_preds.shape
    if N != len(y_dev):
        raise ValueError(
            f"{dir_name}/all_preds.npy holds {N} samples but y_dev has {len(y_dev)} labels"
        )

    preds = all_preds.reshape(-1, C) # Over all rotations

    # reshape(-1, C) keeps the R rows belonging to one sample adjacent, so each
    # label has to be repeated R times in place (np.repeat, not np.tile).
    # Using the file's own R instead of a hard-coded 16 keeps labels and
    # predictions aligned whatever number of rotations was saved.
    targets = np.repeat(y_dev, R)

    ll = get_ll(preds, targets)
    brier = get_brier(preds, targets)
    acc = get_accuracy(preds, targets)

    metrics[idx,:] = ll,brier,acc

pd.DataFrame(metrics, columns=["Log-likelihood", "Brier score", "Accuracy"], index=names)

### CIFAR-10-C

In [None]:
# labels = []
# for key in preds_dict.keys():
#     labels.append(key.split(".")[0])
# print(labels)

In [None]:
## Boxplot of metrics across distortion types

# plt.figure(figsize=(25,10))
# ax = plt.boxplot(ll_list, labels=labels)
# plt.xlabel("Distortion type"); plt.ylabel("Log likelihood")
# plt.grid()
# plt.show()

# plt.figure(figsize=(25,10))
# plt.boxplot(brier_list, labels=labels)
# plt.xlabel("Distortion type"); plt.ylabel("Brier score")
# plt.grid()
# plt.show()

In [None]:
# One results directory per method; `names` holds the matching row labels for
# the summary table, in the same order.
result_name_list = ["Results/Regular_results", "Results/MCdrop_results", "Results/Ensemble_results", "Results/BBP_results/bbb",
                 "Results/BBP_results/lrt", "Results/Contrastive_Reasoning_results", "Results/Tent_results"]
names = ["Regular", 'MC Dropout', 'Bootstrap Ensemble', 'BBP: without lrt' , 'BBP: with lrt', 'Contrastive Reasoning', 'TENT']

# NOTE(review): presumably 5 severity levels x 19 distortion types, i.e. the
# CIFAR-10-C layout -- confirm against the script that wrote preds_CIFAR-10-C.npy.
n_severities = 5
n_distortions = 19

# Row i of `metrics` holds (log-likelihood, Brier score, accuracy) for method i.
metrics = np.zeros((len(result_name_list),3))
# np.tile repeats the whole label vector back to back, so this assumes the
# saved predictions are whole test-set passes stacked one after another.
targets = np.tile(y_dev, n_severities * n_distortions)

for idx, dir_name in enumerate(result_name_list):

    preds = np.load(dir_name + '/preds_CIFAR-10-C.npy')
    # Fail loudly on a length mismatch: with more prediction rows than labels,
    # get_ll and get_brier would otherwise score only a prefix of the rows
    # without raising.
    if len(preds) != len(targets):
        raise ValueError(
            f"{dir_name}/preds_CIFAR-10-C.npy has {len(preds)} rows, expected {len(targets)}"
        )

    ll = get_ll(preds, targets)
    brier = get_brier(preds, targets)
    acc = get_accuracy(preds, targets)

    metrics[idx,:] = ll,brier,acc

pd.DataFrame(metrics, columns=["Log-likelihood", "Brier score", "Accuracy"], index=names)