In [1]:
import torch
import torchvision
from torch import nn, optim, autograd
from torch.nn import functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.autograd import Variable
import numpy as np
#import input_data
from sklearn.utils import shuffle as skshuffle
from math import *
from backpack import backpack, extend
from backpack.extensions import KFAC, DiagHessian, DiagGGNMC
from sklearn.metrics import roc_auc_score
import scipy
from tqdm import tqdm, trange
from bpjacext import NetJac
import pytest
from DirLPA_utils import * 
import time

import matplotlib.pyplot as plt

s = 127
np.random.seed(s)
torch.manual_seed(s)
torch.cuda.manual_seed(s)

In [2]:
def LPADirNN(num_classes=10, num_LL=256):
    
    features = torch.nn.Sequential(
        torch.nn.Conv2d(1, 32, 5),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2,2),
        torch.nn.Conv2d(32, 64, 5),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2,2),
        torch.nn.Flatten(),
        torch.nn.Linear(4 * 4 * 64, num_LL), #changed from 500
        torch.nn.Linear(num_LL, num_classes)  #changed from 500
    )
    return(features)

In [3]:
BATCH_SIZE_TRAIN_MNIST = 128
BATCH_SIZE_TEST_MNIST = 128
MAX_ITER_MNIST = 6
LR_TRAIN_MNIST = 10e-6

In [4]:
MNIST_transform = torchvision.transforms.ToTensor()

MNIST_train = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=True,
        download=True,
        transform=MNIST_transform)

mnist_train_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_train,
    batch_size=BATCH_SIZE_TRAIN_MNIST,
    shuffle=True
)


MNIST_test = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform)

mnist_test_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_test,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False,
)

In [5]:
mnist_model = LPADirNN(num_LL=256).cuda()
loss_function = torch.nn.CrossEntropyLoss()

#mnist_train_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=LR_TRAIN_MNIST)
mnist_train_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=1e-3, weight_decay=5e-4)
MNIST_PATH = "weights/mnist_test_6iter_10c_simpleCNN_256.pth"

In [6]:
#Training routine

