In [1]:
import math
import sys
sys.path.append('/Users/philippvonbachmann/Documents/University/WiSe2122/ResearchProject/ResearchProjectLLVI/BasicExample')
from torch import nn, optim
from torch.nn import functional as F
import torch
import torchvision
from laplace import Laplace
from tqdm import tqdm
from matplotlib import pyplot as plt
import matplotlib

from src.weight_distribution.Full import FullCovariance
from src.weight_distribution.Diagonal import Diagonal
from src.network.Classification import LLVIClassification
from src.network import PredictApprox, LikApprox


In [2]:
class CNN(nn.Module):
    def __init__(self, out_dim=10, optimizer=optim.Adam, **optim_kwargs):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 50)
        self.fc3 = nn.Linear(50, out_dim, bias=False)
        self.optimizer: optim = optimizer(self.parameters(), **optim_kwargs)
        self.nonll = torch.sigmoid # nonlinear layer

    def forward(self, x):
        x = self.pool(self.nonll(self.conv1(x)))
        x = self.pool(self.nonll(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = self.nonll(self.fc1(x))
        x = self.nonll(self.fc2(x))
        x = self.fc3(x)
        return x 


class VICNN(nn.Module):
    def __init__(self, feature_dim=50, optimizer=optim.Adam, **optim_kwargs):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, feature_dim)
        self.optimizer: optim = optimizer(self.parameters(), **optim_kwargs)
        self.nonll = torch.sigmoid # nonlinear layer

    def forward(self, x):
        x = self.pool(self.nonll(self.conv1(x)))
        x = self.pool(self.nonll(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = self.nonll(self.fc1(x))
        x = self.nonll(self.fc2(x))
        return x

In [3]:
batch_size_train = 32
batch_size_test = 60000
filepath = "/Users/philippvonbachmann/Documents/University/WiSe2122/ResearchProject/ResearchProjectLLVI/BasicExample/datasets/Classification"
# create dataset
dataset = torchvision.datasets.MNIST(filepath, train=True, download=False,
                            transform=torchvision.transforms.Compose([
                              torchvision.transforms.ToTensor(),
                              torchvision.transforms.Normalize(
                                (0.1307,), (0.3081,))
                            ]))
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size_test, shuffle=True)
n_datapoints = batch_size_train * len(train_loader)


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [4]:
# general parameters
feature_dim = 50
out_dim=10
weight_decay = 1e-3
lr = 1e-3
base_train_epochs = 10

# VI parameters
tau = 1
vi_train_epochs = 10

In [5]:
torch.manual_seed(3)
laplace_model = CNN(weight_decay=weight_decay, lr=lr, out_dim=out_dim)
criterion = torch.nn.CrossEntropyLoss()

# train
pbar = tqdm(range(base_train_epochs))
for i in pbar:
    epoch_loss = 0
    for X_batch, y_batch in train_loader:
        laplace_model.optimizer.zero_grad()
        output = laplace_model(X_batch)
        output = F.log_softmax(output, dim=-1)
        loss = criterion(output, y_batch)
        loss.backward()
        epoch_loss += loss.item()
        laplace_model.optimizer.step()

    pbar.set_description(f"Loss: {round(float(torch.mean(loss)), 2)}")

Loss: 0.34: 100%|██████████| 10/10 [02:09<00:00, 12.94s/it]


In [6]:
# define laplace
la = Laplace(laplace_model, "classification",
    subset_of_weights="last_layer", hessian_structure="diag",
    prior_precision=weight_decay) # prior precision is set to wdecay
la.fit(train_loader)
for X_batch, y_batch in test_loader:
    predictions = la(X_batch, link_approx='mc', n_samples=1000)
    pred_test = torch.argmax(predictions, dim=1)
    print("Accuracy with Laplace", torch.mean((pred_test == y_batch).float()).item())

Accuracy with Laplace 0.9710000157356262


In [14]:
dist = Diagonal(50, 10, lr=lr)
dist.update_var(torch.reshape(la.posterior_variance, (50, 10)))
dist.update_mean(torch.t(laplace_model.fc3.weight))

# dist = FullCovariance(50, 10, lr=lr)
# dist.update_cov(la.posterior_covariance)
# dist.update_mean(torch.t(laplace_model.fc3.weight))


vi_model = VICNN(weight_decay=weight_decay, lr=lr, feature_dim=feature_dim)
with torch.no_grad():
    vi_model.load_state_dict(laplace_model.state_dict(), strict=False)


prior_log_var = math.log(1/(weight_decay * n_datapoints))
prior_log_var = -7
net = LLVIClassification(50, 10, vi_model, dist, prior_log_var=prior_log_var, optimizer_type=torch.optim.Adam,
tau=tau, lr=lr)

for X_batch, y_batch in test_loader:
    predictions = net(X_batch, method=PredictApprox.MONTECARLO, samples=100)
    pred_test = torch.argmax(predictions, dim=1)
    print("Accuracy with VI before Training", torch.mean((pred_test == y_batch).float()).item())

net.train_model(train_loader, epochs=vi_train_epochs, n_datapoints=n_datapoints, samples=10, method=LikApprox.MONTECARLO)

for X_batch, y_batch in test_loader:
    predictions = net(X_batch, method=PredictApprox.MONTECARLO, samples=100)
    pred_test = torch.argmax(predictions, dim=1)
    print("Accuracy with VI after Training", torch.mean((pred_test == y_batch).float()).item())

Accuracy with VI before Training 0.9726999998092651


prediction_loss:0.15 kl_loss:0.73: 100%|██████████| 10/10 [02:37<00:00, 15.75s/it]


Accuracy with VI after Training 0.972100019454956


In [15]:
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST(filepath, train=False, download=False,
                            transform=torchvision.transforms.Compose([
                            torchvision.transforms.ToTensor(),
                            torchvision.transforms.Normalize(
                                (0.1307,), (0.3081,))
                            ])),
batch_size=batch_size_test, shuffle=True)

ood_test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST(filepath, train=False, download=True,
                            transform=torchvision.transforms.Compose([
                                torchvision.transforms.ToTensor(),
                                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
                                torchvision.transforms.RandomHorizontalFlip(p=1)
                            ])),
batch_size=batch_size_test, shuffle=True)

