# Predictive Uncertainty in Active Deep Learning

How important is precise *predictive uncertainty* (i.e. model confidence) for *active deep learning* (ADA)?

In [None]:
from copy import deepcopy
import os
import random

from matplotlib import pyplot as plt
import numpy as np
from scipy.special import entr
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import euclidean_distances
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset, TensorDataset
from torchvision import datasets, transforms

In [None]:
# Number of output classes of every classifier head in this notebook.
N_CLASSES = 10

# Batch size of the loaders used for prediction and evaluation passes.
BS = 2048

# Template for checkpoint files; "{}" is filled with the model name.
MODELPATH = "models/{}.pt"

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Bare expression so the notebook cell displays the selected device.
DEVICE

In [None]:
# SVHN preprocessing: grayscale conversion, resize to 28x28, tensor
# conversion, then normalisation with mean 0.5 and std 0.5.
# NOTE(review): presumably this is meant to match the MNIST input format
# expected by LeNet (single channel, 28x28) -- confirm.
svhn_transforms = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=0.5, std=0.5),
    ]
)

# Train and test splits share the same root folder and preprocessing.
# No download flag is passed, so the files are presumably expected to
# already exist under "data" -- verify before a fresh run.
svhn_trainset, svhn_testset = (
    datasets.SVHN(root="data", split=split, transform=svhn_transforms)
    for split in ("train", "test")
)

In [None]:
# MNIST preprocessing: tensor conversion followed by normalisation with
# mean 0.5 and std 0.5 (the same normalisation as the SVHN pipeline).
mnist_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=0.5, std=0.5),
    ]
)

# Train and test splits share the same root folder and preprocessing.
# No download flag is passed, so the files are presumably expected to
# already exist under "data" -- verify before a fresh run.
mnist_trainset, mnist_testset = (
    datasets.MNIST(root="data", train=is_train, transform=mnist_transforms)
    for is_train in (True, False)
)

