In [2]:
import math
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F


from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix

In [17]:
def bin_initializer(num_bins=10):
    """
    Create an empty per-bin accumulator for ``num_bins`` confidence bins.

    Returns a dict keyed by bin index 0..num_bins-1. Each value is a dict with:
      'count'    -- number of samples in the bin
      'conf'     -- sum of the bin's samples' confidences
      'acc'      -- number of the bin's samples whose prediction was correct
      'bin_acc'  -- average accuracy (filled in by populate_bins)
      'bin_conf' -- average confidence (filled in by populate_bins)
    All fields start at 0.
    """
    fields = ('count', 'conf', 'acc', 'bin_acc', 'bin_conf')
    # A fresh inner dict per bin; sharing the immutable 0 via fromkeys is safe.
    return {b: dict.fromkeys(fields, 0) for b in range(num_bins)}


def populate_bins(confs, preds, labels, num_bins=10):
    """
    Sort samples into ``num_bins`` equal-width confidence bins and aggregate
    per-bin statistics.

    Bin ``b`` holds confidences in ``(b / num_bins, (b + 1) / num_bins]`` (up to
    floating-point rounding). A confidence of exactly 0, which that formula
    would send to bin -1, goes to bin 0. Anything above 1 goes to the last bin.

    Args:
        confs: sequence of confidences, presumably in [0, 1].
        preds: sequence of predicted labels, same length as ``confs``.
        labels: sequence of true labels, same length as ``confs``.
        num_bins: number of equal-width bins.

    Returns:
        The dict produced by ``bin_initializer`` with 'count', 'conf' and 'acc'
        accumulated. For every non-empty bin it also sets
        'bin_acc' = acc / count and 'bin_conf' = conf / count.
        Empty bins keep 0 for both averages.

    Raises:
        ValueError: if ``confs``, ``preds`` and ``labels`` differ in length.
    """
    if not (len(confs) == len(preds) == len(labels)):
        raise ValueError(
            'confs, preds and labels must have the same length, got '
            '{}, {} and {}'.format(len(confs), len(preds), len(labels)))

    bin_dict = bin_initializer(num_bins)
    for confidence, prediction, label in zip(confs, preds, labels):
        # ceil(num_bins * c - 1) maps (b/num_bins, (b+1)/num_bins] to b.
        # Clamp so that c == 0 (which yields -1) and c > 1 stay in range.
        binn = int(math.ceil(num_bins * confidence - 1))
        binn = min(max(binn, 0), num_bins - 1)
        bin_dict[binn]['count'] += 1
        bin_dict[binn]['conf'] += confidence
        bin_dict[binn]['acc'] += (1 if label == prediction else 0)

    for stats in bin_dict.values():
        if stats['count'] > 0:
            stats['bin_acc'] = stats['acc'] / stats['count']
            stats['bin_conf'] = stats['conf'] / stats['count']
    return bin_dict

In [18]:
# Worked example: 4 samples in 2 bins, i.e. bin 0 = (0, 0.5] and bin 1 = (0.5, 1].
populate_bins(confs=[0.1,0.2,0.5,0.8], preds=[1,0,0,0], labels=[1,1,0,0], num_bins=2)

{0: {'count': 3,
  'conf': 0.8,
  'acc': 2,
  'bin_acc': 0.6666666666666666,
  'bin_conf': 0.26666666666666666},
 1: {'count': 1, 'conf': 0.8, 'acc': 1, 'bin_acc': 1.0, 'bin_conf': 0.8}}

In [20]:
def expected_calibration_error(confs, preds, labels, num_bins=10):
    """
    Expected Calibration Error over ``num_bins`` equal-width confidence bins.

    ECE = sum over bins of (samples in bin / total samples)
          * |bin average accuracy - bin average confidence|,
    using the per-bin statistics from ``populate_bins``. The total is
    ``len(labels)``.
    """
    stats = populate_bins(confs, preds, labels, num_bins)
    total = len(labels)
    weighted_gaps = (
        (stats[b]['count'] / total) * abs(stats[b]['bin_acc'] - stats[b]['bin_conf'])
        for b in range(num_bins)
    )
    return sum(weighted_gaps)

In [21]:
expected_calibration_error(
    confs=[0.1, 0.2, 0.5, 0.8],
    preds=[1, 0, 0, 0],
    labels=[1, 1, 0, 0],
    num_bins=2,
)

0.35

In [26]:
def maximum_calibration_error(confs, preds, labels, num_bins=10):
    """
    Maximum Calibration Error: the largest |bin accuracy - bin confidence|
    over all ``num_bins`` bins produced by ``populate_bins``.

    Empty bins keep both averages at 0, so their gap is 0.
    """
    stats = populate_bins(confs, preds, labels, num_bins)
    return max(abs(stats[b]['bin_acc'] - stats[b]['bin_conf']) for b in range(num_bins))

def average_calibration_error(confs, preds, labels, num_bins=10):
    """
    Average Calibration Error: the unweighted mean of
    |bin accuracy - bin confidence| over the non-empty bins produced by
    ``populate_bins``.

    Returns 0.0 when no bin is populated (e.g. empty input) instead of
    dividing by zero.
    """
    bin_dict = populate_bins(confs, preds, labels, num_bins)
    # Only populated bins count toward the mean.
    gaps = [abs(stats['bin_acc'] - stats['bin_conf'])
            for stats in (bin_dict[i] for i in range(num_bins))
            if stats['count'] > 0]
    if not gaps:
        return 0.0
    return sum(gaps) / float(len(gaps))

In [25]:
maximum_calibration_error(
    confs=[0.1, 0.2, 0.5, 0.8],
    preds=[1, 0, 0, 0],
    labels=[1, 1, 0, 0],
    num_bins=2,
)