In [16]:
def test_confidence(predict_fun, test_loader, ood_test_loader):
    confidence_batch = []
    with torch.no_grad():
        for data, target in test_loader:
            output = predict_fun(data)
            pred, _ = torch.max(output, dim=1) # confidence in choice
            confidence_batch.append(torch.mean(pred))
        print(f"The mean confidence for in distribution data is: {sum(confidence_batch)/len(confidence_batch)}")

    ood_confidence_batch = []
    with torch.no_grad():
        for data, target in ood_test_loader:
            output = predict_fun(data)
            pred, _ = torch.max(output, dim=1) # confidence in choice
            ood_confidence_batch.append(torch.mean(pred))
        print(f"The mean confidence for out-of distribution data is: {sum(ood_confidence_batch)/len(ood_confidence_batch)}")

In [17]:
predict_samples = 100
la_predict_fun = lambda x: la(x, link_approx='mc', n_samples=predict_samples)
vi_predict_fun = lambda x: net(x, method=PredictApprox.MONTECARLO, samples=100)

In [18]:
test_confidence(la_predict_fun, test_loader, ood_test_loader)

The mean confidence for in distribution data is: 0.9315642714500427
The mean confidence for out-of distribution data is: 0.7018277049064636


In [19]:
test_confidence(vi_predict_fun, test_loader, ood_test_loader)

The mean confidence for in distribution data is: 0.9121099710464478
The mean confidence for out-of distribution data is: 0.6925305724143982
