In [1]:
import torch
import torchvision
from torch import nn, optim, autograd
from torch.nn import functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.autograd import Variable
import numpy as np
#import input_data
from sklearn.utils import shuffle as skshuffle
from math import *
from backpack import backpack, extend
from backpack.extensions import KFAC, DiagHessian, DiagGGNMC
from sklearn.metrics import roc_auc_score
import scipy
from tqdm import tqdm, trange
from bpjacext import NetJac
import pytest
from DirLPA_utils import * 
import time
from torch.distributions.normal import Normal

import matplotlib.pyplot as plt

s = 123
np.random.seed(s)
torch.manual_seed(s)
torch.cuda.manual_seed(s)

#NOTE: DO NOT RUN THIS CODE: the function NetJac comes from a private repository that is not yet available for you. 
#If you search for it aggressively you might find out my identity. I would therefore prefer if you just looked at the provided results.

In [2]:
def LPADirNN(num_classes=10, num_LL=256):
    
    features = torch.nn.Sequential(
        torch.nn.Conv2d(1, 32, 5),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2,2),
        torch.nn.Conv2d(32, 64, 5),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2,2),
        torch.nn.Flatten(),
        torch.nn.Linear(4 * 4 * 64, num_LL), #changed from 500
        torch.nn.Linear(num_LL, num_classes)  #changed from 500
    )
    return(features)

In [3]:
BATCH_SIZE_TRAIN_MNIST = 128
BATCH_SIZE_TEST_MNIST = 128
MAX_ITER_MNIST = 6
LR_TRAIN_MNIST = 10e-6

In [4]:
MNIST_transform = torchvision.transforms.ToTensor()

MNIST_train = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=True,
        download=True,
        transform=MNIST_transform)

mnist_train_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_train,
    batch_size=BATCH_SIZE_TRAIN_MNIST,
    shuffle=True
)


MNIST_test = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform)

mnist_test_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_test,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False,
)

In [5]:
mnist_model = LPADirNN(num_LL=256)#.cuda()
loss_function = torch.nn.CrossEntropyLoss()

mnist_train_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=1e-3, weight_decay=5e-4)
MNIST_PATH = "weights/mnist_test_6iter_10c_simpleCNN_256.pth"

In [6]:
#Training routine

