In [1]:
import torch
import torchvision
from torch import nn, optim, autograd
from torch.nn import functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.autograd import Variable
import numpy as np
#import input_data
from sklearn.utils import shuffle as skshuffle
from math import *
from backpack import backpack, extend
from backpack.extensions import KFAC, DiagHessian, DiagGGNMC
from sklearn.metrics import roc_auc_score
import scipy
from tqdm import tqdm, trange
from backpack.core.layers import Flatten
import pytest
from DirLPA_utils import * 

import matplotlib.pyplot as plt

np.random.seed(123)
torch.manual_seed(123)

<torch._C.Generator at 0x7f4b4190c050>

In [2]:
class LPADirNN(torch.nn.Module):
    def __init__(self, num_classes=10, x=400):
        
        super(LPADirNN, self).__init__()
        
        self.features = torch.nn.Sequential(
            torch.nn.Linear(28*28, x),
            torch.nn.ReLU(),
            torch.nn.Linear(x,x),
            torch.nn.ReLU()
        )
        self.fc = torch.nn.Linear(x, num_classes)
        
    def forward(self, x):
        
        x = x.view(-1, 28*28)
        x = self.features(x)
        out = self.fc(x)
        
        return(out)
        

In [3]:
BATCH_SIZE_TRAIN_MNIST = 128
BATCH_SIZE_TEST_MNIST = 128
MAX_ITER_MNIST = 6
LR_TRAIN_MNIST = 10e-6

In [4]:
MNIST_transform = torchvision.transforms.ToTensor()

MNIST_train = torchvision.datasets.MNIST(
        './data',
        train=True,
        download=True,
        transform=MNIST_transform)

mnist_train_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_train,
    batch_size=BATCH_SIZE_TRAIN_MNIST,
    shuffle=True
)


MNIST_test = torchvision.datasets.MNIST(
        './data',
        train=False,
        download=False,
        transform=MNIST_transform)

mnist_test_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_test,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False,
)

In [5]:
mnist_model = LPADirNN(x=100)
loss_function = torch.nn.CrossEntropyLoss()

#mnist_train_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=LR_TRAIN_MNIST)
mnist_train_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=1e-3, weight_decay=5e-4)
MNIST_PATH = "models/mnist_test_6iter_10c_simpleNN_100.pth"

In [6]:
#Training routine

for iter in range(MAX_ITER_MNIST):
    for batch_idx, (x, y) in enumerate(mnist_train_loader):
        max_len = int(np.ceil(len(mnist_train_loader.dataset)/BATCH_SIZE_TRAIN_MNIST))
        output = mnist_model(x)

        accuracy = get_accuracy(output, y)

        loss = loss_function(output, y)
        loss.backward()
        mnist_train_optimizer.step()
        mnist_train_optimizer.zero_grad()

        print(
            "Iteration {}; {}/{} \t".format(iter, batch_idx, max_len) +
            "Minibatch Loss %.3f  " % (loss) +
            "Accuracy %.0f" % (accuracy * 100) + "%"
        )

print("saving model at: {}".format(MNIST_PATH))
torch.save(mnist_model.state_dict(), MNIST_PATH)

Iteration 0; 0/469 	Minibatch Loss 2.305  Accuracy 9%
Iteration 0; 1/469 	Minibatch Loss 2.288  Accuracy 18%
Iteration 0; 2/469 	Minibatch Loss 2.279  Accuracy 14%
Iteration 0; 3/469 	Minibatch Loss 2.253  Accuracy 10%
Iteration 0; 4/469 	Minibatch Loss 2.228  Accuracy 23%
Iteration 0; 5/469 	Minibatch Loss 2.217  Accuracy 25%
Iteration 0; 6/469 	Minibatch Loss 2.183  Accuracy 35%
Iteration 0; 7/469 	Minibatch Loss 2.173  Accuracy 38%
Iteration 0; 8/469 	Minibatch Loss 2.123  Accuracy 55%
Iteration 0; 9/469 	Minibatch Loss 2.063  Accuracy 58%
Iteration 0; 10/469 	Minibatch Loss 2.040  Accuracy 59%
Iteration 0; 11/469 	Minibatch Loss 2.000  Accuracy 57%
Iteration 0; 12/469 	Minibatch Loss 1.972  Accuracy 62%
Iteration 0; 13/469 	Minibatch Loss 1.926  Accuracy 61%
Iteration 0; 14/469 	Minibatch Loss 1.901  Accuracy 58%
Iteration 0; 15/469 	Minibatch Loss 1.804  Accuracy 61%
Iteration 0; 16/469 	Minibatch Loss 1.762  Accuracy 62%
Iteration 0; 17/469 	Minibatch Loss 1.745  Accuracy 56%
Ite

Iteration 0; 151/469 	Minibatch Loss 0.458  Accuracy 86%
Iteration 0; 152/469 	Minibatch Loss 0.232  Accuracy 96%
Iteration 0; 153/469 	Minibatch Loss 0.276  Accuracy 91%
Iteration 0; 154/469 	Minibatch Loss 0.425  Accuracy 87%
Iteration 0; 155/469 	Minibatch Loss 0.464  Accuracy 88%
Iteration 0; 156/469 	Minibatch Loss 0.298  Accuracy 91%
Iteration 0; 157/469 	Minibatch Loss 0.327  Accuracy 90%
Iteration 0; 158/469 	Minibatch Loss 0.306  Accuracy 91%
Iteration 0; 159/469 	Minibatch Loss 0.339  Accuracy 88%
Iteration 0; 160/469 	Minibatch Loss 0.249  Accuracy 93%
Iteration 0; 161/469 	Minibatch Loss 0.203  Accuracy 96%
Iteration 0; 162/469 	Minibatch Loss 0.286  Accuracy 93%
Iteration 0; 163/469 	Minibatch Loss 0.274  Accuracy 88%
Iteration 0; 164/469 	Minibatch Loss 0.404  Accuracy 91%
Iteration 0; 165/469 	Minibatch Loss 0.241  Accuracy 92%
Iteration 0; 166/469 	Minibatch Loss 0.340  Accuracy 90%
Iteration 0; 167/469 	Minibatch Loss 0.418  Accuracy 89%
Iteration 0; 168/469 	Minibatch

