# Analyze BNN confidence on corrupted data

This experiment analyzes the confidence of Bayesian neural network (BNN) predictions on a corrupted data set.

In [None]:
import sys
sys.path.append("./../")

In [None]:
import numpy as np

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim
import torch.nn.functional as F

import models
from transforms import normalize_x
import datasets

import matplotlib.pyplot as plt
# Matplotlib 3.6 deprecated the bundled 'seaborn' style name in favour of
# 'seaborn-v0_8', and later releases removed the old name. Use the new name
# when this Matplotlib provides it and fall back to the legacy one otherwise.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

In [None]:
def normalize_and_blur(std=1.0):
    """
        Build the preprocessing pipeline used for blurred images.

        The returned transform converts its input to a tensor, applies
        ``Normalize`` with mean 0 and std 1, and finally applies a
        ``GaussianBlur`` with kernel size 13.

        Args:
            std: value forwarded to ``GaussianBlur`` as ``sigma``.

        Returns:
            A ``transforms.Compose`` chaining the three steps in that order.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=[0.0], std=[1.0])
    blur = transforms.GaussianBlur(13, sigma=std)
    return transforms.Compose([to_tensor, normalize, blur])

## Build dataset

In [None]:
# NOTE(review): x_transform is not used in this cell; each dataset below
# builds its own normalize_x() instance.
x_transform = normalize_x()

# 'identity' train and test splits of BinaryMNISTC, built in that order.
trainset, testset = (
    datasets.BinaryMNISTC(53, 'identity', split, transform=normalize_x())
    for split in ('train', 'test')
)

# Shuffle only the training data; the test order stays fixed so that indices
# computed from the predictions can be used to index testset.
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=1024, shuffle=False)

## Train model

In [None]:
# LeNet from the project's `models` module. The training loop below unpacks
# its forward result as an (output, kl) pair.
# K=2 is presumably the number of classes — TODO confirm against models.LeNet.
model = models.LeNet(K=2)

In [None]:
# Adam over all model parameters; no hyperparameters are passed, so the
# optimizer's defaults apply.
opt = optim.Adam(model.parameters())

In [None]:
# Number of full passes over trainloader performed by the training loop.
n_epochs = 5

In [None]:
# Training loop: every mini-batch is passed through the model 16 times, the
# 16 outputs are averaged, and the averaged output is scored with F.nll_loss.
for epoch in range(n_epochs):
    for batch_idx, (data, target) in enumerate(trainloader):
        opt.zero_grad()

        # Each forward call yields an (output, kl) pair.
        mc_passes = [model(data) for _ in range(16)]
        output = torch.stack([out for out, _ in mc_passes]).mean(dim=0)
        # Only the kl value returned by the last pass enters the loss.
        kl = mc_passes[-1][1]

        nll_loss = F.nll_loss(output, target)

        # ELBO loss
        # NOTE(review): the second term equals 1 / (64 * kl), so it shrinks as
        # kl grows. If kl is a KL divergence that should be penalised, a term
        # proportional to kl (not its reciprocal) was presumably intended —
        # confirm before relying on this objective.
        loss = nll_loss + 1 / kl / 64

        loss.backward()
        opt.step()

        # Report progress on every 10th mini-batch only.
        if batch_idx % 10 != 0:
            continue

        seen = batch_idx * len(data)
        percent = 100. * batch_idx / len(trainloader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(trainloader.dataset), percent, loss.item()))

## Evaluate model

In [None]:
def evaluate_model(model, testloader, n_mc_samples=16):
    """
    Run repeated forward passes over a data loader and collect predictions.

    The model is switched to eval mode once and then, for every batch, called
    ``n_mc_samples`` times under ``torch.no_grad()``. The first element of
    each forward result is treated as log-probabilities and exponentiated.
    Whether the passes actually differ depends on the model remaining
    stochastic in eval mode, which this function does not control.

    Args:
        model: module whose call returns a pair; the first element holds
            per-class log-probabilities of shape (batch, classes) and the
            second element is ignored.
        testloader: iterable yielding ``(data, target)`` batches.
        n_mc_samples: number of forward passes per batch. Defaults to 16,
            the value that used to be hard-coded.

    Returns:
        A tuple ``(pred_labels, targets, predictions, pred_probs_mc)`` of
        NumPy arrays, where N is the total number of examples:
            pred_labels: (N,) argmax of the averaged probabilities.
            targets: (N,) labels taken from the loader.
            predictions: (N, classes) probabilities averaged over the passes.
            pred_probs_mc: (N, n_mc_samples, classes) per-pass probabilities.

    Raises:
        ValueError: if ``n_mc_samples`` is smaller than 1, or (from NumPy)
            if ``testloader`` yields no batches.
    """
    if n_mc_samples < 1:
        raise ValueError("n_mc_samples must be a positive integer")

    pred_probs_mc = []
    targets = []
    predictions = []
    pred_labels = []

    # Switching to eval mode is loop-invariant, so do it once up front
    # instead of before every single forward pass.
    model.eval()

    with torch.no_grad():
        for data, target in testloader:
            # Call the module itself (not .forward) so registered hooks run.
            # Stacking on axis 1 gives shape (batch, n_mc_samples, classes).
            batch_probs = np.stack(
                [torch.exp(model(data)[0]).cpu().numpy()
                 for _ in range(n_mc_samples)],
                axis=1,
            )
            batch_mean = batch_probs.mean(axis=1)

            pred_probs_mc.append(batch_probs)
            targets.append(target.cpu().numpy())
            predictions.append(batch_mean)
            pred_labels.append(np.argmax(batch_mean, axis=1))

    # Stack all results
    pred_probs_mc = np.vstack(pred_probs_mc)
    targets = np.hstack(targets)
    predictions = np.vstack(predictions)
    pred_labels = np.hstack(pred_labels)

    return pred_labels, targets, predictions, pred_probs_mc

# Evaluate the freshly trained model on the current test loader and report
# the fraction of examples whose predicted label matches the target.
(pred_labels,
 targets,
 predictions,
 pred_probs_mc) = evaluate_model(model, testloader)
print("Test accuracy: ", np.mean(targets == pred_labels))

In [None]:
# Keep only the probability assigned to label 1: index 1 on the last (class)
# axis of both the per-pass array and the averaged array.
pred_probs_mc = pred_probs_mc[..., 1]
predictions = predictions[..., 1]

## Plot distribution of predictions and MC sample variance

In [None]:
# Histogram of the averaged label-1 probabilities over the test set.
fig, hist_ax = plt.subplots()
hist_ax.hist(predictions)
hist_ax.set_title("Distribution of p values")

In [None]:
# Per-example variance of the label-1 probability across the MC passes.
# The original computed np.std here although the variable name, the plot
# title and the section heading all say "variance"; np.var makes the plotted
# quantity match its label.
mc_var = np.var(pred_probs_mc, axis=1)

fig = plt.figure()
plt.hist(mc_var)
plt.title("Distribution of MC variance")

## Outputs for false predictions

In [None]:
# Indices of misclassified test examples, plus their averaged probability and
# per-pass probabilities.
err_idx = np.flatnonzero(targets != pred_labels)
err_predictions, err_mc_samples = predictions[err_idx], pred_probs_mc[err_idx]

In [None]:
# Show every misclassified test image, titled with its averaged score and the
# standard error of that score over the MC passes.
n = len(err_idx)

if n == 0:
    # plt.subplots rejects ncols=0, so there is nothing to draw.
    print("No misclassified test examples to display.")
else:
    # squeeze=False keeps `ax` an array even when n == 1; otherwise a single
    # Axes object is returned and cannot be iterated.
    fig, ax = plt.subplots(nrows=1, ncols=n, figsize=(n*2.5, 2.5), squeeze=False)

    for (i, _ax) in enumerate(ax.ravel()):
        # testset[k] is indexed as (image, ...); take the image's first channel.
        x = testset[err_idx[i]][0][0]
        _ax.imshow(x)
        sem = err_mc_samples[i, :].std() / np.sqrt(err_mc_samples[i, :].shape[0])
        _ax.set_title("s = {:.3f} +/- {:.3f}".format(err_predictions[i], sem))

## Plot the predicted scores and their confidence

In [None]:
# Scatter of the averaged score against its standard error over the MC passes
# (std across passes divided by sqrt of the number of passes).
# Misclassified examples are drawn again on top as larger red markers.
err_idx = np.flatnonzero(targets != pred_labels)

mc = pred_probs_mc.shape[1]
unc = pred_probs_mc.std(axis=1) / np.sqrt(mc)

fig, scatter_ax = plt.subplots(figsize=(12, 8))
scatter_ax.scatter(predictions, unc, s=10)
scatter_ax.scatter(predictions[err_idx], unc[err_idx], c='r', s=50)
scatter_ax.set_xlim(0, 1)
scatter_ax.set_ylim(0, 0.1)

scatter_ax.set_xlabel("Predicted value of p")
scatter_ax.set_ylabel("Uncertainty")

## Test with blurred images

In [None]:
def _plot_uncertainty_panel(_ax, model, x_transform, label):
    """Evaluate ``model`` on the test split under ``x_transform`` and draw one panel.

    Prints the accuracy, then scatters the averaged label-1 probability
    against its standard error over the MC passes. Misclassified examples are
    overlaid in red. The panel title is ``label`` followed by the accuracy.
    """
    testset = datasets.BinaryMNISTC(53, 'identity', 'test', transform=x_transform)
    testloader = DataLoader(dataset=testset, batch_size=1024, shuffle=False)

    pred_labels, targets, predictions, pred_probs_mc = evaluate_model(model, testloader)
    acc = (targets == pred_labels).mean()
    print("Test accuracy: ", acc)

    # Keep only the probability of label 1.
    pred_probs_mc = pred_probs_mc[:, :, 1]
    predictions = predictions[:, 1]

    err_idx = np.where(targets != pred_labels)[0]
    mc = pred_probs_mc.shape[1]
    # Standard error of the mean over the MC passes.
    unc = np.std(pred_probs_mc, axis=1) / np.sqrt(mc)

    _ax.scatter(predictions, unc, s=10)
    _ax.scatter(predictions[err_idx], unc[err_idx], c='r', s=50)
    _ax.set_xlim(-0.05, 1.05)
    _ax.set_ylim(0.0, 0.10)
    _ax.set_xlabel("Predicted value of p")
    _ax.set_ylabel("Uncertainty")
    _ax.set_title("{} (Acc = {:.4f})".format(label, acc))


def analyze_bnn_model(model, blur_levels=(1.0, 2.0, 3.0, 4.0, 5.0)):
    """
    Plot score-vs-uncertainty panels for clean data and several blur levels.

    The first panel uses the test split preprocessed with ``normalize_x()``.
    Each further panel uses ``normalize_and_blur(std=level)`` for one entry
    of ``blur_levels``. The accuracy of every evaluation is printed.

    Args:
        model: model accepted by ``evaluate_model``.
        blur_levels: iterable of blur sigmas, one panel each. The default is
            an immutable tuple so it cannot be shared and mutated between
            calls. An empty iterable yields only the clean-data panel.

    Returns:
        The Matplotlib figure holding ``1 + len(blur_levels)`` panels.
    """
    blur_levels = list(blur_levels)
    n_panels = 1 + len(blur_levels)

    # squeeze=False keeps the axes an array even for a single panel, so
    # indexing works when blur_levels is empty.
    fig, ax = plt.subplots(1, n_panels, figsize=(6 * n_panels, 4), squeeze=False)
    ax = ax.ravel()

    _plot_uncertainty_panel(ax[0], model, normalize_x(), "Clean data")

    for _ax, blur_std in zip(ax[1:], blur_levels):
        _plot_uncertainty_panel(
            _ax, model, normalize_and_blur(std=blur_std), "Blur {}".format(blur_std))

    return fig

In [None]:
# Evaluate the model trained in this notebook on clean data and on five
# Gaussian-blur levels; one scatter panel per setting.
fig = analyze_bnn_model(model, blur_levels=[0.50, 1.0, 2.0, 3.0, 5.0])

In [None]:
# Rebuild the test set and loader with the blur pipeline (sigma = 5.0).
x_transform = normalize_and_blur(std=5.0)
testset = datasets.BinaryMNISTC(
    53, 'identity', 'test',
    transform=x_transform,
)
testloader = DataLoader(testset, batch_size=1024, shuffle=False)

In [None]:
# Preview the first n images of the first test batch (first channel only).
n = 10
x, y = (part[:n] for part in next(iter(testloader)))

fig, ax = plt.subplots(1, n, figsize=(n * 2.5, 2.5))

for i, _ax in enumerate(ax):
    _ax.imshow(x[i, 0])

In [None]:
# Evaluate on the current (blurred) test loader and report the accuracy.
(pred_labels,
 targets,
 predictions,
 pred_probs_mc) = evaluate_model(model, testloader)
print("Test accuracy: ", np.mean(targets == pred_labels))

# Keep only the label-1 probability (index 1 on the class axis).
pred_probs_mc = pred_probs_mc[..., 1]
predictions = predictions[..., 1]

In [None]:
# Averaged score against its standard error over the MC passes for the
# blurred test set; misclassified examples are overlaid in red.
err_idx = np.flatnonzero(targets != pred_labels)

mc = pred_probs_mc.shape[1]
unc = pred_probs_mc.std(axis=1) / np.sqrt(mc)

fig, scatter_ax = plt.subplots(figsize=(12, 8))
scatter_ax.scatter(predictions, unc, s=10)
scatter_ax.scatter(predictions[err_idx], unc[err_idx], c='r', s=50)
scatter_ax.set_xlim(0, 1)
scatter_ax.set_ylim(0, 0.1)

scatter_ax.set_xlabel("Predicted value of p")
scatter_ax.set_ylabel("Uncertainty")

## Check SL model

In [None]:
import os, json
import methods

In [None]:
# Alternative runs, kept for quick switching:
# model_dir = "./../zoo/alpha-sl-scaling/BinaryMNISTC-53-identity/LeNet/sl-lam1e+00-alpha1e+02-nw-5-20220401055444"
# model_dir = "./../zoo/alpha-sl-scaling/BinaryMNISTC-53-identity/LeNet/sl-lam1e-05-alpha1e+01-nw-1-20220401045543"
model_dir = "./../zoo/alpha-sl-scaling/BinaryMNISTC-53-identity/LeNet/sl-lam1e-02-alpha1e+00-nw-5-20220401033018"

# Default paths
config_json = os.path.join(model_dir, "config.json")
ckpt_file = os.path.join(model_dir, "last.ckpt")

# Read the run configuration. The context manager closes the file right
# away instead of leaving the handle to the garbage collector.
with open(config_json, 'r') as config_fh:
    config = json.load(config_fh)

# Resolve the classes named in the config from the project modules.
MethodClass = getattr(methods, config['method'])
DatasetClass = getattr(datasets, config['dataset'])
ModelClass = getattr(models, config['model'])
# The transform is fixed to normalize_x rather than looked up from the config.
TransformClass = normalize_x# getattr(transforms, config['transform'])

In [None]:
# Restore the trained method from the run's last checkpoint, wrapping a
# freshly constructed ModelClass(2).
checkpoint_path = os.path.join(model_dir, "last.ckpt")
model = MethodClass.load_from_checkpoint(checkpoint_path, model=ModelClass(2))

In [None]:
# Clean test split, preprocessed with a TransformClass() instance.
testset = datasets.BinaryMNISTC(
    53, 'identity', 'test',
    transform=TransformClass(),
)
testloader = DataLoader(testset, batch_size=1024, shuffle=False)

In [None]:
# Preview the first n images of the first test batch (first channel only).
n = 10
x, y = (part[:n] for part in next(iter(testloader)))

fig, ax = plt.subplots(1, n, figsize=(n * 2.5, 2.5))

for i, _ax in enumerate(ax):
    _ax.imshow(x[i, 0])

In [None]:
# Evaluate the loaded model on the current test loader and report accuracy.
(pred_labels,
 targets,
 predictions,
 pred_probs_mc) = evaluate_model(model, testloader)
print("Test accuracy: ", np.mean(targets == pred_labels))

# Keep only the label-1 probability (index 1 on the class axis).
pred_probs_mc = pred_probs_mc[..., 1]
predictions = predictions[..., 1]

In [None]:
# Averaged score against its standard error over the MC passes;
# misclassified examples are overlaid in red.
err_idx = np.flatnonzero(targets != pred_labels)

mc = pred_probs_mc.shape[1]
unc = pred_probs_mc.std(axis=1) / np.sqrt(mc)

fig, scatter_ax = plt.subplots(figsize=(12, 8))
scatter_ax.scatter(predictions, unc, s=10)
scatter_ax.scatter(predictions[err_idx], unc[err_idx], c='r', s=50)
scatter_ax.set_xlim(-0.05, 1.05)
scatter_ax.set_ylim(0.0, 0.10)

scatter_ax.set_xlabel("Predicted value of p")
scatter_ax.set_ylabel("Uncertainty")

In [None]:
# Rebuild the test set and loader with the blur pipeline (sigma = 5.0).
x_transform = normalize_and_blur(std=5.0)
testset = datasets.BinaryMNISTC(
    53, 'identity', 'test',
    transform=x_transform,
)
testloader = DataLoader(testset, batch_size=1024, shuffle=False)

In [None]:
# Preview the first n images of the first test batch (first channel only).
n = 10
x, y = (part[:n] for part in next(iter(testloader)))

fig, ax = plt.subplots(1, n, figsize=(n * 2.5, 2.5))

for i, _ax in enumerate(ax):
    _ax.imshow(x[i, 0])

In [None]:
# Evaluate the loaded model on the blurred test loader and report accuracy.
(pred_labels,
 targets,
 predictions,
 pred_probs_mc) = evaluate_model(model, testloader)
print("Test accuracy: ", np.mean(targets == pred_labels))

# Keep only the label-1 probability (index 1 on the class axis).
pred_probs_mc = pred_probs_mc[..., 1]
predictions = predictions[..., 1]

In [None]:
# Averaged score against its standard error over the MC passes for the
# blurred test set; misclassified examples are overlaid in red.
# The y-range here is wider than in the other scatter cells.
err_idx = np.flatnonzero(targets != pred_labels)

mc = pred_probs_mc.shape[1]
unc = pred_probs_mc.std(axis=1) / np.sqrt(mc)

fig, scatter_ax = plt.subplots(figsize=(12, 8))
scatter_ax.scatter(predictions, unc, s=10)
scatter_ax.scatter(predictions[err_idx], unc[err_idx], c='r', s=50)
scatter_ax.set_xlim(-0.05, 1.05)
scatter_ax.set_ylim(-0.05, 0.15)

scatter_ax.set_xlabel("Predicted value of p")
scatter_ax.set_ylabel("Uncertainty")

In [None]:
def analyze_model(model_dir, blur_levels=(1.0, 2.0, 3.0, 4.0, 5.0)):
    """
    Load a stored run and plot its score-vs-uncertainty panels.

    Reads ``config.json`` from ``model_dir``, restores the method named there
    from ``last.ckpt`` (wrapping ``ModelClass(2)``), and delegates evaluation
    and plotting to ``analyze_bnn_model``. That function draws the same
    clean-data panel (``normalize_x()``) and blur panels this function used
    to duplicate. The figure is then given a super-title showing the run's
    ``lam_sl`` and ``alpha`` values.

    Args:
        model_dir: directory containing ``config.json`` and ``last.ckpt``.
        blur_levels: iterable of blur sigmas, one panel each. The default is
            an immutable tuple so it cannot be shared and mutated between
            calls.

    Returns:
        The Matplotlib figure produced by ``analyze_bnn_model``.
    """
    # Default paths
    config_json = os.path.join(model_dir, "config.json")
    ckpt_file = os.path.join(model_dir, "last.ckpt")

    # Close the config file promptly instead of leaking the handle.
    with open(config_json, 'r') as config_fh:
        config = json.load(config_fh)

    MethodClass = getattr(methods, config['method'])
    ModelClass = getattr(models, config['model'])

    model = MethodClass.load_from_checkpoint(ckpt_file, model=ModelClass(2))

    # Pass a list so analyze_bnn_model receives the kind of argument it has
    # always been called with.
    fig = analyze_bnn_model(model, blur_levels=list(blur_levels))

    lam_sl = config['method_params']['lam_sl']
    alpha = config['method_params']['alpha']
    # Set the title on the returned figure explicitly rather than on
    # whichever figure pyplot currently considers active.
    fig.suptitle("$\\lambda_{} = {:1.0e} \\qquad    \\alpha = {:1.0e}$".format("{SL}", lam_sl, alpha))

    return fig

In [None]:
# Analyze the stored run whose directory name is tagged lam1e-05 / alpha1e+00.
model_dir = "./../zoo/alpha-sl-scaling/BinaryMNISTC-53-identity/LeNet/sl-lam1e-05-alpha1e+00-nw-3-20220401035110"
fig = analyze_model(model_dir, blur_levels=[0.50, 1.0, 2.0, 3.0, 5.0])

In [None]:
# Analyze the stored run whose directory name is tagged lam1e-04 / alpha1e+00.
model_dir = "./../zoo/alpha-sl-scaling/BinaryMNISTC-53-identity/LeNet/sl-lam1e-04-alpha1e+00-nw-1-20220401034326"
fig = analyze_model(model_dir, blur_levels=[0.50, 1.0, 2.0, 3.0, 5.0])

In [None]:
# Analyze the stored run whose directory name is tagged lam1e-03 / alpha1e+00.
model_dir = "./../zoo/alpha-sl-scaling/BinaryMNISTC-53-identity/LeNet/sl-lam1e-03-alpha1e+00-nw-5-20220401032602"
fig = analyze_model(model_dir, blur_levels=[0.50, 1.0, 2.0, 3.0, 5.0])

In [None]:
# Analyze the stored run whose directory name is tagged lam1e-02 / alpha1e+00.
model_dir = "./../zoo/alpha-sl-scaling/BinaryMNISTC-53-identity/LeNet/sl-lam1e-02-alpha1e+00-nw-1-20220401035306"
fig = analyze_model(model_dir, blur_levels=[0.50, 1.0, 2.0, 3.0, 5.0])

In [None]:
# Analyze the stored run whose directory name is tagged lam1e-01 / alpha1e+00.
model_dir = "./../zoo/alpha-sl-scaling/BinaryMNISTC-53-identity/LeNet/sl-lam1e-01-alpha1e+00-nw-1-20220401035756"
fig = analyze_model(model_dir, blur_levels=[0.50, 1.0, 2.0, 3.0, 5.0])

In [None]:
# Analyze the stored run whose directory name is tagged lam1e+00 / alpha1e+00.
model_dir = "./../zoo/alpha-sl-scaling/BinaryMNISTC-53-identity/LeNet/sl-lam1e+00-alpha1e+00-nw-1-20220401040245"
fig = analyze_model(model_dir, blur_levels=[0.50, 1.0, 2.0, 3.0, 5.0])