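This is an implementation of the paper *Deep Bayesian Active Learning with Image Data* (Gal, Islam & Ghahramani, 2017) using PyTorch and modAL. In this notebook, each batch of queried images is additionally augmented with reconstructions from a pre-trained variational autoencoder (VAE).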

modAL is an active learning framework for Python3, designed with modularity, flexibility and extensibility in mind. Built on top of scikit-learn, it allows you to rapidly create active learning workflows with nearly complete freedom. What is more, you can easily replace parts with your custom built solutions, allowing you to design novel algorithms with ease.

Since modAL works with scikit-learn estimators, we also use [skorch](https://skorch.readthedocs.io/en/stable/), a scikit-learn-compatible neural network library that wraps PyTorch.

In [1]:
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from skorch import NeuralNetClassifier
from modAL.models import ActiveLearner

from VAE import VAE

  warn(f"Failed to load image Python extension: {e}")


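### Architecture of the network

The convolutional network described in the paper is kept below for reference (commented out); the experiments in this notebook use a LeNet-style network (`lenet`) instead.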

In [2]:
# class CNN(nn.Module):
#     def __init__(self,):
#         super(CNN, self).__init__()
#         self.convs = nn.Sequential(
#                                 nn.Conv2d(1,32,4),
#                                 nn.ReLU(),
#                                 nn.Conv2d(32,32,4),
#                                 nn.ReLU(),
#                                 nn.MaxPool2d(2),
#                                 nn.Dropout(0.25)
#         )
#         self.fcs = nn.Sequential(
#                                 nn.Linear(11*11*32,128),
#                                 nn.ReLU(),
#                                 nn.Dropout(0.5),
#                                 nn.Linear(128,10),
#         )

#     def forward(self, x):
#         out = x
#         out = self.convs(out)
#         out = out.view(-1,11*11*32)
#         out = self.fcs(out)
#         return out

class lenet(nn.Module):
    """LeNet-style CNN classifier: two conv + max-pool stages, three FC layers.

    The defaults match single-channel 28x28 images (MNIST), so ``lenet()``
    with no arguments builds the same network as before.

    Args:
        input_height: Height of the input images in pixels.
        input_width: Width of the input images in pixels.
        input_dim: Number of input channels.
        class_num: Number of output logits.
        dropout: Dropout probability applied after flattening and before the
            final linear layer. The default 0.0 disables it entirely. A
            positive value makes forward passes in train mode stochastic,
            which MC-dropout acquisition functions need to produce differing
            samples.
    """

    def __init__(self, input_height=28, input_width=28, input_dim=1,
                 class_num=10, dropout=0.0):
        super().__init__()
        self.input_height = input_height
        self.input_width = input_width
        self.input_dim = input_dim
        self.class_num = class_num
        self.dropout = dropout

        self.conv1 = nn.Conv2d(self.input_dim, 6, (5, 5), padding=2)
        self.conv2 = nn.Conv2d(6, 16, (5, 5))
        # conv1 (padding=2) keeps the spatial size, each 2x2 pool halves it
        # (floor), and conv2 (no padding) trims 4 pixels:
        # 28 -> 28 -> 14 -> 10 -> 5 for MNIST.
        out_h = (input_height // 2 - 4) // 2
        out_w = (input_width // 2 - 4) // 2
        self.fc1 = nn.Linear(16 * out_h * out_w, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, self.class_num)

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, class_num)."""
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = x.view(-1, self.num_flat_features(x))
        x = self._maybe_dropout(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self._maybe_dropout(x)
        x = self.fc3(x)
        return x

    def _maybe_dropout(self, x):
        """Apply dropout only when enabled; with p == 0 the input is returned untouched."""
        if self.dropout > 0:
            return F.dropout(x, p=self.dropout, training=self.training)
        return x

    def num_flat_features(self, x):
        """Return the number of elements per sample (product of all non-batch dims)."""
        num_features = 1
        for s in x.size()[1:]:
            num_features *= s
        return num_features

### Read the training and test data

In [3]:
# One shared preprocessing pipeline for both splits: convert to tensors and
# standardize with (0.1307, 0.3081), the values commonly used as the MNIST
# pixel mean / std.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
mnist_train = MNIST('.', train=True, download=True, transform=mnist_transform)
mnist_test = MNIST('.', train=False, download=True, transform=mnist_transform)
# Load each split as one single batch so it can be turned into a NumPy array.
traindataloader = DataLoader(mnist_train, shuffle=True, batch_size=len(mnist_train))
testdataloader = DataLoader(mnist_test, shuffle=True, batch_size=len(mnist_test))
X_train, y_train = next(iter(traindataloader))
X_test, y_test = next(iter(testdataloader))
X_train, y_train = X_train.detach().cpu().numpy(), y_train.detach().cpu().numpy()
X_test, y_test = X_test.detach().cpu().numpy(), y_test.detach().cpu().numpy()

### preprocessing

In [4]:
# Ensure the (N, 1, 28, 28) layout expected by the conv net; the sample count
# is read from the arrays rather than hard-coded.
X_train = X_train.reshape(len(X_train), 1, 28, 28)
X_test = X_test.reshape(len(X_test), 1, 28, 28)

### Initial labelled data
We initialize the labelled set with 100 randomly sampled examples, balanced across the classes (10 per digit).

In [5]:
# For each digit 0-9, draw 10 distinct training indices of that class (no seed
# is set here, so the draw differs between runs) and join them in digit order.
initial_idx = np.concatenate([
    np.random.choice(np.where(y_train == digit)[0], size=10, replace=False)
    for digit in range(10)
])

X_initial, y_initial = X_train[initial_idx], y_train[initial_idx]

### initial unlabelled pool

In [6]:
# Unlabelled pool: every training example except the initial labelled ones.
# Keeping them in the pool would let a query strategy select an
# already-labelled point and add it to the training set a second time.
X_pool = np.delete(X_train, initial_idx, axis=0)
y_pool = np.delete(y_train, initial_idx, axis=0)

## Query Strategies

### Uniform
All the acquisition functions we use are compared against the uniform acquisition function $\mathbb{U}_{[0,1]}$, which picks pool points uniformly at random and serves as the baseline we would like to beat.

In [7]:
def uniform(learner, X, n_instances=1):
    """Baseline query strategy: pick pool points uniformly at random.

    Args:
        learner: Unused; present only to match the modAL strategy signature.
        X: Unlabelled pool, indexable by an integer index array.
        n_instances: Number of distinct points to draw.

    Returns:
        Tuple of (selected indices, corresponding rows of ``X``).
    """
    pool_size = len(X)
    chosen = np.random.choice(pool_size, size=n_instances, replace=False)
    return chosen, X[chosen]

### Entropy
Our first acquisition function is the predictive entropy:
$$ \mathbb{H} = - \sum_{c} p_c \log(p_c)$$
where $p_c$ is the probability predicted for class $c$. Using Monte Carlo dropout, this probability is approximated by:
\begin{align}
p_c &= \frac{1}{T} \sum_{t=1}^{T} p_{c}^{(t)}
\end{align}
where $p_{c}^{(t)}$ is the probability predicted for class $c$ on the $t$-th stochastic (dropout-enabled) forward pass.

In [8]:
def max_entropy(learner, X, n_instances=1, T=100, subset_size=2000):
    """Select the pool points with the highest MC-dropout predictive entropy.

    Only a random subset of the pool is scored, to bound the cost per query.

    Args:
        learner: modAL learner whose ``estimator`` exposes skorch's
            ``forward(X, training=...)``.
        X: Unlabelled pool, indexable by an integer index array.
        n_instances: Number of points to select.
        T: Number of stochastic forward passes averaged per point (this was
            previously ignored in favour of a hard-coded 20).
        subset_size: Maximum number of pool points scored per query; clamped
            to ``len(X)`` so small pools no longer raise.

    Returns:
        Tuple of (indices into ``X``, the selected rows of ``X``).
    """
    subset = np.random.choice(len(X), size=min(subset_size, len(X)), replace=False)
    with torch.no_grad():
        # training=True keeps dropout active so the passes differ (only if the
        # model has dropout). .detach() guards against forward() running with
        # autograd re-enabled internally in training mode.
        outputs = np.stack([
            torch.softmax(learner.estimator.forward(X[subset], training=True), dim=-1)
            .detach().cpu().numpy()
            for _ in range(T)
        ])  # assumed (T, len(subset), n_classes)
    pc = outputs.mean(axis=0)
    # Entropy of the averaged prediction; 1e-10 avoids log(0).
    acquisition = (-pc * np.log(pc + 1e-10)).sum(axis=-1)
    idx = (-acquisition).argsort()[:n_instances]
    query_idx = subset[idx]
    return query_idx, X[query_idx]

In [9]:
def bald(learner, X, n_instances=1, T=100, subset_size=2000):
    """Select the pool points with the highest BALD score under MC dropout.

    BALD is the mutual information between the prediction and the model
    weights. It is estimated as the entropy of the averaged prediction minus
    the average entropy of the individual stochastic passes. Only a random
    subset of the pool is scored, to bound the cost per query.

    Args:
        learner: modAL learner whose ``estimator`` exposes skorch's
            ``forward(X, training=...)``.
        X: Unlabelled pool, indexable by an integer index array.
        n_instances: Number of points to select.
        T: Number of stochastic forward passes per point (this was previously
            ignored in favour of a hard-coded 20).
        subset_size: Maximum number of pool points scored per query; clamped
            to ``len(X)`` so small pools no longer raise.

    Returns:
        Tuple of (indices into ``X``, the selected rows of ``X``).
    """
    subset = np.random.choice(len(X), size=min(subset_size, len(X)), replace=False)
    with torch.no_grad():
        # training=True keeps dropout active so the passes differ (only if the
        # model has dropout). .detach() guards against forward() running with
        # autograd re-enabled internally in training mode.
        outputs = np.stack([
            torch.softmax(learner.estimator.forward(X[subset], training=True), dim=-1)
            .detach().cpu().numpy()
            for _ in range(T)
        ])  # assumed (T, len(subset), n_classes)
    pc = outputs.mean(axis=0)
    # Entropy of the mean prediction (1e-10 avoids log(0)).
    H = (-pc * np.log(pc + 1e-10)).sum(axis=-1)
    # Mean over passes of each pass's entropy -> one value per subset point.
    E_H = -np.mean(np.sum(outputs * np.log(outputs + 1e-10), axis=-1), axis=0)
    acquisition = H - E_H
    idx = (-acquisition).argsort()[:n_instances]
    query_idx = subset[idx]
    return query_idx, X[query_idx]

In [10]:
def save_list(input_list, name, directory="perf_lists_vae"):
    """Write ``input_list`` to ``directory/name``, one value per line.

    Each value is written as its string form followed by a newline. The
    target directory is created if it does not exist yet; the original
    version raised FileNotFoundError in that case. An existing file is
    overwritten.

    Args:
        input_list: Iterable of values to write (e.g. accuracies).
        name: File name inside ``directory``.
        directory: Output directory (defaults to the original location).
    """
    out_dir = Path(directory)
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / name, "w", encoding="utf-8") as f:
        for val in input_list:
            f.write(f"{val}\n")

### Active Learning Procedure

In [15]:
def active_learning_procedure_vae(query_strategy,
                                  X_test,
                                  y_test,
                                  X_pool,
                                  y_pool,
                                  X_initial,
                                  y_initial,
                                  estimator,
                                  n_queries=150,
                                  n_instances=10,
                                  weights_location='weights.pt'):
    """Run pool-based active learning, augmenting every query with VAE reconstructions.

    On each iteration the strategy picks ``n_instances`` pool points. Both the
    real points and their VAE reconstructions (labelled with the real labels)
    are appended to a rolling training set. The learner is refit on the whole
    rolling set, and the queried points are removed from the local copy of
    the pool.

    Args:
        query_strategy: modAL-style strategy ``f(learner, X, n_instances)``
            returning ``(indices, rows)``.
        X_test, y_test: Held-out data used for the accuracy curve.
        X_pool, y_pool: Unlabelled pool and its oracle labels (not mutated;
            the function works on rebinds).
        X_initial, y_initial: Initial labelled set.
        estimator: sklearn-compatible classifier (e.g. a skorch net).
        n_queries: Number of query/refit rounds.
        n_instances: Points queried per round.
        weights_location: Path of the pre-trained VAE ``state_dict``.

    Returns:
        List of test accuracies: one before the first query, then one per query.
    """
    # map_location lets weights saved from a GPU load on a CPU-only machine;
    # the VAE itself is built (and run) on the CPU.
    vae = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)
    vae.load_state_dict(torch.load(weights_location, map_location="cpu"))
    vae.eval()

    def reconstruct(samples):
        """Return VAE reconstructions of ``samples``, reshaped to the input shape."""
        # The VAE takes flattened samples; x_dim=784 means each sample must
        # have 784 values (e.g. 1x28x28).
        with torch.no_grad():
            recon, _, _ = vae(samples.reshape(samples.shape[0], -1))
        return recon.reshape(samples.shape)

    learner = ActiveLearner(estimator=estimator,
                            X_training=X_initial,
                            y_training=y_initial,
                            query_strategy=query_strategy)
    perf_hist = [learner.score(X_test, y_test)]
    X_rolling, y_rolling = np.copy(X_initial), np.copy(y_initial)
    for index in range(n_queries):
        query_idx, _ = learner.query(X_pool, n_instances)
        real_samples = X_pool[query_idx]
        real_labels = y_pool[query_idx]

        # NOTE(review): in this notebook X_pool is mean/std-normalized; if the
        # VAE was trained on [0, 1] pixels, its reconstructions are on a
        # different scale than the real samples -- confirm.
        fake_samples = reconstruct(torch.tensor(real_samples)).cpu().numpy()
        new_samples = np.concatenate((fake_samples, real_samples))
        new_labels = np.concatenate((real_labels, real_labels))

        X_rolling = np.concatenate((X_rolling, new_samples), axis=0)
        y_rolling = np.concatenate((y_rolling, new_labels))
        learner.fit(X_rolling, y_rolling)

        X_pool = np.delete(X_pool, query_idx, axis=0)
        y_pool = np.delete(y_pool, query_idx, axis=0)
        model_accuracy = learner.score(X_test, y_test)
        print('Accuracy after query {n}: {acc:0.4f}'.format(n=index + 1, acc=model_accuracy))
        perf_hist.append(model_accuracy)
    return perf_hist

In [16]:
# BALD + VAE-augmentation experiment: a fresh LeNet estimator trained with
# Adadelta, run through the active learning loop, accuracy curve saved to disk.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
estimator = NeuralNetClassifier(
    lenet,
    max_epochs=50,
    batch_size=100,
    lr=1.0,
    optimizer=torch.optim.Adadelta,
    optimizer__rho=0.9,
    optimizer__eps=1e-6,
    criterion=torch.nn.CrossEntropyLoss,
    train_split=None,
    verbose=0,
    device=device,
)
bald_vae_perf_hist = active_learning_procedure_vae(
    bald,
    X_test, y_test,
    X_pool, y_pool,
    X_initial, y_initial,
    estimator,
    n_instances=100,
)
save_list(bald_vae_perf_hist, "bald_vae_perf_hist")



Accuracy after query 1: 0.6283




Accuracy after query 2: 0.6191




Accuracy after query 3: 0.6698




Accuracy after query 4: 0.6991




Accuracy after query 5: 0.7457




Accuracy after query 6: 0.7524




Accuracy after query 7: 0.7890




Accuracy after query 8: 0.7995




Accuracy after query 9: 0.8167




Accuracy after query 10: 0.8007




Accuracy after query 11: 0.7953




Accuracy after query 12: 0.8141




Accuracy after query 13: 0.8163




Accuracy after query 14: 0.8169




Accuracy after query 15: 0.8306




Accuracy after query 16: 0.8279




Accuracy after query 17: 0.8367




Accuracy after query 18: 0.8577




Accuracy after query 19: 0.8620




Accuracy after query 20: 0.8601




Accuracy after query 21: 0.8760




Accuracy after query 22: 0.8817




Accuracy after query 23: 0.8792




Accuracy after query 24: 0.8857




Accuracy after query 25: 0.8953




Accuracy after query 26: 0.9107




Accuracy after query 27: 0.9017




Accuracy after query 28: 0.9099




Accuracy after query 29: 0.9073




Accuracy after query 30: 0.9143




Accuracy after query 31: 0.9155




Accuracy after query 32: 0.9147




Accuracy after query 33: 0.9173




Accuracy after query 34: 0.9264




Accuracy after query 35: 0.9266




Accuracy after query 36: 0.9272




Accuracy after query 37: 0.9332




Accuracy after query 38: 0.9301




Accuracy after query 39: 0.9350




Accuracy after query 40: 0.9400




Accuracy after query 41: 0.9335




Accuracy after query 42: 0.9434




Accuracy after query 43: 0.9376




Accuracy after query 44: 0.9433




Accuracy after query 45: 0.9045




Accuracy after query 46: 0.9306




Accuracy after query 47: 0.9378




Accuracy after query 48: 0.9422




Accuracy after query 49: 0.9419




Accuracy after query 50: 0.9453




Accuracy after query 51: 0.9351




Accuracy after query 52: 0.9461




Accuracy after query 53: 0.9532




Accuracy after query 54: 0.9512




Accuracy after query 55: 0.9423




Accuracy after query 56: 0.9491




Accuracy after query 57: 0.9556




Accuracy after query 58: 0.9498




Accuracy after query 59: 0.9521




Accuracy after query 60: 0.9556




Accuracy after query 61: 0.9632




Accuracy after query 62: 0.9500




Accuracy after query 63: 0.9605




Accuracy after query 64: 0.9602




Accuracy after query 65: 0.9585




Accuracy after query 66: 0.9524




Accuracy after query 67: 0.9643




Accuracy after query 68: 0.9568




Accuracy after query 69: 0.9635




Accuracy after query 70: 0.9652




Accuracy after query 71: 0.9612




Accuracy after query 72: 0.9486




Accuracy after query 73: 0.9618




Accuracy after query 74: 0.9616




Accuracy after query 75: 0.9671




Accuracy after query 76: 0.9649




Accuracy after query 77: 0.9628




Accuracy after query 78: 0.9646




Accuracy after query 79: 0.9627




Accuracy after query 80: 0.9549




Accuracy after query 81: 0.9625




Accuracy after query 82: 0.9681




Accuracy after query 83: 0.9622




Accuracy after query 84: 0.9733




Accuracy after query 85: 0.9685




Accuracy after query 86: 0.9648




Accuracy after query 87: 0.9717




Accuracy after query 88: 0.9685




Accuracy after query 89: 0.9728




Accuracy after query 90: 0.9696




Accuracy after query 91: 0.9724




Accuracy after query 92: 0.9765




Accuracy after query 93: 0.9724




Accuracy after query 94: 0.9751




Accuracy after query 95: 0.9729




Accuracy after query 96: 0.9711




Accuracy after query 97: 0.9707




Accuracy after query 98: 0.9721




Accuracy after query 99: 0.9734




Accuracy after query 100: 0.9744


In [17]:
# Max-entropy + VAE-augmentation experiment: a fresh LeNet estimator
# (Adadelta, max_epochs=50) run through the active learning loop; the
# accuracy curve is saved to disk.
device = "cuda" if torch.cuda.is_available() else "cpu"
estimator = NeuralNetClassifier(lenet,
                                max_epochs=50,
                                batch_size=100,
                                lr=1.0,
                                optimizer=torch.optim.Adadelta,
                                optimizer__rho=0.9,
                                optimizer__eps=1e-6,
                                criterion=torch.nn.CrossEntropyLoss,
                                train_split=None,
                                verbose=0,
                                device=device)
entropy_vae_perf_hist = active_learning_procedure_vae(max_entropy,
                                                      X_test,
                                                      y_test,
                                                      X_pool,
                                                      y_pool,
                                                      X_initial,
                                                      y_initial,
                                                      estimator,
                                                      n_instances=100)
save_list(entropy_vae_perf_hist, "entropy_vae_perf_hist")



Accuracy after query 1: 0.6873




Accuracy after query 2: 0.6613




Accuracy after query 3: 0.6846




Accuracy after query 4: 0.7370




Accuracy after query 5: 0.7375




Accuracy after query 6: 0.6703




Accuracy after query 7: 0.6969




Accuracy after query 8: 0.7615




Accuracy after query 9: 0.7716




Accuracy after query 10: 0.7844




Accuracy after query 11: 0.8008




Accuracy after query 12: 0.8070




Accuracy after query 13: 0.7930




Accuracy after query 14: 0.8342




Accuracy after query 15: 0.8177




Accuracy after query 16: 0.8503




Accuracy after query 17: 0.8498




Accuracy after query 18: 0.8578




Accuracy after query 19: 0.8469




Accuracy after query 20: 0.8683




Accuracy after query 21: 0.8689




Accuracy after query 22: 0.8877




Accuracy after query 23: 0.8913




Accuracy after query 24: 0.8999




Accuracy after query 25: 0.9011




Accuracy after query 26: 0.8950




Accuracy after query 27: 0.9014




Accuracy after query 28: 0.9156




Accuracy after query 29: 0.9044




Accuracy after query 30: 0.9122




Accuracy after query 31: 0.9155




Accuracy after query 32: 0.9068




Accuracy after query 33: 0.9247




Accuracy after query 34: 0.9073




Accuracy after query 35: 0.9162




Accuracy after query 36: 0.9376




Accuracy after query 37: 0.9227




Accuracy after query 38: 0.9467




Accuracy after query 39: 0.9370




Accuracy after query 40: 0.9301




Accuracy after query 41: 0.9357




Accuracy after query 42: 0.9488




Accuracy after query 43: 0.9457




Accuracy after query 44: 0.9502




Accuracy after query 45: 0.9379




Accuracy after query 46: 0.9551




Accuracy after query 47: 0.9510




Accuracy after query 48: 0.9430




Accuracy after query 49: 0.9508




Accuracy after query 50: 0.9497




Accuracy after query 51: 0.9569




Accuracy after query 52: 0.9511




Accuracy after query 53: 0.9573




Accuracy after query 54: 0.9527




Accuracy after query 55: 0.9564




Accuracy after query 56: 0.9583




Accuracy after query 57: 0.9556




Accuracy after query 58: 0.9633




Accuracy after query 59: 0.9476




Accuracy after query 60: 0.9629




Accuracy after query 61: 0.9587




Accuracy after query 62: 0.9571




Accuracy after query 63: 0.9600




Accuracy after query 64: 0.9657




Accuracy after query 65: 0.9611




Accuracy after query 66: 0.9598




Accuracy after query 67: 0.9573




Accuracy after query 68: 0.9661




Accuracy after query 69: 0.9592




Accuracy after query 70: 0.9640




Accuracy after query 71: 0.9621




Accuracy after query 72: 0.9627




Accuracy after query 73: 0.9648




Accuracy after query 74: 0.9702




Accuracy after query 75: 0.9689




Accuracy after query 76: 0.9572




Accuracy after query 77: 0.9667




Accuracy after query 78: 0.9673




Accuracy after query 79: 0.9697




Accuracy after query 80: 0.9722




Accuracy after query 81: 0.9679




Accuracy after query 82: 0.9683




Accuracy after query 83: 0.9670




Accuracy after query 84: 0.9734




Accuracy after query 85: 0.9690




Accuracy after query 86: 0.9728




Accuracy after query 87: 0.9717




Accuracy after query 88: 0.9766




Accuracy after query 89: 0.9728




Accuracy after query 90: 0.9673




Accuracy after query 91: 0.9725




Accuracy after query 92: 0.9664




Accuracy after query 93: 0.9723




Accuracy after query 94: 0.9632




Accuracy after query 95: 0.9703




Accuracy after query 96: 0.9722




Accuracy after query 97: 0.9681




Accuracy after query 98: 0.9673




Accuracy after query 99: 0.9685




Accuracy after query 100: 0.9705


In [None]:
# Uniform-sampling baseline + VAE augmentation: a fresh LeNet estimator
# (Adadelta, max_epochs=50) run through the active learning loop; the
# accuracy curve is saved to disk.
device = "cuda" if torch.cuda.is_available() else "cpu"
estimator = NeuralNetClassifier(lenet,
                                max_epochs=50,
                                batch_size=100,
                                lr=1.0,
                                optimizer=torch.optim.Adadelta,
                                optimizer__rho=0.9,
                                optimizer__eps=1e-6,
                                criterion=torch.nn.CrossEntropyLoss,
                                train_split=None,
                                verbose=0,
                                device=device)
uniform_vae_perf_hist = active_learning_procedure_vae(uniform,
                                                      X_test,
                                                      y_test,
                                                      X_pool,
                                                      y_pool,
                                                      X_initial,
                                                      y_initial,
                                                      estimator,
                                                      n_instances=100)
save_list(uniform_vae_perf_hist, "uniform_vae_perf_hist")