Iteration 0; 297/469 	Minibatch Loss 0.258  Accuracy 92%
Iteration 0; 298/469 	Minibatch Loss 0.358  Accuracy 89%
Iteration 0; 299/469 	Minibatch Loss 0.189  Accuracy 93%
Iteration 0; 300/469 	Minibatch Loss 0.276  Accuracy 91%
Iteration 0; 301/469 	Minibatch Loss 0.401  Accuracy 88%
Iteration 0; 302/469 	Minibatch Loss 0.358  Accuracy 90%
Iteration 0; 303/469 	Minibatch Loss 0.279  Accuracy 94%
Iteration 0; 304/469 	Minibatch Loss 0.311  Accuracy 95%
Iteration 0; 305/469 	Minibatch Loss 0.303  Accuracy 91%
Iteration 0; 306/469 	Minibatch Loss 0.170  Accuracy 95%
Iteration 0; 307/469 	Minibatch Loss 0.231  Accuracy 95%
Iteration 0; 308/469 	Minibatch Loss 0.316  Accuracy 90%
Iteration 0; 309/469 	Minibatch Loss 0.200  Accuracy 94%
Iteration 0; 310/469 	Minibatch Loss 0.449  Accuracy 90%
Iteration 0; 311/469 	Minibatch Loss 0.255  Accuracy 91%
Iteration 0; 312/469 	Minibatch Loss 0.376  Accuracy 90%
Iteration 0; 313/469 	Minibatch Loss 0.203  Accuracy 94%
Iteration 0; 314/469 	Minibatch

Iteration 0; 442/469 	Minibatch Loss 0.241  Accuracy 91%
Iteration 0; 443/469 	Minibatch Loss 0.112  Accuracy 97%
Iteration 0; 444/469 	Minibatch Loss 0.129  Accuracy 95%
Iteration 0; 445/469 	Minibatch Loss 0.155  Accuracy 96%
Iteration 0; 446/469 	Minibatch Loss 0.144  Accuracy 96%
Iteration 0; 447/469 	Minibatch Loss 0.168  Accuracy 95%
Iteration 0; 448/469 	Minibatch Loss 0.228  Accuracy 95%
Iteration 0; 449/469 	Minibatch Loss 0.182  Accuracy 93%
Iteration 0; 450/469 	Minibatch Loss 0.196  Accuracy 95%
Iteration 0; 451/469 	Minibatch Loss 0.162  Accuracy 96%
Iteration 0; 452/469 	Minibatch Loss 0.235  Accuracy 92%
Iteration 0; 453/469 	Minibatch Loss 0.203  Accuracy 94%
Iteration 0; 454/469 	Minibatch Loss 0.235  Accuracy 95%
Iteration 0; 455/469 	Minibatch Loss 0.241  Accuracy 93%
Iteration 0; 456/469 	Minibatch Loss 0.245  Accuracy 95%
Iteration 0; 457/469 	Minibatch Loss 0.232  Accuracy 94%
Iteration 0; 458/469 	Minibatch Loss 0.267  Accuracy 94%
Iteration 0; 459/469 	Minibatch

Iteration 1; 121/469 	Minibatch Loss 0.138  Accuracy 95%
Iteration 1; 122/469 	Minibatch Loss 0.153  Accuracy 95%
Iteration 1; 123/469 	Minibatch Loss 0.118  Accuracy 97%
Iteration 1; 124/469 	Minibatch Loss 0.213  Accuracy 92%
Iteration 1; 125/469 	Minibatch Loss 0.298  Accuracy 92%
Iteration 1; 126/469 	Minibatch Loss 0.130  Accuracy 97%
Iteration 1; 127/469 	Minibatch Loss 0.316  Accuracy 92%
Iteration 1; 128/469 	Minibatch Loss 0.206  Accuracy 95%
Iteration 1; 129/469 	Minibatch Loss 0.195  Accuracy 92%
Iteration 1; 130/469 	Minibatch Loss 0.128  Accuracy 97%
Iteration 1; 131/469 	Minibatch Loss 0.230  Accuracy 94%
Iteration 1; 132/469 	Minibatch Loss 0.342  Accuracy 91%
Iteration 1; 133/469 	Minibatch Loss 0.259  Accuracy 91%
Iteration 1; 134/469 	Minibatch Loss 0.140  Accuracy 95%
Iteration 1; 135/469 	Minibatch Loss 0.182  Accuracy 96%
Iteration 1; 136/469 	Minibatch Loss 0.264  Accuracy 94%
Iteration 1; 137/469 	Minibatch Loss 0.161  Accuracy 95%
Iteration 1; 138/469 	Minibatch

Iteration 1; 268/469 	Minibatch Loss 0.201  Accuracy 95%
Iteration 1; 269/469 	Minibatch Loss 0.178  Accuracy 93%
Iteration 1; 270/469 	Minibatch Loss 0.156  Accuracy 95%
Iteration 1; 271/469 	Minibatch Loss 0.176  Accuracy 94%
Iteration 1; 272/469 	Minibatch Loss 0.095  Accuracy 98%
Iteration 1; 273/469 	Minibatch Loss 0.123  Accuracy 95%
Iteration 1; 274/469 	Minibatch Loss 0.053  Accuracy 99%
Iteration 1; 275/469 	Minibatch Loss 0.104  Accuracy 97%
Iteration 1; 276/469 	Minibatch Loss 0.228  Accuracy 94%
Iteration 1; 277/469 	Minibatch Loss 0.155  Accuracy 96%
Iteration 1; 278/469 	Minibatch Loss 0.166  Accuracy 94%
Iteration 1; 279/469 	Minibatch Loss 0.200  Accuracy 91%
Iteration 1; 280/469 	Minibatch Loss 0.174  Accuracy 96%
Iteration 1; 281/469 	Minibatch Loss 0.124  Accuracy 95%
Iteration 1; 282/469 	Minibatch Loss 0.245  Accuracy 94%
Iteration 1; 283/469 	Minibatch Loss 0.141  Accuracy 96%
Iteration 1; 284/469 	Minibatch Loss 0.132  Accuracy 96%
Iteration 1; 285/469 	Minibatch

Iteration 1; 423/469 	Minibatch Loss 0.307  Accuracy 90%
Iteration 1; 424/469 	Minibatch Loss 0.135  Accuracy 95%
Iteration 1; 425/469 	Minibatch Loss 0.135  Accuracy 96%
Iteration 1; 426/469 	Minibatch Loss 0.179  Accuracy 95%
Iteration 1; 427/469 	Minibatch Loss 0.123  Accuracy 98%
Iteration 1; 428/469 	Minibatch Loss 0.206  Accuracy 95%
Iteration 1; 429/469 	Minibatch Loss 0.283  Accuracy 95%
Iteration 1; 430/469 	Minibatch Loss 0.222  Accuracy 93%
Iteration 1; 431/469 	Minibatch Loss 0.183  Accuracy 94%
Iteration 1; 432/469 	Minibatch Loss 0.110  Accuracy 98%
Iteration 1; 433/469 	Minibatch Loss 0.148  Accuracy 95%
Iteration 1; 434/469 	Minibatch Loss 0.139  Accuracy 98%
Iteration 1; 435/469 	Minibatch Loss 0.269  Accuracy 93%
Iteration 1; 436/469 	Minibatch Loss 0.111  Accuracy 98%
Iteration 1; 437/469 	Minibatch Loss 0.177  Accuracy 95%
Iteration 1; 438/469 	Minibatch Loss 0.128  Accuracy 95%
Iteration 1; 439/469 	Minibatch Loss 0.306  Accuracy 92%
Iteration 1; 440/469 	Minibatch