0.39999999999999997

In [27]:
average_calibration_error(
    confs=[0.1, 0.2, 0.5, 0.8],
    preds=[1, 0, 0, 0],
    labels=[1, 1, 0, 0],
    num_bins=2,
)

0.29999999999999993

In [41]:
def test_classification_net_logits(logits, labels):
    '''
    Compute classification metrics from a model's raw logits.

    Args:
        logits: tensor of unnormalized scores. Softmax and argmax are taken over
            dim 1, so presumably shape (num_samples, num_classes).
        labels: tensor of integer class labels, one per row of ``logits``.

    Returns:
        A tuple ``(confusion_matrix, accuracy, labels_list, predictions_list,
        confidence_vals_list)``. The three lists are plain Python lists, and
        the confidences are each row's top-1 softmax probability.
    '''
    # Metrics only: no autograd graph is needed. Without this, converting
    # softmax outputs of grad-requiring logits to NumPy raised RuntimeError.
    with torch.no_grad():
        probabilities = F.softmax(logits, dim=1)
        confidence_vals, predictions = torch.max(probabilities, dim=1)

    labels_list = labels.detach().cpu().tolist()
    predictions_list = predictions.cpu().tolist()
    confidence_vals_list = confidence_vals.cpu().tolist()

    accuracy = accuracy_score(labels_list, predictions_list)
    return confusion_matrix(labels_list, predictions_list), accuracy, labels_list, \
        predictions_list, confidence_vals_list

In [42]:
test_classification_net_logits(
    logits=torch.tensor([[0.1, 0.6], [0.2, 0.8], [0.5, 0.1], [0.8, 0.2]]),
    labels=torch.tensor([1, 1, 0, 0]),
)

(array([[2, 0],
        [0, 2]]),
 1.0,
 [1, 1, 0, 0],
 [1, 1, 0, 0],
 [0.622459352016449,
  0.6456562876701355,
  0.5986876487731934,
  0.6456562876701355])

In [72]:
class ECELoss(nn.Module):
    '''
    Compute ECE (Expected Calibration Error) over ``n_bins`` equal-width
    confidence bins.

    ECE = sum over bins of |avg confidence - accuracy| * (#bin / #samples),
    where each bin covers top-1 softmax confidences in (lower, upper].
    '''
    def __init__(self, n_bins=15):
        super(ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        # Stored on the module so forward() can read them. As a bare local of
        # __init__ they were unreachable, and forward() raised NameError.
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        '''
        Args:
            logits: unnormalized scores; softmax is taken over dim 1
                (presumably shape (N, num_classes)).
            labels: integer class labels compared elementwise with the argmax.

        Returns:
            A 1-element tensor on ``logits.device`` holding the ECE.
        '''
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Compute |confidence - accuracy| * (#bin / #samples) in each bin

            # Find the samples that belong to the bin
            in_bin = confidences.gt(bin_lower.item()) & confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()

            # If the bin isn't empty
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean() # Bin accuracy
                avg_confidence_in_bin = confidences[in_bin].mean() # Bin confidence
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

In [None]:
class AdaptiveECELoss(nn.Module):
    '''
    Compute Adaptive ECE.

    Like ECE, but the ``n_bins`` bin edges sit at equally spaced ranks of the
    sorted confidences, so each bin holds roughly the same number of samples.
    '''
    def __init__(self, n_bins=15):
        super(AdaptiveECELoss, self).__init__()
        self.n_bins = n_bins

    def histedges_equalN(self, x):
        '''
        Return ``self.n_bins + 1`` equal-mass bin edges for the 1-D values ``x``.

        The sorted values are linearly interpolated at ``n_bins + 1`` evenly
        spaced positions in [0, len(x)]. The first edge is min(x). Because
        np.interp clamps beyond the last sample, the last edge is max(x).
        ``x`` must be non-empty.
        '''
        npt = len(x)
        return np.interp(np.linspace(0, npt, self.n_bins + 1),
                         np.arange(npt),
                         np.sort(x))

    def forward(self, logits, labels):
        '''
        Args:
            logits: unnormalized scores; softmax is taken over dim 1
                (presumably shape (N, num_classes)).
            labels: integer class labels compared elementwise with the argmax.

        Returns:
            A 1-element tensor on ``logits.device`` holding the adaptive ECE
            (0 for an empty batch). The edges used are also stored in
            ``self.bin_lowers`` / ``self.bin_uppers``.
        '''
        if logits.shape[0] == 0:
            # No samples: equal-mass edges cannot be built from an empty array.
            return torch.zeros(1, device=logits.device)

        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)

        bin_boundaries = self.histedges_equalN(confidences.detach().cpu().numpy())
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

        ece = torch.zeros(1, device=logits.device)
        for i, (bin_lower, bin_upper) in enumerate(zip(self.bin_lowers, self.bin_uppers)):
            # Bins are (lower, upper], except the first, which is [lower, upper].
            # Its lower edge equals min(confidences), so a strict '>' would drop
            # the least-confident sample from every bin.
            if i == 0:
                above_lower = confidences.ge(bin_lower.item())
            else:
                above_lower = confidences.gt(bin_lower.item())
            in_bin = above_lower & confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                # Calculated |confidence - accuracy| * (#bin / #samples)
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
        return ece

In [74]:
logits = torch.tensor([
    [0.1, 0.6],
    [0.2, 0.8],
    [0.5, 0.1],
    [0.8, 0.2],
])
labels = torch.tensor([1, 1, 0, 0])

# Top-1 softmax probability and predicted class per row, plus a bool mask
# marking which predictions match the labels.
softmaxes = logits.softmax(dim=1)
confidences, predictions = softmaxes.max(dim=1)
accuracies = predictions == labels
        
        