def train(model, train_loader, optimizer, max_iter, path, verbose=True):
    max_len = len(train_loader)

    for iter in range(max_iter):
        for batch_idx, (x, y) in enumerate(train_loader):
            
            x, y = x.cuda(), y.cuda()
            
            output = model(x)

            accuracy = get_accuracy(output, y)

            loss = loss_function(output, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if verbose:
                print(
                    "Iteration {}; {}/{} \t".format(iter, batch_idx, max_len) +
                    "Minibatch Loss %.3f  " % (loss) +
                    "Accuracy %.0f" % (accuracy * 100) + "%"
                )

    print("saving model at: {}".format(path))
    torch.save(mnist_model.state_dict(), path)

In [7]:
train(mnist_model, mnist_train_loader, mnist_train_optimizer, MAX_ITER_MNIST, MNIST_PATH, verbose=True)

Iteration 0; 0/469 	Minibatch Loss 2.307  Accuracy 9%
Iteration 0; 1/469 	Minibatch Loss 2.283  Accuracy 7%
Iteration 0; 2/469 	Minibatch Loss 2.167  Accuracy 20%
Iteration 0; 3/469 	Minibatch Loss 2.152  Accuracy 48%
Iteration 0; 4/469 	Minibatch Loss 2.070  Accuracy 45%
Iteration 0; 5/469 	Minibatch Loss 1.927  Accuracy 66%
Iteration 0; 6/469 	Minibatch Loss 1.728  Accuracy 68%
Iteration 0; 7/469 	Minibatch Loss 1.643  Accuracy 59%
Iteration 0; 8/469 	Minibatch Loss 1.483  Accuracy 57%
Iteration 0; 9/469 	Minibatch Loss 1.423  Accuracy 56%
Iteration 0; 10/469 	Minibatch Loss 1.223  Accuracy 64%
Iteration 0; 11/469 	Minibatch Loss 1.117  Accuracy 80%
Iteration 0; 12/469 	Minibatch Loss 0.980  Accuracy 80%
Iteration 0; 13/469 	Minibatch Loss 0.936  Accuracy 72%
Iteration 0; 14/469 	Minibatch Loss 0.793  Accuracy 80%
Iteration 0; 15/469 	Minibatch Loss 0.638  Accuracy 81%
Iteration 0; 16/469 	Minibatch Loss 0.583  Accuracy 80%
Iteration 0; 17/469 	Minibatch Loss 0.526  Accuracy 82%
Iter

Iteration 0; 151/469 	Minibatch Loss 0.150  Accuracy 97%
Iteration 0; 152/469 	Minibatch Loss 0.098  Accuracy 98%
Iteration 0; 153/469 	Minibatch Loss 0.086  Accuracy 97%
Iteration 0; 154/469 	Minibatch Loss 0.155  Accuracy 97%
Iteration 0; 155/469 	Minibatch Loss 0.144  Accuracy 94%
Iteration 0; 156/469 	Minibatch Loss 0.070  Accuracy 98%
Iteration 0; 157/469 	Minibatch Loss 0.102  Accuracy 97%
Iteration 0; 158/469 	Minibatch Loss 0.120  Accuracy 97%
Iteration 0; 159/469 	Minibatch Loss 0.102  Accuracy 95%
Iteration 0; 160/469 	Minibatch Loss 0.066  Accuracy 98%
Iteration 0; 161/469 	Minibatch Loss 0.079  Accuracy 98%
Iteration 0; 162/469 	Minibatch Loss 0.125  Accuracy 96%
Iteration 0; 163/469 	Minibatch Loss 0.142  Accuracy 96%
Iteration 0; 164/469 	Minibatch Loss 0.134  Accuracy 96%
Iteration 0; 165/469 	Minibatch Loss 0.135  Accuracy 97%
Iteration 0; 166/469 	Minibatch Loss 0.043  Accuracy 99%
Iteration 0; 167/469 	Minibatch Loss 0.147  Accuracy 95%
Iteration 0; 168/469 	Minibatch

Iteration 0; 303/469 	Minibatch Loss 0.100  Accuracy 98%
Iteration 0; 304/469 	Minibatch Loss 0.075  Accuracy 97%
Iteration 0; 305/469 	Minibatch Loss 0.118  Accuracy 95%
Iteration 0; 306/469 	Minibatch Loss 0.105  Accuracy 95%
Iteration 0; 307/469 	Minibatch Loss 0.046  Accuracy 98%
Iteration 0; 308/469 	Minibatch Loss 0.086  Accuracy 97%
Iteration 0; 309/469 	Minibatch Loss 0.044  Accuracy 98%
Iteration 0; 310/469 	Minibatch Loss 0.057  Accuracy 98%
Iteration 0; 311/469 	Minibatch Loss 0.031  Accuracy 98%
Iteration 0; 312/469 	Minibatch Loss 0.175  Accuracy 96%
Iteration 0; 313/469 	Minibatch Loss 0.091  Accuracy 98%
Iteration 0; 314/469 	Minibatch Loss 0.123  Accuracy 98%
Iteration 0; 315/469 	Minibatch Loss 0.045  Accuracy 99%
Iteration 0; 316/469 	Minibatch Loss 0.071  Accuracy 98%
Iteration 0; 317/469 	Minibatch Loss 0.070  Accuracy 98%
Iteration 0; 318/469 	Minibatch Loss 0.085  Accuracy 98%
Iteration 0; 319/469 	Minibatch Loss 0.071  Accuracy 97%
Iteration 0; 320/469 	Minibatch

Iteration 0; 455/469 	Minibatch Loss 0.093  Accuracy 97%
Iteration 0; 456/469 	Minibatch Loss 0.067  Accuracy 98%
Iteration 0; 457/469 	Minibatch Loss 0.089  Accuracy 98%
Iteration 0; 458/469 	Minibatch Loss 0.027  Accuracy 99%
Iteration 0; 459/469 	Minibatch Loss 0.049  Accuracy 97%
Iteration 0; 460/469 	Minibatch Loss 0.056  Accuracy 98%
Iteration 0; 461/469 	Minibatch Loss 0.044  Accuracy 99%
Iteration 0; 462/469 	Minibatch Loss 0.038  Accuracy 98%
Iteration 0; 463/469 	Minibatch Loss 0.083  Accuracy 96%
Iteration 0; 464/469 	Minibatch Loss 0.062  Accuracy 98%
Iteration 0; 465/469 	Minibatch Loss 0.079  Accuracy 95%
Iteration 0; 466/469 	Minibatch Loss 0.021  Accuracy 99%
Iteration 0; 467/469 	Minibatch Loss 0.047  Accuracy 99%
Iteration 0; 468/469 	Minibatch Loss 0.073  Accuracy 97%
Iteration 1; 0/469 	Minibatch Loss 0.045  Accuracy 99%
Iteration 1; 1/469 	Minibatch Loss 0.046  Accuracy 98%
Iteration 1; 2/469 	Minibatch Loss 0.031  Accuracy 99%
Iteration 1; 3/469 	Minibatch Loss 0.

Iteration 1; 138/469 	Minibatch Loss 0.030  Accuracy 98%
Iteration 1; 139/469 	Minibatch Loss 0.081  Accuracy 98%
Iteration 1; 140/469 	Minibatch Loss 0.029  Accuracy 99%
Iteration 1; 141/469 	Minibatch Loss 0.033  Accuracy 99%
Iteration 1; 142/469 	Minibatch Loss 0.010  Accuracy 100%
Iteration 1; 143/469 	Minibatch Loss 0.016  Accuracy 99%
Iteration 1; 144/469 	Minibatch Loss 0.065  Accuracy 98%
Iteration 1; 145/469 	Minibatch Loss 0.015  Accuracy 99%
Iteration 1; 146/469 	Minibatch Loss 0.022  Accuracy 99%
Iteration 1; 147/469 	Minibatch Loss 0.010  Accuracy 100%
Iteration 1; 148/469 	Minibatch Loss 0.049  Accuracy 99%
Iteration 1; 149/469 	Minibatch Loss 0.077  Accuracy 98%
Iteration 1; 150/469 	Minibatch Loss 0.094  Accuracy 96%
Iteration 1; 151/469 	Minibatch Loss 0.075  Accuracy 98%
Iteration 1; 152/469 	Minibatch Loss 0.016  Accuracy 99%
Iteration 1; 153/469 	Minibatch Loss 0.061  Accuracy 98%
Iteration 1; 154/469 	Minibatch Loss 0.036  Accuracy 99%
Iteration 1; 155/469 	Minibat

Iteration 1; 290/469 	Minibatch Loss 0.036  Accuracy 98%
Iteration 1; 291/469 	Minibatch Loss 0.083  Accuracy 98%
Iteration 1; 292/469 	Minibatch Loss 0.045  Accuracy 99%
Iteration 1; 293/469 	Minibatch Loss 0.064  Accuracy 98%
Iteration 1; 294/469 	Minibatch Loss 0.069  Accuracy 98%
Iteration 1; 295/469 	Minibatch Loss 0.066  Accuracy 98%
Iteration 1; 296/469 	Minibatch Loss 0.034  Accuracy 99%
Iteration 1; 297/469 	Minibatch Loss 0.086  Accuracy 95%
Iteration 1; 298/469 	Minibatch Loss 0.064  Accuracy 98%
Iteration 1; 299/469 	Minibatch Loss 0.006  Accuracy 100%
Iteration 1; 300/469 	Minibatch Loss 0.055  Accuracy 98%
Iteration 1; 301/469 	Minibatch Loss 0.056  Accuracy 98%
Iteration 1; 302/469 	Minibatch Loss 0.049  Accuracy 98%
Iteration 1; 303/469 	Minibatch Loss 0.013  Accuracy 100%
Iteration 1; 304/469 	Minibatch Loss 0.121  Accuracy 95%
Iteration 1; 305/469 	Minibatch Loss 0.035  Accuracy 98%
Iteration 1; 306/469 	Minibatch Loss 0.005  Accuracy 100%
Iteration 1; 307/469 	Miniba

Iteration 1; 442/469 	Minibatch Loss 0.025  Accuracy 99%
Iteration 1; 443/469 	Minibatch Loss 0.029  Accuracy 98%
Iteration 1; 444/469 	Minibatch Loss 0.081  Accuracy 97%
Iteration 1; 445/469 	Minibatch Loss 0.055  Accuracy 98%
Iteration 1; 446/469 	Minibatch Loss 0.034  Accuracy 99%
Iteration 1; 447/469 	Minibatch Loss 0.042  Accuracy 99%
Iteration 1; 448/469 	Minibatch Loss 0.029  Accuracy 99%
Iteration 1; 449/469 	Minibatch Loss 0.128  Accuracy 97%
Iteration 1; 450/469 	Minibatch Loss 0.033  Accuracy 99%
Iteration 1; 451/469 	Minibatch Loss 0.047  Accuracy 97%
Iteration 1; 452/469 	Minibatch Loss 0.015  Accuracy 100%
Iteration 1; 453/469 	Minibatch Loss 0.022  Accuracy 99%
Iteration 1; 454/469 	Minibatch Loss 0.019  Accuracy 99%
Iteration 1; 455/469 	Minibatch Loss 0.039  Accuracy 98%
Iteration 1; 456/469 	Minibatch Loss 0.065  Accuracy 98%
Iteration 1; 457/469 	Minibatch Loss 0.078  Accuracy 98%
Iteration 1; 458/469 	Minibatch Loss 0.048  Accuracy 98%
Iteration 1; 459/469 	Minibatc

Iteration 2; 125/469 	Minibatch Loss 0.063  Accuracy 98%
Iteration 2; 126/469 	Minibatch Loss 0.042  Accuracy 98%
Iteration 2; 127/469 	Minibatch Loss 0.056  Accuracy 98%
Iteration 2; 128/469 	Minibatch Loss 0.137  Accuracy 96%
Iteration 2; 129/469 	Minibatch Loss 0.006  Accuracy 100%
Iteration 2; 130/469 	Minibatch Loss 0.043  Accuracy 99%
Iteration 2; 131/469 	Minibatch Loss 0.024  Accuracy 99%
Iteration 2; 132/469 	Minibatch Loss 0.044  Accuracy 98%
Iteration 2; 133/469 	Minibatch Loss 0.038  Accuracy 98%
Iteration 2; 134/469 	Minibatch Loss 0.030  Accuracy 98%
Iteration 2; 135/469 	Minibatch Loss 0.060  Accuracy 98%
Iteration 2; 136/469 	Minibatch Loss 0.049  Accuracy 98%
Iteration 2; 137/469 	Minibatch Loss 0.016  Accuracy 100%
Iteration 2; 138/469 	Minibatch Loss 0.024  Accuracy 99%
Iteration 2; 139/469 	Minibatch Loss 0.052  Accuracy 98%
Iteration 2; 140/469 	Minibatch Loss 0.030  Accuracy 100%
Iteration 2; 141/469 	Minibatch Loss 0.079  Accuracy 98%
Iteration 2; 142/469 	Miniba

Iteration 2; 277/469 	Minibatch Loss 0.032  Accuracy 98%
Iteration 2; 278/469 	Minibatch Loss 0.065  Accuracy 98%
Iteration 2; 279/469 	Minibatch Loss 0.031  Accuracy 98%
Iteration 2; 280/469 	Minibatch Loss 0.009  Accuracy 100%
Iteration 2; 281/469 	Minibatch Loss 0.041  Accuracy 98%
Iteration 2; 282/469 	Minibatch Loss 0.016  Accuracy 99%
Iteration 2; 283/469 	Minibatch Loss 0.050  Accuracy 98%
Iteration 2; 284/469 	Minibatch Loss 0.067  Accuracy 98%
Iteration 2; 285/469 	Minibatch Loss 0.042  Accuracy 98%
Iteration 2; 286/469 	Minibatch Loss 0.022  Accuracy 99%
Iteration 2; 287/469 	Minibatch Loss 0.021  Accuracy 99%
Iteration 2; 288/469 	Minibatch Loss 0.046  Accuracy 98%
Iteration 2; 289/469 	Minibatch Loss 0.018  Accuracy 99%
Iteration 2; 290/469 	Minibatch Loss 0.042  Accuracy 98%
Iteration 2; 291/469 	Minibatch Loss 0.068  Accuracy 98%
Iteration 2; 292/469 	Minibatch Loss 0.062  Accuracy 98%
Iteration 2; 293/469 	Minibatch Loss 0.029  Accuracy 98%
Iteration 2; 294/469 	Minibatc

Iteration 2; 429/469 	Minibatch Loss 0.078  Accuracy 97%
Iteration 2; 430/469 	Minibatch Loss 0.031  Accuracy 99%
Iteration 2; 431/469 	Minibatch Loss 0.047  Accuracy 98%
Iteration 2; 432/469 	Minibatch Loss 0.004  Accuracy 100%
Iteration 2; 433/469 	Minibatch Loss 0.079  Accuracy 98%
Iteration 2; 434/469 	Minibatch Loss 0.028  Accuracy 99%
Iteration 2; 435/469 	Minibatch Loss 0.146  Accuracy 95%
Iteration 2; 436/469 	Minibatch Loss 0.041  Accuracy 98%
Iteration 2; 437/469 	Minibatch Loss 0.027  Accuracy 99%
Iteration 2; 438/469 	Minibatch Loss 0.016  Accuracy 100%
Iteration 2; 439/469 	Minibatch Loss 0.027  Accuracy 99%
Iteration 2; 440/469 	Minibatch Loss 0.031  Accuracy 98%
Iteration 2; 441/469 	Minibatch Loss 0.088  Accuracy 97%
Iteration 2; 442/469 	Minibatch Loss 0.053  Accuracy 98%
Iteration 2; 443/469 	Minibatch Loss 0.015  Accuracy 100%
Iteration 2; 444/469 	Minibatch Loss 0.026  Accuracy 99%
Iteration 2; 445/469 	Minibatch Loss 0.096  Accuracy 95%
Iteration 2; 446/469 	Miniba

Iteration 3; 111/469 	Minibatch Loss 0.062  Accuracy 98%
Iteration 3; 112/469 	Minibatch Loss 0.021  Accuracy 98%
Iteration 3; 113/469 	Minibatch Loss 0.030  Accuracy 99%
Iteration 3; 114/469 	Minibatch Loss 0.032  Accuracy 99%
Iteration 3; 115/469 	Minibatch Loss 0.069  Accuracy 98%
Iteration 3; 116/469 	Minibatch Loss 0.027  Accuracy 98%
Iteration 3; 117/469 	Minibatch Loss 0.010  Accuracy 100%
Iteration 3; 118/469 	Minibatch Loss 0.067  Accuracy 98%
Iteration 3; 119/469 	Minibatch Loss 0.043  Accuracy 98%
Iteration 3; 120/469 	Minibatch Loss 0.026  Accuracy 99%
Iteration 3; 121/469 	Minibatch Loss 0.041  Accuracy 99%
Iteration 3; 122/469 	Minibatch Loss 0.028  Accuracy 99%
Iteration 3; 123/469 	Minibatch Loss 0.011  Accuracy 100%
Iteration 3; 124/469 	Minibatch Loss 0.008  Accuracy 100%
Iteration 3; 125/469 	Minibatch Loss 0.064  Accuracy 98%
Iteration 3; 126/469 	Minibatch Loss 0.027  Accuracy 99%
Iteration 3; 127/469 	Minibatch Loss 0.012  Accuracy 100%
Iteration 3; 128/469 	Minib

Iteration 3; 263/469 	Minibatch Loss 0.043  Accuracy 98%
Iteration 3; 264/469 	Minibatch Loss 0.041  Accuracy 98%
Iteration 3; 265/469 	Minibatch Loss 0.100  Accuracy 95%
Iteration 3; 266/469 	Minibatch Loss 0.020  Accuracy 100%
Iteration 3; 267/469 	Minibatch Loss 0.016  Accuracy 100%
Iteration 3; 268/469 	Minibatch Loss 0.013  Accuracy 100%
Iteration 3; 269/469 	Minibatch Loss 0.027  Accuracy 99%
Iteration 3; 270/469 	Minibatch Loss 0.022  Accuracy 98%
Iteration 3; 271/469 	Minibatch Loss 0.021  Accuracy 99%
Iteration 3; 272/469 	Minibatch Loss 0.007  Accuracy 100%
Iteration 3; 273/469 	Minibatch Loss 0.025  Accuracy 99%
Iteration 3; 274/469 	Minibatch Loss 0.017  Accuracy 99%
Iteration 3; 275/469 	Minibatch Loss 0.032  Accuracy 98%
Iteration 3; 276/469 	Minibatch Loss 0.025  Accuracy 98%
Iteration 3; 277/469 	Minibatch Loss 0.101  Accuracy 96%
Iteration 3; 278/469 	Minibatch Loss 0.016  Accuracy 99%
Iteration 3; 279/469 	Minibatch Loss 0.014  Accuracy 100%
Iteration 3; 280/469 	Mini

Iteration 3; 415/469 	Minibatch Loss 0.033  Accuracy 98%
Iteration 3; 416/469 	Minibatch Loss 0.075  Accuracy 98%
Iteration 3; 417/469 	Minibatch Loss 0.026  Accuracy 99%
Iteration 3; 418/469 	Minibatch Loss 0.057  Accuracy 98%
Iteration 3; 419/469 	Minibatch Loss 0.038  Accuracy 99%
Iteration 3; 420/469 	Minibatch Loss 0.027  Accuracy 99%
Iteration 3; 421/469 	Minibatch Loss 0.042  Accuracy 99%
Iteration 3; 422/469 	Minibatch Loss 0.022  Accuracy 99%
Iteration 3; 423/469 	Minibatch Loss 0.006  Accuracy 100%
Iteration 3; 424/469 	Minibatch Loss 0.041  Accuracy 98%
Iteration 3; 425/469 	Minibatch Loss 0.033  Accuracy 98%
Iteration 3; 426/469 	Minibatch Loss 0.009  Accuracy 100%
Iteration 3; 427/469 	Minibatch Loss 0.038  Accuracy 98%
Iteration 3; 428/469 	Minibatch Loss 0.032  Accuracy 99%
Iteration 3; 429/469 	Minibatch Loss 0.025  Accuracy 99%
Iteration 3; 430/469 	Minibatch Loss 0.010  Accuracy 100%
Iteration 3; 431/469 	Minibatch Loss 0.085  Accuracy 95%
Iteration 3; 432/469 	Miniba

Iteration 4; 98/469 	Minibatch Loss 0.012  Accuracy 100%
Iteration 4; 99/469 	Minibatch Loss 0.050  Accuracy 98%
Iteration 4; 100/469 	Minibatch Loss 0.016  Accuracy 99%
Iteration 4; 101/469 	Minibatch Loss 0.006  Accuracy 100%
Iteration 4; 102/469 	Minibatch Loss 0.012  Accuracy 100%
Iteration 4; 103/469 	Minibatch Loss 0.006  Accuracy 100%
Iteration 4; 104/469 	Minibatch Loss 0.035  Accuracy 98%
Iteration 4; 105/469 	Minibatch Loss 0.029  Accuracy 99%
Iteration 4; 106/469 	Minibatch Loss 0.012  Accuracy 99%
Iteration 4; 107/469 	Minibatch Loss 0.024  Accuracy 99%
Iteration 4; 108/469 	Minibatch Loss 0.024  Accuracy 99%
Iteration 4; 109/469 	Minibatch Loss 0.009  Accuracy 100%
Iteration 4; 110/469 	Minibatch Loss 0.023  Accuracy 98%
Iteration 4; 111/469 	Minibatch Loss 0.027  Accuracy 99%
Iteration 4; 112/469 	Minibatch Loss 0.015  Accuracy 99%
Iteration 4; 113/469 	Minibatch Loss 0.020  Accuracy 99%
Iteration 4; 114/469 	Minibatch Loss 0.024  Accuracy 99%
Iteration 4; 115/469 	Miniba

Iteration 4; 250/469 	Minibatch Loss 0.064  Accuracy 98%
Iteration 4; 251/469 	Minibatch Loss 0.048  Accuracy 98%
Iteration 4; 252/469 	Minibatch Loss 0.026  Accuracy 99%
Iteration 4; 253/469 	Minibatch Loss 0.065  Accuracy 97%
Iteration 4; 254/469 	Minibatch Loss 0.031  Accuracy 98%
Iteration 4; 255/469 	Minibatch Loss 0.093  Accuracy 97%
Iteration 4; 256/469 	Minibatch Loss 0.018  Accuracy 99%
Iteration 4; 257/469 	Minibatch Loss 0.048  Accuracy 98%
Iteration 4; 258/469 	Minibatch Loss 0.089  Accuracy 95%
Iteration 4; 259/469 	Minibatch Loss 0.092  Accuracy 98%
Iteration 4; 260/469 	Minibatch Loss 0.015  Accuracy 99%
Iteration 4; 261/469 	Minibatch Loss 0.044  Accuracy 98%
Iteration 4; 262/469 	Minibatch Loss 0.021  Accuracy 99%
Iteration 4; 263/469 	Minibatch Loss 0.067  Accuracy 98%
Iteration 4; 264/469 	Minibatch Loss 0.010  Accuracy 99%
Iteration 4; 265/469 	Minibatch Loss 0.028  Accuracy 99%
Iteration 4; 266/469 	Minibatch Loss 0.031  Accuracy 99%
Iteration 4; 267/469 	Minibatch

Iteration 4; 402/469 	Minibatch Loss 0.016  Accuracy 100%
Iteration 4; 403/469 	Minibatch Loss 0.013  Accuracy 100%
Iteration 4; 404/469 	Minibatch Loss 0.029  Accuracy 98%
Iteration 4; 405/469 	Minibatch Loss 0.012  Accuracy 100%
Iteration 4; 406/469 	Minibatch Loss 0.028  Accuracy 98%
Iteration 4; 407/469 	Minibatch Loss 0.022  Accuracy 99%
Iteration 4; 408/469 	Minibatch Loss 0.057  Accuracy 98%
Iteration 4; 409/469 	Minibatch Loss 0.043  Accuracy 99%
Iteration 4; 410/469 	Minibatch Loss 0.067  Accuracy 99%
Iteration 4; 411/469 	Minibatch Loss 0.028  Accuracy 99%
Iteration 4; 412/469 	Minibatch Loss 0.015  Accuracy 100%
Iteration 4; 413/469 	Minibatch Loss 0.012  Accuracy 100%
Iteration 4; 414/469 	Minibatch Loss 0.058  Accuracy 98%
Iteration 4; 415/469 	Minibatch Loss 0.027  Accuracy 98%
Iteration 4; 416/469 	Minibatch Loss 0.016  Accuracy 100%
Iteration 4; 417/469 	Minibatch Loss 0.029  Accuracy 99%
Iteration 4; 418/469 	Minibatch Loss 0.066  Accuracy 98%
Iteration 4; 419/469 	Min

Iteration 5; 79/469 	Minibatch Loss 0.033  Accuracy 98%
Iteration 5; 80/469 	Minibatch Loss 0.042  Accuracy 98%
Iteration 5; 81/469 	Minibatch Loss 0.021  Accuracy 99%
Iteration 5; 82/469 	Minibatch Loss 0.031  Accuracy 98%
Iteration 5; 83/469 	Minibatch Loss 0.075  Accuracy 98%
Iteration 5; 84/469 	Minibatch Loss 0.004  Accuracy 100%
Iteration 5; 85/469 	Minibatch Loss 0.008  Accuracy 100%
Iteration 5; 86/469 	Minibatch Loss 0.005  Accuracy 100%
Iteration 5; 87/469 	Minibatch Loss 0.044  Accuracy 98%
Iteration 5; 88/469 	Minibatch Loss 0.032  Accuracy 98%
Iteration 5; 89/469 	Minibatch Loss 0.007  Accuracy 100%
Iteration 5; 90/469 	Minibatch Loss 0.012  Accuracy 100%
Iteration 5; 91/469 	Minibatch Loss 0.014  Accuracy 99%
Iteration 5; 92/469 	Minibatch Loss 0.034  Accuracy 98%
Iteration 5; 93/469 	Minibatch Loss 0.054  Accuracy 98%
Iteration 5; 94/469 	Minibatch Loss 0.017  Accuracy 99%
Iteration 5; 95/469 	Minibatch Loss 0.062  Accuracy 98%
Iteration 5; 96/469 	Minibatch Loss 0.031  

Iteration 5; 223/469 	Minibatch Loss 0.018  Accuracy 100%
Iteration 5; 224/469 	Minibatch Loss 0.016  Accuracy 99%
Iteration 5; 225/469 	Minibatch Loss 0.032  Accuracy 99%
Iteration 5; 226/469 	Minibatch Loss 0.020  Accuracy 99%
Iteration 5; 227/469 	Minibatch Loss 0.047  Accuracy 98%
Iteration 5; 228/469 	Minibatch Loss 0.050  Accuracy 98%
Iteration 5; 229/469 	Minibatch Loss 0.016  Accuracy 99%
Iteration 5; 230/469 	Minibatch Loss 0.041  Accuracy 99%
Iteration 5; 231/469 	Minibatch Loss 0.045  Accuracy 99%
Iteration 5; 232/469 	Minibatch Loss 0.026  Accuracy 98%
Iteration 5; 233/469 	Minibatch Loss 0.023  Accuracy 99%
Iteration 5; 234/469 	Minibatch Loss 0.020  Accuracy 98%
Iteration 5; 235/469 	Minibatch Loss 0.064  Accuracy 99%
Iteration 5; 236/469 	Minibatch Loss 0.023  Accuracy 99%
Iteration 5; 237/469 	Minibatch Loss 0.031  Accuracy 98%
Iteration 5; 238/469 	Minibatch Loss 0.009  Accuracy 100%
Iteration 5; 239/469 	Minibatch Loss 0.010  Accuracy 100%
Iteration 5; 240/469 	Miniba

Iteration 5; 368/469 	Minibatch Loss 0.036  Accuracy 99%
Iteration 5; 369/469 	Minibatch Loss 0.023  Accuracy 100%
Iteration 5; 370/469 	Minibatch Loss 0.034  Accuracy 98%
Iteration 5; 371/469 	Minibatch Loss 0.053  Accuracy 98%
Iteration 5; 372/469 	Minibatch Loss 0.055  Accuracy 99%
Iteration 5; 373/469 	Minibatch Loss 0.013  Accuracy 100%
Iteration 5; 374/469 	Minibatch Loss 0.036  Accuracy 98%
Iteration 5; 375/469 	Minibatch Loss 0.041  Accuracy 99%
Iteration 5; 376/469 	Minibatch Loss 0.023  Accuracy 99%
Iteration 5; 377/469 	Minibatch Loss 0.010  Accuracy 100%
Iteration 5; 378/469 	Minibatch Loss 0.077  Accuracy 97%
Iteration 5; 379/469 	Minibatch Loss 0.027  Accuracy 98%
Iteration 5; 380/469 	Minibatch Loss 0.069  Accuracy 98%
Iteration 5; 381/469 	Minibatch Loss 0.013  Accuracy 100%
Iteration 5; 382/469 	Minibatch Loss 0.066  Accuracy 97%
Iteration 5; 383/469 	Minibatch Loss 0.015  Accuracy 99%
Iteration 5; 384/469 	Minibatch Loss 0.022  Accuracy 99%
Iteration 5; 385/469 	Minib

In [8]:
#predict in distribution
MNIST_PATH = "weights/mnist_test_6iter_10c_simpleCNN_256.pth"

#mnist_model = LPADirNN(x=100)
mnist_model = LPADirNN(num_LL=256).cuda()
print("loading model from: {}".format(MNIST_PATH))
mnist_model.load_state_dict(torch.load(MNIST_PATH))
mnist_model.eval()

acc = []

max_len = len(mnist_test_loader)
for batch_idx, (x, y) in enumerate(mnist_test_loader):

    x, y = x.cuda(), y.cuda()
    output = mnist_model(x)

    accuracy = get_accuracy(output, y)
    if batch_idx % 10 == 0:
        print(
            "Batch {}/{} \t".format(batch_idx, max_len) + 
            "Accuracy %.0f" % (accuracy * 100) + "%"
        )
    acc.append(accuracy)

avg_acc = np.mean(acc)
print('overall test accuracy on MNIST: {:.02f} %'.format(avg_acc * 100))


loading model from: weights/mnist_test_6iter_10c_simpleCNN_256.pth
Batch 0/79 	Accuracy 100%
Batch 10/79 	Accuracy 98%
Batch 20/79 	Accuracy 98%
Batch 30/79 	Accuracy 99%
Batch 40/79 	Accuracy 100%
Batch 50/79 	Accuracy 98%
Batch 60/79 	Accuracy 100%
Batch 70/79 	Accuracy 97%
overall test accuracy on MNIST: 98.95 %


In [9]:
## play around with Backpack
def get_Hessian_NN(model, train_loader, var0, device='cpu', verbose=True):
    lossfunc = torch.nn.CrossEntropyLoss()

    extend(lossfunc, debug=False)
    extend(model, debug=False)

    Hessian_diag = []
    for param in mnist_model.parameters():
        ps = param.size()
        print("parameter size: ", ps)
        Hessian_diag.append(torch.zeros(ps, device=device))
        #print(param.numel())

    tau = 1/var0
    max_len = len(train_loader)

    with backpack(DiagHessian()):

        for batch_idx, (x, y) in enumerate(train_loader):

            if device == 'cuda':
                x, y = x.cuda(), y.cuda()

            model.zero_grad()
            loss = lossfunc(model(x), y)
            loss.backward()

            with torch.no_grad():
                # Hessian of weight
                for idx, param in enumerate(model.parameters()):

                    H_ = param.diag_h
                    #add bias here
                    H_ += tau * torch.ones(H_.size(), device=device)

                    rho = min(1-1/(batch_idx+1), 0.995)

                    Hessian_diag[idx] = rho*Hessian_diag[idx] + (1-rho)*H_
            
            if verbose:
                print("Batch: {}/{}".format(batch_idx, max_len))

    #combine all elements of the Hessian to one big vector
    Hessian_diag = torch.cat([el.view(-1) for el in Hessian_diag])
    print("Hessian_size: ", Hessian_diag.size())
    num_params = np.sum([p.numel() for p in model.parameters()])
    assert(num_params == Hessian_diag.size(-1))
    return(Hessian_diag)
        

In [10]:
Hessian_MNIST = get_Hessian_NN(model=mnist_model, train_loader=mnist_train_loader, var0=200, verbose=False, device='cuda')

parameter size:  torch.Size([32, 1, 5, 5])
parameter size:  torch.Size([32])
parameter size:  torch.Size([64, 32, 5, 5])
parameter size:  torch.Size([64])
parameter size:  torch.Size([256, 1024])
parameter size:  torch.Size([256])
parameter size:  torch.Size([10, 256])
parameter size:  torch.Size([10])
Hessian_size:  torch.Size([317066])


In [11]:
print(Hessian_MNIST)

tensor([0.0057, 0.0060, 0.0068,  ..., 0.0067, 0.0073, 0.0071], device='cuda:0')


In [12]:
def compute_jacobians_with_backpack(model, x, y, lossfunc):
    """
    Returns the jacobians of the network

    The output is a list. Each element in the list is a tensor
    corresponding to the model.parameters().

    The tensor are of the form [N, *, C] where N is the batch dimension,
    C is the number of classes (output size of the network)
    and * is the shape of the model parameters
    """
    loss = lossfunc(model(x), y)

    with backpack(NetJac()):
        loss.backward()

    jacs = []
    for p in model.parameters():
        jacs.append(p.netjacs.data.detach())
    return jacs

def transform2full_jac(backpack_jacobian):

    jac_full = []
    #batch_size
    N = backpack_jacobian[0].size(0)
    #num classes
    k = backpack_jacobian[0].size(-1)
    for j in backpack_jacobian:
        jac_full.append(j.view(N, -1, k).permute(0,2,1))
    jac_full = torch.cat(jac_full, dim=-1)
    return(jac_full)

def get_Jacobian(model, x, y, lossfunc):
    return(transform2full_jac(compute_jacobians_with_backpack(model, x, y, lossfunc)))

In [13]:
def predict_Diagonal_full(model, test_loader, Hessian, verbose=True, num_samples=100, cuda=False, timing=False):
    
    lossfunc = torch.nn.CrossEntropyLoss()
    extend(lossfunc, debug=False)
    
    py = []
    if timing:
        time_sum = 0
    
    max_len = len(test_loader)
    for batch_idx, (x, y) in enumerate(test_loader):
        
        if cuda:
            x, y = x.cuda(), y.cuda()
        
        J = get_Jacobian(model, x, y, lossfunc)
        J = J.detach()
        batch_size = J.size(0)
        num_classes = J.size(1)
        Cov_pred = torch.bmm(J * Hessian, J.permute(0, 2, 1))
        Cov_pred = Cov_pred.detach()
        if verbose:
            print("Jacobian size: ", J.size())
            print("cov pred size: ", Cov_pred.size())
        
        mu_pred = model(x).detach()
        post_pred = MultivariateNormal(mu_pred, Cov_pred)

        # MC-integral
        t0 = time.time()
        py_ = 0

        for _ in range(num_samples):
            f_s = post_pred.rsample()
            py_ += torch.softmax(f_s, 1)


        py_ /= num_samples
        py_ = py_.detach()

        py.append(py_)
        t1 = time.time()
        if timing:
            time_sum += (t1-t0)

        if verbose:
            print("Batch: {}/{}".format(batch_idx, max_len))
    
    if timing:
        print("total time used for transform: {:.05f}".format(time_sum))

    return torch.cat(py, dim=0)

In [14]:
BATCH_SIZE_TEST_FMNIST = 128
BATCH_SIZE_TEST_KMNIST = 128

In [15]:
FMNIST_test = torchvision.datasets.FashionMNIST(
        '~/data/fmnist', train=False, download=True,
        transform=MNIST_transform)   #torchvision.transforms.ToTensor())

FMNIST_test_loader = torch.utils.data.DataLoader(
    FMNIST_test,
    batch_size=BATCH_SIZE_TEST_FMNIST, shuffle=False)

In [16]:
KMNIST_test = torchvision.datasets.KMNIST(
        '~/data/kmnist', train=False, download=True,
        transform=MNIST_transform)

KMNIST_test_loader = torch.utils.data.DataLoader(
    KMNIST_test,
    batch_size=BATCH_SIZE_TEST_KMNIST, shuffle=False)

In [17]:
"""Load notMNIST"""

import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from matplotlib.pyplot import imread
from torch import Tensor

"""
Loads the train/test set. 
Every image in the dataset is 28x28 pixels and the labels are numbered from 0-9
for A-J respectively.
Set root to point to the Train/Test folders.
"""

# Creating a sub class of torch.utils.data.dataset.Dataset
class notMNIST(Dataset):

    # The init method is called when this class will be instantiated
    def __init__(self, root, transform):
        
        #super(notMNIST, self).__init__(root, transform=transform)

        self.transform = transform
        
        Images, Y = [], []
        folders = os.listdir(root)

        for folder in folders:
            folder_path = os.path.join(root, folder)
            for ims in os.listdir(folder_path):
                try:
                    img_path = os.path.join(folder_path, ims)
                    Images.append(np.array(imread(img_path)))
                    Y.append(ord(folder) - 65)  # Folders are A-J so labels will be 0-9
                except:
                    # Some images in the dataset are damaged
                    print("File {}/{} is broken".format(folder, ims))
        data = [(x, y) for x, y in zip(Images, Y)]
        self.data = data
        self.targets = torch.Tensor(Y)

    # The number of items in the dataset
    def __len__(self):
        return len(self.data)

    # The Dataloader is a generator that repeatedly calls the getitem method.
    # getitem is supposed to return (X, Y) for the specified index.
    def __getitem__(self, index):
        img = self.data[index][0]

        if self.transform is not None:
            img = self.transform(img)
            
        # Input for Conv2D should be Channels x Height x Width
        img_tensor = Tensor(img).view(1, 28, 28).float()
        label = self.data[index][1]
        return (img_tensor, label)

In [18]:
#root = os.path.abspath('~/data')
root = os.path.expanduser('~/data')

# Instantiating the notMNIST dataset class we created
notMNIST_test = notMNIST(root=os.path.join(root, 'notMNIST_small'),
                               transform=MNIST_transform)

# Creating a dataloader
not_mnist_test_loader = torch.utils.data.dataloader.DataLoader(
                            dataset=notMNIST_test,
                            batch_size=BATCH_SIZE_TEST_KMNIST,
                            shuffle=False)

File F/Q3Jvc3NvdmVyIEJvbGRPYmxpcXVlLnR0Zg==.png is broken
File A/RGVtb2NyYXRpY2FCb2xkT2xkc3R5bGUgQm9sZC50dGY=.png is broken


In [19]:
def get_in_dist_values(py_in, targets):
    acc_in = np.mean(np.argmax(py_in, 1) == targets)
    prob_correct = np.choose(targets, py_in.T).mean()
    average_entropy = -np.sum(py_in*np.log(py_in+1e-8), axis=1).mean()
    MMC = py_in.max(1).mean()
    return(acc_in, prob_correct, average_entropy, MMC)
    
def get_out_dist_values(py_in, py_out, targets):
    average_entropy = -np.sum(py_out*np.log(py_out+1e-8), axis=1).mean()
    acc_out = np.mean(np.argmax(py_out, 1) == targets)
    prob_correct = np.choose(targets, py_out.T).mean()
    labels = np.zeros(len(py_in)+len(py_out), dtype='int32')
    labels[:len(py_in)] = 1
    examples = np.concatenate([py_in.max(1), py_out.max(1)])
    auroc = roc_auc_score(labels, examples)
    MMC = py_out.max(1).mean()
    return(acc_out, prob_correct, average_entropy, MMC, auroc)

def print_in_dist_values(acc_in, prob_correct, average_entropy, MMC, train='mnist', method='LLLA-KF'):
    
    print(f'[In, {method}, {train}] Accuracy: {acc_in:.3f}; average entropy: {average_entropy:.3f}; \
    MMC: {MMC:.3f}; Prob @ correct: {prob_correct:.3f}')


def print_out_dist_values(acc_out, prob_correct, average_entropy, MMC, auroc, train='mnist', test='FMNIST', method='LLLA-KF'):
   
    print(f'[Out-{test}, {method}, {train}] Accuracy: {acc_out:.3f}; Average entropy: {average_entropy:.3f};\
    MMC: {MMC:.3f}; AUROC: {auroc:.3f}; Prob @ correct: {prob_correct:.3f}')

# MAP estimate

In [20]:
targets = MNIST_test.targets.numpy()
targets_FMNIST = FMNIST_test.targets.numpy()
targets_notMNIST = notMNIST_test.targets.numpy().astype(int)
targets_KMNIST = KMNIST_test.targets.numpy()

In [21]:
mnist_test_in_MAP = predict_MAP(mnist_model, mnist_test_loader, cuda=True).cpu().numpy()
mnist_test_out_fmnist_MAP = predict_MAP(mnist_model, FMNIST_test_loader, cuda=True).cpu().numpy()
mnist_test_out_notMNIST_MAP = predict_MAP(mnist_model, not_mnist_test_loader, cuda=True).cpu().numpy()
mnist_test_out_KMNIST_MAP = predict_MAP(mnist_model, KMNIST_test_loader, cuda=True).cpu().numpy()

In [22]:
acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP = get_in_dist_values(mnist_test_in_MAP, targets)
acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP, MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_fmnist_MAP, targets_FMNIST)
acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP, MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_notMNIST_MAP, targets_notMNIST)
acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP, MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_KMNIST_MAP, targets_KMNIST)