def train(model, train_loader, optimizer, max_iter, path, verbose=True):
    max_len = len(train_loader)

    for iter in range(max_iter):
        for batch_idx, (x, y) in enumerate(train_loader):
            
            x, y = x, y#.cuda()
            
            output = model(x)

            accuracy = get_accuracy(output, y)

            loss = loss_function(output, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if verbose and batch_idx % 50 == 0:
                print(
                    "Iteration {}; {}/{} \t".format(iter, batch_idx, max_len) +
                    "Minibatch Loss %.3f  " % (loss) +
                    "Accuracy %.0f" % (accuracy * 100) + "%"
                )

    print("saving model at: {}".format(path))
    torch.save(mnist_model.state_dict(), path)

In [7]:
#train(mnist_model, mnist_train_loader, mnist_train_optimizer, MAX_ITER_MNIST, MNIST_PATH, verbose=True)

In [8]:
#predict in distribution
MNIST_PATH = "weights/mnist_test_6iter_10c_simpleCNN_256.pth"

#mnist_model = LPADirNN(x=100)
mnist_model = LPADirNN(num_LL=256)#.cuda()
print("loading model from: {}".format(MNIST_PATH))
mnist_model.load_state_dict(torch.load(MNIST_PATH))
mnist_model.eval()

acc = []

max_len = len(mnist_test_loader)
for batch_idx, (x, y) in enumerate(mnist_test_loader):

    x, y = x, y#.cuda()
    output = mnist_model(x)

    accuracy = get_accuracy(output, y)
    if batch_idx % 10 == 0:
        print(
            "Batch {}/{} \t".format(batch_idx, max_len) + 
            "Accuracy %.0f" % (accuracy * 100) + "%"
        )
    acc.append(accuracy)

avg_acc = np.mean(acc)
print('overall test accuracy on MNIST: {:.02f} %'.format(avg_acc * 100))


loading model from: weights/mnist_test_6iter_10c_simpleCNN_256.pth
Batch 0/79 	Accuracy 100%
Batch 10/79 	Accuracy 98%
Batch 20/79 	Accuracy 98%
Batch 30/79 	Accuracy 99%
Batch 40/79 	Accuracy 100%
Batch 50/79 	Accuracy 100%
Batch 60/79 	Accuracy 100%
Batch 70/79 	Accuracy 100%
overall test accuracy on MNIST: 99.20 %


In [9]:
# get all weight dists
def get_Hessian_NN(model, train_loader, prec0, device='cpu', verbose=True):
    lossfunc = torch.nn.CrossEntropyLoss()
    #lossfunc = torch.nn.BCELoss()

    extend(lossfunc, debug=False)
    extend(model, debug=False)

    Hessian_diag = []
    for param in model.parameters():
        ps = param.size()
        print("parameter size: ", ps)
        Hessian_diag.append(torch.zeros(ps, device=device))
        #print(param.numel())

    var0 = 1/prec0
    max_len = len(train_loader)

    with backpack(DiagHessian()):

        for batch_idx, (x, y) in enumerate(train_loader):

            if device == 'cuda':
                x, y = x.float().cuda(), y.long().cuda()

            model.zero_grad()
            lossfunc(model(x), y).backward()

            with torch.no_grad():
                # Hessian of weight
                for idx, param in enumerate(model.parameters()):

                    H_ = param.diag_h
                    #add prior here
                    H_ += var0 * torch.ones(H_.size())

                    rho = 1-1/(batch_idx+1)

                    Hessian_diag[idx] = rho*Hessian_diag[idx] + (1-rho)* H_
            
            if verbose and batch_idx%20==0:
                print("Batch: {}/{}".format(batch_idx, max_len))
    
    return(Hessian_diag)

In [10]:
# sample from weights
def get_param_dists(model, Hessian):
    param_dists = []
    for idx, param in enumerate(model.parameters()):
        ps = param.size()
        mu = param.view(-1)
        Sigma = Hessian[idx].view(-1)
        print("mu size: ", mu.size())
        print("Sigma size: ", Sigma.size())
        dist = Normal(mu, Sigma)
        param_dists.append(dist)
        
    return(param_dists)

## Load additional data

In [12]:
BATCH_SIZE_TEST_FMNIST = 128
BATCH_SIZE_TEST_KMNIST = 128

In [13]:
FMNIST_test = torchvision.datasets.FashionMNIST(
        '~/data/fmnist', train=False, download=True,
        transform=MNIST_transform)   #torchvision.transforms.ToTensor())

FMNIST_test_loader = torch.utils.data.DataLoader(
    FMNIST_test,
    batch_size=BATCH_SIZE_TEST_FMNIST, shuffle=False)

In [14]:
KMNIST_test = torchvision.datasets.KMNIST(
        '~/data/kmnist', train=False, download=True,
        transform=MNIST_transform)

KMNIST_test_loader = torch.utils.data.DataLoader(
    KMNIST_test,
    batch_size=BATCH_SIZE_TEST_KMNIST, shuffle=False)

In [15]:
"""Load notMNIST"""

import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from matplotlib.pyplot import imread
from torch import Tensor

"""
Loads the train/test set. 
Every image in the dataset is 28x28 pixels and the labels are numbered from 0-9
for A-J respectively.
Set root to point to the Train/Test folders.
"""

# Creating a sub class of torch.utils.data.dataset.Dataset
class notMNIST(Dataset):

    # The init method is called when this class will be instantiated
    def __init__(self, root, transform):
        
        #super(notMNIST, self).__init__(root, transform=transform)

        self.transform = transform
        
        Images, Y = [], []
        folders = os.listdir(root)

        for folder in folders:
            folder_path = os.path.join(root, folder)
            for ims in os.listdir(folder_path):
                try:
                    img_path = os.path.join(folder_path, ims)
                    Images.append(np.array(imread(img_path)))
                    Y.append(ord(folder) - 65)  # Folders are A-J so labels will be 0-9
                except:
                    # Some images in the dataset are damaged
                    print("File {}/{} is broken".format(folder, ims))
        data = [(x, y) for x, y in zip(Images, Y)]
        self.data = data
        self.targets = torch.Tensor(Y)

    # The number of items in the dataset
    def __len__(self):
        return len(self.data)

    # The Dataloader is a generator that repeatedly calls the getitem method.
    # getitem is supposed to return (X, Y) for the specified index.
    def __getitem__(self, index):
        img = self.data[index][0]

        if self.transform is not None:
            img = self.transform(img)
            
        # Input for Conv2D should be Channels x Height x Width
        img_tensor = Tensor(img).view(1, 28, 28).float()
        label = self.data[index][1]
        return (img_tensor, label)

In [16]:
#root = os.path.abspath('~/data')
root = os.path.expanduser('~/data')

# Instantiating the notMNIST dataset class we created
notMNIST_test = notMNIST(root=os.path.join(root, 'notMNIST_small'),
                               transform=MNIST_transform)

# Creating a dataloader
not_mnist_test_loader = torch.utils.data.dataloader.DataLoader(
                            dataset=notMNIST_test,
                            batch_size=BATCH_SIZE_TEST_KMNIST,
                            shuffle=False)

File F/Q3Jvc3NvdmVyIEJvbGRPYmxpcXVlLnR0Zg==.png is broken
File A/RGVtb2NyYXRpY2FCb2xkT2xkc3R5bGUgQm9sZC50dGY=.png is broken


In [17]:
def get_in_dist_values(py_in, targets):
    acc_in = np.mean(np.argmax(py_in, 1) == targets)
    prob_correct = np.choose(targets, py_in.T).mean()
    average_entropy = -np.sum(py_in*np.log(py_in+1e-8), axis=1).mean()
    MMC = py_in.max(1).mean()
    return(acc_in, prob_correct, average_entropy, MMC)
    
def get_out_dist_values(py_in, py_out, targets):
    average_entropy = -np.sum(py_out*np.log(py_out+1e-8), axis=1).mean()
    acc_out = np.mean(np.argmax(py_out, 1) == targets)
    prob_correct = np.choose(targets, py_out.T).mean()
    labels = np.zeros(len(py_in)+len(py_out), dtype='int32')
    labels[:len(py_in)] = 1
    examples = np.concatenate([py_in.max(1), py_out.max(1)])
    auroc = roc_auc_score(labels, examples)
    MMC = py_out.max(1).mean()
    return(acc_out, prob_correct, average_entropy, MMC, auroc)

def print_in_dist_values(acc_in, prob_correct, average_entropy, MMC, train='mnist', method='LLLA-KF'):
    
    print(f'[In, {method}, {train}] Accuracy: {acc_in:.3f}; average entropy: {average_entropy:.3f}; \
    MMC: {MMC:.3f}; Prob @ correct: {prob_correct:.3f}')


def print_out_dist_values(acc_out, prob_correct, average_entropy, MMC, auroc, train='mnist', test='FMNIST', method='LLLA-KF'):
   
    print(f'[Out-{test}, {method}, {train}] Accuracy: {acc_out:.3f}; Average entropy: {average_entropy:.3f};\
    MMC: {MMC:.3f}; AUROC: {auroc:.3f}; Prob @ correct: {prob_correct:.3f}')

# MAP estimate

In [18]:
targets = MNIST_test.targets.numpy()
targets_FMNIST = FMNIST_test.targets.numpy()
targets_notMNIST = notMNIST_test.targets.numpy().astype(int)
targets_KMNIST = KMNIST_test.targets.numpy()

In [19]:
mnist_test_in_MAP = predict_MAP(mnist_model, mnist_test_loader, cuda=False).cpu().numpy()
mnist_test_out_fmnist_MAP = predict_MAP(mnist_model, FMNIST_test_loader, cuda=False).cpu().numpy()
mnist_test_out_notMNIST_MAP = predict_MAP(mnist_model, not_mnist_test_loader, cuda=False).cpu().numpy()
mnist_test_out_KMNIST_MAP = predict_MAP(mnist_model, KMNIST_test_loader, cuda=False).cpu().numpy()

In [20]:
acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP = get_in_dist_values(mnist_test_in_MAP, targets)
acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP, MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_fmnist_MAP, targets_FMNIST)
acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP, MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_notMNIST_MAP, targets_notMNIST)
acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP, MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_KMNIST_MAP, targets_KMNIST)

