In [1]:
import math
import sys
sys.path.append('/Users/philippvonbachmann/Documents/University/WiSe2122/ResearchProject/ResearchProjectLLVI/BasicExample')
sys.path.append('/Users/philippvonbachmann/Documents/University/WiSe2122/ResearchProject/ResearchProjectLLVI')
from torch import nn, optim
from torch.nn import functional as F
import torch
import torchvision
from laplace import Laplace
from tqdm import tqdm
from matplotlib import pyplot as plt
import matplotlib

from src.weight_distribution.Full import FullCovariance
from src.weight_distribution.Diagonal import Diagonal
from src.network.Classification import LLVIClassification
from src.network import PredictApprox, LikApprox
from PyTorch_CIFAR10.cifar10_models.resnet import resnet18, resnet34, resnet18VI


In [2]:
batch_size_train = 32
dataset_size = 1024
data_dir = "/Users/philippvonbachmann/Documents/University/WiSe2122/ResearchProject/ResearchProjectLLVI/BasicExample/datasets/Classification/CIFAR10"
dataset = torchvision.datasets.CIFAR10(data_dir, train=True, download=False,
                            transform=torchvision.transforms.Compose(
            [
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
            ]
        ))
dataset, _  = torch.utils.data.random_split(dataset, [dataset_size, len(dataset) - dataset_size])
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size_train, shuffle=True)
n_datapoints = batch_size_train * len(train_loader)


In [3]:
batch_size_test = 1000 # whole dataset
test_dataset_size = 1000
data_dir = "/Users/philippvonbachmann/Documents/University/WiSe2122/ResearchProject/ResearchProjectLLVI/BasicExample/datasets/Classification/CIFAR10"
test_dataset = torchvision.datasets.CIFAR10(data_dir, train=False, download=False,
                            transform=torchvision.transforms.Compose(
            [
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
            ]
        ))
test_dataset, _  = torch.utils.data.random_split(test_dataset, [test_dataset_size, len(test_dataset) - test_dataset_size])
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_test, shuffle=True)
n_test_datapoints = batch_size_test * len(test_loader)

In [14]:
# general parameters
feature_dim = 50
out_dim=10
weight_decay = 1e-3
lr = 1e-3
base_train_epochs = 1

# VI parameters
tau = 1
vi_train_epochs = 2

In [5]:
laplace_model = resnet18(pretrained=True)

In [6]:
for X_batch, y_batch in test_loader:
    predictions = laplace_model(X_batch)
    pred_test = torch.argmax(predictions, dim=1)
    print("MAP accuracy", torch.mean((pred_test == y_batch).float()).item())

MAP accuracy 0.9229999780654907


In [7]:
# define laplace
la = Laplace(laplace_model, "classification",
    subset_of_weights="last_layer", hessian_structure="diag",
    prior_precision=1) # prior precision is set to wdecay
la.fit(train_loader)

In [8]:
for X_batch, y_batch in test_loader:
    predictions = la(X_batch, link_approx='mc', n_samples=100)
    pred_test = torch.argmax(predictions, dim=1)
    print("Accuracy with Laplace", torch.mean((pred_test == y_batch).float()).item())

Accuracy with Laplace 0.9079999923706055


In [15]:
dist = Diagonal(512, 10, lr=lr)
# dist.update_var(torch.reshape(la.posterior_variance[:-10], (512, 10)))
dist.update_mean(torch.t(laplace_model.fc.weight))


vi_model = resnet18VI(pretrained=True)

prior_log_var = 1
net = LLVIClassification(512, 10, vi_model, dist, prior_log_var=prior_log_var, optimizer_type=torch.optim.Adam,
tau=tau, lr=lr)

for X_batch, y_batch in test_loader:
    predictions = net(X_batch, method=PredictApprox.MONTECARLO, samples=100)
    pred_test = torch.argmax(predictions, dim=1)
    print("Accuracy with VI before Training", torch.mean((pred_test == y_batch).float()).item())

net.train_ll_only(train_loader, epochs=vi_train_epochs, n_datapoints=n_datapoints, samples=10, method=LikApprox.MONTECARLO)

for X_batch, y_batch in test_loader:
    predictions = net(X_batch, method=PredictApprox.MONTECARLO, samples=100)
    pred_test = torch.argmax(predictions, dim=1)
    print("Accuracy with VI after Training", torch.mean((pred_test == y_batch).float()).item())

Accuracy with VI before Training 0.902999997138977


prediction_loss:1.38 kl_loss:1.54: 100%|██████████| 2/2 [00:49<00:00, 24.90s/it]


Accuracy with VI after Training 0.9100000262260437


In [None]:
def test_confidence(predict_fun, test_loader, ood_test_loader):
    confidence_batch = []
    with torch.no_grad():
        for data, target in test_loader:
            output = predict_fun(data)
            pred, _ = torch.max(output, dim=1) # confidence in choice
            confidence_batch.append(torch.mean(pred))
        print(f"The mean confidence for in distribution data is: {sum(confidence_batch)/len(confidence_batch)}")

    ood_confidence_batch = []
    with torch.no_grad():
        for data, target in ood_test_loader:
            output = predict_fun(data)
            pred, _ = torch.max(output, dim=1) # confidence in choice
            ood_confidence_batch.append(torch.mean(pred))
        print(f"The mean confidence for out-of distribution data is: {sum(ood_confidence_batch)/len(ood_confidence_batch)}")

In [None]:
predict_samples = 100
la_predict_fun = lambda x: la(x, link_approx='mc', n_samples=predict_samples)
vi_predict_fun = lambda x: net(x, method=PredictApprox.MONTECARLO, samples=100)

In [None]:
test_confidence(la_predict_fun, test_loader, ood_test_loader)

The mean confidence for in distribution data is: 0.8469984531402588
The mean confidence for out-of distribution data is: 0.6441346406936646


In [None]:
test_confidence(vi_predict_fun, test_loader, ood_test_loader)

The mean confidence for in distribution data is: 0.8522112369537354
The mean confidence for out-of distribution data is: 0.6150723695755005