Iteration 2; 111/469 	Minibatch Loss 0.076  Accuracy 98%
Iteration 2; 112/469 	Minibatch Loss 0.095  Accuracy 98%
Iteration 2; 113/469 	Minibatch Loss 0.108  Accuracy 98%
Iteration 2; 114/469 	Minibatch Loss 0.043  Accuracy 100%
Iteration 2; 115/469 	Minibatch Loss 0.083  Accuracy 98%
Iteration 2; 116/469 	Minibatch Loss 0.167  Accuracy 94%
Iteration 2; 117/469 	Minibatch Loss 0.078  Accuracy 97%
Iteration 2; 118/469 	Minibatch Loss 0.092  Accuracy 97%
Iteration 2; 119/469 	Minibatch Loss 0.150  Accuracy 96%
Iteration 2; 120/469 	Minibatch Loss 0.088  Accuracy 97%
Iteration 2; 121/469 	Minibatch Loss 0.096  Accuracy 98%
Iteration 2; 122/469 	Minibatch Loss 0.113  Accuracy 98%
Iteration 2; 123/469 	Minibatch Loss 0.115  Accuracy 95%
Iteration 2; 124/469 	Minibatch Loss 0.147  Accuracy 97%
Iteration 2; 125/469 	Minibatch Loss 0.258  Accuracy 93%
Iteration 2; 126/469 	Minibatch Loss 0.085  Accuracy 97%
Iteration 2; 127/469 	Minibatch Loss 0.179  Accuracy 96%
Iteration 2; 128/469 	Minibatc

Iteration 2; 261/469 	Minibatch Loss 0.061  Accuracy 99%
Iteration 2; 262/469 	Minibatch Loss 0.152  Accuracy 95%
Iteration 2; 263/469 	Minibatch Loss 0.115  Accuracy 96%
Iteration 2; 264/469 	Minibatch Loss 0.185  Accuracy 92%
Iteration 2; 265/469 	Minibatch Loss 0.122  Accuracy 98%
Iteration 2; 266/469 	Minibatch Loss 0.126  Accuracy 96%
Iteration 2; 267/469 	Minibatch Loss 0.115  Accuracy 97%
Iteration 2; 268/469 	Minibatch Loss 0.107  Accuracy 96%
Iteration 2; 269/469 	Minibatch Loss 0.122  Accuracy 98%
Iteration 2; 270/469 	Minibatch Loss 0.183  Accuracy 95%
Iteration 2; 271/469 	Minibatch Loss 0.142  Accuracy 95%
Iteration 2; 272/469 	Minibatch Loss 0.121  Accuracy 95%
Iteration 2; 273/469 	Minibatch Loss 0.120  Accuracy 98%
Iteration 2; 274/469 	Minibatch Loss 0.101  Accuracy 98%
Iteration 2; 275/469 	Minibatch Loss 0.137  Accuracy 95%
Iteration 2; 276/469 	Minibatch Loss 0.100  Accuracy 97%
Iteration 2; 277/469 	Minibatch Loss 0.103  Accuracy 97%
Iteration 2; 278/469 	Minibatch

Iteration 2; 405/469 	Minibatch Loss 0.181  Accuracy 95%
Iteration 2; 406/469 	Minibatch Loss 0.116  Accuracy 95%
Iteration 2; 407/469 	Minibatch Loss 0.108  Accuracy 97%
Iteration 2; 408/469 	Minibatch Loss 0.149  Accuracy 96%
Iteration 2; 409/469 	Minibatch Loss 0.226  Accuracy 94%
Iteration 2; 410/469 	Minibatch Loss 0.133  Accuracy 97%
Iteration 2; 411/469 	Minibatch Loss 0.122  Accuracy 97%
Iteration 2; 412/469 	Minibatch Loss 0.060  Accuracy 99%
Iteration 2; 413/469 	Minibatch Loss 0.206  Accuracy 92%
Iteration 2; 414/469 	Minibatch Loss 0.098  Accuracy 98%
Iteration 2; 415/469 	Minibatch Loss 0.091  Accuracy 98%
Iteration 2; 416/469 	Minibatch Loss 0.091  Accuracy 98%
Iteration 2; 417/469 	Minibatch Loss 0.121  Accuracy 96%
Iteration 2; 418/469 	Minibatch Loss 0.120  Accuracy 97%
Iteration 2; 419/469 	Minibatch Loss 0.082  Accuracy 98%
Iteration 2; 420/469 	Minibatch Loss 0.124  Accuracy 95%
Iteration 2; 421/469 	Minibatch Loss 0.087  Accuracy 96%
Iteration 2; 422/469 	Minibatch

Iteration 3; 82/469 	Minibatch Loss 0.091  Accuracy 97%
Iteration 3; 83/469 	Minibatch Loss 0.166  Accuracy 94%
Iteration 3; 84/469 	Minibatch Loss 0.059  Accuracy 98%
Iteration 3; 85/469 	Minibatch Loss 0.054  Accuracy 98%
Iteration 3; 86/469 	Minibatch Loss 0.101  Accuracy 98%
Iteration 3; 87/469 	Minibatch Loss 0.123  Accuracy 98%
Iteration 3; 88/469 	Minibatch Loss 0.089  Accuracy 96%
Iteration 3; 89/469 	Minibatch Loss 0.090  Accuracy 98%
Iteration 3; 90/469 	Minibatch Loss 0.063  Accuracy 98%
Iteration 3; 91/469 	Minibatch Loss 0.120  Accuracy 96%
Iteration 3; 92/469 	Minibatch Loss 0.104  Accuracy 97%
Iteration 3; 93/469 	Minibatch Loss 0.064  Accuracy 98%
Iteration 3; 94/469 	Minibatch Loss 0.047  Accuracy 98%
Iteration 3; 95/469 	Minibatch Loss 0.088  Accuracy 97%
Iteration 3; 96/469 	Minibatch Loss 0.081  Accuracy 98%
Iteration 3; 97/469 	Minibatch Loss 0.136  Accuracy 95%
Iteration 3; 98/469 	Minibatch Loss 0.081  Accuracy 98%
Iteration 3; 99/469 	Minibatch Loss 0.112  Accur

