In [1]:
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
# from torchvision.datasets import CIFAR10
from skorch import NeuralNetClassifier
from modAL.models import ActiveLearner
from batchbald_redux import batchbald
from acquisition_functions import *
import os
from keras.datasets import cifar10

In [2]:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

In [3]:
MAX_EPOCHS = 50
BATCH_SIZE = 128
LEARNING_RATE = 0.001
EXPERIMENT_COUNT = 3
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
results_path = 'results\\cifar10_dbal'

if not os.path.exists(results_path):
    os.makedirs(results_path)

ACQ_FUNCS = {
    "var_ratios": var_ratios,
    "mean_std": mean_std,
    "max_entropy": max_entropy,
    "bald": bald,
    "uniform": uniform,
#     "batch_bald": batch_bald
}

### architecture of the network we will be using

We will use the architecture described in the paper.

In [4]:
class MLP_REG(nn.Module):
    def __init__(self):
        super(MLP_REG, self).__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(3072, 512),
            nn.ReLU(),
            nn.Dropout(p=0.25),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(p=0.25),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, 10),)
        
    def forward(self, x):
        return self.layers(x)


### read training data

In [5]:
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

X_train = X_train.reshape(50000, 32, 32, 3).astype('float32') / 255.
X_test = X_test.reshape(10000, 32, 32, 3).astype('float32') / 255.

X_train = X_train.reshape(X_train.shape[0], -1)
X_test = X_test.reshape(X_test.shape[0], -1)

In [6]:
y_train = y_train.reshape((y_train.shape[0], ))
y_test = y_test.reshape((y_test.shape[0], ))

### Active Learning Procedure

In [7]:
def active_learning_procedure(query_strategy,
                              X_test,
                              y_test,
                              X_pool,
                              y_pool,
                              X_initial,
                              y_initial,
                              estimator,
                              n_queries=98,
                              n_instances=10):
    learner = ActiveLearner(estimator=estimator,
                            X_training=X_initial,
                            y_training=y_initial,
                            query_strategy=query_strategy,
                           )
    perf_hist = [learner.score(X_test, y_test)]
    active_pool_size = [len(X_initial)]
    pool_size = len(X_initial)
    for index in range(n_queries):
        query_idx, query_instance = learner.query(X_pool, n_instances)
        learner.teach(X_pool[query_idx], y_pool[query_idx])
        X_pool = np.delete(X_pool, query_idx, axis=0)
        y_pool = np.delete(y_pool, query_idx, axis=0)
        model_accuracy = learner.score(X_test, y_test)
        pool_size = pool_size + n_instances
        print('Accuracy after query {n}: {acc:0.4f}'.format(n=index + 1, acc=model_accuracy))
        perf_hist.append(model_accuracy)
        active_pool_size.append(pool_size)
    return perf_hist, active_pool_size

In [8]:
for exp_iter in range(EXPERIMENT_COUNT):
    np.random.seed(exp_iter)
    initial_idx = np.array([],dtype=int)
    for i in range(10):
        idx = np.random.choice(np.where(y_train==i)[0], size=2, replace=False)
        initial_idx = np.concatenate((initial_idx, idx))
    
    for func_name, acquisition_func in ACQ_FUNCS.items():  
        X_initial = X_train[initial_idx]
        y_initial = y_train[initial_idx]

        X_pool = np.delete(X_train, initial_idx, axis=0)
        y_pool = np.delete(y_train, initial_idx, axis=0)

        model = MLP_REG().to(DEVICE)

        estimator = NeuralNetClassifier(model,
                                      max_epochs=MAX_EPOCHS,
                                      batch_size=BATCH_SIZE,
                                      lr=LEARNING_RATE,
                                      optimizer=torch.optim.Adam,
                                      criterion=torch.nn.CrossEntropyLoss,
                                      train_split=None,
                                      verbose=0,
                                      device=DEVICE)



        acc_arr, dataset_size_arr = active_learning_procedure(acquisition_func,
                                                          X_test,
                                                          y_test,
                                                          X_pool,
                                                          y_pool,
                                                          X_initial,
                                                          y_initial,
                                                          estimator,)
        file_name = os.path.join(results_path, func_name + "_exp_" + str(exp_iter) + ".npy")
        np.save(file_name, (acc_arr, dataset_size_arr))

