This is an implementation of the paper Deep Bayesian Active Learning with Image Data using PyTorch and modAL. 

modAL is an active learning framework for Python3, designed with modularity, flexibility and extensibility in mind. Built on top of scikit-learn, it allows you to rapidly create active learning workflows with nearly complete freedom. What is more, you can easily replace parts with your custom built solutions, allowing you to design novel algorithms with ease.

Since modAL only supports sklearn models, we will also use [skorch](https://skorch.readthedocs.io/en/stable/), a scikit-learn compatible neural network library that wraps PyTorch. 

In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
PATH = '..\\'
# import warnings
# warnings.filterwarnings("ignore")

In [3]:
# !pip install skorch modAL

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import os

class VAE(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()
        
        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)
        self.fc32 = nn.Linear(h_dim2, z_dim)
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)
        
    def encoder(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h) # mu, log_var
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample
        
    def decoder(self, z):
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return F.sigmoid(self.fc6(h)) 
    
    def forward(self, x):
        mu, log_var = self.encoder(x.view(-1, 784))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

  warn(f"Failed to load image Python extension: {e}")


In [5]:
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from skorch import NeuralNetClassifier
from modAL.models import ActiveLearner

# from VAE import VAE

### architecture of the network we will be using

We will use the architecture described in the paper.

In [6]:
# class CNN(nn.Module):
#     def __init__(self,):
#         super(CNN, self).__init__()
#         self.convs = nn.Sequential(
#                                 nn.Conv2d(1,32,4),
#                                 nn.ReLU(),
#                                 nn.Conv2d(32,32,4),
#                                 nn.ReLU(),
#                                 nn.MaxPool2d(2),
#                                 nn.Dropout(0.25)
#         )
#         self.fcs = nn.Sequential(
#                                 nn.Linear(11*11*32,128),
#                                 nn.ReLU(),
#                                 nn.Dropout(0.5),
#                                 nn.Linear(128,10),
#         )

#     def forward(self, x):
#         out = x
#         out = self.convs(out)
#         out = out.view(-1,11*11*32)
#         out = self.fcs(out)
#         return out
input_dim, input_height, input_width = 1, 28, 28
class lenet(nn.Module):  
    def __init__(self):
        super(lenet, self).__init__()
        self.input_height = input_height
        self.input_width = input_width
        self.input_dim = input_dim
        self.class_num = 10

        self.conv1 = nn.Conv2d(self.input_dim, 6, (5, 5), padding=2)
        self.conv2 = nn.Conv2d(6, 16, (5, 5))
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, self.class_num)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

### read training data

