In [1]:
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from skorch import NeuralNetClassifier
from modAL.models import ActiveLearner
from batchbald_redux import batchbald
from acquisition_functions import *
import os
from keras.datasets import mnist
from models import MLP
import ssl
# WARNING: this swaps the default HTTPS context factory for the *unverified* one,
# so TLS certificate verification is disabled for every HTTPS request that relies
# on the default context for the rest of this kernel session.
# NOTE(review): presumably a workaround for certificate errors while downloading
# the dataset — confirm, and prefer fixing the local CA bundle over keeping this.
ssl._create_default_https_context = ssl._create_unverified_context

In [3]:
# --- Experiment configuration ------------------------------------------------
# Epoch budget passed to the skorch estimator as `max_epochs`.
MAX_EPOCHS = 200
# Mini-batch size passed to the skorch estimator.
BATCH_SIZE = 128
# Learning rate passed to the skorch estimator as `lr`.
LEARNING_RATE = 0.001
# Number of independent repetitions of the whole experiment.
EXPERIMENT_COUNT = 3
# Use the first CUDA device when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Directory (relative to the working directory) that receives the .npy results.
results_path = 'results'

# exist_ok=True makes this idempotent and avoids the check-then-create race of
# the `if not os.path.exists(...)` pattern.
os.makedirs(results_path, exist_ok=True)

# Registry of the acquisition strategies under comparison. Each key is the label
# used to identify the strategy; each value is the matching callable
# (presumably supplied by the `acquisition_functions` star-import — verify).
ACQ_FUNCS = dict(
    var_ratios=var_ratios,
    mean_std=mean_std,
    max_entropy=max_entropy,
    bald=bald,
    uniform=uniform,
    batch_bald=batch_bald,
)

### Load the MNIST training and test data

In [5]:
# Load the MNIST train/test splits through the Keras dataset helper.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Add a trailing channel axis and rescale pixel values by 1/255.
# -1 lets numpy infer the sample count instead of hard-coding 60000 / 10000.
X_train = X_train.reshape(-1, 28, 28, 1).astype('float32') / 255.
X_test = X_test.reshape(-1, 28, 28, 1).astype('float32') / 255.

# The estimator below is trained with torch.nn.CrossEntropyLoss, which takes
# int64 class indices as targets; cast explicitly rather than relying on
# whatever integer dtype the loader returns.
y_train = y_train.astype('int64')
y_test = y_test.astype('int64')

### Active Learning Procedure

In [6]:
def active_learning_procedure(query_strategy,
                              X_test,
                              y_test,
                              X_pool,
                              y_pool,
                              X_initial,
                              y_initial,
                              estimator,
                              n_queries=98,
                              n_instances=10):
    """Run a pool-based active-learning loop and track test accuracy per round.

    An ``ActiveLearner`` is built from ``estimator`` and the initial labelled
    set. In every round the learner is asked for ``n_instances`` pool samples,
    is taught with them, and the selected rows are removed from the (local)
    pool. The caller's pool arrays are not modified: ``np.delete`` returns new
    arrays that are rebound to the local names only.

    Args:
        query_strategy: Callable handed to ``ActiveLearner`` as its
            ``query_strategy``.
        X_test, y_test: Held-out data passed to ``learner.score`` after the
            initial fit and after every round.
        X_pool, y_pool: Unlabelled pool and its labels, indexed along axis 0.
        X_initial, y_initial: Initial labelled set used to create the learner.
        estimator: Estimator wrapped by the ``ActiveLearner``.
        n_queries: Maximum number of query/teach rounds.
        n_instances: Number of samples requested from the strategy per round.

    Returns:
        ``(perf_hist, active_pool_size)`` — two equally long lists holding the
        score and the labelled-set size after the initial fit and after each
        completed round.
    """
    learner = ActiveLearner(estimator=estimator,
                            X_training=X_initial,
                            y_training=y_initial,
                            query_strategy=query_strategy)
    perf_hist = [learner.score(X_test, y_test)]
    labeled_count = len(X_initial)
    active_pool_size = [labeled_count]
    for index in range(n_queries):
        # Nothing left to label: stop instead of querying an empty pool.
        if len(X_pool) == 0:
            print('Pool exhausted after {n} queries; stopping early.'.format(n=index))
            break
        # The second return value (the selected instances) is not needed here;
        # the rows are looked up again through the indices.
        query_idx, _ = learner.query(X_pool, n_instances)
        learner.teach(X_pool[query_idx], y_pool[query_idx])
        X_pool = np.delete(X_pool, query_idx, axis=0)
        y_pool = np.delete(y_pool, query_idx, axis=0)
        model_accuracy = learner.score(X_test, y_test)
        # Count what the strategy actually returned rather than assuming it
        # always delivers exactly n_instances (e.g. when the pool runs short).
        labeled_count += int(np.size(query_idx))
        print('Accuracy after query {n}: {acc:0.4f}'.format(n=index + 1, acc=model_accuracy))
        perf_hist.append(model_accuracy)
        active_pool_size.append(labeled_count)
    return perf_hist, active_pool_size

