In [1]:
import torch
import torchvision
from torch import nn, optim, autograd
from torch.nn import functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.autograd import Variable
import numpy as np
#import input_data
from sklearn.utils import shuffle as skshuffle
from math import *
from backpack import backpack, extend
from backpack.extensions import KFAC, DiagHessian, DiagGGNMC
from sklearn.metrics import roc_auc_score
import scipy
from tqdm import tqdm, trange
from bpjacext import NetJac
import pytest
from DirLPA_utils import * 
import time

import matplotlib.pyplot as plt

# Seed the global NumPy and PyTorch random number generators so that weight
# initialisation, loader shuffling and Monte Carlo sampling are repeatable.
np.random.seed(123)
torch.manual_seed(123)

<torch._C.Generator at 0x7fed66446350>

In [2]:
def LPADirNN(num_classes=10, num_LL=256):
    """Build the small convolutional classifier used throughout this notebook.

    Two 5x5 convolution + ReLU + 2x2 max-pool stages are followed by a
    flatten and two linear layers. The first linear layer expects
    4 * 4 * 64 features, which is what a single-channel 28x28 input yields
    after the two conv/pool stages (28 -> 24 -> 12 -> 8 -> 4).

    Note that there is no activation between the two linear layers.

    Args:
        num_classes: Size of the output (logit) layer.
        num_LL: Width of the penultimate linear layer.

    Returns:
        A ``torch.nn.Sequential`` with nine sub-modules (indices 0..8).
    """
    # Layers are created in the same order as they are stacked so that the
    # random initialisation and the state_dict keys stay stable.
    conv_stack = [
        torch.nn.Conv2d(1, 32, 5),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2, 2),
        torch.nn.Conv2d(32, 64, 5),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2, 2),
    ]
    classifier_head = [
        torch.nn.Flatten(),
        torch.nn.Linear(4 * 4 * 64, num_LL),   # changed from 500
        torch.nn.Linear(num_LL, num_classes),  # changed from 500
    ]
    return torch.nn.Sequential(*conv_stack, *classifier_head)

In [3]:
# Mini-batch sizes used by the MNIST train / test DataLoaders below.
BATCH_SIZE_TRAIN_MNIST = 128
BATCH_SIZE_TEST_MNIST = 32
# Number of passes over the training loader (the `max_iter` argument of train()).
MAX_ITER_MNIST = 6
# NOTE: 10e-6 is 1e-5. In the visible code this constant is only referenced by
# a commented-out optimizer line; the active Adam optimizer hard-codes lr=1e-3.
LR_TRAIN_MNIST = 10e-6

In [4]:
# ToTensor is the only preprocessing step; the same transform object is reused
# for every dataset in this notebook.
MNIST_transform = torchvision.transforms.ToTensor()

# Training split; downloaded into ~/data/mnist if it is not there yet.
MNIST_train = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=True,
        download=True,
        transform=MNIST_transform)

# Shuffled mini-batches for training (and for the Hessian estimate later on).
mnist_train_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_train,
    batch_size=BATCH_SIZE_TRAIN_MNIST,
    shuffle=True
)


# Test split; download=False relies on the files fetched for the training
# split above already being present in the same root directory.
MNIST_test = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform)

# Not shuffled, so predictions line up with MNIST_test.targets.
mnist_test_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_test,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False,
)

In [5]:
# Fresh network with a 256-unit penultimate layer and the loss used by train().
mnist_model = LPADirNN(num_LL=256)
loss_function = torch.nn.CrossEntropyLoss()

# Adam with L2 weight decay; the learning rate is hard-coded here rather than
# taken from LR_TRAIN_MNIST (see the commented-out alternative).
#mnist_train_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=LR_TRAIN_MNIST)
mnist_train_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=1e-3, weight_decay=5e-4)
# Checkpoint location (relative to the working directory) for the trained weights.
MNIST_PATH = "models/mnist_test_6iter_10c_simpleCNN_256.pth"

In [6]:
#Training routine

def train(model, train_loader, optimizer, max_iter, path, verbose=True):
    """Train ``model`` and save its weights to ``path``.

    The loss is the module-level ``loss_function`` and the per-batch accuracy
    comes from ``get_accuracy`` (not defined in this notebook; presumably
    provided by the ``DirLPA_utils`` star import - verify).

    Args:
        model: Network to optimise; called as ``model(x)``.
        train_loader: Iterable yielding ``(x, y)`` mini-batches.
        optimizer: Optimizer already bound to ``model``'s parameters.
        max_iter: Number of full passes over ``train_loader``.
        path: Destination passed to ``torch.save`` for the state dict.
        verbose: If True, print loss and accuracy for every mini-batch.
    """
    max_len = len(train_loader)

    # `epoch` rather than `iter`, so the builtin is not shadowed.
    for epoch in range(max_iter):
        for batch_idx, (x, y) in enumerate(train_loader):
            # Clear gradients *before* the backward pass so that gradients left
            # over from earlier work on this model cannot leak into the first step.
            optimizer.zero_grad()

            output = model(x)
            accuracy = get_accuracy(output, y)

            loss = loss_function(output, y)
            loss.backward()
            optimizer.step()

            if verbose:
                print(
                    "Iteration {}; {}/{} \t".format(epoch, batch_idx, max_len) +
                    "Minibatch Loss %.3f  " % (loss.item()) +
                    "Accuracy %.0f" % (accuracy * 100) + "%"
                )

    print("saving model at: {}".format(path))
    # Save the model that was actually trained, not the notebook-global one.
    torch.save(model.state_dict(), path)

In [8]:
#train(mnist_model, mnist_train_loader, mnist_train_optimizer, MAX_ITER_MNIST, MNIST_PATH, verbose=True)

In [9]:
#predict in distribution
# Rebuild the architecture and load the trained weights from disk so that this
# cell does not depend on the (commented-out) training call having been run.
MNIST_PATH = "models/mnist_test_6iter_10c_simpleCNN_256.pth"