In [None]:
class LeNet(nn.Module):
    """LeNet-style convolutional classifier. See Hoffman et al. (2017).

    The network is split into a ``feature_extractor`` (two conv/pool stages
    followed by a 500-unit linear layer) and a ``classifier`` head producing
    ``N_CLASSES`` logits. Both parts are moved to ``DEVICE`` on construction.
    The first convolution takes one input channel and the flattened feature
    map is sized ``50 * 4 * 4``, which corresponds to 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Conv2d(20, 50, 5),
            nn.Dropout2d(p=0.5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(50 * 4 * 4, 500)).to(DEVICE)
        self.classifier = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(500, N_CLASSES)).to(DEVICE)

    def forward(self, X):
        """Return the class logits for the input batch ``X``."""
        return self.classifier(self.feature_extractor(X))

    @torch.no_grad()
    def predict(self, dataset):
        """Return the logits for every item of ``dataset`` as a CPU tensor.

        The model is switched to evaluation mode (dropout disabled) and the
        dataset is processed in batches of ``BS`` without shuffling, so row
        ``i`` of the result belongs to ``dataset[i]``. Only the first element
        of each batch (the inputs) is used.
        """
        self.eval()
        dataloader = DataLoader(dataset, batch_size=BS)
        return torch.concat([self(batch[0].to(DEVICE)) for batch in dataloader]).cpu()

    def scores(self, logits):
        """Turn logits into class probabilities with a softmax over the last axis."""
        return F.softmax(logits, dim=-1)

    @torch.no_grad()
    def test(self, testset):
        """Return the classification accuracy on ``testset`` as a Python float.

        Labels are taken from the last element of each batch.
        """
        y_pred = torch.argmax(self.scores(self.predict(testset)), dim=-1)
        y = torch.concat([batch[-1] for batch in DataLoader(testset, batch_size=BS)])
        return torch.mean((y_pred == y).type(torch.float)).item()

    def train_epoch(self, trainloader, optimiser):
        """Run one pass over ``trainloader``, minimising the cross-entropy loss.

        Returns ``self`` so calls can be chained.
        """
        self.train()
        for X_batch, y_batch in trainloader:
            optimiser.zero_grad()
            loss = F.cross_entropy(self(X_batch.to(DEVICE)), y_batch.to(DEVICE))
            loss.backward()
            optimiser.step()
        return self

    def train_epochs(self, trainset, testset, hyperparams, modelname):
        """Train for ``hyperparams["n_epochs"]`` epochs and save a checkpoint.

        ``hyperparams`` must provide the keys ``"optimiser"`` (optimiser
        class), ``"lr"``, ``"wd"`` (weight decay), ``"bs"`` (training batch
        size) and ``"n_epochs"``. The accuracy on ``testset`` is printed
        before training and after every epoch, the state dict is written to
        ``MODELPATH.format(modelname)`` and the accuracy curve is plotted.

        Returns ``self`` so calls can be chained.
        """
        # Create the checkpoint folder up front so that a missing directory
        # cannot cost a finished training run when saving at the end.
        modelpath = MODELPATH.format(modelname)
        modeldir = os.path.dirname(modelpath)
        if modeldir:
            os.makedirs(modeldir, exist_ok=True)

        optimiser = hyperparams["optimiser"](self.parameters(), lr=hyperparams["lr"], weight_decay=hyperparams["wd"])
        trainloader = DataLoader(trainset, batch_size=hyperparams["bs"], shuffle=True)
        accuracies = np.zeros(hyperparams["n_epochs"] + 1)
        accuracies[0] = self.test(testset)
        print(0, accuracies[0])
        for epoch in range(1, hyperparams["n_epochs"] + 1):
            self.train_epoch(trainloader, optimiser)
            accuracies[epoch] = self.test(testset)
            print(epoch, accuracies[epoch])
        torch.save(self.state_dict(), modelpath)
        fig, ax = plt.subplots()
        ax.scatter(np.arange(hyperparams["n_epochs"] + 1), accuracies)
        return self

    def load(self, modelname):
        """Restore the weights saved under ``MODELPATH.format(modelname)``.

        Tensors are mapped onto ``DEVICE`` while loading, so a checkpoint
        written on a GPU machine can also be restored on a CPU-only one.

        Returns ``self`` so calls can be chained.
        """
        state_dict = torch.load(MODELPATH.format(modelname), map_location=DEVICE)
        self.load_state_dict(state_dict)
        return self

In [None]:
# Settings consumed by ``train_epochs`` when a model is trained from scratch.
HYPERPARAMS = {
    "bs": 128,                # batch size of the training loader
    "n_epochs": 60,           # number of passes over the training set
    "optimiser": optim.Adam,  # optimiser class, instantiated per training run
    "lr": 2e-4,               # learning rate handed to the optimiser
    "wd": 1e-5,               # weight decay handed to the optimiser
}

In [None]:
# target only: aim for 99.2 ± 0.1
# The checkpoint "target" is restored instead of retraining; to train it, run
#   LeNet().train_epochs(mnist_trainset, mnist_testset, HYPERPARAMS, "target")
lenet_target = LeNet().load("target")
# Last expression of the cell: displays the accuracy on the MNIST test set.
lenet_target.test(mnist_testset)

In [None]:
# source only: aim for 67.1 ± 0.6
# The checkpoint "source" is restored instead of retraining; to train it, run
#   LeNet().train_epochs(svhn_trainset, mnist_testset, HYPERPARAMS, "source")
lenet_source = LeNet().load("source")
# Last expression of the cell: displays the accuracy on the MNIST test set.
lenet_source.test(mnist_testset)

In [None]:
def ada(model, target_trainset, target_testset, params):
    """Run the round-based active-learning loop and return the accuracy curve.

    In every round ``params["strategy"]`` picks ``params["query_size"]``
    indices from the still-unannotated part of ``target_trainset``; they are
    moved to the annotated set and ``model`` is trained for
    ``params["n_epochs"]`` epochs on all annotated items, with a freshly
    created optimiser and learning-rate scheduler.

    ``params`` must provide the keys ``"n_rounds"``, ``"strategy"``,
    ``"query_size"``, ``"optimiser"``, ``"lr"``, ``"wd"``, ``"scheduler"``,
    ``"step_size"``, ``"gamma"``, ``"bs"`` and ``"n_epochs"``.

    Returns a NumPy array of length ``n_rounds + 1`` whose entry ``r`` is the
    accuracy on ``target_testset`` after round ``r`` (entry 0: before any
    training). Every accuracy is also printed next to its round number.
    """
    n_rounds = params["n_rounds"]
    pool = set(range(len(target_trainset)))
    labelled = set()
    accuracies = np.zeros(n_rounds + 1)

    def record(round_no):
        # Evaluate the current model and log the result for this round.
        accuracies[round_no] = model.test(target_testset)
        print(round_no, accuracies[round_no])

    record(0)
    for round_no in range(1, n_rounds + 1):
        # Move the newly queried indices from the pool to the labelled set.
        queried = params["strategy"](model, target_trainset, pool, params["query_size"])
        pool -= queried
        labelled |= queried

        # Optimiser and scheduler start from scratch in every round.
        optimiser = params["optimiser"](
            model.parameters(), lr=params["lr"], weight_decay=params["wd"])
        scheduler = params["scheduler"](
            optimiser, step_size=params["step_size"], gamma=params["gamma"])

        labelled_subset = Subset(target_trainset, tuple(labelled))
        labelled_loader = DataLoader(labelled_subset, batch_size=params["bs"], shuffle=True)
        for _ in range(params["n_epochs"]):
            model.train_epoch(labelled_loader, optimiser)
            scheduler.step()

        record(round_no)
    return accuracies

In [None]:
# Settings consumed by ``ada``. The "strategy" entry (the query function) is
# not part of this literal; it is assigned before each experiment.
ADA_PARAMS = {
    # active-learning loop
    "n_rounds": 30,    # number of query/training rounds
    "query_size": 10,  # number of items requested from the strategy per round
    # training on the annotated subset
    "n_epochs": 60,    # epochs per round
    "bs": 128,         # batch size of the training loader
    "optimiser": optim.Adam,
    "lr": 2e-4,
    "wd": 1e-5,        # passed to the optimiser as weight decay
    # learning-rate schedule, re-created in every round
    "scheduler": optim.lr_scheduler.StepLR,
    "step_size": 20,
    "gamma": 0.5,
}

In [None]:
def uniform(model, target_trainset, idx_unannotated, query_size):
    """Query strategy that samples indices uniformly at random.

    ``model`` and ``target_trainset`` are unused; they are accepted so that
    all query strategies share one signature.

    Returns a set of ``query_size`` distinct indices drawn without
    replacement from ``idx_unannotated``. When fewer than ``query_size``
    indices are left, all remaining ones are returned instead of raising,
    which matches the truncating behaviour of the ranking-based strategies.
    """
    pool = tuple(idx_unannotated)
    return set(random.sample(pool, min(query_size, len(pool))))

# Baseline experiment: start from the "source" checkpoint and query the
# items to annotate uniformly at random.
ADA_PARAMS.update(strategy=uniform)
lenet_uniform = LeNet()
lenet_uniform.load("source")
accuracies_uniform = ada(
    model=lenet_uniform,
    target_trainset=mnist_trainset,
    target_testset=mnist_testset,
    params=ADA_PARAMS,
)
# Last expression of the cell: displays the final test accuracy.
lenet_uniform.test(mnist_testset)

In [None]:
def entropy(model, target_trainset, idx_unannotated, query_size):
    """Query strategy that selects the items with the highest predictive entropy.

    ``model`` must offer ``predict(dataset)`` returning logits and
    ``scores(logits)`` returning one probability vector per item.

    Returns a set with the (at most) ``query_size`` indices from
    ``idx_unannotated`` whose predicted class distribution has the largest
    entropy. An empty pool yields an empty set.
    """
    idx_pool = np.array(tuple(idx_unannotated))
    if idx_pool.size == 0:
        # Nothing left to rank; avoid running the model on an empty subset.
        return set()
    logits = model.predict(Subset(target_trainset, idx_pool.tolist()))
    scores = model.scores(logits)
    # torch.special.entr computes -p * log(p) elementwise on the tensor
    # itself, so no implicit tensor <-> NumPy conversion is involved.
    entropies = torch.sum(torch.special.entr(scores), dim=-1)
    ranked = torch.argsort(entropies, descending=True)[:query_size]
    # Positions in the ranking refer to idx_pool; map them back to dataset
    # indices through an explicit NumPy index array.
    return set(idx_pool[ranked.cpu().numpy()].tolist())

# Entropy experiment: start from the "source" checkpoint and query the items
# with the highest predictive entropy.
ADA_PARAMS.update(strategy=entropy)
lenet_entropy = LeNet()
lenet_entropy.load("source")
accuracies_entropy = ada(
    model=lenet_entropy,
    target_trainset=mnist_trainset,
    target_testset=mnist_testset,
    params=ADA_PARAMS,
)
# Last expression of the cell: displays the final test accuracy.
lenet_entropy.test(mnist_testset)

In [None]:
class MCLeNet(LeNet):
    """LeNet whose predictions are sampled with dropout left active.

    ``predict`` runs ``T`` stochastic forward passes per item and ``scores``
    averages the resulting softmax distributions.
    """

    def __init__(self, T):
        super(MCLeNet, self).__init__()
        # Number of stochastic forward passes per prediction.
        self.T = T

    def _sample_pass(self, loader):
        """Return the logits of one forward pass over every batch of ``loader``."""
        return torch.concat([self(inputs[0].to(DEVICE)) for inputs in loader])

    @torch.no_grad()
    def predict(self, dataset):
        """Return logits of shape ``(T, len(dataset), N_CLASSES)`` on the CPU.

        The model is put into training mode so that the dropout layers stay
        active, which makes each of the ``T`` passes a different sample.
        """
        self.train()
        loader = DataLoader(dataset, batch_size=BS)
        samples = torch.empty(self.T, len(dataset), N_CLASSES, device=DEVICE)
        for t in range(self.T):
            samples[t] = self._sample_pass(loader)
        return samples.cpu()

    def scores(self, logits):
        """Average the per-pass softmax probabilities over the first axis."""
        per_pass = F.softmax(logits, dim=-1)
        return torch.sum(per_pass, dim=0) / self.T

# Sanity check: accuracy of the "source" checkpoint when predictions are
# averaged over 20 dropout samples.
mclenet_entropy = MCLeNet(T=20).load("source")
# Last expression of the cell: displays the accuracy on the MNIST test set.
mclenet_entropy.test(mnist_testset)

In [None]:
# MC-dropout entropy experiment: same entropy strategy, but the class
# probabilities come from the sample-averaged MCLeNet.
ADA_PARAMS.update(strategy=entropy)
mclenet_entropy = MCLeNet(T=20)
mclenet_entropy.load("source")
accuracies_mcentropy = ada(
    model=mclenet_entropy,
    target_trainset=mnist_trainset,
    target_testset=mnist_testset,
    params=ADA_PARAMS,
)
# Last expression of the cell: displays the final test accuracy.
mclenet_entropy.test(mnist_testset)

In [None]:
def bald(model, target_trainset, idx_unannotated, query_size):
    """Bayesian active learning by disagreement (BALD).

    ``model.predict(dataset)`` must return logits of shape
    ``(n_samples, n_items, n_classes)`` (one slice per stochastic pass or
    ensemble member) and ``model.scores(logits)`` the sample-averaged class
    probabilities of shape ``(n_items, n_classes)``.

    The score of an item is the entropy of its averaged prediction minus the
    average entropy of its individual predictions. The number of samples is
    read from the logits, so the model does not need a ``T`` attribute.

    Returns a set with the (at most) ``query_size`` indices from
    ``idx_unannotated`` that have the highest score. An empty pool yields an
    empty set.
    """
    idx_pool = np.array(tuple(idx_unannotated))
    if idx_pool.size == 0:
        # Nothing left to rank; avoid running the model on an empty subset.
        return set()
    logits = model.predict(Subset(target_trainset, idx_pool.tolist()))
    scores = model.scores(logits)
    n_samples = logits.shape[0]
    # torch.special.entr computes -p * log(p) elementwise on the tensor
    # itself, so no implicit tensor <-> NumPy conversion is involved.
    entropy_of_mean = torch.sum(torch.special.entr(scores), dim=-1)
    mean_of_entropies = torch.sum(
        torch.special.entr(F.softmax(logits, dim=-1)), dim=(0, 2)) / n_samples
    balds = entropy_of_mean - mean_of_entropies
    ranked = torch.argsort(balds, descending=True)[:query_size]
    # Positions in the ranking refer to idx_pool; map them back to dataset
    # indices through an explicit NumPy index array.
    return set(idx_pool[ranked.cpu().numpy()].tolist())

# BALD experiment: start from the "source" checkpoint with MC dropout and
# query the items with the highest BALD score.
ADA_PARAMS["strategy"] = bald
mclenet_bald = MCLeNet(T=20).load("source")
# ada returns only the accuracy array (the model is trained in place), so
# the result is bound to a single name.
accuracies_bald = ada(mclenet_bald, mnist_trainset, mnist_testset, ADA_PARAMS)
# Last expression of the cell: displays the final test accuracy.
mclenet_bald.test(mnist_testset)

In [None]:
class DeepEnsemble(nn.Module):
    """Ensemble of ``M`` independent ``LeNet`` models.

    The members are kept in the list ``self.models`` and additionally
    registered as sub-modules named ``"model-<i>"``, so ``parameters()``
    covers all of them. Predictions are combined by averaging the members'
    softmax distributions.
    """

    def __init__(self, M):
        super().__init__()
        # Number of ensemble members.
        self.M = M
        self.models = []
        for i in range(M):
            member = LeNet()
            self.models.append(member)
            self.add_module("model-" + str(i), member)

    @torch.no_grad()
    def predict(self, dataset):
        """Return the members' logits, shape ``(M, len(dataset), N_CLASSES)``."""
        logits = torch.empty(self.M, len(dataset), N_CLASSES)
        for m, model in enumerate(self.models):
            logits[m] = model.predict(dataset)
        return logits

    def scores(self, logits):
        """Average the per-member softmax probabilities over the first axis."""
        return torch.sum(F.softmax(logits, dim=-1), dim=0) / self.M

    @torch.no_grad()
    def test(self, testset):
        """Return the accuracy of the averaged prediction on ``testset``.

        Labels are taken from the last element of each batch.
        """
        y_pred = torch.argmax(self.scores(self.predict(testset)), dim=-1)
        y = torch.concat([batch[-1] for batch in DataLoader(testset, batch_size=BS)])
        return torch.mean((y_pred == y).type(torch.float)).item()

    def train_epoch(self, trainloader, optimiser):
        """Train every member for one epoch, one member after the other.

        The same ``optimiser`` object is handed to every member.
        Returns ``self`` so calls can be chained.
        """
        for model in self.models:
            model.train_epoch(trainloader, optimiser)
        return self

    def train_epochs(self, trainset, testset, hyperparams, modelname):
        """Fully train every member; member ``i`` is saved as ``<modelname>-<i>``.

        Returns ``self`` so calls can be chained.
        """
        for i, model in enumerate(self.models):
            model.train_epochs(trainset, testset, hyperparams, modelname + "-" + str(i))
        return self

    def load(self, modelname):
        """Restore member ``i`` from the checkpoint named ``<modelname>-<i>``.

        Tensors are mapped onto ``DEVICE`` while loading, so checkpoints
        written on a GPU machine can also be restored on a CPU-only one.

        Returns ``self`` so calls can be chained.
        """
        for i, model in enumerate(self.models):
            modelpath = MODELPATH.format(modelname + "-" + str(i))
            model.load_state_dict(torch.load(modelpath, map_location=DEVICE))
        return self
    