In [7]:
mnist_train = MNIST(PATH, train=True, download=True, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
mnist_test  = MNIST(PATH, train=False,download=True, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
traindataloader = DataLoader(mnist_train, shuffle=True, batch_size=60000)
testdataloader  = DataLoader(mnist_test , shuffle=True, batch_size=10000)
X_train, y_train = next(iter(traindataloader))
X_test , y_test  = next(iter(testdataloader))
X_train, y_train = X_train.detach().cpu().numpy(), y_train.detach().cpu().numpy()
X_test, y_test = X_test.detach().cpu().numpy(), y_test.detach().cpu().numpy()

### preprocessing

In [8]:
X_train = X_train.reshape(60000, 1, 28, 28)
X_test = X_test.reshape(10000, 1, 28, 28)

### initial labelled data
We initialize the labelled set with 100 balanced randomly sampled examples

In [9]:
initial_idx = np.array([],dtype=int)
for i in range(10):
    idx = np.random.choice(np.where(y_train==i)[0], size=10, replace=False)
    initial_idx = np.concatenate((initial_idx, idx))

X_initial = X_train[initial_idx]
y_initial = y_train[initial_idx]

### initial unlabelled pool

In [10]:
# X_pool = np.delete(X_train, initial_idx, axis=0)
# y_pool = np.delete(y_train, initial_idx, axis=0)

X_pool = np.copy(X_train)
y_pool = np.copy(y_train)

## Query Strategies

### Uniform
All the acquisition function we will use will be compared to the uniform acquisition function $\mathbb{U}_{[0,1]}$ which will be our baseline that we would like to beat.

In [11]:
def uniform(learner, X, n_instances=1):
    query_idx = np.random.choice(range(len(X)), size=n_instances, replace=False)
    return query_idx, X[query_idx]

### Entropy
Our first acquisition function is the entropy:
$$ \mathbb{H} = - \sum_{c} p_c \log(p_c)$$
where $p_c$ is the probability predicted for class c. This is approximated by:
\begin{align}
p_c &= \frac{1}{T} \sum_t p_{c}^{(t)} 
\end{align}
where $p_{c}^{t}$ is the probability predicted for class c at the t th feedforward pass.

In [12]:
def max_entropy(learner, X, n_instances=1, T=100):
    random_subset = np.random.choice(range(len(X)), size=2000, replace=False)
    with torch.no_grad():
        outputs = np.stack([torch.softmax(learner.estimator.forward(X[random_subset], training=True),dim=-1).cpu().numpy()
                            for t in range(20)])
    pc = outputs.mean(axis=0)
    acquisition = (-pc*np.log(pc + 1e-10)).sum(axis=-1)
    idx = (-acquisition).argsort()[:n_instances]
    query_idx = random_subset[idx]
    return query_idx, X[query_idx]

In [13]:
def bald(learner, X, n_instances=1, T=100):
    random_subset = np.random.choice(range(len(X)), size=2000, replace=False)
    with torch.no_grad():
        outputs = np.stack([torch.softmax(learner.estimator.forward(X[random_subset], training=True),dim=-1).cpu().numpy()
                            for t in range(20)])
    pc = outputs.mean(axis=0)
    H   = (-pc*np.log(pc + 1e-10)).sum(axis=-1)
    E_H = - np.mean(np.sum(outputs * np.log(outputs + 1e-10), axis=-1), axis=0)  # [batch size]
    acquisition = H - E_H
    idx = (-acquisition).argsort()[:n_instances]
    query_idx = random_subset[idx]
    return query_idx, X[query_idx]    

In [14]:
os.makedirs(f"{PATH}/perf_lists_vae/",exist_ok=True)
def save_list(input_list, name):
    with open(f"{PATH}/perf_lists_vae/" + name, 'w') as f:
        for val in input_list:
            f.write("%s\n" % val)

### Active Learning Procedure

In [46]:
X_roll1, X_roll2 = None, None
def active_learning_procedure_vae(query_strategy,
                              X_test,
                              y_test,
                              X_pool,
                              y_pool,
                              X_initial,
                              y_initial,
                              estimator,
                              n_queries=150,
                              n_instances=10):
    weights_location = f'{PATH}/weights.pt'
    vae = VAE(x_dim=784, h_dim1= 512, h_dim2=256, z_dim=2)
    vae.load_state_dict(torch.load(weights_location))
    global X_roll1, X_roll2
    def vaeNewSampleGenerator(vae, samples):    
        vae.eval()
        with torch.no_grad():
            new_samples, _, _ = vae(samples.reshape(samples.shape[0], -1))
            return new_samples.reshape(samples.shape)

    learner = ActiveLearner(estimator=estimator,
                            X_training=X_initial,
                            y_training=y_initial,
                            query_strategy=query_strategy,
                           )
    perf_hist = [learner.score(X_test, y_test)]
    X_rolling, y_rolling = np.copy(X_initial), np.copy(y_initial)
    for index in range(n_queries):
        query_idx, query_instance = learner.query(X_pool, n_instances)
        print(query_idx.shape)
        new_samples = vaeNewSampleGenerator(vae, torch.tensor(X_pool[query_idx]))
        new_samples = new_samples.detach().cpu().numpy()
        new_samples = np.concatenate((new_samples, X_pool[query_idx]))
        new_labels = np.concatenate((y_pool[query_idx], y_pool[query_idx]))
        
        X_rolling, y_rolling = np.concatenate((X_rolling, new_samples), axis=0), np.concatenate((y_rolling, new_labels))
        if index == 0:
            print("0")
            X_roll1 = np.copy(X_rolling[-90:])
        if index == 1:
            X_roll2 = np.copy(X_rolling[-90:])
        learner.fit(X_rolling, y_rolling)
        
        if index == 1:
            break
        X_pool = np.delete(X_pool, query_idx, axis=0)
        y_pool = np.delete(y_pool, query_idx, axis=0)
        model_accuracy = learner.score(X_test, y_test)
        print('Accuracy after query {n}: {acc:0.4f}'.format(n=index + 1, acc=model_accuracy))
        perf_hist.append(model_accuracy)
    return perf_hist

In [47]:
device = "cuda" if torch.cuda.is_available() else "cpu"
estimator = NeuralNetClassifier(lenet,
                                max_epochs=50,
                                batch_size=100,
                                lr=1.0,
                                optimizer=torch.optim.Adadelta,
                                optimizer__rho=0.9,
                                optimizer__eps=1e-6,
                                criterion=torch.nn.CrossEntropyLoss,
                                train_split=None,
                                verbose=0,
                                device=device)
bald_vae_perf_hist = active_learning_procedure_vae(bald,
                                           X_test,
                                           y_test,
                                           X_pool,
                                           y_pool,
                                           X_initial,
                                           y_initial,
                                           estimator,
                                            n_instances=100)
save_list(bald_vae_perf_hist, "bald_vae_perf_hist")

(100,)




0
Accuracy after query 1: 0.8192
(100,)




In [37]:
# 100 org -> 100 new samples -> 100 X_pool -> 100 new samples -> 100 X_pool

In [43]:
np.sum((X_roll1-X_roll2)**2)

102764.68

In [50]:
tmp = np.array([1, 2, 3, 4, 5])
idx = np.array([0, 2])
tmp2 = tmp[idx]
tmp[0] = -1
tmp2, tmp

(array([1, 3]), array([-1,  2,  3,  4,  5]))

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
estimator = NeuralNetClassifier(lenet,
                                max_epochs=50,
                                batch_size=100,
                                lr=1.0,
                                optimizer=torch.optim.Adadelta,
                                optimizer__rho=0.9,
                                optimizer__eps=1e-6,
                                criterion=torch.nn.CrossEntropyLoss,
                                train_split=None,
                                verbose=0,
                                device=device)
entropy_vae_perf_hist = active_learning_procedure_vae(max_entropy,
                                           X_test,
                                           y_test,
                                           X_pool,
                                           y_pool,
                                           X_initial,
                                           y_initial,
                                           estimator,
                                            n_instances=100)
save_list(entropy_vae_perf_hist, "entropy_vae_perf_hist")



Accuracy after query 1: 0.8481




Accuracy after query 2: 0.9091




Accuracy after query 3: 0.8653




Accuracy after query 4: 0.9174




Accuracy after query 5: 0.9335




Accuracy after query 6: 0.9463




Accuracy after query 7: 0.9295




Accuracy after query 8: 0.9392




Accuracy after query 9: 0.9569




Accuracy after query 10: 0.9568




Accuracy after query 11: 0.9434




Accuracy after query 12: 0.9634




Accuracy after query 13: 0.9633




Accuracy after query 14: 0.9637




Accuracy after query 15: 0.9652




Accuracy after query 16: 0.9663




Accuracy after query 17: 0.9588




Accuracy after query 18: 0.9686




Accuracy after query 19: 0.9669




Accuracy after query 20: 0.9703




Accuracy after query 21: 0.9732




Accuracy after query 22: 0.9731




Accuracy after query 23: 0.9643




Accuracy after query 24: 0.9656




Accuracy after query 25: 0.9688




Accuracy after query 26: 0.9674




Accuracy after query 27: 0.9688




Accuracy after query 28: 0.9730




Accuracy after query 29: 0.9774




Accuracy after query 30: 0.9753




Accuracy after query 31: 0.9800




Accuracy after query 32: 0.9741




Accuracy after query 33: 0.9567




Accuracy after query 34: 0.9734




Accuracy after query 35: 0.9794




Accuracy after query 36: 0.9761




Accuracy after query 37: 0.9812




Accuracy after query 38: 0.9792




Accuracy after query 39: 0.9616




Accuracy after query 40: 0.9537




Accuracy after query 41: 0.9815




Accuracy after query 42: 0.9796




Accuracy after query 43: 0.9763




Accuracy after query 44: 0.9794




Accuracy after query 45: 0.9788




Accuracy after query 46: 0.9776




Accuracy after query 47: 0.9801




Accuracy after query 48: 0.9807




Accuracy after query 49: 0.9780




Accuracy after query 50: 0.9795




Accuracy after query 51: 0.9829




Accuracy after query 52: 0.9779




Accuracy after query 53: 0.9758




Accuracy after query 54: 0.9830




Accuracy after query 55: 0.9819




Accuracy after query 56: 0.9817




Accuracy after query 57: 0.9802




Accuracy after query 58: 0.9798




Accuracy after query 59: 0.9794




Accuracy after query 60: 0.9801




Accuracy after query 61: 0.9723




Accuracy after query 62: 0.9746




Accuracy after query 63: 0.9824




Accuracy after query 64: 0.9798




Accuracy after query 65: 0.9846




Accuracy after query 66: 0.9795




Accuracy after query 67: 0.9808




Accuracy after query 68: 0.9819




Accuracy after query 69: 0.9788




Accuracy after query 70: 0.9834




Accuracy after query 71: 0.9842




Accuracy after query 72: 0.9823




Accuracy after query 73: 0.9811




Accuracy after query 74: 0.9833




Accuracy after query 75: 0.9840




Accuracy after query 76: 0.9802




Accuracy after query 77: 0.9790




Accuracy after query 78: 0.9820




Accuracy after query 79: 0.9824




Accuracy after query 80: 0.9833




Accuracy after query 81: 0.9857




Accuracy after query 82: 0.9831




Accuracy after query 83: 0.9757




Accuracy after query 84: 0.9817




Accuracy after query 85: 0.9853




Accuracy after query 86: 0.9844




Accuracy after query 87: 0.9850




Accuracy after query 88: 0.9821




Accuracy after query 89: 0.9799




Accuracy after query 90: 0.9821




Accuracy after query 91: 0.9824




Accuracy after query 92: 0.9831




Accuracy after query 93: 0.9841




Accuracy after query 94: 0.9836




Accuracy after query 95: 0.9848




Accuracy after query 96: 0.9872




Accuracy after query 97: 0.9835




Accuracy after query 98: 0.9842




Accuracy after query 99: 0.9795




Accuracy after query 100: 0.9820




Accuracy after query 101: 0.9854




Accuracy after query 102: 0.9849




Accuracy after query 103: 0.9848




Accuracy after query 104: 0.9828




Accuracy after query 105: 0.9834




Accuracy after query 106: 0.9850




Accuracy after query 107: 0.9837




Accuracy after query 108: 0.9832




Accuracy after query 109: 0.9830




Accuracy after query 110: 0.9838




Accuracy after query 111: 0.9815




Accuracy after query 112: 0.9850




Accuracy after query 113: 0.9833




Accuracy after query 114: 0.9855




Accuracy after query 115: 0.9837




Accuracy after query 116: 0.9856




Accuracy after query 117: 0.9840




Accuracy after query 118: 0.9832




Accuracy after query 119: 0.9854




Accuracy after query 120: 0.9816




Accuracy after query 121: 0.9848




Accuracy after query 122: 0.9852




Accuracy after query 123: 0.9850




Accuracy after query 124: 0.9835




Accuracy after query 125: 0.9842




Accuracy after query 126: 0.9839




Accuracy after query 127: 0.9852




Accuracy after query 128: 0.9844




Accuracy after query 129: 0.9847




Accuracy after query 130: 0.9858




Accuracy after query 131: 0.9856




Accuracy after query 132: 0.9844




Accuracy after query 133: 0.9861




Accuracy after query 134: 0.9849




Accuracy after query 135: 0.9846




Accuracy after query 136: 0.9853




Accuracy after query 137: 0.9840




Accuracy after query 138: 0.9830




Accuracy after query 139: 0.9837




Accuracy after query 140: 0.9842




Accuracy after query 141: 0.9868




Accuracy after query 142: 0.9841




Accuracy after query 143: 0.9837




Accuracy after query 144: 0.9831




Accuracy after query 145: 0.9862




Accuracy after query 146: 0.9851




Accuracy after query 147: 0.9871




Accuracy after query 148: 0.9841




Accuracy after query 149: 0.9847




Accuracy after query 150: 0.9832


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
estimator = NeuralNetClassifier(lenet,
                                max_epochs=50,
                                batch_size=100,
                                lr=1.0,
                                optimizer=torch.optim.Adadelta,
                                optimizer__rho=0.9,
                                optimizer__eps=1e-6,
                                criterion=torch.nn.CrossEntropyLoss,
                                train_split=None,
                                verbose=0,
                                device=device)
uniform_vae_perf_hist = active_learning_procedure_vae(uniform,
                                           X_test,
                                           y_test,
                                           X_pool,
                                           y_pool,
                                           X_initial,
                                           y_initial,
                                           estimator,
                                            n_instances=100)
save_list(uniform_vae_perf_hist, "uniform_vae_perf_hist")

Accuracy after query 1: 0.8598
Accuracy after query 2: 0.8223
Accuracy after query 3: 0.8677
Accuracy after query 4: 0.9046
Accuracy after query 5: 0.9044
Accuracy after query 6: 0.9063
Accuracy after query 7: 0.8286
Accuracy after query 8: 0.9240
Accuracy after query 9: 0.9201
Accuracy after query 10: 0.9285
Accuracy after query 11: 0.9189
Accuracy after query 12: 0.9180
Accuracy after query 13: 0.9354
Accuracy after query 14: 0.9283
Accuracy after query 15: 0.9298
Accuracy after query 16: 0.9448
Accuracy after query 17: 0.9436
Accuracy after query 18: 0.9398
Accuracy after query 19: 0.9453
Accuracy after query 20: 0.9364
Accuracy after query 21: 0.9306
Accuracy after query 22: 0.9422
Accuracy after query 23: 0.9461
Accuracy after query 24: 0.9515
Accuracy after query 25: 0.9242
Accuracy after query 26: 0.9404
Accuracy after query 27: 0.9406
Accuracy after query 28: 0.9497
Accuracy after query 29: 0.9419
Accuracy after query 30: 0.9409
Accuracy after query 31: 0.9462
Accuracy after qu