This is an implementation of the paper *Deep Bayesian Active Learning with Image Data* (Gal, Islam and Ghahramani, 2017) using PyTorch and modAL.

modAL is an active learning framework for Python3, designed with modularity, flexibility and extensibility in mind. Built on top of scikit-learn, it allows you to rapidly create active learning workflows with nearly complete freedom. What is more, you can easily replace parts with your custom built solutions, allowing you to design novel algorithms with ease.

Since modAL only supports sklearn models, we will also use [skorch](https://skorch.readthedocs.io/en/stable/), a scikit-learn compatible neural network library that wraps PyTorch. 

In [1]:
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from skorch import NeuralNetClassifier
from modAL.models import ActiveLearner
from torch.autograd import Variable

from VAE import VAE

  warn(f"Failed to load image Python extension: {e}")


### architecture of the network we will be using

We will use the architecture described in the paper.

In [2]:
class CNN(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Two 4x4 convolutions (32 filters each) with ReLU, a 2x2 max-pool and
    dropout, followed by a 128-unit hidden layer with dropout and a 10-way
    linear output. ``forward`` returns raw class scores: no softmax is
    applied inside the network.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 28x28 -> 25x25 -> 22x22 -> (pool) 11x11, 32 channels.
        feature_layers = [
            nn.Conv2d(1, 32, 4),
            nn.ReLU(),
            nn.Conv2d(32, 32, 4),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
        ]
        # Classifier head operating on the flattened 11*11*32 feature vector.
        head_layers = [
            nn.Linear(11 * 11 * 32, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10),
        ]
        self.convs = nn.Sequential(*feature_layers)
        self.fcs = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a batch of images to a batch of 10 class scores."""
        features = self.convs(x)
        flattened = features.view(-1, 11 * 11 * 32)
        return self.fcs(flattened)

### read training data

In [3]:
# Download (if needed) both MNIST splits into the current directory.
mnist_train = MNIST('.', train=True, download=True, transform=ToTensor())
mnist_test = MNIST('.', train=False, download=True, transform=ToTensor())

# A single batch as large as the whole split, so one `next(iter(...))`
# yields every sample of that split in shuffled order.
traindataloader = DataLoader(mnist_train, shuffle=True, batch_size=60000)
testdataloader = DataLoader(mnist_test, shuffle=True, batch_size=10000)

# Pull the one big batch out of each loader and hand the tensors over to
# NumPy, which is what the pool bookkeeping below works with.
X_train, y_train = (
    tensor.detach().cpu().numpy() for tensor in next(iter(traindataloader))
)
X_test, y_test = (
    tensor.detach().cpu().numpy() for tensor in next(iter(testdataloader))
)

### preprocessing

In [4]:
# Force the (N, 1, 28, 28) channel-first layout the CNN's first Conv2d
# expects. -1 lets NumPy infer the number of samples, so this keeps working
# if the splits are subsampled instead of holding exactly 60000/10000 rows.
X_train = X_train.reshape(-1, 1, 28, 28)
X_test = X_test.reshape(-1, 1, 28, 28)

### initial labelled data
We initialize the labelled set with 20 randomly sampled examples, balanced across classes (2 per digit).

In [5]:
# For every digit class 0-9, draw two distinct training indices carrying that
# label, then join the ten pairs into one index array of length 20.
initial_idx = np.concatenate([
    np.random.choice(np.where(y_train == digit)[0], size=2, replace=False)
    for digit in range(10)
])

# The seed labelled set the learner is first fitted on.
X_initial, y_initial = X_train[initial_idx], y_train[initial_idx]

### initial unlabelled pool

In [6]:
# Everything that was not drawn into the initial labelled set stays in the
# pool; a boolean keep-mask drops exactly the rows listed in initial_idx.
pool_mask = np.ones(len(X_train), dtype=bool)
pool_mask[initial_idx] = False
X_pool = X_train[pool_mask]
y_pool = y_train[pool_mask]

## Query Strategies

### Uniform
All the acquisition functions we use will be compared to the uniform acquisition function $\mathbb{U}_{[0,1]}$, which serves as the baseline we would like to beat.

In [7]:
def uniform(learner, X, n_instances=1):
    """Baseline query strategy: pick pool points uniformly at random.

    Args:
        learner: Unused; accepted so the signature matches the other
            query strategies.
        X: Indexable pool of candidate samples.
        n_instances: How many distinct pool positions to draw.

    Returns:
        A ``(query_idx, X[query_idx])`` tuple: the drawn positions and the
        corresponding samples.
    """
    candidate_positions = np.arange(len(X))
    query_idx = np.random.choice(candidate_positions, size=n_instances, replace=False)
    chosen = X[query_idx]
    return query_idx, chosen

### Entropy
Our first acquisition function is the entropy:
$$ \mathbb{H} = - \sum_{c} p_c \log(p_c)$$
where $p_c$ is the probability predicted for class c. This is approximated by:
\begin{align}
p_c &= \frac{1}{T} \sum_t p_{c}^{(t)} 
\end{align}
where $p_{c}^{(t)}$ is the probability predicted for class $c$ at the $t$-th stochastic forward pass, and $T$ is the number of passes.

In [8]:
def max_entropy(learner, X, n_instances=1, T=100, subset_size=2000):
    """Query the pool points with the highest predictive entropy.

    A random subset of the pool is scored with ``T`` stochastic forward
    passes (``training=True`` is passed to the estimator's ``forward``, so
    layers such as dropout are expected to stay active). The class
    probabilities are averaged over the passes and the entropy of that mean
    distribution is used as the acquisition score.

    Args:
        learner: Object exposing ``estimator.forward(X, training=True)`` that
            returns a tensor of class scores.
        X: Pool of candidate samples (NumPy array, indexed along axis 0).
        n_instances: Number of pool points to return.
        T: Number of stochastic forward passes to average over.
        subset_size: Upper bound on how many randomly chosen pool points are
            scored; capped at ``len(X)`` so small pools do not raise.

    Returns:
        ``(query_idx, X[query_idx])`` where ``query_idx`` indexes into ``X``,
        ordered from highest to lowest entropy.

    Raises:
        ValueError: If ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError("T must be at least 1")

    # Scoring the whole pool is expensive, so only a random subset is
    # considered. Cap the subset at the pool size: sampling without
    # replacement cannot return more items than exist.
    n_candidates = min(subset_size, len(X))
    random_subset = np.random.choice(range(len(X)), size=n_candidates, replace=False)
    candidates = X[random_subset]  # sliced once, reused for every pass

    with torch.no_grad():
        # Resulting shape: [T, n_candidates, n_classes]
        outputs = np.stack([
            torch.softmax(learner.estimator.forward(candidates, training=True), dim=-1).cpu().numpy()
            for _ in range(T)
        ])

    pc = outputs.mean(axis=0)
    # The 1e-10 offset keeps log() finite when a class probability is 0.
    acquisition = (-pc * np.log(pc + 1e-10)).sum(axis=-1)
    idx = (-acquisition).argsort()[:n_instances]
    query_idx = random_subset[idx]
    return query_idx, X[query_idx]

In [9]:
def bald(learner, X, n_instances=1, T=100, subset_size=2000):
    """Query the pool points with the highest BALD score.

    The score is the entropy of the mean predictive distribution minus the
    mean entropy of the individual stochastic predictions, i.e. how much the
    ``T`` forward passes disagree with each other. ``training=True`` is
    passed to the estimator's ``forward`` so that layers such as dropout are
    expected to stay active between passes.

    Args:
        learner: Object exposing ``estimator.forward(X, training=True)`` that
            returns a tensor of class scores.
        X: Pool of candidate samples (NumPy array, indexed along axis 0).
        n_instances: Number of pool points to return.
        T: Number of stochastic forward passes.
        subset_size: Upper bound on how many randomly chosen pool points are
            scored; capped at ``len(X)`` so small pools do not raise.

    Returns:
        ``(query_idx, X[query_idx])`` where ``query_idx`` indexes into ``X``,
        ordered from highest to lowest score.

    Raises:
        ValueError: If ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError("T must be at least 1")

    # Only a random subset of the pool is scored to bound the cost. Cap the
    # subset at the pool size: sampling without replacement cannot return
    # more items than exist.
    n_candidates = min(subset_size, len(X))
    random_subset = np.random.choice(range(len(X)), size=n_candidates, replace=False)
    candidates = X[random_subset]  # sliced once, reused for every pass

    with torch.no_grad():
        # Resulting shape: [T, n_candidates, n_classes]
        outputs = np.stack([
            torch.softmax(learner.estimator.forward(candidates, training=True), dim=-1).cpu().numpy()
            for _ in range(T)
        ])

    pc = outputs.mean(axis=0)
    # Entropy of the averaged prediction; 1e-10 keeps log() finite at p == 0.
    H = (-pc * np.log(pc + 1e-10)).sum(axis=-1)
    # Average of the per-pass entropies -> [n_candidates]
    E_H = -np.mean(np.sum(outputs * np.log(outputs + 1e-10), axis=-1), axis=0)
    acquisition = H - E_H
    idx = (-acquisition).argsort()[:n_instances]
    query_idx = random_subset[idx]
    return query_idx, X[query_idx]

In [10]:
def save_list(input_list, name):
    """Write each element of ``input_list`` on its own line to ``perf_lists0/<name>``.

    The ``perf_lists0`` directory is created if it does not exist yet, so a
    long experiment does not lose its results at the very last step because
    of a missing output folder. An existing file of the same name is
    overwritten.

    Args:
        input_list: Iterable of values; each one is written via ``str()``.
        name: File name (relative to ``perf_lists0/``) to write to.
    """
    import os

    out_dir = "perf_lists0"
    os.makedirs(out_dir, exist_ok=True)
    with open(out_dir + "/" + name, 'w') as f:
        # f-string instead of "%s" % val so tuple values are written as-is
        # rather than being unpacked as format arguments.
        f.writelines(f"{val}\n" for val in input_list)

### Active Learning Procedure

In [11]:
def active_learning_procedure_generative(query_strategy,
                              X_test,
                              y_test,
                              X_pool,
                              y_pool,
                              X_initial,
                              y_initial,
                              estimator,
                              n_queries=100,
                              n_instances=10):
    """Run pool-based active learning, augmenting each query with generated samples.

    At every round the learner queries ``n_instances`` pool points. Each
    queried image is passed through a pre-trained encoder/decoder pair to
    produce one synthetic variant; the variants plus the original images are
    taught to the learner, all under the labels of the queried points. The
    queried points are then removed from the pool and test accuracy is
    recorded.

    Args:
        query_strategy: Callable handed to ``ActiveLearner`` as its
            query strategy.
        X_test, y_test: Held-out data used for ``learner.score``.
        X_pool, y_pool: Unlabelled pool and its labels (NumPy arrays). The
            caller's arrays are not modified; the pool is rebound locally.
        X_initial, y_initial: Initial labelled set the learner starts from.
        estimator: Estimator wrapped by the ``ActiveLearner``.
        n_queries: Number of query rounds.
        n_instances: Number of pool points requested per round.

    Returns:
        List of test accuracies: one before any query, then one per round
        (``n_queries + 1`` entries).
    """
    # Use the GPU when there is one, otherwise fall back to the CPU, which is
    # the same rule used to pick the estimator's device. A hard-coded "cuda"
    # makes torch.load fail on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # checkpoints from a trusted source. Depending on the installed PyTorch
    # version, loading whole pickled modules may also need weights_only=False.
    # NOTE(review): the modules are used in whatever train/eval mode they
    # were saved in; confirm whether .eval() is wanted here.
    enc = torch.load('./checkpoints/enc.pt', map_location=device)
    dec = torch.load('./checkpoints/dec.pt', map_location=device)

    def _reparameterize(mu, logvar):
        # z = mu + eps * sigma with eps ~ N(0, I) and sigma = exp(logvar / 2).
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def _generate_variants(images):
        # Pure inference: no autograd graph is needed for the generator.
        with torch.no_grad():
            mu, logvar = enc(images)
            reconstructions = dec(_reparameterize(mu, logvar))
        return reconstructions.reshape(-1, 1, 28, 28)

    learner = ActiveLearner(estimator=estimator,
                            X_training=X_initial,
                            y_training=y_initial,
                            query_strategy=query_strategy,
                           )
    perf_hist = [learner.score(X_test, y_test)]
    for index in range(n_queries):
        query_idx, _ = learner.query(X_pool, n_instances)

        queried_images = X_pool[query_idx]
        queried_labels = y_pool[query_idx]

        generated = _generate_variants(torch.as_tensor(queried_images).to(device))
        generated = generated.cpu().numpy()

        # Generated variants come first, followed by the real images; the
        # labels are duplicated so both halves share the queried labels.
        new_samples = np.concatenate((generated, queried_images))
        new_labels = np.concatenate((queried_labels, queried_labels))
        learner.teach(new_samples, new_labels)

        # Drop the freshly labelled points so they cannot be queried again.
        X_pool = np.delete(X_pool, query_idx, axis=0)
        y_pool = np.delete(y_pool, query_idx, axis=0)

        model_accuracy = learner.score(X_test, y_test)
        print('Accuracy after query {n}: {acc:0.4f}'.format(n=index + 1, acc=model_accuracy))
        perf_hist.append(model_accuracy)
    return perf_hist

In [12]:
# Train on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# skorch wrapper around the CNN; train_split=None means every labelled
# sample is used for fitting (no internal validation split).
estimator = NeuralNetClassifier(
    CNN,
    max_epochs=50,
    batch_size=128,
    lr=0.001,
    optimizer=torch.optim.Adam,
    criterion=torch.nn.CrossEntropyLoss,
    train_split=None,
    verbose=0,
    device=device,
)

# Active learning with the BALD acquisition function; n_queries and
# n_instances keep their defaults.
bald_gan_std0_perf_hist = active_learning_procedure_generative(
    query_strategy=bald,
    X_test=X_test,
    y_test=y_test,
    X_pool=X_pool,
    y_pool=y_pool,
    X_initial=X_initial,
    y_initial=y_initial,
    estimator=estimator,
)
save_list(bald_gan_std0_perf_hist, "bald_gan_std0_perf_hist")



Accuracy after query 1: 0.5906
Accuracy after query 2: 0.5887
Accuracy after query 3: 0.5968
Accuracy after query 4: 0.6089
Accuracy after query 5: 0.6677
Accuracy after query 6: 0.6494
Accuracy after query 7: 0.6752
Accuracy after query 8: 0.6691
Accuracy after query 9: 0.6893
Accuracy after query 10: 0.7631
Accuracy after query 11: 0.7265
Accuracy after query 12: 0.7747
Accuracy after query 13: 0.7785
Accuracy after query 14: 0.7744
Accuracy after query 15: 0.8104
Accuracy after query 16: 0.8043
Accuracy after query 17: 0.8134
Accuracy after query 18: 0.8162
Accuracy after query 19: 0.8281
Accuracy after query 20: 0.8345
Accuracy after query 21: 0.8516
Accuracy after query 22: 0.8541
Accuracy after query 23: 0.8507
Accuracy after query 24: 0.8585
Accuracy after query 25: 0.8728
Accuracy after query 26: 0.8643
Accuracy after query 27: 0.8673
Accuracy after query 28: 0.8889
Accuracy after query 29: 0.8809
Accuracy after query 30: 0.8873
Accuracy after query 31: 0.8673
Accuracy after qu

In [None]:
# Train on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# A fresh skorch wrapper around the CNN for this run; train_split=None means
# every labelled sample is used for fitting (no internal validation split).
estimator = NeuralNetClassifier(
    CNN,
    max_epochs=50,
    batch_size=128,
    lr=0.001,
    optimizer=torch.optim.Adam,
    criterion=torch.nn.CrossEntropyLoss,
    train_split=None,
    verbose=0,
    device=device,
)

# Active learning with the max-entropy acquisition function; n_queries and
# n_instances keep their defaults.
entropy_gan_perf_hist = active_learning_procedure_generative(
    query_strategy=max_entropy,
    X_test=X_test,
    y_test=y_test,
    X_pool=X_pool,
    y_pool=y_pool,
    X_initial=X_initial,
    y_initial=y_initial,
    estimator=estimator,
)
save_list(entropy_gan_perf_hist, "entropy_gan_perf_hist")



Accuracy after query 1: 0.6143
Accuracy after query 2: 0.6402
Accuracy after query 3: 0.6397
Accuracy after query 4: 0.6014
Accuracy after query 5: 0.6391
Accuracy after query 6: 0.6841
Accuracy after query 7: 0.6896
Accuracy after query 8: 0.7289
Accuracy after query 9: 0.7140
Accuracy after query 10: 0.7216
Accuracy after query 11: 0.7533
Accuracy after query 12: 0.7862
Accuracy after query 13: 0.7803
Accuracy after query 14: 0.7703
Accuracy after query 15: 0.8221
Accuracy after query 16: 0.8161
Accuracy after query 17: 0.8508
Accuracy after query 18: 0.8635
Accuracy after query 19: 0.8579
Accuracy after query 20: 0.8608
Accuracy after query 21: 0.8768
Accuracy after query 22: 0.8733
Accuracy after query 23: 0.8833
Accuracy after query 24: 0.9051
Accuracy after query 25: 0.8845
Accuracy after query 26: 0.8952
Accuracy after query 27: 0.8818
Accuracy after query 28: 0.8882
Accuracy after query 29: 0.8978
Accuracy after query 30: 0.9102
Accuracy after query 31: 0.9167
Accuracy after qu