Iteration 3; 228/469 	Minibatch Loss 0.059  Accuracy 98%
Iteration 3; 229/469 	Minibatch Loss 0.094  Accuracy 98%
Iteration 3; 230/469 	Minibatch Loss 0.070  Accuracy 98%
Iteration 3; 231/469 	Minibatch Loss 0.070  Accuracy 98%
Iteration 3; 232/469 	Minibatch Loss 0.119  Accuracy 98%
Iteration 3; 233/469 	Minibatch Loss 0.121  Accuracy 95%
Iteration 3; 234/469 	Minibatch Loss 0.110  Accuracy 98%
Iteration 3; 235/469 	Minibatch Loss 0.146  Accuracy 97%
Iteration 3; 236/469 	Minibatch Loss 0.176  Accuracy 95%
Iteration 3; 237/469 	Minibatch Loss 0.233  Accuracy 92%
Iteration 3; 238/469 	Minibatch Loss 0.051  Accuracy 99%
Iteration 3; 239/469 	Minibatch Loss 0.078  Accuracy 97%
Iteration 3; 240/469 	Minibatch Loss 0.118  Accuracy 96%
Iteration 3; 241/469 	Minibatch Loss 0.104  Accuracy 97%
Iteration 3; 242/469 	Minibatch Loss 0.097  Accuracy 97%
Iteration 3; 243/469 	Minibatch Loss 0.155  Accuracy 95%
Iteration 3; 244/469 	Minibatch Loss 0.143  Accuracy 96%
Iteration 3; 245/469 	Minibatch

Iteration 3; 378/469 	Minibatch Loss 0.146  Accuracy 96%
Iteration 3; 379/469 	Minibatch Loss 0.106  Accuracy 95%
Iteration 3; 380/469 	Minibatch Loss 0.134  Accuracy 95%
Iteration 3; 381/469 	Minibatch Loss 0.096  Accuracy 96%
Iteration 3; 382/469 	Minibatch Loss 0.090  Accuracy 96%
Iteration 3; 383/469 	Minibatch Loss 0.114  Accuracy 96%
Iteration 3; 384/469 	Minibatch Loss 0.043  Accuracy 99%
Iteration 3; 385/469 	Minibatch Loss 0.127  Accuracy 96%
Iteration 3; 386/469 	Minibatch Loss 0.057  Accuracy 99%
Iteration 3; 387/469 	Minibatch Loss 0.077  Accuracy 98%
Iteration 3; 388/469 	Minibatch Loss 0.074  Accuracy 97%
Iteration 3; 389/469 	Minibatch Loss 0.122  Accuracy 95%
Iteration 3; 390/469 	Minibatch Loss 0.048  Accuracy 98%
Iteration 3; 391/469 	Minibatch Loss 0.104  Accuracy 98%
Iteration 3; 392/469 	Minibatch Loss 0.111  Accuracy 94%
Iteration 3; 393/469 	Minibatch Loss 0.166  Accuracy 95%
Iteration 3; 394/469 	Minibatch Loss 0.087  Accuracy 97%
Iteration 3; 395/469 	Minibatch

Iteration 4; 61/469 	Minibatch Loss 0.043  Accuracy 99%
Iteration 4; 62/469 	Minibatch Loss 0.084  Accuracy 98%
Iteration 4; 63/469 	Minibatch Loss 0.097  Accuracy 98%
Iteration 4; 64/469 	Minibatch Loss 0.104  Accuracy 97%
Iteration 4; 65/469 	Minibatch Loss 0.043  Accuracy 100%
Iteration 4; 66/469 	Minibatch Loss 0.136  Accuracy 97%
Iteration 4; 67/469 	Minibatch Loss 0.063  Accuracy 99%
Iteration 4; 68/469 	Minibatch Loss 0.103  Accuracy 96%
Iteration 4; 69/469 	Minibatch Loss 0.149  Accuracy 95%
Iteration 4; 70/469 	Minibatch Loss 0.087  Accuracy 97%
Iteration 4; 71/469 	Minibatch Loss 0.029  Accuracy 100%
Iteration 4; 72/469 	Minibatch Loss 0.076  Accuracy 98%
Iteration 4; 73/469 	Minibatch Loss 0.107  Accuracy 96%
Iteration 4; 74/469 	Minibatch Loss 0.152  Accuracy 96%
Iteration 4; 75/469 	Minibatch Loss 0.153  Accuracy 96%
Iteration 4; 76/469 	Minibatch Loss 0.117  Accuracy 98%
Iteration 4; 77/469 	Minibatch Loss 0.225  Accuracy 93%
Iteration 4; 78/469 	Minibatch Loss 0.092  Acc

Iteration 4; 208/469 	Minibatch Loss 0.127  Accuracy 95%
Iteration 4; 209/469 	Minibatch Loss 0.078  Accuracy 98%
Iteration 4; 210/469 	Minibatch Loss 0.227  Accuracy 92%
Iteration 4; 211/469 	Minibatch Loss 0.087  Accuracy 97%
Iteration 4; 212/469 	Minibatch Loss 0.042  Accuracy 98%
Iteration 4; 213/469 	Minibatch Loss 0.057  Accuracy 99%
Iteration 4; 214/469 	Minibatch Loss 0.104  Accuracy 95%
Iteration 4; 215/469 	Minibatch Loss 0.128  Accuracy 95%
Iteration 4; 216/469 	Minibatch Loss 0.086  Accuracy 98%
Iteration 4; 217/469 	Minibatch Loss 0.120  Accuracy 96%
Iteration 4; 218/469 	Minibatch Loss 0.053  Accuracy 98%
Iteration 4; 219/469 	Minibatch Loss 0.077  Accuracy 98%
Iteration 4; 220/469 	Minibatch Loss 0.049  Accuracy 98%
Iteration 4; 221/469 	Minibatch Loss 0.127  Accuracy 96%
Iteration 4; 222/469 	Minibatch Loss 0.067  Accuracy 98%
Iteration 4; 223/469 	Minibatch Loss 0.059  Accuracy 99%
Iteration 4; 224/469 	Minibatch Loss 0.116  Accuracy 96%
Iteration 4; 225/469 	Minibatch