Accuracy after query 1: 0.1782
Accuracy after query 2: 0.1839
Accuracy after query 3: 0.1881
Accuracy after query 4: 0.2103
Accuracy after query 5: 0.2098
Accuracy after query 6: 0.2196
Accuracy after query 7: 0.2126
Accuracy after query 8: 0.2274
Accuracy after query 9: 0.2361
Accuracy after query 10: 0.2341
Accuracy after query 11: 0.2313
Accuracy after query 12: 0.2303
Accuracy after query 13: 0.2285
Accuracy after query 14: 0.2388
Accuracy after query 15: 0.2519
Accuracy after query 16: 0.2471
Accuracy after query 17: 0.2590
Accuracy after query 18: 0.2592
Accuracy after query 19: 0.2639
Accuracy after query 20: 0.2565
Accuracy after query 21: 0.2525
Accuracy after query 22: 0.2608
Accuracy after query 23: 0.2595
Accuracy after query 24: 0.2608
Accuracy after query 25: 0.2620
Accuracy after query 26: 0.2512
Accuracy after query 27: 0.2689
Accuracy after query 28: 0.2740
Accuracy after query 29: 0.2662
Accuracy after query 30: 0.2692
Accuracy after query 31: 0.2673
Accuracy after qu

Accuracy after query 61: 0.2474
Accuracy after query 62: 0.2427
Accuracy after query 63: 0.2313
Accuracy after query 64: 0.2075
Accuracy after query 65: 0.2392
Accuracy after query 66: 0.1883
Accuracy after query 67: 0.2235
Accuracy after query 68: 0.2475
Accuracy after query 69: 0.2440
Accuracy after query 70: 0.2278
Accuracy after query 71: 0.2473
Accuracy after query 72: 0.2525
Accuracy after query 73: 0.2490
Accuracy after query 74: 0.2487
Accuracy after query 75: 0.2206
Accuracy after query 76: 0.2306
Accuracy after query 77: 0.2301
Accuracy after query 78: 0.2267
Accuracy after query 79: 0.2312
Accuracy after query 80: 0.2262
Accuracy after query 81: 0.2183
Accuracy after query 82: 0.1908
Accuracy after query 83: 0.2281
Accuracy after query 84: 0.2042
Accuracy after query 85: 0.2129
Accuracy after query 86: 0.2301
Accuracy after query 87: 0.1975
Accuracy after query 88: 0.1634
Accuracy after query 89: 0.1812
Accuracy after query 90: 0.1841
Accuracy after query 91: 0.1925
Accuracy

Accuracy after query 23: 0.2743
Accuracy after query 24: 0.2779
Accuracy after query 25: 0.2702
Accuracy after query 26: 0.2727
Accuracy after query 27: 0.2614
Accuracy after query 28: 0.2680
Accuracy after query 29: 0.2732
Accuracy after query 30: 0.2715
Accuracy after query 31: 0.2800
Accuracy after query 32: 0.2738
Accuracy after query 33: 0.2756
Accuracy after query 34: 0.2792
Accuracy after query 35: 0.2784
Accuracy after query 36: 0.2743
Accuracy after query 37: 0.2713
Accuracy after query 38: 0.2633
Accuracy after query 39: 0.2724
Accuracy after query 40: 0.2598
Accuracy after query 41: 0.2683
Accuracy after query 42: 0.2626
Accuracy after query 43: 0.2559
Accuracy after query 44: 0.2641
Accuracy after query 45: 0.2605
Accuracy after query 46: 0.2662
Accuracy after query 47: 0.2657
Accuracy after query 48: 0.2670
Accuracy after query 49: 0.2657
Accuracy after query 50: 0.2685
Accuracy after query 51: 0.2613
Accuracy after query 52: 0.2552
Accuracy after query 53: 0.2557
Accuracy

