This is an implementation of the paper Deep Bayesian Active Learning with Image Data using PyTorch and modAL. 

modAL is an active learning framework for Python3, designed with modularity, flexibility and extensibility in mind. Built on top of scikit-learn, it allows you to rapidly create active learning workflows with nearly complete freedom. What is more, you can easily replace parts with your custom built solutions, allowing you to design novel algorithms with ease.

Since modAL only supports sklearn models, we will also use [skorch](https://skorch.readthedocs.io/en/stable/), a scikit-learn compatible neural network library that wraps PyTorch. 

In [1]:
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from skorch import NeuralNetClassifier
from modAL.models import ActiveLearner
from batchbald_redux import batchbald

### architecture of the network we will be using

We will use the architecture described in the paper.

In [2]:
class MLP_REG(nn.Module):
    def __init__(self,):
        super(MLP_REG, self).__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 256),
            nn.Dropout(p=0.25),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.Dropout(p=0.25),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(128, 10),)
        
    def forward(self, x):
        return self.layers(x)


### read training data

In [3]:
mnist_train = MNIST('.', train=True, download=True, transform=ToTensor())
mnist_test  = MNIST('.', train=False,download=True, transform=ToTensor())
traindataloader = DataLoader(mnist_train, shuffle=True, batch_size=60000)
testdataloader  = DataLoader(mnist_test , shuffle=True, batch_size=10000)
X_train, y_train = next(iter(traindataloader))
X_test , y_test  = next(iter(testdataloader))

X_train = X_train.detach().cpu().numpy()
y_train = y_train.detach().cpu().numpy()

X_test = X_test.detach().cpu().numpy()
y_test = y_test.detach().cpu().numpy()

### preprocessing

In [4]:
#X_train = X_train.reshape(60000, 1, 28, 28)
#X_test = X_test.reshape(10000, 1, 28, 28)

### initial labelled data
We initialize the labelled set with 20 balanced randomly sampled examples

In [5]:
initial_idx = np.array([],dtype=int)
for i in range(10):
    idx = np.random.choice(np.where(y_train==i)[0], size=2, replace=False)
    initial_idx = np.concatenate((initial_idx, idx))

X_initial = X_train[initial_idx]
y_initial = y_train[initial_idx]

### initial unlabelled pool

In [6]:
X_pool = np.delete(X_train, initial_idx, axis=0)
y_pool = np.delete(y_train, initial_idx, axis=0)

## Query Strategies

### Uniform
All the acquisition function we will use will be compared to the uniform acquisition function $\mathbb{U}_{[0,1]}$ which will be our baseline that we would like to beat.

In [7]:
def uniform(learner, X, n_instances=1):
    query_idx = np.random.choice(range(len(X)), size=n_instances, replace=False)
    return query_idx, X[query_idx]

### Entropy
Our first acquisition function is the entropy:
$$ \mathbb{H} = - \sum_{c} p_c \log(p_c)$$
where $p_c$ is the probability predicted for class c. This is approximated by:
\begin{align}
p_c &= \frac{1}{T} \sum_t p_{c}^{(t)} 
\end{align}
where $p_{c}^{t}$ is the probability predicted for class c at the t th feedforward pass.

In [8]:
def max_entropy(learner, X, n_instances=1, T=100):
    random_subset = np.random.choice(range(len(X)), size=2000, replace=False)
    with torch.no_grad():
        outputs = np.stack([torch.softmax(learner.estimator.forward(X[random_subset], training=True),dim=-1).cpu().numpy()
                            for t in range(100)])
    pc = outputs.mean(axis=0)
    acquisition = (-pc*np.log(pc + 1e-10)).sum(axis=-1)
    idx = (-acquisition).argsort()[:n_instances]
    query_idx = random_subset[idx]
    return query_idx, X[query_idx]

In [9]:
def bald(learner, X, n_instances=1, T=100):
    random_subset = np.random.choice(range(len(X)), size=2000, replace=False)
    with torch.no_grad():
        outputs = np.stack([torch.softmax(learner.estimator.forward(X[random_subset], training=True),dim=-1).cpu().numpy()
                            for t in range(100)])
    pc = outputs.mean(axis=0)
    H   = (-pc*np.log(pc + 1e-10)).sum(axis=-1)
    E_H = - np.mean(np.sum(outputs * np.log(outputs + 1e-10), axis=-1), axis=0)  # [batch size]
    acquisition = H - E_H
    idx = (-acquisition).argsort()[:n_instances]
    query_idx = random_subset[idx]
    return query_idx, X[query_idx]    