mnist_model = LPADirNN(num_LL=256)
print("loading model from: {}".format(MNIST_PATH))
mnist_model.load_state_dict(torch.load(MNIST_PATH))
mnist_model.eval()

acc = []
batch_sizes = []
# Number of batches; loop-invariant, so computed once.
max_len = len(mnist_test_loader)

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for batch_idx, (x, y) in enumerate(mnist_test_loader):
        output = mnist_model(x)

        # get_accuracy is assumed to return the fraction of correct predictions
        # in the batch as a scalar (float or 0-d tensor) - confirm in DirLPA_utils.
        accuracy = float(get_accuracy(output, y))
        if batch_idx % 10 == 0:
            print(
                "Batch {}/{} \t".format(batch_idx, max_len) +
                "Accuracy %.0f" % (accuracy * 100) + "%"
            )
        acc.append(accuracy)
        batch_sizes.append(len(y))

# Weight each batch by its size: the last batch is smaller than the others
# whenever the dataset size is not a multiple of the batch size, and a plain
# mean of per-batch accuracies would over-weight it.
avg_acc = np.average(acc, weights=batch_sizes)
print('overall test accuracy on MNIST: {:.02f} %'.format(avg_acc * 100))


loading model from: models/mnist_test_6iter_10c_simpleCNN_256.pth
Batch 0/313 	Accuracy 100%
Batch 10/313 	Accuracy 94%
Batch 20/313 	Accuracy 97%
Batch 30/313 	Accuracy 100%
Batch 40/313 	Accuracy 97%
Batch 50/313 	Accuracy 97%
Batch 60/313 	Accuracy 100%
Batch 70/313 	Accuracy 97%
Batch 80/313 	Accuracy 97%
Batch 90/313 	Accuracy 97%
Batch 100/313 	Accuracy 100%
Batch 110/313 	Accuracy 100%
Batch 120/313 	Accuracy 97%
Batch 130/313 	Accuracy 97%
Batch 140/313 	Accuracy 91%
Batch 150/313 	Accuracy 100%
Batch 160/313 	Accuracy 100%
Batch 170/313 	Accuracy 100%
Batch 180/313 	Accuracy 100%
Batch 190/313 	Accuracy 94%
Batch 200/313 	Accuracy 100%
Batch 210/313 	Accuracy 100%
Batch 220/313 	Accuracy 100%
Batch 230/313 	Accuracy 100%
Batch 240/313 	Accuracy 100%
Batch 250/313 	Accuracy 100%
Batch 260/313 	Accuracy 97%
Batch 270/313 	Accuracy 100%
Batch 280/313 	Accuracy 100%
Batch 290/313 	Accuracy 100%
Batch 300/313 	Accuracy 94%
Batch 310/313 	Accuracy 100%
overall test accuracy on MNIST

In [10]:
## play around with Backpack
def get_Hessian_NN(model, train_loader, var0, device='cpu', verbose=True):
    """Estimate the diagonal of the loss Hessian over all parameters of ``model``.

    BackPACK's ``DiagHessian`` extension provides the per-parameter diagonal
    for every mini-batch. ``tau = 1 / var0`` is added to each entry and the
    per-batch results are combined with a running average (see ``rho`` below).

    Args:
        model: Network to analyse. It is extended in place with BackPACK and
            is expected to already live on ``device``.
        train_loader: Iterable yielding ``(x, y)`` mini-batches.
        var0: Scalar whose reciprocal ``tau`` is added to every diagonal entry.
        device: Device the inputs are moved to and the accumulators live on.
        verbose: If True, print a progress line per batch.

    Returns:
        1-D tensor with one entry per model parameter, ordered like
        ``model.parameters()`` with each parameter flattened. Every entry is
        the averaged Hessian diagonal plus ``tau``.
    """
    lossfunc = torch.nn.CrossEntropyLoss()

    extend(lossfunc, debug=False)
    extend(model, debug=False)

    # One zero accumulator per parameter tensor of the model that was passed in
    # (not the notebook-global model).
    Hessian_diag = []
    for param in model.parameters():
        ps = param.size()
        print("parameter size: ", ps)
        Hessian_diag.append(torch.zeros(ps, device=device))

    tau = 1/var0
    max_len = len(train_loader)

    with backpack(DiagHessian()):

        for batch_idx, (x, y) in enumerate(train_loader):

            # Works for any device string / torch.device, not only 'cuda'.
            x, y = x.float().to(device), y.long().to(device)

            model.zero_grad()
            lossfunc(model(x), y).backward()

            with torch.no_grad():
                # rho = b / (b + 1) for the first 20 batches, i.e. an exact
                # running mean; afterwards it is capped at 0.95, which turns
                # the update into an exponential moving average.
                rho = min(1-1/(batch_idx+1), 0.95)

                for idx, param in enumerate(model.parameters()):
                    # Out-of-place addition: keeps BackPACK's `diag_h` buffer
                    # untouched and works on whatever device it lives on.
                    H_ = param.diag_h + tau

                    Hessian_diag[idx] = rho*Hessian_diag[idx] + (1-rho)*H_

            if verbose:
                print("Batch: {}/{}".format(batch_idx, max_len))

    #combine all elements of the Hessian to one big vector
    Hessian_diag = torch.cat([el.view(-1) for el in Hessian_diag])
    print("Hessian_size: ", Hessian_diag.size())
    num_params = sum(p.numel() for p in model.parameters())
    assert(num_params == Hessian_diag.size(-1))
    return(Hessian_diag)
        

In [90]:
# Diagonal Hessian estimate over the MNIST training loader; var0=200 means
# tau = 1/200 = 0.005 is added to every entry inside get_Hessian_NN.
Hessian_MNIST = get_Hessian_NN(model=mnist_model, train_loader=mnist_train_loader, var0=200, verbose=False)

parameter size:  torch.Size([32, 1, 5, 5])
parameter size:  torch.Size([32])
parameter size:  torch.Size([64, 32, 5, 5])
parameter size:  torch.Size([64])
parameter size:  torch.Size([256, 1024])
parameter size:  torch.Size([256])
parameter size:  torch.Size([10, 256])
parameter size:  torch.Size([10])
Hessian_size:  torch.Size([317066])