Iteration 4; 360/469 	Minibatch Loss 0.049  Accuracy 98%
Iteration 4; 361/469 	Minibatch Loss 0.113  Accuracy 97%
Iteration 4; 362/469 	Minibatch Loss 0.043  Accuracy 98%
Iteration 4; 363/469 	Minibatch Loss 0.102  Accuracy 97%
Iteration 4; 364/469 	Minibatch Loss 0.133  Accuracy 95%
Iteration 4; 365/469 	Minibatch Loss 0.100  Accuracy 98%
Iteration 4; 366/469 	Minibatch Loss 0.072  Accuracy 98%
Iteration 4; 367/469 	Minibatch Loss 0.072  Accuracy 98%
Iteration 4; 368/469 	Minibatch Loss 0.055  Accuracy 99%
Iteration 4; 369/469 	Minibatch Loss 0.042  Accuracy 98%
Iteration 4; 370/469 	Minibatch Loss 0.071  Accuracy 97%
Iteration 4; 371/469 	Minibatch Loss 0.148  Accuracy 95%
Iteration 4; 372/469 	Minibatch Loss 0.099  Accuracy 97%
Iteration 4; 373/469 	Minibatch Loss 0.062  Accuracy 98%
Iteration 4; 374/469 	Minibatch Loss 0.059  Accuracy 98%
Iteration 4; 375/469 	Minibatch Loss 0.145  Accuracy 97%
Iteration 4; 376/469 	Minibatch Loss 0.031  Accuracy 100%
Iteration 4; 377/469 	Minibatc

Iteration 5; 36/469 	Minibatch Loss 0.047  Accuracy 98%
Iteration 5; 37/469 	Minibatch Loss 0.054  Accuracy 98%
Iteration 5; 38/469 	Minibatch Loss 0.081  Accuracy 98%
Iteration 5; 39/469 	Minibatch Loss 0.043  Accuracy 98%
Iteration 5; 40/469 	Minibatch Loss 0.075  Accuracy 98%
Iteration 5; 41/469 	Minibatch Loss 0.049  Accuracy 98%
Iteration 5; 42/469 	Minibatch Loss 0.104  Accuracy 96%
Iteration 5; 43/469 	Minibatch Loss 0.037  Accuracy 99%
Iteration 5; 44/469 	Minibatch Loss 0.060  Accuracy 98%
Iteration 5; 45/469 	Minibatch Loss 0.098  Accuracy 98%
Iteration 5; 46/469 	Minibatch Loss 0.073  Accuracy 98%
Iteration 5; 47/469 	Minibatch Loss 0.099  Accuracy 95%
Iteration 5; 48/469 	Minibatch Loss 0.144  Accuracy 97%
Iteration 5; 49/469 	Minibatch Loss 0.057  Accuracy 99%
Iteration 5; 50/469 	Minibatch Loss 0.076  Accuracy 98%
Iteration 5; 51/469 	Minibatch Loss 0.099  Accuracy 95%
Iteration 5; 52/469 	Minibatch Loss 0.106  Accuracy 99%
Iteration 5; 53/469 	Minibatch Loss 0.118  Accur

Iteration 5; 193/469 	Minibatch Loss 0.078  Accuracy 98%
Iteration 5; 194/469 	Minibatch Loss 0.072  Accuracy 97%
Iteration 5; 195/469 	Minibatch Loss 0.056  Accuracy 97%
Iteration 5; 196/469 	Minibatch Loss 0.111  Accuracy 95%
Iteration 5; 197/469 	Minibatch Loss 0.071  Accuracy 98%
Iteration 5; 198/469 	Minibatch Loss 0.052  Accuracy 98%
Iteration 5; 199/469 	Minibatch Loss 0.072  Accuracy 97%
Iteration 5; 200/469 	Minibatch Loss 0.148  Accuracy 96%
Iteration 5; 201/469 	Minibatch Loss 0.085  Accuracy 98%
Iteration 5; 202/469 	Minibatch Loss 0.109  Accuracy 97%
Iteration 5; 203/469 	Minibatch Loss 0.041  Accuracy 98%
Iteration 5; 204/469 	Minibatch Loss 0.039  Accuracy 98%
Iteration 5; 205/469 	Minibatch Loss 0.037  Accuracy 99%
Iteration 5; 206/469 	Minibatch Loss 0.082  Accuracy 98%
Iteration 5; 207/469 	Minibatch Loss 0.053  Accuracy 98%
Iteration 5; 208/469 	Minibatch Loss 0.173  Accuracy 97%
Iteration 5; 209/469 	Minibatch Loss 0.069  Accuracy 98%
Iteration 5; 210/469 	Minibatch

Iteration 5; 347/469 	Minibatch Loss 0.077  Accuracy 98%
Iteration 5; 348/469 	Minibatch Loss 0.088  Accuracy 98%
Iteration 5; 349/469 	Minibatch Loss 0.090  Accuracy 97%
Iteration 5; 350/469 	Minibatch Loss 0.053  Accuracy 99%
Iteration 5; 351/469 	Minibatch Loss 0.050  Accuracy 98%
Iteration 5; 352/469 	Minibatch Loss 0.066  Accuracy 98%
Iteration 5; 353/469 	Minibatch Loss 0.095  Accuracy 95%
Iteration 5; 354/469 	Minibatch Loss 0.039  Accuracy 98%
Iteration 5; 355/469 	Minibatch Loss 0.075  Accuracy 98%
Iteration 5; 356/469 	Minibatch Loss 0.042  Accuracy 98%
Iteration 5; 357/469 	Minibatch Loss 0.055  Accuracy 98%
Iteration 5; 358/469 	Minibatch Loss 0.104  Accuracy 98%
Iteration 5; 359/469 	Minibatch Loss 0.036  Accuracy 100%
Iteration 5; 360/469 	Minibatch Loss 0.056  Accuracy 99%
Iteration 5; 361/469 	Minibatch Loss 0.111  Accuracy 96%
Iteration 5; 362/469 	Minibatch Loss 0.095  Accuracy 96%
Iteration 5; 363/469 	Minibatch Loss 0.169  Accuracy 95%
Iteration 5; 364/469 	Minibatc

In [7]:
#predict in distribution
MNIST_PATH = "models/mnist_test_6iter_10c_simpleNN_100.pth"

mnist_model = LPADirNN(x=100)
print("loading model from: {}".format(MNIST_PATH))
mnist_model.load_state_dict(torch.load(MNIST_PATH))
mnist_model.eval()

acc = []