Accuracy after query 83: 0.1915
Accuracy after query 84: 0.1819
Accuracy after query 85: 0.1887
Accuracy after query 86: 0.1682
Accuracy after query 87: 0.2082
Accuracy after query 88: 0.1639
Accuracy after query 89: 0.1675
Accuracy after query 90: 0.1771
Accuracy after query 91: 0.1581
Accuracy after query 92: 0.1772
Accuracy after query 93: 0.1811
Accuracy after query 94: 0.1874
Accuracy after query 95: 0.1825
Accuracy after query 96: 0.1801
Accuracy after query 97: 0.1878
Accuracy after query 98: 0.1809
Accuracy after query 1: 0.1764
Accuracy after query 2: 0.2261
Accuracy after query 3: 0.2395
Accuracy after query 4: 0.2382
Accuracy after query 5: 0.2423
Accuracy after query 6: 0.2578
Accuracy after query 7: 0.2434
Accuracy after query 8: 0.2486
Accuracy after query 9: 0.2481
Accuracy after query 10: 0.2532
Accuracy after query 11: 0.2640
Accuracy after query 12: 0.2653
Accuracy after query 13: 0.2511
Accuracy after query 14: 0.2652
Accuracy after query 15: 0.2675
Accuracy after qu

Accuracy after query 45: 0.2765
Accuracy after query 46: 0.2823
Accuracy after query 47: 0.2905
Accuracy after query 48: 0.2864
Accuracy after query 49: 0.2881
Accuracy after query 50: 0.2770
Accuracy after query 51: 0.2796
Accuracy after query 52: 0.2729
Accuracy after query 53: 0.2643
Accuracy after query 54: 0.2704
Accuracy after query 55: 0.2832
Accuracy after query 56: 0.2828
Accuracy after query 57: 0.2901
Accuracy after query 58: 0.2823
Accuracy after query 59: 0.2862
Accuracy after query 60: 0.2800
Accuracy after query 61: 0.2833
Accuracy after query 62: 0.2864
Accuracy after query 63: 0.2770
Accuracy after query 64: 0.2834
Accuracy after query 65: 0.2871
Accuracy after query 66: 0.2829
Accuracy after query 67: 0.2823
Accuracy after query 68: 0.2849
Accuracy after query 69: 0.2842
Accuracy after query 70: 0.2778
Accuracy after query 71: 0.2771
Accuracy after query 72: 0.2805
Accuracy after query 73: 0.2697
Accuracy after query 74: 0.2712
Accuracy after query 75: 0.2693
Accuracy

Accuracy after query 7: 0.2098
Accuracy after query 8: 0.2277
Accuracy after query 9: 0.2097
Accuracy after query 10: 0.2265
Accuracy after query 11: 0.2291
Accuracy after query 12: 0.2266
Accuracy after query 13: 0.2424
Accuracy after query 14: 0.2239
Accuracy after query 15: 0.2289
Accuracy after query 16: 0.2505
Accuracy after query 17: 0.2518
Accuracy after query 18: 0.2465
Accuracy after query 19: 0.2545
Accuracy after query 20: 0.2622
Accuracy after query 21: 0.2458
Accuracy after query 22: 0.2635
Accuracy after query 23: 0.2617
Accuracy after query 24: 0.2471
Accuracy after query 25: 0.2444
Accuracy after query 26: 0.2549
Accuracy after query 27: 0.2580
Accuracy after query 28: 0.2655
Accuracy after query 29: 0.2610
Accuracy after query 30: 0.2634
Accuracy after query 31: 0.2628
Accuracy after query 32: 0.2776
Accuracy after query 33: 0.2806
Accuracy after query 34: 0.2788
Accuracy after query 35: 0.2700
Accuracy after query 36: 0.2855
Accuracy after query 37: 0.2777
Accuracy af