In [95]:
def compute_jacobians_with_backpack(model, x, y, lossfunc):
    """
    Collect the network Jacobians for one mini-batch.

    A forward pass and a backward pass through ``lossfunc`` are run with the
    ``NetJac`` BackPACK extension active; afterwards every parameter carries
    a ``netjacs`` attribute, which is gathered here.

    Returns a list with one tensor per entry of ``model.parameters()``.
    Each tensor has the form [N, *, C] where N is the batch dimension,
    * is the shape of the corresponding parameter and C is the number of
    classes (the output size of the network).
    """
    batch_loss = lossfunc(model(x), y)

    # Only the backward pass needs to run inside the BackPACK context.
    with backpack(NetJac()):
        batch_loss.backward()

    return [param.netjacs.data for param in model.parameters()]

def transform2full_jac(backpack_jacobian):
    """Merge per-parameter Jacobians into one tensor of shape [N, k, P].

    Each entry of ``backpack_jacobian`` has the batch size N as its first and
    the number of outputs k as its last dimension. Everything in between is
    flattened, the flattened axis is moved to the end, and all entries are
    concatenated along it, so P is the total number of flattened entries.
    """
    reference = backpack_jacobian[0]
    n_batch = reference.size(0)   # batch size
    n_out = reference.size(-1)    # number of classes

    flattened = [
        jac.view(n_batch, -1, n_out).transpose(1, 2)
        for jac in backpack_jacobian
    ]
    return torch.cat(flattened, dim=-1)

def get_Jacobian(model, x, y, lossfunc):
    """Return the batch Jacobian as a single [N, k, P] tensor.

    Thin wrapper: collects the per-parameter Jacobians with BackPACK and
    concatenates them with ``transform2full_jac``.
    """
    per_parameter_jacs = compute_jacobians_with_backpack(model, x, y, lossfunc)
    return transform2full_jac(per_parameter_jacs)

In [97]:
# Quick sanity check of the flattened diagonal (one entry per parameter).
print(Hessian_MNIST)

tensor([0.0077, 0.0077, 0.0074,  ..., 0.0064, 0.0077, 0.0074])


