# Analyze BNN confidence on corrupted data

This experiment analyzes how confident the BNN's predictions are on a corrupted test set (Gaussian blur at several strengths), compared with the clean test set.

In [None]:
import sys
sys.path.append("./../")

In [None]:
import numpy as np
import os, json

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim
import torch.nn.functional as F
from sklearn.calibration import calibration_curve
from sklearn.metrics import f1_score

import methods
import models
from transforms import normalize_x
import datasets

import matplotlib.pyplot as plt
# The bare 'seaborn' style name was deprecated in matplotlib 3.6 (renamed to
# 'seaborn-v0_8') and later removed, so prefer the new name when it exists
# and fall back to the old one on older matplotlib releases.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

In [None]:
def normalize_and_blur(std=1.0):
    """
        Build the transform pipeline that normalizes and then blurs an image

    std:  Value passed as `sigma` to the Gaussian blur step
    """
    to_tensor = transforms.ToTensor()
    # Normalization is configured with mean 0 and std 1
    normalize = transforms.Normalize(mean=[0.0], std=[1.0])
    # Gaussian blur with a fixed kernel size of 13
    blur = transforms.GaussianBlur(13, sigma=std)

    steps = [to_tensor, normalize, blur]
    return transforms.Compose(steps)

In [None]:
def evaluate_model(model, testloader, n_mc_samples=16):
    """
        Run a Monte-Carlo evaluation of a model over a data loader

    model       :  Object exposing `eval()` and `forward(data)`; `forward` must
                   return a pair whose first element holds log-probabilities
                   (they are exponentiated here)
    testloader  :  Iterable yielding (data, target) batches of torch tensors
    n_mc_samples:  No. of forward passes per batch (default 16)

    Returns a tuple (pred_labels, targets, predictions, pred_probs_mc):
    pred_labels  :  argmax along axis 1 of the MC-averaged probabilities
    targets      :  Target labels concatenated over all batches
    predictions  :  MC-averaged probabilities, batches stacked along axis 0
    pred_probs_mc:  Per-pass probabilities with the MC passes on axis 1
                    (assumes `forward` yields (batch, classes) -- TODO confirm)

    Raises ValueError if n_mc_samples < 1 or the loader yields no batches.
    """
    if n_mc_samples < 1:
        raise ValueError("n_mc_samples must be at least 1")

    pred_probs_mc = []
    targets = []
    predictions = []
    pred_labels = []

    # Switching to eval mode does not depend on the batch or the MC pass,
    # so it is done once instead of before every forward call.
    # NOTE(review): the MC passes only differ if the model stays stochastic
    # in eval mode -- confirm against the model implementation.
    model.eval()

    with torch.no_grad():
        for data, target in testloader:
            mc_samples = []
            for _ in range(n_mc_samples):
                output, _aux = model.forward(data)
                # get probabilities from log-prob
                pred_probs = torch.exp(output)
                mc_samples.append(pred_probs.detach().cpu().numpy())

            target_labels = target.detach().cpu().numpy()
            pred_mean = np.mean(mc_samples, axis=0)
            Y_pred = np.argmax(pred_mean, axis=1)

            pred_probs_mc.append(np.stack(mc_samples, axis=1))
            targets.append(target_labels)
            predictions.append(pred_mean)
            pred_labels.append(Y_pred)

    if not targets:
        raise ValueError("testloader yielded no batches")

    # Stack all results
    pred_probs_mc = np.vstack(pred_probs_mc)
    targets = np.hstack(targets)
    predictions = np.vstack(predictions)
    pred_labels = np.hstack(pred_labels)

    return pred_labels, targets, predictions, pred_probs_mc

In [None]:
def compute_ece(prob_pred, prob_true, counts):
    """
        Compute expected calibration error
    prob_pred:  Mean predicted probability in each non-empty bin  (1d array-like)
    prob_true:  True proportion of samples in each non-empty bin (1d array-like)
    counts   :  No. of samples in each bin; empty bins may be included and
                are dropped here (1d array-like)

    Returns the count-weighted mean of |prob_pred - prob_true|, or NaN when
    there are no samples at all.

    Raises ValueError if prob_pred, prob_true and the non-empty entries of
    counts do not all have the same shape.
    """
    # Accept plain lists as well as arrays.
    prob_pred = np.atleast_1d(np.asarray(prob_pred, dtype=float))
    prob_true = np.atleast_1d(np.asarray(prob_true, dtype=float))
    counts = np.atleast_1d(np.asarray(counts))

    modified_counts = counts[counts > 0]  # Keep only non-zero bins

    # Without this check a length-1 input would silently broadcast against
    # the counts and produce a plausible-looking but wrong value.
    if not (prob_pred.shape == prob_true.shape == modified_counts.shape):
        raise ValueError(
            "prob_pred {}, prob_true {} and non-empty counts {} must have the same shape".format(
                prob_pred.shape, prob_true.shape, modified_counts.shape))

    total = np.sum(modified_counts)
    if total == 0:
        # No samples: the error is undefined; return NaN without a 0/0 warning.
        return float("nan")

    error = np.abs(prob_pred - prob_true)
    ece = np.sum(error * modified_counts) / total

    return ece