# Five-member ensemble restored from the checkpoints "source-0" ... "source-4";
# to train the members instead, run
#   DeepEnsemble(M=5).train_epochs(svhn_trainset, mnist_testset, HYPERPARAMS, "source")
ensemble = DeepEnsemble(M=5).load("source")
# Last expression of the cell: displays the accuracy on the MNIST test set.
ensemble.test(mnist_testset)

In [None]:
# Ensemble entropy experiment: same entropy strategy, but the class
# probabilities come from the averaged five-member ensemble.
ADA_PARAMS.update(strategy=entropy)
ensemble_entropy = DeepEnsemble(M=5)
ensemble_entropy.load("source")
accuracies_ensemble_entropy = ada(
    model=ensemble_entropy,
    target_trainset=mnist_trainset,
    target_testset=mnist_testset,
    params=ADA_PARAMS,
)
# Last expression of the cell: displays the final test accuracy.
ensemble_entropy.test(mnist_testset)

In [None]:
# Accuracy per active-learning round for the different query strategies.
# The curves are read from the accuracies_* arrays returned by ada.
fig, ax = plt.subplots()
x = np.arange(ADA_PARAMS["n_rounds"] + 1)
ax.scatter(x, accuracies_uniform, label="uniform")
ax.scatter(x, accuracies_entropy, label="entropy")
# The BALD curve is sliced to its own length in case that run covered fewer
# rounds than the current ADA_PARAMS["n_rounds"].
ax.scatter(x[:len(accuracies_bald)], accuracies_bald, label="BALD")
# TODO: add the CLUE curve once CLUE accuracies are available, e.g.
#   ax.scatter(x, accuracies_clue, label="CLUE")
ax.set_xlabel("round")
ax.set_ylabel("accuracy")
ax.legend()