In [23]:
print_in_dist_values(acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP, 'mnist', 'MAP')
print_out_dist_values(acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP, MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP, 'FMNIST', 'MAP')
print_out_dist_values(acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP, MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP, 'notMNIST', 'MAP')
print_out_dist_values(acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP, MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP, 'KMNIST', 'MAP')

[In, MAP, mnist] Accuracy: 0.989; average entropy: 0.035;     MMC: 0.989; Prob @ correct: 0.984
[Out-MAP, LLLA-KF, FMNIST] Accuracy: 0.099; Average entropy: 1.392;    MMC: 0.513; AUROC: 0.992; Prob @ correct: 0.107
[Out-MAP, LLLA-KF, notMNIST] Accuracy: 0.110; Average entropy: 0.807;    MMC: 0.709; AUROC: 0.952; Prob @ correct: 0.114
[Out-MAP, LLLA-KF, KMNIST] Accuracy: 0.085; Average entropy: 0.867;    MMC: 0.687; AUROC: 0.974; Prob @ correct: 0.086


In [2]:
import numpy as np

In [3]:
#MAP estimate
#seeds are 123,124,125,126,127
acc_in = [0.991, 0.990, 0.993, 0.988, 0.989]
mmc_in = [0.988, 0.990, 0.989, 0.989, 0.989]
mmc_out_fmnist = [0.516, 0.571, 0.534, 0.554, 0.513]
mmc_out_notmnist = [0.692, 0.731, 0.696, 0.702, 0.709]
mmc_out_kmnist = [0.678, 0.697, 0.659, 0.700, 0.687]