In [7]:
for exp_iter in range(EXPERIMENT_COUNT):
    # Seed numpy with the repetition index so the initial labelled set of each
    # repetition is reproducible.
    np.random.seed(exp_iter)

    # Balanced initial set: two randomly drawn training samples for each of the
    # ten label values 0..9 (20 samples in total).
    initial_idx = np.concatenate([
        np.random.choice(np.where(y_train == digit)[0], size=2, replace=False)
        for digit in range(10)
    ])

    # The initial/pool split does not depend on the acquisition function, so it
    # is built once per repetition instead of once per strategy.
    # NOTE(review): assumes the strategies never modify the pool arrays in place.
    X_initial = X_train[initial_idx]
    y_initial = y_train[initial_idx]
    X_pool = np.delete(X_train, initial_idx, axis=0)
    y_pool = np.delete(y_train, initial_idx, axis=0)

    for func_name, acquisition_func in ACQ_FUNCS.items():
        # Seed torch as well (the original only seeded numpy), so weight
        # initialisation and torch-side randomness start from the same state
        # for every strategy within a repetition. Bit-exact reproducibility on
        # GPU is still not guaranteed.
        torch.manual_seed(exp_iter)

        model = MLP().to(DEVICE)

        # NOTE(review): an already-instantiated module is passed to skorch here;
        # skorch appears to reuse such an instance across fit() calls, so the
        # weights likely carry over between teach() rounds — confirm intended.
        estimator = NeuralNetClassifier(model,
                                        max_epochs=MAX_EPOCHS,
                                        batch_size=BATCH_SIZE,
                                        lr=LEARNING_RATE,
                                        optimizer=torch.optim.Adam,
                                        criterion=torch.nn.CrossEntropyLoss,
                                        train_split=None,
                                        verbose=0,
                                        device=DEVICE)

        acc_arr, dataset_size_arr = active_learning_procedure(acquisition_func,
                                                              X_test,
                                                              y_test,
                                                              X_pool,
                                                              y_pool,
                                                              X_initial,
                                                              y_initial,
                                                              estimator)

        # One file per (strategy, repetition): <func_name>_exp_<exp_iter>.npy
        # holding the accuracy history and the labelled-set sizes.
        file_name = os.path.join(results_path, func_name + "_exp_" + str(exp_iter) + ".npy")
        np.save(file_name, (acc_arr, dataset_size_arr))

(2000,)
Accuracy after query 1: 0.5546
(2000,)
Accuracy after query 2: 0.6045
(2000,)
Accuracy after query 3: 0.6419
(2000,)
Accuracy after query 4: 0.6607
(2000,)
Accuracy after query 5: 0.7038
(2000,)
Accuracy after query 6: 0.7343
(2000,)
Accuracy after query 7: 0.7326
(2000,)
Accuracy after query 8: 0.7468
(2000,)
Accuracy after query 9: 0.7625
(2000,)
Accuracy after query 10: 0.7821
(2000,)
Accuracy after query 11: 0.7809
(2000,)
Accuracy after query 12: 0.7833
(2000,)
Accuracy after query 13: 0.8014
(2000,)
Accuracy after query 14: 0.8021
(2000,)
Accuracy after query 15: 0.8149
(2000,)
Accuracy after query 16: 0.8204
(2000,)
Accuracy after query 17: 0.8193
(2000,)
Accuracy after query 18: 0.8116
(2000,)
Accuracy after query 19: 0.8188
(2000,)
Accuracy after query 20: 0.8225
(2000,)
Accuracy after query 21: 0.8160
(2000,)
Accuracy after query 22: 0.8296
(2000,)
Accuracy after query 23: 0.8174
(2000,)
Accuracy after query 24: 0.8458
(2000,)
Accuracy after query 25: 0.8395
(2000,)
A