In [None]:
def _evaluate_and_plot(model, testset, axis, label, n_bins=10):
    """
        Evaluate `model` on `testset` and draw one prediction/uncertainty panel

    model  :  Model handed to `evaluate_model`
    testset:  Dataset wrapped in a DataLoader (batch size 1024, no shuffling)
    axis   :  Matplotlib axes to draw on
    label  :  Prefix of the panel title (e.g. "Clean data")
    n_bins :  No. of uniform bins used for the calibration metrics
    """
    testloader = DataLoader(dataset=testset, batch_size=1024, shuffle=False)

    pred_labels, targets, predictions, pred_probs_mc = evaluate_model(model, testloader)
    acc = (targets == pred_labels).mean()
    print("Test accuracy: ", acc)

    # F1 score
    f1_sc = f1_score(targets, pred_labels)
    print("F1 Score :", f1_sc)

    # Keep only the probability assigned to class index 1
    pred_probs_mc = pred_probs_mc[:, :, 1]
    predictions = predictions[:, 1]

    err_idx = np.where(targets != pred_labels)[0]
    # Uncertainty: std over the MC passes divided by sqrt(no. of passes)
    mc = pred_probs_mc.shape[1]
    unc = np.std(pred_probs_mc, axis=1) / np.sqrt(mc)

    # Compute metrics
    calib_prob_true, calib_prob_pred = calibration_curve(
        targets, predictions, n_bins=n_bins, strategy='uniform')
    # linspace gives exactly n_bins + 1 edges; a float-step arange does not
    # guarantee that for every n_bins.
    # NOTE(review): np.histogram and calibration_curve may assign a value lying
    # exactly on a bin edge to different bins -- confirm this is acceptable.
    scores_hist, _ = np.histogram(predictions, bins=np.linspace(0.0, 1.0, n_bins + 1))
    ece = compute_ece(calib_prob_pred, calib_prob_true, scores_hist)

    # All predictions in the default color, misclassified samples in red on top
    axis.scatter(predictions, unc, s=10)
    axis.scatter(predictions[err_idx], unc[err_idx], c='r', s=10)
    axis.set_xlim(-0.05, 1.05)
    axis.set_ylim(0.0, 0.15)
    axis.set_xlabel("Predicted value of p")
    axis.set_ylabel("Uncertainty")
    axis.set_title("{} :: F1 = {:.4f}, ECE = {:.4f}".format(label, f1_sc, ece))


def analyze_model(model_dir, blur_levels=(1.0, 2.0, 3.0, 4.0, 5.0)):
    """
        Plot prediction vs. uncertainty for a trained model on clean and blurred data

    model_dir  :  Directory containing "config.json" and "last.ckpt"
    blur_levels:  Gaussian blur sigmas to evaluate; one panel is drawn per
                  level, after a first panel for the unblurred test set

    Returns the matplotlib figure holding all panels. Accuracy and F1 score
    of every evaluation are printed.
    """
    blur_levels = list(blur_levels)

    # Default paths
    config_json = os.path.join(model_dir, "config.json")
    ckpt_file = os.path.join(model_dir, "last.ckpt")
    with open(config_json, 'r') as config_file:
        config = json.load(config_file)

    MethodClass = getattr(methods, config['method'])
    ModelClass = getattr(models, config['model'])
    TransformClass = normalize_x  # getattr(transforms, config['transform'])
    # NOTE(review): config['dataset'] is not consulted; the test set below is
    # fixed to datasets.BinaryMNISTC(53, 'identity', 'test', ...).

    model = MethodClass.load_from_checkpoint(ckpt_file, model=ModelClass(2))

    n_panels = 1 + len(blur_levels)
    # squeeze=False keeps the axes in an array even when there is one panel
    fig, ax = plt.subplots(1, n_panels, figsize=(6 * n_panels, 4), squeeze=False)
    ax = ax[0]

    # -----------------------------------------------------------
    # Clean (unblurred) test data
    testset = datasets.BinaryMNISTC(53, 'identity', 'test', transform=TransformClass())
    _evaluate_and_plot(model, testset, ax[0], "Clean data")

    # ------------------------------------------------------------
    # Blurred test data, one panel per blur level
    for i, blur_std in enumerate(blur_levels):
        x_transform = normalize_and_blur(std=blur_std)
        testset = datasets.BinaryMNISTC(53, 'identity', 'test', transform=x_transform)
        _evaluate_and_plot(model, testset, ax[i + 1], "Blur-{}".format(blur_std))

    #--
#     lam_sl = config['method_params']['lam_sl']
#     alpha = config['method_params']['alpha']
#     plt.suptitle("$\\lambda_{} = {:1.0e} \\qquad    \\alpha = {:1.0e}$".format("{SL}", lam_sl, alpha))

    return fig

In [None]:
# Analyze the checkpoint in this model directory on the clean test set and on
# five blur levels; `fig` holds the figure returned by analyze_model.
model_dir = "./../zoo/MedMNIST/ChestMNIST-infiltration/ConvNet/sl-auto-wt-1-20220408203309"
fig = analyze_model(model_dir, blur_levels=[0.50, 1.0, 2.0, 3.0, 5.0])

In [None]:
# Same analysis (clean test set plus five blur levels) for a second model
# directory; `fig` is rebound to the new figure.
model_dir = "./../zoo/MedMNIST/ChestMNIST-infiltration/ConvNet/sl-wt-1-20220408135953"
fig = analyze_model(model_dir, blur_levels=[0.50, 1.0, 2.0, 3.0, 5.0])

In [None]:
# Same analysis (clean test set plus five blur levels) for a third model
# directory; `fig` is rebound to the new figure.
model_dir = "./../zoo/MedMNIST/ChestMNIST-infiltration/ConvNet/sl-wt-1-20220408135357"
fig = analyze_model(model_dir, blur_levels=[0.50, 1.0, 2.0, 3.0, 5.0])