In [98]:
def predict_Diagonal_full(model, test_loader, Hessian, verbose=True, num_samples=100, cuda=False, timing=False):
    """Monte Carlo predictive class probabilities from a Gaussian over the logits.

    For every batch the Jacobian ``J`` ([N, k, P]) is used to form a k x k
    covariance ``J diag(Hessian) J^T`` per input. Logits are sampled from a
    multivariate normal centred at the network output with that covariance,
    and the softmax of ``num_samples`` draws is averaged.

    NOTE(review): ``Hessian`` is multiplied in directly. If it holds a
    curvature / precision (as the name and ``get_Hessian_NN`` suggest), a
    Laplace approximation would normally use its reciprocal here - confirm
    that the current scaling is intended.

    Args:
        model: BackPACK-extended network.
        test_loader: Iterable yielding ``(x, y)`` mini-batches.
        Hessian: 1-D tensor with one entry per model parameter.
        verbose: If True, print tensor sizes and progress for every batch.
        num_samples: Number of Monte Carlo draws per input (>= 1).
        cuda: Kept for backward compatibility. Inputs and ``Hessian`` are
            always moved to the device the model's parameters live on; move
            the model to the GPU beforehand to run there.
        timing: If True, print the accumulated CPU time of the sampling step.

    Returns:
        CPU tensor of shape [len(dataset), k] with the averaged probabilities.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1, got {}".format(num_samples))

    lossfunc = torch.nn.CrossEntropyLoss()
    extend(lossfunc, debug=False)

    # Everything follows the model's device; this avoids the device mismatch
    # caused by moving only the inputs.
    device = next(model.parameters()).device
    Hessian = Hessian.to(device)

    py = []
    time_sum = 0

    max_len = len(test_loader)
    for batch_idx, (x, y) in enumerate(test_loader):

        x, y = x.to(device), y.to(device)

        J = get_Jacobian(model, x, y, lossfunc)
        Cov_pred = torch.bmm(J * Hessian, J.permute(0, 2, 1))
        if verbose:
            print("Jacobian size: ", J.size())
            print("cov pred size: ", Cov_pred.size())

        # The samples are never differentiated, so drop the graph right away.
        mu_pred = model(x).detach()
        post_pred = MultivariateNormal(mu_pred, Cov_pred)

        # MC-integral
        t0 = time.process_time()
        py_ = 0

        for _ in range(num_samples):
            f_s = post_pred.rsample()
            py_ += torch.softmax(f_s, 1)

        py_ /= num_samples

        # Store on the CPU so callers can convert with .numpy() directly.
        py.append(py_.detach().cpu())
        t1 = time.process_time()
        time_sum += (t1-t0)

        if verbose:
            print("Batch: {}/{}".format(batch_idx, max_len))

    if timing:
        print("total time used for transform: {:.05f}".format(time_sum))

    return torch.cat(py, dim=0)

In [99]:
# Mini-batch sizes for the out-of-distribution test loaders.
BATCH_SIZE_TEST_FMNIST = 32
BATCH_SIZE_TEST_KMNIST = 32

In [100]:
# Fashion-MNIST test split, used as out-of-distribution data for the MNIST
# model. download=True fetches the files on a fresh machine and is a no-op
# when they are already present (consistent with the KMNIST cell).
FMNIST_test = torchvision.datasets.FashionMNIST(
        '~/data/fmnist', train=False, download=True,
        transform=MNIST_transform)   #torchvision.transforms.ToTensor())

# Not shuffled, so predictions line up with FMNIST_test.targets.
FMNIST_test_loader = torch.utils.data.DataLoader(
    FMNIST_test,
    batch_size=BATCH_SIZE_TEST_FMNIST, shuffle=False)

In [101]:
# KMNIST test split, used as out-of-distribution data for the MNIST model;
# downloaded into ~/data/kmnist if it is not there yet.
KMNIST_test = torchvision.datasets.KMNIST(
        '~/data/kmnist', train=False, download=True,
        transform=MNIST_transform)

# Not shuffled, so predictions line up with KMNIST_test.targets.
KMNIST_test_loader = torch.utils.data.DataLoader(
    KMNIST_test,
    batch_size=BATCH_SIZE_TEST_KMNIST, shuffle=False)

In [102]:
"""Load notMNIST"""

import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from matplotlib.pyplot import imread
from torch import Tensor

"""
Loads the train/test set. 
Every image in the dataset is 28x28 pixels and the labels are numbered from 0-9
for A-J respectively.
Set root to point to the Train/Test folders.
"""

# Creating a sub class of torch.utils.data.dataset.Dataset
class notMNIST(Dataset):
    """In-memory notMNIST dataset read from one sub-folder per class.

    ``root`` must contain single-letter folders (``A``, ``B``, ...); the label
    of a folder is ``ord(name) - 65``, so ``A`` maps to 0. All readable images
    are loaded eagerly in ``__init__``; unreadable ones are reported and
    skipped. Folders and files are visited in sorted order so the sample
    order does not depend on the filesystem.

    Attributes set by ``__init__``:
        transform: Callable applied to each image in ``__getitem__`` (or None).
        data: List of ``(image_array, label)`` pairs.
        targets: Float tensor with the labels, in the same order as ``data``.
    """

    # The init method is called when this class will be instantiated
    def __init__(self, root, transform):
        self.transform = transform

        Images, Y = [], []

        for folder in sorted(os.listdir(root)):
            folder_path = os.path.join(root, folder)
            # Ignore stray files and anything that is not a single-letter class
            # folder; such entries would otherwise break listdir() or ord().
            if len(folder) != 1 or not os.path.isdir(folder_path):
                continue

            # Folders are A-J so labels will be 0-9
            label = ord(folder) - 65

            for ims in sorted(os.listdir(folder_path)):
                img_path = os.path.join(folder_path, ims)
                try:
                    image = np.array(imread(img_path))
                except Exception:
                    # Some images in the dataset are damaged
                    print("File {}/{} is broken".format(folder, ims))
                    continue
                # Image and label are appended together, so the two lists can
                # never get out of step.
                Images.append(image)
                Y.append(label)

        self.data = list(zip(Images, Y))
        self.targets = torch.Tensor(Y)

    # The number of items in the dataset
    def __len__(self):
        return len(self.data)

    # The Dataloader is a generator that repeatedly calls the getitem method.
    # getitem is supposed to return (X, Y) for the specified index.
    def __getitem__(self, index):
        img, label = self.data[index]

        if self.transform is not None:
            img = self.transform(img)

        # Input for Conv2D should be Channels x Height x Width.
        # as_tensor accepts both arrays and tensors (the transform may already
        # have produced a tensor).
        img_tensor = torch.as_tensor(img, dtype=torch.float32).reshape(1, 28, 28)
        return (img_tensor, label)

In [103]:
#root = os.path.abspath('~/data')
# expanduser (not abspath) is needed so that '~' resolves to the home directory.
root = os.path.expanduser('~/data')

# Instantiating the notMNIST dataset class we created; images are expected
# under ~/data/notMNIST_small/<letter>/.
notMNIST_test = notMNIST(root=os.path.join(root, 'notMNIST_small'),
                               transform=MNIST_transform)

# Creating a dataloader; not shuffled, so predictions line up with
# notMNIST_test.targets.
not_mnist_test_loader = torch.utils.data.dataloader.DataLoader(
                            dataset=notMNIST_test,
                            batch_size=32,
                            shuffle=False)

File F/Q3Jvc3NvdmVyIEJvbGRPYmxpcXVlLnR0Zg==.png is broken
File A/RGVtb2NyYXRpY2FCb2xkT2xkc3R5bGUgQm9sZC50dGY=.png is broken


In [104]:
def get_in_dist_values(py_in, targets):
    """Summary statistics of predictive probabilities on in-distribution data.

    Args:
        py_in: NumPy array of shape [N, num_classes] with class probabilities.
        targets: Integer class labels of length N.

    Returns:
        Tuple ``(acc_in, prob_correct, average_entropy, MMC)``:
        accuracy of the arg-max prediction, mean probability assigned to the
        true class, mean predictive entropy (with 1e-8 added inside the log),
        and mean maximum confidence.
    """
    targets = np.asarray(targets)
    acc_in = np.mean(np.argmax(py_in, 1) == targets)
    # Integer-array indexing instead of np.choose, which only supports a
    # limited number of choices and therefore breaks for many classes.
    prob_correct = py_in[np.arange(len(targets)), targets].mean()
    average_entropy = -np.sum(py_in*np.log(py_in+1e-8), axis=1).mean()
    MMC = py_in.max(1).mean()
    return(acc_in, prob_correct, average_entropy, MMC)
    
def get_out_dist_values(py_in, py_out, targets):
    """Summary statistics of predictive probabilities on out-of-distribution data.

    Args:
        py_in: NumPy array [N_in, num_classes] of in-distribution probabilities.
        py_out: NumPy array [N_out, num_classes] of out-of-distribution probabilities.
        targets: Integer labels of length N_out belonging to ``py_out``.

    Returns:
        Tuple ``(acc_out, prob_correct, average_entropy, MMC, auroc)``. The
        first four are computed on ``py_out`` only. ``auroc`` scores how well
        the maximum confidence separates in-distribution samples (label 1)
        from out-of-distribution samples (label 0).
    """
    targets = np.asarray(targets)
    average_entropy = -np.sum(py_out*np.log(py_out+1e-8), axis=1).mean()
    acc_out = np.mean(np.argmax(py_out, 1) == targets)
    # Integer-array indexing instead of np.choose, which only supports a
    # limited number of choices and therefore breaks for many classes.
    prob_correct = py_out[np.arange(len(targets)), targets].mean()
    labels = np.zeros(len(py_in)+len(py_out), dtype='int32')
    labels[:len(py_in)] = 1
    examples = np.concatenate([py_in.max(1), py_out.max(1)])
    auroc = roc_auc_score(labels, examples)
    MMC = py_out.max(1).mean()
    return(acc_out, prob_correct, average_entropy, MMC, auroc)

def print_in_dist_values(acc_in, prob_correct, average_entropy, MMC, train='mnist', method='LLLA-KF'):
    """Print the in-distribution metrics as a single line.

    The line is tagged ``[In, <method>, <train>]`` and every metric is shown
    with three decimals.
    """
    tag = f'[In, {method}, {train}]'
    # The run of spaces in front of "MMC" is part of the established output
    # format and is kept as is.
    report = (
        f'{tag} Accuracy: {acc_in:.3f}; average entropy: {average_entropy:.3f}; '
        f'    MMC: {MMC:.3f}; Prob @ correct: {prob_correct:.3f}'
    )
    print(report)


def print_out_dist_values(acc_out, prob_correct, average_entropy, MMC, auroc, train='mnist', test='FMNIST', method='LLLA-KF'):
   
    print(f'[Out-{test}, {method}, {train}] Accuracy: {acc_out:.3f}; Average entropy: {average_entropy:.3f};\
    MMC: {MMC:.3f}; AUROC: {auroc:.3f}; Prob @ correct: {prob_correct:.3f}')

# MAP estimate

In [105]:
# Ground-truth labels of every test set as NumPy arrays. notMNIST stores its
# targets as a float tensor, hence the explicit cast to int.
targets = MNIST_test.targets.numpy()
targets_FMNIST = FMNIST_test.targets.numpy()
targets_notMNIST = notMNIST_test.targets.numpy().astype(int)
targets_KMNIST = KMNIST_test.targets.numpy()

In [106]:
# Predictions of the plain (MAP) network on the in-distribution test set and
# the three out-of-distribution sets. predict_MAP is not defined in this
# notebook (presumably it comes from the DirLPA_utils star import - verify).
mnist_test_in_MAP = predict_MAP(mnist_model, mnist_test_loader).numpy()
mnist_test_out_fmnist_MAP = predict_MAP(mnist_model, FMNIST_test_loader).numpy()
mnist_test_out_notMNIST_MAP = predict_MAP(mnist_model, not_mnist_test_loader).numpy()
mnist_test_out_KMNIST_MAP = predict_MAP(mnist_model, KMNIST_test_loader).numpy()

In [107]:
# In-distribution metrics, then OOD metrics for FMNIST / notMNIST / KMNIST;
# every OOD call is compared against the same in-distribution predictions.
acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP = get_in_dist_values(mnist_test_in_MAP, targets)
acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP, MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_fmnist_MAP, targets_FMNIST)
acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP, MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_notMNIST_MAP, targets_notMNIST)
acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP, MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_KMNIST_MAP, targets_KMNIST)

In [108]:
# Report the MAP baseline. The OOD calls pass `test` and `method` by keyword:
# passed positionally, the two strings would bind to `train` and `test`
# instead and the printed tag would be mislabelled.
print_in_dist_values(acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP, 'mnist', 'MAP')
print_out_dist_values(acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP, MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP, test='FMNIST', method='MAP')
print_out_dist_values(acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP, MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP, test='notMNIST', method='MAP')
print_out_dist_values(acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP, MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP, test='KMNIST', method='MAP')

[In, MAP, mnist] Accuracy: 0.990; average entropy: 0.041;     MMC: 0.987; Prob @ correct: 0.983
[Out-MAP, LLLA-KF, FMNIST] Accuracy: 0.075; Average entropy: 1.387;    MMC: 0.512; AUROC: 0.989; Prob @ correct: 0.097
[Out-MAP, LLLA-KF, notMNIST] Accuracy: 0.145; Average entropy: 0.844;    MMC: 0.696; AUROC: 0.964; Prob @ correct: 0.130
[Out-MAP, LLLA-KF, KMNIST] Accuracy: 0.082; Average entropy: 0.849;    MMC: 0.693; AUROC: 0.968; Prob @ correct: 0.081


# Diag Hessian Sampling estimate

In [109]:
# Sampling-based predictions with the diagonal Hessian for all four test sets.
# .cpu() makes the NumPy conversion safe wherever the result tensor lives
# (it is a no-op for tensors that are already on the CPU).
mnist_test_in_D = predict_Diagonal_full(mnist_model, mnist_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True, num_samples=100).cpu().numpy()
mnist_test_out_FMNIST_D = predict_Diagonal_full(mnist_model, FMNIST_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True, num_samples=100).cpu().numpy()
mnist_test_out_notMNIST_D = predict_Diagonal_full(mnist_model, not_mnist_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True, num_samples=100).cpu().numpy()
mnist_test_out_KMNIST_D = predict_Diagonal_full(mnist_model, KMNIST_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True, num_samples=100).cpu().numpy()

In [110]:
# Same metrics as for the MAP baseline, now for the sampling-based predictions.
acc_in_D, prob_correct_in_D, ent_in_D, MMC_in_D = get_in_dist_values(mnist_test_in_D, targets)
acc_out_FMNIST_D, prob_correct_out_FMNIST_D, ent_out_FMNIST_D, MMC_out_FMNIST_D, auroc_out_FMNIST_D = get_out_dist_values(mnist_test_in_D, mnist_test_out_FMNIST_D, targets_FMNIST)
acc_out_notMNIST_D, prob_correct_out_notMNIST_D, ent_out_notMNIST_D, MMC_out_notMNIST_D, auroc_out_notMNIST_D = get_out_dist_values(mnist_test_in_D, mnist_test_out_notMNIST_D, targets_notMNIST)
acc_out_KMNIST_D, prob_correct_out_KMNIST_D, ent_out_KMNIST_D, MMC_out_KMNIST_D, auroc_out_KMNIST_D = get_out_dist_values(mnist_test_in_D, mnist_test_out_KMNIST_D, targets_KMNIST)

In [111]:
# Report the sampling-based results; `test` and `method` are passed by keyword
# so they end up in the intended slots of the printed tag.
print_in_dist_values(acc_in_D, prob_correct_in_D, ent_in_D, MMC_in_D, 'mnist', 'Diag')
print_out_dist_values(acc_out_FMNIST_D, prob_correct_out_FMNIST_D, ent_out_FMNIST_D, MMC_out_FMNIST_D, auroc_out_FMNIST_D, test='fmnist', method='Diag')
print_out_dist_values(acc_out_notMNIST_D, prob_correct_out_notMNIST_D, ent_out_notMNIST_D, MMC_out_notMNIST_D, auroc_out_notMNIST_D, test='notMNIST', method='Diag')
print_out_dist_values(acc_out_KMNIST_D, prob_correct_out_KMNIST_D, ent_out_KMNIST_D, MMC_out_KMNIST_D, auroc_out_KMNIST_D, test='KMNIST', method='Diag')

[In, Diag, mnist] Accuracy: 0.989; average entropy: 0.278;     MMC: 0.927; Prob @ correct: 0.924
[Out-fmnist, Diag, mnist] Accuracy: 0.088; Average entropy: 1.686;    MMC: 0.395; AUROC: 0.986; Prob @ correct: 0.106
[Out-notMNIST, Diag, mnist] Accuracy: 0.134; Average entropy: 1.345;    MMC: 0.517; AUROC: 0.966; Prob @ correct: 0.130
[Out-KMNIST, Diag, mnist] Accuracy: 0.081; Average entropy: 1.327;    MMC: 0.517; AUROC: 0.964; Prob @ correct: 0.084


# Dirichlet Laplace Approximation

In [112]:
def get_alpha_from_Normal(mu, Sigma):
    """Map Gaussian logit parameters to Dirichlet concentration parameters.

    For every batch element and class k, with K classes::

        alpha_k = 1 / Sigma_kk * (1 - 2/K + exp(mu_k) / K**2 * sum_l exp(-mu_l))

    Only the diagonal of ``Sigma`` is used.

    Args:
        mu: Tensor of shape [N, K] with the mean logits.
        Sigma: Tensor of shape [N, K, K] with the logit covariances.

    Returns:
        Tensor of shape [N, K] with the concentration parameters, on the same
        device as the inputs.
    """
    K = mu.size(-1)
    Sigma_d = torch.diagonal(Sigma, dim1=1, dim2=2)
    # exp(mu_k) * sum_l exp(-mu_l), evaluated in log space: the naive product
    # overflows to inf for large logits even when the result itself is small.
    # Working on `mu` directly (no torch.Tensor(...) re-wrap) also keeps this
    # usable for tensors that do not live on the CPU.
    exp_ratio_sum = torch.exp(mu + torch.logsumexp(-mu, dim=1, keepdim=True))
    alpha = 1/Sigma_d * (1 - 2/K + exp_ratio_sum/K**2)

    assert(alpha.size() == mu.size())

    return(alpha)

In [113]:
def predict_DIR_LPA(model, test_loader, Hessian, verbose=True, cuda=False, timing=False):
    """Dirichlet concentration parameters for every input of ``test_loader``.

    For every batch the Jacobian ``J`` ([N, k, P]) is used to form a k x k
    covariance ``J diag(Hessian) J^T`` per input; together with the network
    output it is converted to Dirichlet parameters by
    ``get_alpha_from_Normal``. No sampling is involved.

    NOTE(review): ``Hessian`` is multiplied in directly. If it holds a
    curvature / precision (as the name and ``get_Hessian_NN`` suggest), a
    Laplace approximation would normally use its reciprocal here - confirm
    that the current scaling is intended.

    Args:
        model: BackPACK-extended network.
        test_loader: Iterable yielding ``(x, y)`` mini-batches.
        Hessian: 1-D tensor with one entry per model parameter.
        verbose: If True, print progress for every batch.
        cuda: Kept for backward compatibility. Inputs and ``Hessian`` are
            always moved to the device the model's parameters live on; move
            the model to the GPU beforehand to run there.
        timing: If True, print the accumulated CPU time of the conversion step.

    Returns:
        CPU tensor of shape [len(dataset), k] with the (unnormalised)
        concentration parameters.
    """
    lossfunc = torch.nn.CrossEntropyLoss()
    extend(lossfunc, debug=False)

    # Everything follows the model's device. This replaces `x.cuda` (missing
    # call parentheses), which bound a method object instead of moving the
    # tensor, and avoids moving only the inputs but not model and Hessian.
    device = next(model.parameters()).device
    Hessian = Hessian.to(device)

    alphas = []
    time_sum = 0

    max_len = len(test_loader)
    for batch_idx, (x, y) in enumerate(test_loader):

        x, y = x.to(device), y.to(device)

        J = get_Jacobian(model, x, y, lossfunc)
        Cov_pred = torch.bmm(J * Hessian, J.permute(0, 2, 1))

        mu_pred = model(x)

        t0 = time.process_time()
        alpha = get_alpha_from_Normal(mu_pred, Cov_pred).detach()
        t1 = time.process_time()
        time_sum += (t1 - t0)

        # Store on the CPU so callers can convert with .numpy() directly.
        alphas.append(alpha.cpu())

        if verbose:
            print("Batch: {}/{}".format(batch_idx, max_len))

    if timing:
        print("total time used for transform: {:.05f}".format(time_sum))

    return(torch.cat(alphas, dim = 0))


In [115]:
# Dirichlet concentration parameters for all four test sets.
# .cpu() makes the NumPy conversion safe wherever the result tensor lives
# (it is a no-op for tensors that are already on the CPU).
mnist_test_in_DIR_LPA = predict_DIR_LPA(mnist_model, mnist_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True).cpu().numpy()
mnist_test_out_FMNIST_DIR_LPA = predict_DIR_LPA(mnist_model, FMNIST_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True).cpu().numpy()
mnist_test_out_notMNIST_DIR_LPA = predict_DIR_LPA(mnist_model, not_mnist_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True).cpu().numpy()
mnist_test_out_KMNIST_DIR_LPA = predict_DIR_LPA(mnist_model, KMNIST_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True).cpu().numpy()

In [116]:
# Divide every row of alphas by its sum so the rows add up to 1 (the mean of
# the corresponding Dirichlet); these are used as class probabilities below.
mnist_test_in_DIR_LPAn = mnist_test_in_DIR_LPA/mnist_test_in_DIR_LPA.sum(1).reshape(-1,1)
mnist_test_out_FMNIST_DIR_LPAn = mnist_test_out_FMNIST_DIR_LPA/mnist_test_out_FMNIST_DIR_LPA.sum(1).reshape(-1,1)
mnist_test_out_notMNIST_DIR_LPAn = mnist_test_out_notMNIST_DIR_LPA/mnist_test_out_notMNIST_DIR_LPA.sum(1).reshape(-1,1)
mnist_test_out_KMNIST_DIR_LPAn = mnist_test_out_KMNIST_DIR_LPA/mnist_test_out_KMNIST_DIR_LPA.sum(1).reshape(-1,1)

In [117]:
# Same metrics as for the other methods, computed on the normalised alphas.
acc_in_DIR_LPAn, prob_correct_in_DIR_LPAn, ent_in_DIR_LPAn, MMC_in_DIR_LPAn = get_in_dist_values(mnist_test_in_DIR_LPAn, targets)
acc_out_FMNIST_DIR_LPAn, prob_correct_out_FMNIST_DIR_LPAn, ent_out_FMNIST_DIR_LPAn, MMC_out_FMNIST_DIR_LPAn, auroc_out_FMNIST_DIR_LPAn = get_out_dist_values(mnist_test_in_DIR_LPAn, mnist_test_out_FMNIST_DIR_LPAn, targets_FMNIST)
acc_out_notMNIST_DIR_LPAn, prob_correct_out_notMNIST_DIR_LPAn, ent_out_notMNIST_DIR_LPAn, MMC_out_notMNIST_DIR_LPAn, auroc_out_notMNIST_DIR_LPAn = get_out_dist_values(mnist_test_in_DIR_LPAn, mnist_test_out_notMNIST_DIR_LPAn, targets_notMNIST)
acc_out_KMNIST_DIR_LPAn, prob_correct_out_KMNIST_DIR_LPAn, ent_out_KMNIST_DIR_LPAn, MMC_out_KMNIST_DIR_LPAn, auroc_out_KMNIST_DIR_LPAn = get_out_dist_values(mnist_test_in_DIR_LPAn, mnist_test_out_KMNIST_DIR_LPAn, targets_KMNIST)

In [118]:
# Report the Dirichlet results; `test` and `method` are passed by keyword so
# they end up in the intended slots of the printed tag.
print_in_dist_values(acc_in_DIR_LPAn, prob_correct_in_DIR_LPAn, ent_in_DIR_LPAn, MMC_in_DIR_LPAn, 'mnist', 'DIR_LPAn')
print_out_dist_values(acc_out_FMNIST_DIR_LPAn, prob_correct_out_FMNIST_DIR_LPAn, ent_out_FMNIST_DIR_LPAn, MMC_out_FMNIST_DIR_LPAn, auroc_out_FMNIST_DIR_LPAn, test='fmnist', method='DIR_LPAn')
print_out_dist_values(acc_out_notMNIST_DIR_LPAn, prob_correct_out_notMNIST_DIR_LPAn, ent_out_notMNIST_DIR_LPAn, MMC_out_notMNIST_DIR_LPAn, auroc_out_notMNIST_DIR_LPAn, test='notMNIST', method='DIR_LPAn')
print_out_dist_values(acc_out_KMNIST_DIR_LPAn, prob_correct_out_KMNIST_DIR_LPAn, ent_out_KMNIST_DIR_LPAn, MMC_out_KMNIST_DIR_LPAn, auroc_out_KMNIST_DIR_LPAn, test='KMNIST', method='DIR_LPAn')

[In, DIR_LPAn, mnist] Accuracy: 0.989; average entropy: 0.057;     MMC: 0.984; Prob @ correct: 0.979
[Out-fmnist, DIR_LPAn, mnist] Accuracy: 0.058; Average entropy: 1.786;    MMC: 0.369; AUROC: 0.991; Prob @ correct: 0.083
[Out-notMNIST, DIR_LPAn, mnist] Accuracy: 0.134; Average entropy: 1.084;    MMC: 0.630; AUROC: 0.968; Prob @ correct: 0.135
[Out-KMNIST, DIR_LPAn, mnist] Accuracy: 0.082; Average entropy: 1.059;    MMC: 0.635; AUROC: 0.966; Prob @ correct: 0.080


# additional Calculations for the Dirichlet

In [119]:
from scipy.special import digamma, loggamma

def beta_function(alpha):
    """Multivariate Beta function of ``alpha``.

    Evaluated as ``exp(sum_i loggamma(alpha_i) - loggamma(sum_i alpha_i))``.
    """
    log_gamma_terms = [loggamma(component) for component in alpha]
    log_numerator = np.sum(log_gamma_terms)
    log_denominator = loggamma(np.sum(alpha))
    return np.exp(log_numerator - log_denominator)

def alphas_norm(alphas):
    """Normalise every row of ``alphas`` (shape [N, K]) so that it sums to 1."""
    alpha_matrix = np.array(alphas)
    row_totals = alpha_matrix.sum(axis=1).reshape(-1, 1)
    return alpha_matrix / row_totals

def alphas_variance(alphas):
    """Per-component variance for every row of ``alphas`` (shape [N, K]).

    With ``m`` the row-normalised alphas and ``alpha_0`` the row sum, the
    result is ``m * (1 - m) / (alpha_0 + 1)``.
    """
    alpha_matrix = np.array(alphas)
    means = alphas_norm(alpha_matrix)
    row_totals_plus_one = alpha_matrix.sum(axis=1).reshape(-1, 1) + 1
    return means * (1 - means) / row_totals_plus_one

def log_beta_function(alpha):
    """Logarithm of the multivariate Beta function of ``alpha``.

    Evaluated as ``sum_i loggamma(alpha_i) - loggamma(sum_i alpha_i)``.
    """
    log_gamma_terms = [loggamma(component) for component in alpha]
    log_numerator = np.sum(log_gamma_terms)
    log_denominator = loggamma(np.sum(alpha))
    return log_numerator - log_denominator

def alphas_entropy(alphas):
    """Differential entropy of the Dirichlet defined by every row of ``alphas``.

    For a row ``a`` with K entries and ``a_0 = sum(a)``::

        H = log B(a) + (a_0 - K) * digamma(a_0) - sum_j (a_j - 1) * digamma(a_j)

    Returns a 1-D array with one entropy per row.
    """
    K = len(alphas[0])
    alpha_matrix = np.array(alphas)

    def _row_entropy(row):
        row_total = np.sum(row)
        log_normaliser = log_beta_function(row)
        total_term = (row_total - K) * digamma(row_total)
        component_term = np.sum((row - 1) * digamma(row))
        return log_normaliser + total_term - component_term

    return np.array([_row_entropy(row) for row in alpha_matrix])
        

def alphas_log_prob(alphas):
    """Return ``digamma(alpha_k) - digamma(alpha_0)`` for every entry.

    ``alpha_0`` is the row sum of ``alphas`` (shape [N, K]); the expression is
    the expectation of ``log p_k`` under the Dirichlet given by that row.
    """
    alpha_matrix = np.array(alphas)
    digamma_of_totals = digamma(alpha_matrix.sum(axis=1).reshape(-1, 1))
    return digamma(alpha_matrix) - digamma_of_totals

def auroc_entropy(alphas_in, alphas_out):
    """AUROC obtained when the Dirichlet entropy is used as the score.

    In-distribution samples are labelled 1 and out-of-distribution samples 0,
    so the returned value measures how well *high* entropy indicates
    in-distribution data.
    """
    scores_in = alphas_entropy(alphas_in)
    scores_out = alphas_entropy(alphas_out)

    n_in, n_out = len(scores_in), len(scores_out)
    labels = np.zeros(n_in + n_out, dtype='int32')
    labels[:n_in] = 1

    scores = np.concatenate([scores_in, scores_out])
    return roc_auc_score(labels, scores)

def auroc_variance(alphas_in, alphas_out, method='mean'):
    """AUROC obtained when the Dirichlet variance is used as the score.

    The per-component variances of every row are reduced to one number per
    sample, either by their mean or by their maximum. In-distribution samples
    are labelled 1 and out-of-distribution samples 0, so the returned value
    measures how well *high* variance indicates in-distribution data.

    Args:
        alphas_in: Array-like [N_in, K] of in-distribution alphas.
        alphas_out: Array-like [N_out, K] of out-of-distribution alphas.
        method: ``'mean'`` or ``'max'``; how the K variances are reduced.

    Raises:
        ValueError: If ``method`` is neither ``'mean'`` nor ``'max'``.
    """
    # Validate up front: an unknown method used to fall through both branches
    # and fail later with an unrelated "variable referenced before assignment".
    if method not in ('mean', 'max'):
        raise ValueError("method must be 'mean' or 'max', got {!r}".format(method))

    variance_in = alphas_variance(alphas_in)
    variance_out = alphas_variance(alphas_out)
    if method == 'mean':
        variance_in, variance_out = variance_in.mean(1), variance_out.mean(1)
    else:
        variance_in, variance_out = variance_in.max(1), variance_out.max(1)

    labels = np.zeros(len(variance_in)+len(variance_out), dtype='int32')
    labels[:len(variance_in)] = 1
    examples = np.concatenate([variance_in, variance_out])
    auroc_var = roc_auc_score(labels, examples)
    return(auroc_var)

In [120]:
# auroc_entropy labels in-distribution samples as 1, so `1 - auroc` is printed:
# that is the AUROC for detecting out-of-distribution inputs by high entropy.
print("auroc entropy: MNIST in, FMNIST out: ", 1 - auroc_entropy(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_FMNIST_DIR_LPA))
print("auroc entropy: MNIST in, notMNIST out: ", 1 - auroc_entropy(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_notMNIST_DIR_LPA))
print("auroc entropy: MNIST in, KMNIST out: ", 1 - auroc_entropy(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_KMNIST_DIR_LPA))

auroc entropy: MNIST in, FMNIST out:  0.87893859
auroc entropy: MNIST in, notMNIST out:  0.8378437486648151
auroc entropy: MNIST in, KMNIST out:  0.83759107


In [121]:
# Variance-based score, reduced with the mean over classes; `1 - auroc` for the
# same reason as in the entropy cell (in-distribution samples are labelled 1).
print("auroc variance: MNIST in, FMNIST out: ", 1-auroc_variance(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_FMNIST_DIR_LPA, method='mean'))
print("auroc variance: MNIST in, notMNIST out: ", 1-auroc_variance(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_notMNIST_DIR_LPA, method='mean'))
print("auroc variance: MNIST in, KMNIST out: ", 1-auroc_variance(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_KMNIST_DIR_LPA, method='mean'))

auroc variance: MNIST in, FMNIST out:  0.99291131
auroc variance: MNIST in, notMNIST out:  0.968292036957915
auroc variance: MNIST in, KMNIST out:  0.95995179


In [122]:
# Same as the previous cell, but the per-class variances are reduced with the
# maximum instead of the mean.
print("auroc variance: MNIST in, FMNIST out: ", 1-auroc_variance(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_FMNIST_DIR_LPA, method='max'))
print("auroc variance: MNIST in, notMNIST out: ", 1-auroc_variance(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_notMNIST_DIR_LPA, method='max'))
print("auroc variance: MNIST in, KMNIST out: ", 1-auroc_variance(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_KMNIST_DIR_LPA, method='max'))

auroc variance: MNIST in, FMNIST out:  0.99238503
auroc variance: MNIST in, notMNIST out:  0.9680424481948302
auroc variance: MNIST in, KMNIST out:  0.95961684