In [21]:
print_in_dist_values(acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP, 'mnist', 'MAP')
print_out_dist_values(acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP, MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP, 'FMNIST', 'MAP')
print_out_dist_values(acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP, MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP, 'notMNIST', 'MAP')
print_out_dist_values(acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP, MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP, 'KMNIST', 'MAP')

[In, MAP, mnist] Accuracy: 0.992; average entropy: 0.033;     MMC: 0.990; Prob @ correct: 0.985
[Out-MAP, LLLA-KF, FMNIST] Accuracy: 0.142; Average entropy: 1.402;    MMC: 0.500; AUROC: 0.993; Prob @ correct: 0.125
[Out-MAP, LLLA-KF, notMNIST] Accuracy: 0.126; Average entropy: 0.828;    MMC: 0.699; AUROC: 0.971; Prob @ correct: 0.127
[Out-MAP, LLLA-KF, KMNIST] Accuracy: 0.096; Average entropy: 0.902;    MMC: 0.674; AUROC: 0.979; Prob @ correct: 0.094


In [22]:
import numpy as np

In [23]:
#MAP estimate
#seeds are 123,124,125,126,127
acc_in = [0.991, 0.990, 0.993, 0.988, 0.989]
mmc_in = [0.988, 0.990, 0.989, 0.989, 0.989]
mmc_out_fmnist = [0.516, 0.571, 0.534, 0.554, 0.513]
mmc_out_notmnist = [0.692, 0.731, 0.696, 0.702, 0.709]
mmc_out_kmnist = [0.678, 0.697, 0.659, 0.700, 0.687]