auroc_out_fmnist = [0.989, 0.990, 0.990, 0.989, 0.992]
auroc_out_notmnist = [0.964, 0.942, 0.959, 0.952, 0.952]
auroc_out_kmnist = [0.970, 0.977, 0.980, 0.967, 0.974]

print("accuracy: {:.03f} with std {:.03f}".format(np.mean(acc_in), np.std(acc_in)))

print("MMC in: {:.03f} with std {:.03f}".format(np.mean(mmc_in), np.std(mmc_in)))
print("MMC out fmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_fmnist), np.std(mmc_out_fmnist)))
print("MMC out notmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_notmnist), np.std(mmc_out_notmnist)))
print("MMC out kmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_kmnist), np.std(mmc_out_kmnist)))

print("AUROC out fmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_fmnist), np.std(auroc_out_fmnist)))
print("AUROC out notmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_notmnist), np.std(auroc_out_notmnist)))
print("AUROC out kmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_kmnist), np.std(auroc_out_kmnist)))

accuracy: 0.990 with std 0.002
MMC in: 0.989 with std 0.001
MMC out fmnist: 0.538 with std 0.022
MMC out notmnist: 0.706 with std 0.014
MMC out kmnist: 0.684 with std 0.015
AUROC out fmnist: 0.990 with std 0.001
AUROC out notmnist: 0.954 with std 0.007
AUROC out kmnist: 0.974 with std 0.005


# Diag Hessian Sampling estimate

In [25]:
mnist_test_in_D = predict_Diagonal_full(mnist_model, mnist_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True, num_samples=1000).cpu().numpy()
mnist_test_out_FMNIST_D = predict_Diagonal_full(mnist_model, FMNIST_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True, num_samples=1000).cpu().numpy()
mnist_test_out_notMNIST_D = predict_Diagonal_full(mnist_model, not_mnist_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True, num_samples=1000).cpu().numpy()
mnist_test_out_KMNIST_D = predict_Diagonal_full(mnist_model, KMNIST_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True, num_samples=1000).cpu().numpy()

total time used for transform: 6.67633
total time used for transform: 6.67161
total time used for transform: 12.42572
total time used for transform: 6.68149


In [26]:
acc_in_D, prob_correct_in_D, ent_in_D, MMC_in_D = get_in_dist_values(mnist_test_in_D, targets)
acc_out_FMNIST_D, prob_correct_out_FMNIST_D, ent_out_FMNIST_D, MMC_out_FMNIST_D, auroc_out_FMNIST_D = get_out_dist_values(mnist_test_in_D, mnist_test_out_FMNIST_D, targets_FMNIST)
acc_out_notMNIST_D, prob_correct_out_notMNIST_D, ent_out_notMNIST_D, MMC_out_notMNIST_D, auroc_out_notMNIST_D = get_out_dist_values(mnist_test_in_D, mnist_test_out_notMNIST_D, targets_notMNIST)
acc_out_KMNIST_D, prob_correct_out_KMNIST_D, ent_out_KMNIST_D, MMC_out_KMNIST_D, auroc_out_KMNIST_D = get_out_dist_values(mnist_test_in_D, mnist_test_out_KMNIST_D, targets_KMNIST)

In [27]:
print_in_dist_values(acc_in_D, prob_correct_in_D, ent_in_D, MMC_in_D, 'mnist', 'Diag')
print_out_dist_values(acc_out_FMNIST_D, prob_correct_out_FMNIST_D, ent_out_FMNIST_D, MMC_out_FMNIST_D, auroc_out_FMNIST_D, test='fmnist', method='Diag')
print_out_dist_values(acc_out_notMNIST_D, prob_correct_out_notMNIST_D, ent_out_notMNIST_D, MMC_out_notMNIST_D, auroc_out_notMNIST_D, test='notMNIST', method='Diag')
print_out_dist_values(acc_out_KMNIST_D, prob_correct_out_KMNIST_D, ent_out_KMNIST_D, MMC_out_KMNIST_D, auroc_out_KMNIST_D, test='KMNIST', method='Diag')

[In, Diag, mnist] Accuracy: 0.990; average entropy: 0.224;     MMC: 0.942; Prob @ correct: 0.939
[Out-fmnist, Diag, mnist] Accuracy: 0.102; Average entropy: 1.701;    MMC: 0.401; AUROC: 0.992; Prob @ correct: 0.111
[Out-notMNIST, Diag, mnist] Accuracy: 0.108; Average entropy: 1.283;    MMC: 0.544; AUROC: 0.959; Prob @ correct: 0.116
[Out-KMNIST, Diag, mnist] Accuracy: 0.086; Average entropy: 1.337;    MMC: 0.512; AUROC: 0.975; Prob @ correct: 0.086


In [42]:
#Diag Sampling
#seeds are 123,124,125,126,127
time_lpb_in = [6.78413, 6.67228, 6.51112,6.44895, 6.67633]
time_lpb_out_fmnist = [6.77055, 6.64844, 6.47705,6.44392,6.67161]
time_lpb_out_notmnist = [12.65939, 12.49839, 11.99764, 12.07371, 12.42572]
time_lpb_out_kmnist = [6.79133, 6.73237, 6.46135, 6.46563, 6.68149]

acc_in = [0.991, 0.990, 0.993, 0.988, 0.990]
mmc_in = [0.928, 0.938, 0.927, 0.924, 0.042]
mmc_out_fmnist = [0.397, 0.426, 0.406, 0.406, 0.401]
mmc_out_notmnist = [0.517, 0.560, 0.526, 0.518, 0.554]
mmc_out_kmnist = [0.514, 0.503, 0.475, 0.497, 0.512]

auroc_out_fmnist = [0.986, 0.990, 0.988, 0.990, 0.992]
auroc_out_notmnist = [0.968, 0.948, 0.959, 0.958, 0.959]
auroc_out_kmnist = [0.967, 0.978, 0.980, 0.970, 0.975]

print("Sampling Bridge time in: {:.03f} with std {:.03f}".format(np.mean(time_lpb_in), np.std(time_lpb_in)))
print("Sampling Bridge time out fmnist: {:.03f} with std {:.03f}".format(np.mean(time_lpb_out_fmnist), np.std(time_lpb_out_fmnist)))
print("Sampling Bridge time out notmnist: {:.03f} with std {:.03f}".format(np.mean(time_lpb_out_notmnist), np.std(time_lpb_out_notmnist)))
print("Sampling Bridge time out kmnist: {:.03f} with std {:.03f}".format(np.mean(time_lpb_out_kmnist), np.std(time_lpb_out_kmnist)))

print("accuracy: {:.03f} with std {:.03f}".format(np.mean(acc_in), np.std(acc_in)))

print("MMC in: {:.03f} with std {:.03f}".format(np.mean(mmc_in), np.std(mmc_in)))
print("MMC out fmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_fmnist), np.std(mmc_out_fmnist)))
print("MMC out notmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_notmnist), np.std(mmc_out_notmnist)))
print("MMC out kmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_kmnist), np.std(mmc_out_kmnist)))

print("AUROC out fmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_fmnist), np.std(auroc_out_fmnist)))
print("AUROC out notmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_notmnist), np.std(auroc_out_notmnist)))
print("AUROC out kmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_kmnist), np.std(auroc_out_kmnist)))

Sampling Bridge time in: 6.619 with std 0.122
Sampling Bridge time out fmnist: 6.602 with std 0.123
Sampling Bridge time out notmnist: 12.331 with std 0.254
Sampling Bridge time out kmnist: 6.626 with std 0.138
accuracy: 0.990 with std 0.002
MMC in: 0.752 with std 0.355
MMC out fmnist: 0.407 with std 0.010
MMC out notmnist: 0.535 with std 0.018
MMC out kmnist: 0.500 with std 0.014
AUROC out fmnist: 0.989 with std 0.002
AUROC out notmnist: 0.958 with std 0.006
AUROC out kmnist: 0.974 with std 0.005


# Dirichlet Laplace Approximation

In [30]:
def get_alpha_from_Normal(mu, Sigma):
    batch_size, K = mu.size(0), mu.size(-1)
    Sigma_d = torch.diagonal(Sigma, dim1=1, dim2=2)
    sum_exp = torch.sum(torch.exp(-1*mu), dim=1).view(-1,1)
    alpha = 1/Sigma_d * (1 - 2/K + torch.exp(mu)/K**2 * sum_exp)
    
    assert(alpha.size() == mu.size())
    
    return(alpha)

In [31]:
def predict_DIR_LPA(model, test_loader, Hessian, verbose=True, cuda=False, timing=False):

    lossfunc = torch.nn.CrossEntropyLoss()
    extend(lossfunc, debug=False)
    
    alphas = []
    if timing:
        time_sum = 0

    max_len = len(test_loader)
    for batch_idx, (x, y) in enumerate(test_loader):
        
        if cuda:
            x, y = x.cuda(), y.cuda()

        J = get_Jacobian(model, x, y, lossfunc)
        J = J.detach()
        batch_size = J.size(0)
        num_classes = J.size(1)
        Cov_pred = torch.bmm(J * Hessian, J.permute(0, 2, 1))
        Cov_pred = Cov_pred.detach()
        
        mu_pred = model(x).detach()
        
        t0 = time.time()
        alpha = get_alpha_from_Normal(mu_pred, Cov_pred)
        t1 = time.time()
        alpha = alpha.detach()
        if timing:
            time_sum += (t1 - t0)

        alphas.append(alpha)
        
        if verbose:
            print("Batch: {}/{}".format(batch_idx, max_len))
    
    if timing:
        print("total time used for transform: {:.05f}".format(time_sum))

    return(torch.cat(alphas, dim = 0))


In [32]:
mnist_test_in_DIR_LPA = predict_DIR_LPA(mnist_model, mnist_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True).cpu().numpy()
mnist_test_out_FMNIST_DIR_LPA = predict_DIR_LPA(mnist_model, FMNIST_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True).cpu().numpy()
mnist_test_out_notMNIST_DIR_LPA = predict_DIR_LPA(mnist_model, not_mnist_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True).cpu().numpy()
mnist_test_out_KMNIST_DIR_LPA = predict_DIR_LPA(mnist_model, KMNIST_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True).cpu().numpy()

total time used for transform: 0.01609
total time used for transform: 0.01563
total time used for transform: 0.02892
total time used for transform: 0.01547


In [33]:
mnist_test_in_DIR_LPAn = mnist_test_in_DIR_LPA/mnist_test_in_DIR_LPA.sum(1).reshape(-1,1)
mnist_test_out_FMNIST_DIR_LPAn = mnist_test_out_FMNIST_DIR_LPA/mnist_test_out_FMNIST_DIR_LPA.sum(1).reshape(-1,1)
mnist_test_out_notMNIST_DIR_LPAn = mnist_test_out_notMNIST_DIR_LPA/mnist_test_out_notMNIST_DIR_LPA.sum(1).reshape(-1,1)
mnist_test_out_KMNIST_DIR_LPAn = mnist_test_out_KMNIST_DIR_LPA/mnist_test_out_KMNIST_DIR_LPA.sum(1).reshape(-1,1)

In [34]:
acc_in_DIR_LPAn, prob_correct_in_DIR_LPAn, ent_in_DIR_LPAn, MMC_in_DIR_LPAn = get_in_dist_values(mnist_test_in_DIR_LPAn, targets)
acc_out_FMNIST_DIR_LPAn, prob_correct_out_FMNIST_DIR_LPAn, ent_out_FMNIST_DIR_LPAn, MMC_out_FMNIST_DIR_LPAn, auroc_out_FMNIST_DIR_LPAn = get_out_dist_values(mnist_test_in_DIR_LPAn, mnist_test_out_FMNIST_DIR_LPAn, targets_FMNIST)
acc_out_notMNIST_DIR_LPAn, prob_correct_out_notMNIST_DIR_LPAn, ent_out_notMNIST_DIR_LPAn, MMC_out_notMNIST_DIR_LPAn, auroc_out_notMNIST_DIR_LPAn = get_out_dist_values(mnist_test_in_DIR_LPAn, mnist_test_out_notMNIST_DIR_LPAn, targets_notMNIST)
acc_out_KMNIST_DIR_LPAn, prob_correct_out_KMNIST_DIR_LPAn, ent_out_KMNIST_DIR_LPAn, MMC_out_KMNIST_DIR_LPAn, auroc_out_KMNIST_DIR_LPAn = get_out_dist_values(mnist_test_in_DIR_LPAn, mnist_test_out_KMNIST_DIR_LPAn, targets_KMNIST)

In [35]:
print_in_dist_values(acc_in_DIR_LPAn, prob_correct_in_DIR_LPAn, ent_in_DIR_LPAn, MMC_in_DIR_LPAn, 'mnist', 'DIR_LPAn')
print_out_dist_values(acc_out_FMNIST_DIR_LPAn, prob_correct_out_FMNIST_DIR_LPAn, ent_out_FMNIST_DIR_LPAn, MMC_out_FMNIST_DIR_LPAn, auroc_out_FMNIST_DIR_LPAn, test='fmnist', method='DIR_LPAn')
print_out_dist_values(acc_out_notMNIST_DIR_LPAn, prob_correct_out_notMNIST_DIR_LPAn, ent_out_notMNIST_DIR_LPAn, MMC_out_notMNIST_DIR_LPAn, auroc_out_notMNIST_DIR_LPAn, test='notMNIST', method='DIR_LPAn')
print_out_dist_values(acc_out_KMNIST_DIR_LPAn, prob_correct_out_KMNIST_DIR_LPAn, ent_out_KMNIST_DIR_LPAn, MMC_out_KMNIST_DIR_LPAn, auroc_out_KMNIST_DIR_LPAn, test='KMNIST', method='DIR_LPAn')

[In, DIR_LPAn, mnist] Accuracy: 0.989; average entropy: 0.044;     MMC: 0.987; Prob @ correct: 0.982
[Out-fmnist, DIR_LPAn, mnist] Accuracy: 0.080; Average entropy: 1.811;    MMC: 0.364; AUROC: 0.995; Prob @ correct: 0.095
[Out-notMNIST, DIR_LPAn, mnist] Accuracy: 0.122; Average entropy: 1.038;    MMC: 0.644; AUROC: 0.958; Prob @ correct: 0.123
[Out-KMNIST, DIR_LPAn, mnist] Accuracy: 0.086; Average entropy: 1.061;    MMC: 0.633; AUROC: 0.974; Prob @ correct: 0.086


In [4]:
#Laplace Bridge
#seeds are 123,124,125,126,127
time_lpb_in = [0.01622, 0.01575, 0.01573, 0.01557, 0.01609]
time_lpb_out_fmnist = [0.01598, 0.01579, 0.01562, 0.01549, 0.01563]
time_lpb_out_notmnist = [0.02956, 0.02926, 0.02902, 0.02869, 0.02892]
time_lpb_out_kmnist = [0.01596, 0.01571, 0.01557, 0.01550, 0.01547]

acc_in = [0.991, 0.990, 0.993, 0.988, 0.989]
mmc_in = [0.984, 0.988, 0.987, 0.988, 0.987]
mmc_out_fmnist = [0.368, 0.413, 0.363, 0.378, 0.364]
mmc_out_notmnist = [0.626, 0.678, 0.636, 0.638, 0.644]
mmc_out_kmnist = [0.628, 0.652, 0.598, 0.639, 0.633]

auroc_out_fmnist = [0.991, 0.994, 0.995, 0.995, 0.995]
auroc_out_notmnist = [0.970, 0.951, 0.966, 0.963, 0.958]
auroc_out_kmnist = [0.969, 0.977, 0.981, 0.972, 0.974]

print("Laplace Bridge time in: {:.03f} with std {:.03f}".format(np.mean(time_lpb_in), np.std(time_lpb_in)))
print("Laplace Bridge time out fmnist: {:.03f} with std {:.03f}".format(np.mean(time_lpb_out_fmnist), np.std(time_lpb_out_fmnist)))
print("Laplace Bridge time out notmnist: {:.03f} with std {:.03f}".format(np.mean(time_lpb_out_notmnist), np.std(time_lpb_out_notmnist)))
print("Laplace Bridge time out kmnist: {:.03f} with std {:.03f}".format(np.mean(time_lpb_out_kmnist), np.std(time_lpb_out_kmnist)))

print("accuracy: {:.03f} with std {:.03f}".format(np.mean(acc_in), np.std(acc_in)))

print("MMC in: {:.03f} with std {:.03f}".format(np.mean(mmc_in), np.std(mmc_in)))
print("MMC out fmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_fmnist), np.std(mmc_out_fmnist)))
print("MMC out notmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_notmnist), np.std(mmc_out_notmnist)))
print("MMC out kmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_kmnist), np.std(mmc_out_kmnist)))

print("AUROC out fmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_fmnist), np.std(auroc_out_fmnist)))
print("AUROC out notmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_notmnist), np.std(auroc_out_notmnist)))
print("AUROC out kmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_kmnist), np.std(auroc_out_kmnist)))

Laplace Bridge time in: 0.016 with std 0.000
Laplace Bridge time out fmnist: 0.016 with std 0.000
Laplace Bridge time out notmnist: 0.029 with std 0.000
Laplace Bridge time out kmnist: 0.016 with std 0.000
accuracy: 0.990 with std 0.002
MMC in: 0.987 with std 0.001
MMC out fmnist: 0.377 with std 0.019
MMC out notmnist: 0.644 with std 0.018
MMC out kmnist: 0.630 with std 0.018
AUROC out fmnist: 0.994 with std 0.002
AUROC out notmnist: 0.962 with std 0.007
AUROC out kmnist: 0.975 with std 0.004


# additional Calculations for the Dirichlet

In [None]:
from scipy.special import digamma, loggamma

def beta_function(alpha):
    return(np.exp(np.sum([loggamma(a_i) for a_i in alpha]) - loggamma(np.sum(alpha))))

def alphas_norm(alphas):
    alphas = np.array(alphas)
    return(alphas/alphas.sum(axis=1).reshape(-1,1))

def alphas_variance(alphas):
    alphas = np.array(alphas)
    norm = alphas_norm(alphas)
    nom = norm * (1 - norm)
    den = alphas.sum(axis=1).reshape(-1,1) + 1
    return(nom/den)

def log_beta_function(alpha):
    return(np.sum([loggamma(a_i) for a_i in alpha]) - loggamma(np.sum(alpha)))

def alphas_entropy(alphas):
    K = len(alphas[0])
    alphas = np.array(alphas)
    entropy = []
    for x in alphas:
        B = log_beta_function(x)
        alpha_0 = np.sum(x)
        C = (alpha_0 - K)*digamma(alpha_0)
        D = np.sum((x-1)*digamma(x))
        entropy.append(B + C - D)
    
    return(np.array(entropy))
        

def alphas_log_prob(alphas):
    alphas = np.array(alphas)
    dig_sum = digamma(alphas.sum(axis=1).reshape(-1,1))
    log_prob = digamma(alphas) - dig_sum
    return(log_prob)

def auroc_entropy(alphas_in, alphas_out):
    
    entropy_in = alphas_entropy(alphas_in)
    entropy_out = alphas_entropy(alphas_out)
    labels = np.zeros(len(entropy_in)+len(entropy_out), dtype='int32')
    labels[:len(entropy_in)] = 1
    examples = np.concatenate([entropy_in, entropy_out])
    auroc_ent = roc_auc_score(labels, examples)
    return(auroc_ent)

def auroc_variance(alphas_in, alphas_out, method='mean'):
    
    if method=='mean':
        variance_in = alphas_variance(alphas_in).mean(1)
        variance_out = alphas_variance(alphas_out).mean(1)
    elif method=='max':
        variance_in = alphas_variance(alphas_in).max(1)
        variance_out = alphas_variance(alphas_out).max(1)
    labels = np.zeros(len(variance_in)+len(variance_out), dtype='int32')
    labels[:len(variance_in)] = 1
    examples = np.concatenate([variance_in, variance_out])
    auroc_ent = roc_auc_score(labels, examples)
    return(auroc_ent)

In [None]:
print("auroc entropy: MNIST in, FMNIST out: ", 1 - auroc_entropy(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_FMNIST_DIR_LPA))
print("auroc entropy: MNIST in, notMNIST out: ", 1 - auroc_entropy(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_notMNIST_DIR_LPA))
print("auroc entropy: MNIST in, KMNIST out: ", 1 - auroc_entropy(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_KMNIST_DIR_LPA))

In [None]:
print("auroc variance: MNIST in, FMNIST out: ", 1-auroc_variance(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_FMNIST_DIR_LPA, method='mean'))
print("auroc variance: MNIST in, notMNIST out: ", 1-auroc_variance(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_notMNIST_DIR_LPA, method='mean'))
print("auroc variance: MNIST in, KMNIST out: ", 1-auroc_variance(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_KMNIST_DIR_LPA, method='mean'))

In [None]:
print("auroc variance: MNIST in, FMNIST out: ", 1-auroc_variance(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_FMNIST_DIR_LPA, method='max'))
print("auroc variance: MNIST in, notMNIST out: ", 1-auroc_variance(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_notMNIST_DIR_LPA, method='max'))
print("auroc variance: MNIST in, KMNIST out: ", 1-auroc_variance(alphas_in=mnist_test_in_DIR_LPA, alphas_out=mnist_test_out_KMNIST_DIR_LPA, method='max'))