for batch_idx, (x, y) in enumerate(mnist_test_loader):
        max_len = int(np.ceil(len(mnist_test_loader.dataset)/BATCH_SIZE_TEST_MNIST))
        output = mnist_model(x)
        
        accuracy = get_accuracy(output, y)
        if batch_idx % 10 == 0:
            print(
                "Batch {}/{} \t".format(batch_idx, max_len) + 
                "Accuracy %.0f" % (accuracy * 100) + "%"
            )
        acc.append(accuracy)
    
avg_acc = np.mean(acc)
print('overall test accuracy on MNIST: {:.02f} %'.format(avg_acc * 100))


loading model from: models/mnist_test_6iter_10c_simpleNN_100.pth
Batch 0/79 	Accuracy 99%
Batch 10/79 	Accuracy 95%
Batch 20/79 	Accuracy 97%
Batch 30/79 	Accuracy 95%
Batch 40/79 	Accuracy 98%
Batch 50/79 	Accuracy 97%
Batch 60/79 	Accuracy 99%
Batch 70/79 	Accuracy 95%
overall test accuracy on MNIST: 97.09 %


In [8]:
## play around with Backpack
lossfunc = torch.nn.CrossEntropyLoss()

extend(lossfunc, debug=False)
extend(mnist_model, debug=False)

device = 'cpu'

Hessian = []
for param in mnist_model.parameters():
    ps = param.size()
    print("parameter size: ", ps)
    Hessian.append(torch.zeros(ps, device=device))
    #print(param.numel())
    
batch_size = BATCH_SIZE_TRAIN_MNIST

tau = 1/10
    
with backpack(DiagHessian()):
    
    max_len = int(np.ceil(len(mnist_train_loader.dataset)/batch_size))
    
    for batch_idx, (x, y) in enumerate(mnist_train_loader):

        if device == 'cuda':
            x, y = torch.from_numpy(x).float().cuda(), torch.from_numpy(y).long().cuda()

        mnist_model.zero_grad()
        lossfunc(mnist_model(x), y).backward()

        with torch.no_grad():
            # Hessian of weight
            for idx, param in enumerate(mnist_model.parameters()):

                H_ = param.diag_h
                #print(H_.size())
                #H_ = np.sqrt(batch_size)*H_ + np.sqrt(tau)*torch.eye(H_.size(0), device=device)
                
                rho = min(1-1/(batch_idx+1), 0.95)

                Hessian[idx] = rho*Hessian[idx] + (1-rho)*H_

        print("Batch: {}/{}".format(batch_idx, max_len))
        
    #combine all elements of the Hessian to one big vector
    Hessian = torch.cat([el.view(-1) for el in Hessian])
    print("Hessian_size: ", Hessian.size())
    num_params = np.sum([p.numel() for p in mnist_model.parameters()])
    assert(num_params == Hessian.size(-1))
        

parameter size:  torch.Size([100, 784])
parameter size:  torch.Size([100])
parameter size:  torch.Size([100, 100])
parameter size:  torch.Size([100])
parameter size:  torch.Size([10, 100])
parameter size:  torch.Size([10])
Batch: 0/469
Batch: 1/469
Batch: 2/469
Batch: 3/469
Batch: 4/469
Batch: 5/469
Batch: 6/469
Batch: 7/469
Batch: 8/469
Batch: 9/469
Batch: 10/469
Batch: 11/469
Batch: 12/469
Batch: 13/469
Batch: 14/469
Batch: 15/469
Batch: 16/469
Batch: 17/469
Batch: 18/469
Batch: 19/469
Batch: 20/469
Batch: 21/469
Batch: 22/469
Batch: 23/469
Batch: 24/469
Batch: 25/469
Batch: 26/469
Batch: 27/469
Batch: 28/469
Batch: 29/469
Batch: 30/469
Batch: 31/469
Batch: 32/469
Batch: 33/469
Batch: 34/469
Batch: 35/469
Batch: 36/469
Batch: 37/469
Batch: 38/469
Batch: 39/469
Batch: 40/469
Batch: 41/469
Batch: 42/469
Batch: 43/469
Batch: 44/469
Batch: 45/469
Batch: 46/469
Batch: 47/469
Batch: 48/469
Batch: 49/469
Batch: 50/469
Batch: 51/469
Batch: 52/469
Batch: 53/469
Batch: 54/469
Batch: 55/469
Bat

In [9]:
def get_Jacobian(model, x):
    
    outputs = model(x)
    batch_size = outputs.shape[0]
    num_classes = outputs.shape[1]
    j_size = np.sum([p.numel() for p in model.parameters()])
    Jacobian = torch.zeros(batch_size, num_classes, j_size)
    
    for i in range(batch_size):
        for j in range(num_classes):
            #Variable(outputs[i,j], requires_grad=True).backward(retain_variables=True)
            outputs[i,j].backward(retain_graph=True)
            single_jacobian = []
            for k, p in enumerate(model.parameters()):
                g = p.grad.clone() # now contains the derivative of the o[i,j] w.r.t. weights
                single_jacobian.append(g)
            
            model.zero_grad()
                
            single_jacobian = torch.cat([el.view(-1) for el in single_jacobian])
            Jacobian[i,j] = single_jacobian
    
    return(Jacobian)

In [11]:
"""
for batch_idx, (x, y) in enumerate(mnist_train_loader):
    
    J = get_Jacobian(mnist_model, x)
    print("J size: ", J.size())
    batch_size = J.size(0)
    num_classes = J.size(1)
    cov = torch.zeros(batch_size, num_classes, num_classes)
    for i in range(batch_size):
        cov[i] = torch.mul(J[i], Hessian) @ J[i].T
    print(cov.size())
    #print(cov)
"""

'\nfor batch_idx, (x, y) in enumerate(mnist_train_loader):\n    \n    J = get_Jacobian(mnist_model, x)\n    print("J size: ", J.size())\n    batch_size = J.size(0)\n    num_classes = J.size(1)\n    cov = torch.zeros(batch_size, num_classes, num_classes)\n    for i in range(batch_size):\n        cov[i] = torch.mul(J[i], Hessian) @ J[i].T\n    print(cov.size())\n    #print(cov)\n'

In [12]:
print(Hessian)

tensor([0.0000, 0.0000, 0.0000,  ..., 0.0043, 0.0063, 0.0066])