auroc_out_fmnist = [0.989, 0.990, 0.990, 0.989, 0.992]
auroc_out_notmnist = [0.964, 0.942, 0.959, 0.952, 0.952]
auroc_out_kmnist = [0.970, 0.977, 0.980, 0.967, 0.974]

print("accuracy: {:.03f} with std {:.03f}".format(np.mean(acc_in), np.std(acc_in)))

print("MMC in: {:.03f} with std {:.03f}".format(np.mean(mmc_in), np.std(mmc_in)))
print("MMC out fmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_fmnist), np.std(mmc_out_fmnist)))
print("MMC out notmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_notmnist), np.std(mmc_out_notmnist)))
print("MMC out kmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_kmnist), np.std(mmc_out_kmnist)))

print("AUROC out fmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_fmnist), np.std(auroc_out_fmnist)))
print("AUROC out notmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_notmnist), np.std(auroc_out_notmnist)))
print("AUROC out kmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_kmnist), np.std(auroc_out_kmnist)))

accuracy: 0.990 with std 0.002
MMC in: 0.989 with std 0.001
MMC out fmnist: 0.538 with std 0.022
MMC out notmnist: 0.706 with std 0.014
MMC out kmnist: 0.684 with std 0.015
AUROC out fmnist: 0.990 with std 0.001
AUROC out notmnist: 0.954 with std 0.007
AUROC out kmnist: 0.974 with std 0.005


# Diag Hessian sample all weights estimate

In [24]:
HessianNN_mnist = get_Hessian_NN(mnist_model, mnist_train_loader, prec0=200, device='cpu', verbose=True)

parameter size:  torch.Size([32, 1, 5, 5])
parameter size:  torch.Size([32])
parameter size:  torch.Size([64, 32, 5, 5])
parameter size:  torch.Size([64])
parameter size:  torch.Size([256, 1024])
parameter size:  torch.Size([256])
parameter size:  torch.Size([10, 256])
parameter size:  torch.Size([10])
Batch: 0/469
Batch: 20/469
Batch: 40/469
Batch: 60/469
Batch: 80/469
Batch: 100/469
Batch: 120/469
Batch: 140/469
Batch: 160/469
Batch: 180/469
Batch: 200/469
Batch: 220/469
Batch: 240/469
Batch: 260/469
Batch: 280/469
Batch: 300/469
Batch: 320/469
Batch: 340/469
Batch: 360/469
Batch: 380/469
Batch: 400/469
Batch: 420/469
Batch: 440/469
Batch: 460/469


In [25]:
mnist_param_dist = get_param_dists(mnist_model, HessianNN_mnist)

mu size:  torch.Size([800])
Sigma size:  torch.Size([800])
mu size:  torch.Size([32])
Sigma size:  torch.Size([32])
mu size:  torch.Size([51200])
Sigma size:  torch.Size([51200])
mu size:  torch.Size([64])
Sigma size:  torch.Size([64])
mu size:  torch.Size([262144])
Sigma size:  torch.Size([262144])
mu size:  torch.Size([256])
Sigma size:  torch.Size([256])
mu size:  torch.Size([2560])
Sigma size:  torch.Size([2560])
mu size:  torch.Size([10])
Sigma size:  torch.Size([10])


In [29]:
# predict with different samples
def get_stacked_results(param_dists, data_loader, network, num_samples=100, verbose=False):

    results = []

    test_model = network(num_classes=10)
    test_model.eval()
    #new_state_dict = fmnist_model.state_dict()
    new_state_dict = dict()
    for s in range(num_samples):
        #torch.manual_seed(s)
        for idx, (p_name, param) in enumerate(test_model.named_parameters()):
            ps = param.size()
            new_weights = param_dists[idx].sample().view(ps)
            new_state_dict[p_name] = new_weights

        #print(new_state_dict.keys())
        test_model.load_state_dict(new_state_dict)
        
        inter_results = []
        for batch_idx, (x, y) in enumerate(data_loader):
            y_pred_logits = test_model(x).detach()
            #print(y_pred_logits)
            y_pred = F.softmax(y_pred_logits, dim=1).detach()
            inter_results.append(y_pred)
            if verbose and batch_idx % 100 == 0:
                print("s: {}; batch_idx: {}".format(s, batch_idx))
                break

    
    stacked_results = torch.stack(results)
    return(stacked_results)

In [None]:
stacked_results_mnist = get_stacked_results(param_dists=mnist_param_dist, data_loader=mnist_train_loader\
                                            ,network=LPADirNN, num_samples=100, verbose=True)

s: 0; batch_idx: 0


In [None]:
#Diag Sampling
#seeds are 123,124,125,126,127
time_lpb_in = [6.78413, 6.67228, 6.51112,6.44895, 6.67633]
time_lpb_out_fmnist = [6.77055, 6.64844, 6.47705,6.44392,6.67161]
time_lpb_out_notmnist = [12.65939, 12.49839, 11.99764, 12.07371, 12.42572]
time_lpb_out_kmnist = [6.79133, 6.73237, 6.46135, 6.46563, 6.68149]

acc_in = [0.991, 0.990, 0.993, 0.988, 0.990]
mmc_in = [0.928, 0.938, 0.927, 0.924, 0.942]
mmc_out_fmnist = [0.397, 0.426, 0.406, 0.406, 0.401]
mmc_out_notmnist = [0.517, 0.560, 0.526, 0.518, 0.554]
mmc_out_kmnist = [0.514, 0.503, 0.475, 0.497, 0.512]

auroc_out_fmnist = [0.986, 0.990, 0.988, 0.990, 0.992]
auroc_out_notmnist = [0.968, 0.948, 0.959, 0.958, 0.959]
auroc_out_kmnist = [0.967, 0.978, 0.980, 0.970, 0.975]

print("Sampling Bridge time in: {:.03f} with std {:.03f}".format(np.mean(time_lpb_in), np.std(time_lpb_in)))
print("Sampling Bridge time out fmnist: {:.03f} with std {:.03f}".format(np.mean(time_lpb_out_fmnist), np.std(time_lpb_out_fmnist)))
print("Sampling Bridge time out notmnist: {:.03f} with std {:.03f}".format(np.mean(time_lpb_out_notmnist), np.std(time_lpb_out_notmnist)))
print("Sampling Bridge time out kmnist: {:.03f} with std {:.03f}".format(np.mean(time_lpb_out_kmnist), np.std(time_lpb_out_kmnist)))

print("accuracy: {:.03f} with std {:.03f}".format(np.mean(acc_in), np.std(acc_in)))

print("MMC in: {:.03f} with std {:.03f}".format(np.mean(mmc_in), np.std(mmc_in)))
print("MMC out fmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_fmnist), np.std(mmc_out_fmnist)))
print("MMC out notmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_notmnist), np.std(mmc_out_notmnist)))
print("MMC out kmnist: {:.03f} with std {:.03f}".format(np.mean(mmc_out_kmnist), np.std(mmc_out_kmnist)))

print("AUROC out fmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_fmnist), np.std(auroc_out_fmnist)))
print("AUROC out notmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_notmnist), np.std(auroc_out_notmnist)))
print("AUROC out kmnist: {:.03f} with std {:.03f}".format(np.mean(auroc_out_kmnist), np.std(auroc_out_kmnist)))