This is an implementation of the paper Deep Bayesian Active Learning with Image Data using PyTorch and modAL. 

modAL is an active learning framework for Python3, designed with modularity, flexibility and extensibility in mind. Built on top of scikit-learn, it allows you to rapidly create active learning workflows with nearly complete freedom. What is more, you can easily replace parts with your custom built solutions, allowing you to design novel algorithms with ease.

Since modAL only supports sklearn models, we will also use [skorch](https://skorch.readthedocs.io/en/stable/), a scikit-learn compatible neural network library that wraps PyTorch. 

In [1]:
!pip install -U skorch
!pip install modAL
!pip install lightning-bolts

Collecting skorch
  Downloading skorch-0.11.0-py3-none-any.whl (155 kB)
[?25l[K     |██▏                             | 10 kB 25.4 MB/s eta 0:00:01[K     |████▎                           | 20 kB 10.4 MB/s eta 0:00:01[K     |██████▍                         | 30 kB 5.8 MB/s eta 0:00:01[K     |████████▌                       | 40 kB 6.2 MB/s eta 0:00:01[K     |██████████▋                     | 51 kB 3.5 MB/s eta 0:00:01[K     |████████████▊                   | 61 kB 4.1 MB/s eta 0:00:01[K     |██████████████▉                 | 71 kB 4.3 MB/s eta 0:00:01[K     |█████████████████               | 81 kB 4.4 MB/s eta 0:00:01[K     |███████████████████             | 92 kB 4.9 MB/s eta 0:00:01[K     |█████████████████████▏          | 102 kB 4.1 MB/s eta 0:00:01[K     |███████████████████████▎        | 112 kB 4.1 MB/s eta 0:00:01[K     |█████████████████████████▍      | 122 kB 4.1 MB/s eta 0:00:01[K     |███████████████████████████▌    | 133 kB 4.1 MB/s eta 0:00:01[K  

In [2]:
!mkdir perf_lists_cifar_vae

In [3]:
PATH = './'

from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
from pl_bolts.models.autoencoders import VAE
from pytorch_lightning import Trainer
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
import torch.nn.functional as F

vae = VAE(input_height=32, first_conv=False)
print(VAE.pretrained_weights_available())

vae = vae.from_pretrained(f'cifar10-resnet18')

vae.freeze()


dm = CIFAR10DataModule(PATH, normalize=True)
dm.prepare_data()
dm.setup("fit")
dataloader = dm.train_dataloader()


mean = torch.tensor(dm.default_transforms().transforms[1].mean)
std = torch.tensor(dm.default_transforms().transforms[1].std)

normalize = transforms.Normalize(mean, std)
unnormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())

def get_new_images(X):
    vae.eval()
    X_hat = vae(normalize(X))
    return unnormalize(X_hat).detach()

['cifar10-resnet18', 'stl10-resnet18']


Downloading: "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/vae/vae-cifar10/checkpoints/epoch%3D89.ckpt" to /root/.cache/torch/hub/checkpoints/epoch%3D89.ckpt


  0%|          | 0.00/230M [00:00<?, ?B/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified


  "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
  "DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7."


In [4]:
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import CIFAR10
from skorch import NeuralNetClassifier
from modAL.models import ActiveLearner
from torchvision import transforms
# from VAE import VAE

### architecture of the network we will be using

We will use the architecture described in the paper.

In [5]:
# class CNN(nn.Module):
#     def __init__(self,):
#         super(CNN, self).__init__()
#         self.convs = nn.Sequential(
#                                 nn.Conv2d(1,32,4),
#                                 nn.ReLU(),
#                                 nn.Conv2d(32,32,4),
#                                 nn.ReLU(),
#                                 nn.MaxPool2d(2),
#                                 nn.Dropout(0.25)
#         )
#         self.fcs = nn.Sequential(
#                                 nn.Linear(11*11*32,128),
#                                 nn.ReLU(),
#                                 nn.Dropout(0.5),
#                                 nn.Linear(128,10),
#         )

#     def forward(self, x):
#         out = x
#         out = self.convs(out)
#         out = out.view(-1,11*11*32)
#         out = self.fcs(out)
#         return out

input_dim, input_height, input_width = 3, 28, 28
class lenet(nn.Module):  
    def __init__(self):
        super(lenet, self).__init__()
        self.input_height = input_height
        self.input_width = input_width
        self.input_dim = input_dim
        self.class_num = 10

        self.conv1 = nn.Conv2d(self.input_dim, 6, (5, 5), padding=2)
        self.conv2 = nn.Conv2d(6, 16, (5, 5))
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, self.class_num)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

# class CNN(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.conv1 = nn.Conv2d(3, 6, 5)
#         self.pool = nn.MaxPool2d(2, 2)
#         self.conv2 = nn.Conv2d(6, 16, 5)
#         self.fc1 = nn.Linear(16 * 5 * 5, 120)
#         self.fc2 = nn.Linear(120, 84)
#         self.fc3 = nn.Linear(84, 10)

#     def forward(self, x):
#         x = self.pool(F.relu(self.conv1(x)))
#         x = self.pool(F.relu(self.conv2(x)))
#         x = torch.flatten(x, 1) # flatten all dimensions except batch
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = self.fc3(x)
#         return x

### read training data

In [6]:
cifar10_train = CIFAR10('.', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Resize(size=(28, 28))]))
cifar10_test  = CIFAR10('.', train=False,download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Resize(size=(28, 28))]))
traindataloader = DataLoader(cifar10_train, shuffle=True, batch_size=60000)
testdataloader  = DataLoader(cifar10_train, shuffle=True, batch_size=10000)
X_train, y_train = next(iter(traindataloader))
X_test , y_test  = next(iter(testdataloader))
X_train, y_train = X_train.detach().cpu().numpy(), y_train.detach().cpu().numpy()
X_test, y_test = X_test.detach().cpu().numpy(), y_test.detach().cpu().numpy()

Files already downloaded and verified
Files already downloaded and verified


In [7]:
X_train.shape

(50000, 3, 28, 28)

### preprocessing

In [8]:
# X_train = X_train.reshape(60000, 1, 28, 28)
# X_test = X_test.reshape(10000, 1, 28, 28)

### initial labelled data
We initialize the labelled set with 100 balanced randomly sampled examples

In [9]:
initial_idx = np.array([],dtype=int)
for i in range(10):
    idx = np.random.choice(np.where(y_train==i)[0], size=10, replace=False)
    initial_idx = np.concatenate((initial_idx, idx))

X_initial = X_train[initial_idx]
y_initial = y_train[initial_idx]

### initial unlabelled pool

In [10]:
# X_pool = np.delete(X_train, initial_idx, axis=0)
# y_pool = np.delete(y_train, initial_idx, axis=0)
X_pool = np.copy(X_train)
y_pool = np.copy(y_train)

## Query Strategies

### Uniform
All the acquisition function we will use will be compared to the uniform acquisition function $\mathbb{U}_{[0,1]}$ which will be our baseline that we would like to beat.

In [11]:
def uniform(learner, X, n_instances=1):
    query_idx = np.random.choice(range(len(X)), size=n_instances, replace=False)
    return query_idx, X[query_idx]

### Entropy
Our first acquisition function is the entropy:
$$ \mathbb{H} = - \sum_{c} p_c \log(p_c)$$
where $p_c$ is the probability predicted for class c. This is approximated by:
\begin{align}
p_c &= \frac{1}{T} \sum_t p_{c}^{(t)} 
\end{align}
where $p_{c}^{t}$ is the probability predicted for class c at the t th feedforward pass.

In [12]:
def max_entropy(learner, X, n_instances=1, T=100):
    random_subset = np.random.choice(range(len(X)), size=2000, replace=False)
    with torch.no_grad():
        outputs = np.stack([torch.softmax(learner.estimator.forward(X[random_subset], training=True),dim=-1).cpu().numpy()
                            for t in range(20)])
    pc = outputs.mean(axis=0)
    acquisition = (-pc*np.log(pc + 1e-10)).sum(axis=-1)
    idx = (-acquisition).argsort()[:n_instances]
    query_idx = random_subset[idx]
    return query_idx, X[query_idx]

In [13]:
def bald(learner, X, n_instances=1, T=100):
    random_subset = np.random.choice(range(len(X)), size=2000, replace=False)
    with torch.no_grad():
        outputs = np.stack([torch.softmax(learner.estimator.forward(X[random_subset], training=True),dim=-1).cpu().numpy()
                            for t in range(20)])
    pc = outputs.mean(axis=0)
    H   = (-pc*np.log(pc + 1e-10)).sum(axis=-1)
    E_H = - np.mean(np.sum(outputs * np.log(outputs + 1e-10), axis=-1), axis=0)  # [batch size]
    acquisition = H - E_H
    idx = (-acquisition).argsort()[:n_instances]
    query_idx = random_subset[idx]
    return query_idx, X[query_idx]    

In [14]:
def save_list(input_list, name):
    with open("perf_lists_cifar_vae/" + name, 'w') as f:
        for val in input_list:
            f.write("%s\n" % val)

### Active Learning Procedure

In [15]:
def active_learning_procedure_vae(query_strategy,
                              X_test,
                              y_test,
                              X_pool,
                              y_pool,
                              X_initial,
                              y_initial,
                              estimator,
                              n_queries=150,
                              n_instances=100):
    learner = ActiveLearner(estimator=estimator,
                            X_training=X_initial,
                            y_training=y_initial,
                            query_strategy=query_strategy,
                           )
    perf_hist = [learner.score(X_test, y_test)]
    X_rolling, y_rolling = np.copy(X_initial), np.copy(y_initial)
    for index in range(n_queries):
        query_idx, query_instance = learner.query(X_pool, n_instances)

        new_samples = get_new_images(torch.tensor(X_pool[query_idx]))
        new_samples = F.interpolate(new_samples, size=28)
        new_samples = new_samples.detach().cpu().numpy()
        new_samples = np.concatenate((new_samples, X_pool[query_idx]))
        new_labels = np.concatenate((y_pool[query_idx], y_pool[query_idx]))

        X_rolling, y_rolling = np.concatenate((X_rolling, new_samples), axis=0), np.concatenate((y_rolling, new_labels))
        learner.fit(X_rolling, y_rolling)
        X_pool = np.delete(X_pool, query_idx, axis=0)
        y_pool = np.delete(y_pool, query_idx, axis=0)
        model_accuracy = learner.score(X_test, y_test)
        print('Accuracy after query {n}: {acc:0.4f}'.format(n=index + 1, acc=model_accuracy))
        perf_hist.append(model_accuracy)
    return perf_hist

In [16]:
device = "cuda" if torch.cuda.is_available() else "cpu"
estimator = NeuralNetClassifier(lenet,
                                max_epochs=50,
                                batch_size=100,
                                lr=1.0,
                                optimizer=torch.optim.Adadelta,
                                optimizer__rho=0.9,
                                optimizer__eps=1e-6,
                                criterion=torch.nn.CrossEntropyLoss,
                                train_split=None,
                                verbose=0,
                                device=device)
uniform_perf_hist = active_learning_procedure_vae(uniform,
                                              X_test,
                                              y_test,
                                              X_pool,
                                              y_pool,
                                              X_initial,
                                              y_initial,
                                              estimator,
                                             n_instances=100)
# save_list(uniform_perf_hist, "uniform_perf_hist")

Accuracy after query 1: 0.2139
Accuracy after query 2: 0.2224
Accuracy after query 3: 0.2912
Accuracy after query 4: 0.3201
Accuracy after query 5: 0.3286
Accuracy after query 6: 0.3105
Accuracy after query 7: 0.2769
Accuracy after query 8: 0.3504
Accuracy after query 9: 0.3484
Accuracy after query 10: 0.2954
Accuracy after query 11: 0.3288
Accuracy after query 12: 0.3437
Accuracy after query 13: 0.3315
Accuracy after query 14: 0.3588
Accuracy after query 15: 0.3621
Accuracy after query 16: 0.3506
Accuracy after query 17: 0.3073
Accuracy after query 18: 0.3836
Accuracy after query 19: 0.3587
Accuracy after query 20: 0.3523
Accuracy after query 21: 0.3484
Accuracy after query 22: 0.3709
Accuracy after query 23: 0.3934
Accuracy after query 24: 0.3699
Accuracy after query 25: 0.3885
Accuracy after query 26: 0.3785
Accuracy after query 27: 0.3944
Accuracy after query 28: 0.3756
Accuracy after query 29: 0.3909
Accuracy after query 30: 0.3906
Accuracy after query 31: 0.4010
Accuracy after qu

In [17]:
device = "cuda" if torch.cuda.is_available() else "cpu"
estimator = NeuralNetClassifier(lenet,
                                max_epochs=50,
                                batch_size=100,
                                lr=1.0,
                                optimizer=torch.optim.Adadelta,
                                optimizer__rho=0.9,
                                optimizer__eps=1e-6,
                                criterion=torch.nn.CrossEntropyLoss,
                                train_split=None,
                                verbose=0,
                                device=device)
bald_perf_hist = active_learning_procedure_vae(bald,
                                              X_test,
                                              y_test,
                                              X_pool,
                                              y_pool,
                                              X_initial,
                                              y_initial,
                                              estimator,
                                             n_instances=100)
save_list(bald_perf_hist, "bald_perf_hist")

Accuracy after query 1: 0.2103
Accuracy after query 2: 0.2152
Accuracy after query 3: 0.2865
Accuracy after query 4: 0.3021
Accuracy after query 5: 0.3024
Accuracy after query 6: 0.3329
Accuracy after query 7: 0.3369
Accuracy after query 8: 0.3146
Accuracy after query 9: 0.3347
Accuracy after query 10: 0.3345
Accuracy after query 11: 0.3001
Accuracy after query 12: 0.3361
Accuracy after query 13: 0.3358
Accuracy after query 14: 0.3449
Accuracy after query 15: 0.3514
Accuracy after query 16: 0.3656
Accuracy after query 17: 0.3535
Accuracy after query 18: 0.3547
Accuracy after query 19: 0.3303
Accuracy after query 20: 0.3768
Accuracy after query 21: 0.3762
Accuracy after query 22: 0.3669
Accuracy after query 23: 0.3568
Accuracy after query 24: 0.3909
Accuracy after query 25: 0.3662
Accuracy after query 26: 0.3672
Accuracy after query 27: 0.3901
Accuracy after query 28: 0.3929
Accuracy after query 29: 0.3590
Accuracy after query 30: 0.3849
Accuracy after query 31: 0.3870
Accuracy after qu

In [18]:
device = "cuda" if torch.cuda.is_available() else "cpu"
estimator = NeuralNetClassifier(lenet,
                                max_epochs=50,
                                batch_size=100,
                                lr=1.0,
                                optimizer=torch.optim.Adadelta,
                                optimizer__rho=0.9,
                                optimizer__eps=1e-6,
                                criterion=torch.nn.CrossEntropyLoss,
                                train_split=None,
                                verbose=0,
                                device=device)
entropy_perf_hist = active_learning_procedure_vae(max_entropy,
                                              X_test,
                                              y_test,
                                              X_pool,
                                              y_pool,
                                              X_initial,
                                              y_initial,
                                              estimator,
                                             n_instances=100)
save_list(entropy_perf_hist, "entropy_perf_hist")

Accuracy after query 1: 0.2386
Accuracy after query 2: 0.2253
Accuracy after query 3: 0.2444
Accuracy after query 4: 0.2487
Accuracy after query 5: 0.2692
Accuracy after query 6: 0.2924
Accuracy after query 7: 0.3336
Accuracy after query 8: 0.3301
Accuracy after query 9: 0.3145
Accuracy after query 10: 0.3087
Accuracy after query 11: 0.3456
Accuracy after query 12: 0.3345
Accuracy after query 13: 0.3520
Accuracy after query 14: 0.3542
Accuracy after query 15: 0.3103
Accuracy after query 16: 0.3313
Accuracy after query 17: 0.3536
Accuracy after query 18: 0.3294
Accuracy after query 19: 0.3010
Accuracy after query 20: 0.3529
Accuracy after query 21: 0.3613
Accuracy after query 22: 0.3610
Accuracy after query 23: 0.3588
Accuracy after query 24: 0.3798
Accuracy after query 25: 0.3680
Accuracy after query 26: 0.3481
Accuracy after query 27: 0.3732
Accuracy after query 28: 0.3783
Accuracy after query 29: 0.3682
Accuracy after query 30: 0.4093
Accuracy after query 31: 0.3612
Accuracy after qu