In [13]:
def predict_Diagonal_full(model, test_loader, verbose=True, num_samples=100):
    
    py = []
    
    max_len = int(np.ceil(len(test_loader.dataset)/len(test_loader)))
    for batch_idx, (x, y) in enumerate(test_loader):
        
        J = get_Jacobian(model, x)
        batch_size = J.size(0)
        num_classes = J.size(1)
        Cov_pred = torch.zeros(batch_size, num_classes, num_classes)
        for i in range(batch_size):
            Cov_pred[i] = torch.mul(J[i], Hessian) @ J[i].T    #inverting the Hessian give worse results for some reason

        mu_pred = model(x)
        #print(Cov_pred)
        post_pred = MultivariateNormal(mu_pred, Cov_pred)

        # MC-integral
        py_ = 0

        for _ in range(num_samples):
            f_s = post_pred.rsample()
            py_ += torch.softmax(f_s, 1)


        py_ /= num_samples
        py_ = py_.detach()

        py.append(py_)

        if verbose:
            print("Batch: {}/{}".format(batch_idx, max_len))

    return torch.cat(py, dim=0)

In [14]:
BATCH_SIZE_TEST_FMNIST = 128
BATCH_SIZE_TRAIN_FMNIST = 32

In [15]:
FMNIST_dataset = torchvision.datasets.FashionMNIST(
        './fmnist', train=True, download=True,
        transform=MNIST_transform) #torchvision.transforms.ToTensor())

train_size = int(0.8 * len(FMNIST_dataset))
val_size = len(FMNIST_dataset) - train_size
FMNIST_train_dataset, FMNIST_val_dataset = torch.utils.data.random_split(FMNIST_dataset, [train_size, val_size])

FMNIST_train_loader = torch.utils.data.DataLoader(
    FMNIST_train_dataset,
    batch_size=BATCH_SIZE_TRAIN_FMNIST, shuffle=True)

FMNIST_val_loader = torch.utils.data.DataLoader(
    FMNIST_val_dataset,
    batch_size=BATCH_SIZE_TRAIN_FMNIST, shuffle=False)


FMNIST_test = torchvision.datasets.FashionMNIST(
        './fmnist', train=False, download=False,
        transform=MNIST_transform)   #torchvision.transforms.ToTensor())

FMNIST_test_loader = torch.utils.data.DataLoader(
    FMNIST_test,
    batch_size=BATCH_SIZE_TEST_FMNIST, shuffle=False)

In [16]:
KMNIST_test = torchvision.datasets.KMNIST(
        './kmnist', train=False, download=True,
        transform=MNIST_transform)

KMNIST_test_loader = torch.utils.data.DataLoader(
    KMNIST_test,
    batch_size=128, shuffle=False)

In [17]:
"""Load notMNIST"""

import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from matplotlib.pyplot import imread
from torch import Tensor

"""
Loads the train/test set. 
Every image in the dataset is 28x28 pixels and the labels are numbered from 0-9
for A-J respectively.
Set root to point to the Train/Test folders.
"""

# Creating a sub class of torch.utils.data.dataset.Dataset
class notMNIST(Dataset):

    # The init method is called when this class will be instantiated
    def __init__(self, root, transform):
        
        #super(notMNIST, self).__init__(root, transform=transform)

        self.transform = transform
        
        Images, Y = [], []
        folders = os.listdir(root)

        for folder in folders:
            folder_path = os.path.join(root, folder)
            for ims in os.listdir(folder_path):
                try:
                    img_path = os.path.join(folder_path, ims)
                    Images.append(np.array(imread(img_path)))
                    Y.append(ord(folder) - 65)  # Folders are A-J so labels will be 0-9
                except:
                    # Some images in the dataset are damaged
                    print("File {}/{} is broken".format(folder, ims))
        data = [(x, y) for x, y in zip(Images, Y)]
        self.data = data
        self.targets = torch.Tensor(Y)

    # The number of items in the dataset
    def __len__(self):
        return len(self.data)

    # The Dataloader is a generator that repeatedly calls the getitem method.
    # getitem is supposed to return (X, Y) for the specified index.
    def __getitem__(self, index):
        img = self.data[index][0]

        if self.transform is not None:
            img = self.transform(img)
            
        # Input for Conv2D should be Channels x Height x Width
        img_tensor = Tensor(img).view(1, 28, 28).float()
        label = self.data[index][1]
        return (img_tensor, label)

In [18]:
root = os.path.abspath('')

# Instantiating the notMNIST dataset class we created
notMNIST_test = notMNIST(root=os.path.join(root, 'notMNIST_small'),
                               transform=MNIST_transform)

# Creating a dataloader
not_mnist_test_loader = torch.utils.data.dataloader.DataLoader(
                            dataset=notMNIST_test,
                            batch_size=128,
                            shuffle=False)

File F/Q3Jvc3NvdmVyIEJvbGRPYmxpcXVlLnR0Zg==.png is broken
File A/RGVtb2NyYXRpY2FCb2xkT2xkc3R5bGUgQm9sZC50dGY=.png is broken


In [19]:
def get_in_dist_values(py_in, targets):
    acc_in = np.mean(np.argmax(py_in, 1) == targets)
    prob_correct = np.choose(targets, py_in.T).mean()
    average_entropy = -np.sum(py_in*np.log(py_in+1e-8), axis=1).mean()
    MMC = py_in.max(1).mean()
    return(acc_in, prob_correct, average_entropy, MMC)
    
def get_out_dist_values(py_in, py_out, targets):
    average_entropy = -np.sum(py_out*np.log(py_out+1e-8), axis=1).mean()
    acc_out = np.mean(np.argmax(py_out, 1) == targets)
    prob_correct = np.choose(targets, py_out.T).mean()
    labels = np.zeros(len(py_in)+len(py_out), dtype='int32')
    labels[:len(py_in)] = 1
    examples = np.concatenate([py_in.max(1), py_out.max(1)])
    auroc = roc_auc_score(labels, examples)
    MMC = py_out.max(1).mean()
    return(acc_out, prob_correct, average_entropy, MMC, auroc)

def print_in_dist_values(acc_in, prob_correct, average_entropy, MMC, train='mnist', method='LLLA-KF'):
    
    print(f'[In, {method}, {train}] Accuracy: {acc_in:.3f}; average entropy: {average_entropy:.3f}; \
    MMC: {MMC:.3f}; Prob @ correct: {prob_correct:.3f}')


def print_out_dist_values(acc_out, prob_correct, average_entropy, MMC, auroc, train='mnist', test='FMNIST', method='LLLA-KF'):
   
    print(f'[Out-{test}, {method}, {train}] Accuracy: {acc_out:.3f}; Average entropy: {average_entropy:.3f};\
    MMC: {MMC:.3f}; AUROC: {auroc:.3f}; Prob @ correct: {prob_correct:.3f}')

# MAP estimate

In [20]:
targets = MNIST_test.targets.numpy()
targets_FMNIST = FMNIST_test.targets.numpy()
targets_notMNIST = notMNIST_test.targets.numpy().astype(int)
targets_KMNIST = KMNIST_test.targets.numpy()