In [10]:
def batch_bald(learner, X, n_instances=1, T=100):
    random_subset = np.random.choice(range(len(X)), size=2000, replace=False)
    X_random = X[random_subset]
    
    batch_size = 128
    #random_subset = np.random.choice(range(len(X)), size=2000, replace=False)
    
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
        
    # Acquire pool predictions
    N = len(X_random)
    logits_N_K_C = torch.empty((N, T, 10), dtype=torch.double, pin_memory=use_cuda)

    with torch.no_grad():
        for i in range(0, N, batch_size):
            lower = i
            upper = min(i + batch_size, N)
            outputs = np.stack([torch.softmax(learner.estimator.forward(X_random[lower:upper], training=True),dim=-1).cpu().numpy()
                            for t in range(100)])
            #pc = outputs.mean(axis=0)
            outputs = outputs.reshape([upper-lower, 100, 10])
            
            logits_N_K_C[lower:upper].copy_(torch.from_numpy(outputs).double())

    with torch.no_grad():
        candidate_batch = batchbald.get_batchbald_batch(
            logits_N_K_C, n_instances, 2000, dtype=torch.double, device=device
        )
    query_idx = random_subset[candidate_batch.indices]
    return query_idx, X[query_idx]

### Active Learning Procedure

In [11]:
def active_learning_procedure(query_strategy,
                              X_test,
                              y_test,
                              X_pool,
                              y_pool,
                              X_initial,
                              y_initial,
                              estimator,
                              n_queries=20,
                              n_instances=100):
    learner = ActiveLearner(estimator=estimator,
                            X_training=X_initial,
                            y_training=y_initial,
                            query_strategy=query_strategy,
                           )
    perf_hist = [learner.score(X_test, y_test)]
    for index in range(n_queries):
        query_idx, query_instance = learner.query(X_pool, n_instances)
        learner.teach(X_pool[query_idx], y_pool[query_idx])
        X_pool = np.delete(X_pool, query_idx, axis=0)
        y_pool = np.delete(y_pool, query_idx, axis=0)
        model_accuracy = learner.score(X_test, y_test)
        print('Accuracy after query {n}: {acc:0.4f}'.format(n=index + 1, acc=model_accuracy))
        perf_hist.append(model_accuracy)
    return perf_hist

In [12]:
device = "cuda" if torch.cuda.is_available() else "cpu"
estimator = NeuralNetClassifier(MLP_REG,
                                max_epochs=200,
                                batch_size=128,
                                lr=0.001,
                                optimizer=torch.optim.Adam,
                                criterion=torch.nn.CrossEntropyLoss,
                                train_split=None,
                                verbose=0,
                                device=device)
batchbald_perf_hist = active_learning_procedure(batch_bald,
                                                X_test,
                                                y_test,
                                                X_pool,
                                                y_pool,
                                                X_initial,
                                                y_initial,
                                                estimator,)

Conditional Entropy:   0%|          | 0/2000 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/100 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

Accuracy after query 1: 0.7728


Conditional Entropy:   0%|          | 0/2000 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/100 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

Accuracy after query 2: 0.8248


Conditional Entropy:   0%|          | 0/2000 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/100 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

SampledJointEntropy.compute_batch:   0%|          | 0/2000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
estimator = NeuralNetClassifier(MLP_REG,
                                max_epochs=200,
                                batch_size=128,
                                lr=0.001,
                                optimizer=torch.optim.Adam,
                                criterion=torch.nn.CrossEntropyLoss,
                                train_split=None,
                                verbose=0,
                                device=device)
entropy_perf_hist = active_learning_procedure(max_entropy,
                                              X_test,
                                              y_test,
                                              X_pool,
                                              y_pool,
                                              X_initial,
                                              y_initial,
                                              estimator,)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
estimator = NeuralNetClassifier(MLP_REG,
                                max_epochs=200,
                                batch_size=128,
                                lr=0.001,
                                optimizer=torch.optim.Adam,
                                criterion=torch.nn.CrossEntropyLoss,
                                train_split=None,
                                verbose=0,
                                device=device)
bald_perf_hist = active_learning_procedure(bald,
                                           X_test,
                                           y_test,
                                           X_pool,
                                           y_pool,
                                           X_initial,
                                           y_initial,
                                           estimator,)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
estimator = NeuralNetClassifier(MLP_REG,
                                max_epochs=200,
                                batch_size=128,
                                lr=0.001,
                                optimizer=torch.optim.Adam,
                                criterion=torch.nn.CrossEntropyLoss,
                                train_split=None,
                                verbose=0,
                                device=device)
uniform_perf_hist = active_learning_procedure(uniform,
                                              X_test,
                                              y_test,
                                              X_pool,
                                              y_pool,
                                              X_initial,
                                              y_initial,
                                              estimator,)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
plt.plot(entropy_perf_hist, label="entropy")
plt.plot(bald_perf_hist, label="bald")
plt.plot(uniform_perf_hist, label="uniform")
plt.ylim([0.7,1])
plt.legend()