In [21]:
mnist_test_in_MAP = predict_MAP(mnist_model, mnist_test_loader).numpy()
mnist_test_out_fmnist_MAP = predict_MAP(mnist_model, FMNIST_test_loader).numpy()
mnist_test_out_notMNIST_MAP = predict_MAP(mnist_model, not_mnist_test_loader).numpy()
mnist_test_out_KMNIST_MAP = predict_MAP(mnist_model, KMNIST_test_loader).numpy()

In [22]:
acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP = get_in_dist_values(mnist_test_in_MAP, targets)
acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP, MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_fmnist_MAP, targets_FMNIST)
acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP, MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_notMNIST_MAP, targets_notMNIST)
acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP, MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_KMNIST_MAP, targets_KMNIST)

In [23]:
print_in_dist_values(acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP, 'mnist', 'MAP')
print_out_dist_values(acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP, MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP, 'FMNIST', 'MAP')
print_out_dist_values(acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP, MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP, 'notMNIST', 'MAP')
print_out_dist_values(acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP, MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP, 'KMNIST', 'MAP')

[In, MAP, mnist] Accuracy: 0.971; average entropy: 0.105;     MMC: 0.968; Prob @ correct: 0.954
[Out-MAP, LLLA-KF, FMNIST] Accuracy: 0.056; Average entropy: 0.506;    MMC: 0.824; AUROC: 0.832; Prob @ correct: 0.059
[Out-MAP, LLLA-KF, notMNIST] Accuracy: 0.121; Average entropy: 0.548;    MMC: 0.799; AUROC: 0.849; Prob @ correct: 0.126
[Out-MAP, LLLA-KF, KMNIST] Accuracy: 0.102; Average entropy: 0.750;    MMC: 0.728; AUROC: 0.908; Prob @ correct: 0.100


# Diag Hessian Sampling estimate

In [24]:
mnist_test_in_D = predict_Diagonal_full(mnist_model, mnist_test_loader, num_samples=100).numpy()
mnist_test_out_FMNIST_D = predict_Diagonal_full(mnist_model, FMNIST_test_loader, num_samples=100).numpy()
mnist_test_out_notMNIST_D = predict_Diagonal_full(mnist_model, not_mnist_test_loader, num_samples=100).numpy()
mnist_test_out_KMNIST_D = predict_Diagonal_full(mnist_model, KMNIST_test_loader, num_samples=100).numpy()

Batch: 0/127
Batch: 1/127
Batch: 2/127
Batch: 3/127
Batch: 4/127
Batch: 5/127
Batch: 6/127
Batch: 7/127
Batch: 8/127
Batch: 9/127
Batch: 10/127
Batch: 11/127
Batch: 12/127
Batch: 13/127
Batch: 14/127
Batch: 15/127
Batch: 16/127
Batch: 17/127
Batch: 18/127
Batch: 19/127
Batch: 20/127
Batch: 21/127
Batch: 22/127
Batch: 23/127
Batch: 24/127
Batch: 25/127
Batch: 26/127
Batch: 27/127
Batch: 28/127
Batch: 29/127
Batch: 30/127
Batch: 31/127
Batch: 32/127
Batch: 33/127
Batch: 34/127
Batch: 35/127
Batch: 36/127
Batch: 37/127
Batch: 38/127
Batch: 39/127
Batch: 40/127
Batch: 41/127
Batch: 42/127
Batch: 43/127
Batch: 44/127
Batch: 45/127
Batch: 46/127
Batch: 47/127
Batch: 48/127
Batch: 49/127
Batch: 50/127
Batch: 51/127
Batch: 52/127
Batch: 53/127
Batch: 54/127
Batch: 55/127
Batch: 56/127
Batch: 57/127
Batch: 58/127
Batch: 59/127
Batch: 60/127
Batch: 61/127
Batch: 62/127
Batch: 63/127
Batch: 64/127
Batch: 65/127
Batch: 66/127
Batch: 67/127
Batch: 68/127
Batch: 69/127
Batch: 70/127
Batch: 71/127
Ba

In [25]:
acc_in_D, prob_correct_in_D, ent_in_D, MMC_in_D = get_in_dist_values(mnist_test_in_D, targets)
acc_out_FMNIST_D, prob_correct_out_FMNIST_D, ent_out_FMNIST_D, MMC_out_FMNIST_D, auroc_out_FMNIST_D = get_out_dist_values(mnist_test_in_D, mnist_test_out_FMNIST_D, targets_FMNIST)
acc_out_notMNIST_D, prob_correct_out_notMNIST_D, ent_out_notMNIST_D, MMC_out_notMNIST_D, auroc_out_notMNIST_D = get_out_dist_values(mnist_test_in_D, mnist_test_out_notMNIST_D, targets_notMNIST)
acc_out_KMNIST_D, prob_correct_out_KMNIST_D, ent_out_KMNIST_D, MMC_out_KMNIST_D, auroc_out_KMNIST_D = get_out_dist_values(mnist_test_in_D, mnist_test_out_KMNIST_D, targets_KMNIST)

In [26]:
print_in_dist_values(acc_in_D, prob_correct_in_D, ent_in_D, MMC_in_D, 'mnist', 'Diag')
print_out_dist_values(acc_out_FMNIST_D, prob_correct_out_FMNIST_D, ent_out_FMNIST_D, MMC_out_FMNIST_D, auroc_out_FMNIST_D, test='fmnist', method='Diag')
print_out_dist_values(acc_out_notMNIST_D, prob_correct_out_notMNIST_D, ent_out_notMNIST_D, MMC_out_notMNIST_D, auroc_out_notMNIST_D, test='notMNIST', method='Diag')
print_out_dist_values(acc_out_KMNIST_D, prob_correct_out_KMNIST_D, ent_out_KMNIST_D, MMC_out_KMNIST_D, auroc_out_KMNIST_D, test='KMNIST', method='Diag')

[In, Diag, mnist] Accuracy: 0.971; average entropy: 0.427;     MMC: 0.876; Prob @ correct: 0.867
[Out-fmnist, Diag, mnist] Accuracy: 0.058; Average entropy: 1.042;    MMC: 0.625; AUROC: 0.876; Prob @ correct: 0.069
[Out-notMNIST, Diag, mnist] Accuracy: 0.120; Average entropy: 1.028;    MMC: 0.617; AUROC: 0.881; Prob @ correct: 0.112
[Out-KMNIST, Diag, mnist] Accuracy: 0.102; Average entropy: 1.145;    MMC: 0.572; AUROC: 0.904; Prob @ correct: 0.101
