In [0]:
%tensorflow_version 1.x

import numpy as np
import keras
from keras.datasets import mnist
import sys
from scipy.stats import entropy
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten, SpatialDropout2D
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.regularizers import l2
from keras import backend as K

# Mount Google Drive (Colab only): the data, starting-index and results paths used by
# this notebook all live under /content/gdrive.
from google.colab import drive
drive.mount("/content/gdrive")

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
def predict_with_uncertainty(f, x, n_iter=100):
    """Average several stochastic (MC dropout) forward passes over the same input.

    Adapted from: https://stackoverflow.com/questions/43529931/how-to-calculate-prediction-uncertainty-using-keras
    #Arguments:
        f: function mapping a (model input, Keras backend learning_phase flag) pair to a list
           whose first element is the model output (e.g. built with K.function)
        x: input batch; its first axis indexes the points being predicted
        n_iter: number of repeated MC dropout predictions per point (must be >= 1)
    #Returns:
        Mean of the MC dropout predictions, as a float64 array with the same shape as one
        output of f (assumed (x.shape[0], n_outputs); n_outputs is no longer fixed at 2)
    #Raises:
        ValueError: if n_iter < 1 (the mean of zero predictions is undefined)
    """
    if n_iter < 1:
        raise ValueError("n_iter must be >= 1, got %r" % (n_iter,))
    # learning_phase=1 keeps dropout active at inference time, so each call samples a
    # different dropout mask; float64 matches the buffer dtype used for averaging.
    samples = [np.asarray(f((x, 1))[0], dtype=np.float64) for _ in range(n_iter)]
    return np.stack(samples).mean(axis=0)

#Used for making repeated pool predictions
def predict_pool_with_uncertainty(f, x, n_iter=50):
    """Generate and return every stochastic (MC dropout) prediction for a pool of points.

    Adapted from: https://stackoverflow.com/questions/43529931/how-to-calculate-prediction-uncertainty-using-keras
    #Arguments:
        f: function mapping a (model input, Keras backend learning_phase flag) pair to a list
           whose first element is the model output (e.g. built with K.function)
        x: input batch; its first axis indexes the pool points
        n_iter: number of repeated MC dropout predictions per point (must be >= 1)
    #Returns:
        float64 array of shape (n_iter,) + shape of one output of f, i.e. presumably
        (n_iter, x.shape[0], n_outputs); n_outputs is no longer fixed at 2
    #Raises:
        ValueError: if n_iter < 1 (the output size could not be inferred)
    """
    if n_iter < 1:
        raise ValueError("n_iter must be >= 1, got %r" % (n_iter,))
    # learning_phase=1 keeps dropout active, so each pass samples a new dropout mask.
    return np.stack([np.asarray(f((x, 1))[0], dtype=np.float64) for _ in range(n_iter)])


def initialize_model(train_data_indices, dropout_prob=0.30, weight_decay=0.0005):
  """Build and compile a freshly initialized 7-vs-9 CNN.

  Called in each active learning iteration so the model is retrained from scratch.
  #Arguments:
    train_data_indices: indices of training data within X_train_All; only the number of
      selected points is used (training set size scales the weight decay)
    dropout_prob: rate of both Dropout layers (default 0.30)
    weight_decay: base L2 coefficient; the applied value is weight_decay / train_size
  #Returns:
    A new, compiled (categorical cross-entropy loss, Adam optimizer) Keras model whose
    last layer is a 2-way softmax
  #Raises:
    ValueError: if train_data_indices selects no points (weight decay would divide by zero)
  """
  # Only the size of the selected training set is needed here, so no X/y copies are built.
  train_size = y_train_All[train_data_indices].shape[0]
  if train_size == 0:
    raise ValueError("train_data_indices must select at least one training point")
  # The L2 penalty shrinks as the labelled set grows.
  l2_coef = weight_decay / train_size
  nb_filters = 35
  nb_pool = 3
  nb_conv = 4
  img_rows = img_cols = 28
  nb_classes = 2
  model = Sequential()
  model.add(Convolution2D(nb_filters, nb_conv, strides=1, data_format="channels_first", input_shape=(1, img_rows, img_cols)))
  model.add(Activation('relu'))
  model.add(Convolution2D(nb_filters, nb_conv, data_format="channels_first", strides=2))
  model.add(Activation('relu'))
  model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool), data_format="channels_first"))
  model.add(Dropout(dropout_prob))
  model.add(Flatten())
  # kernel_regularizer is the Keras 2 name of the legacy W_regularizer argument.
  model.add(Dense(128, kernel_regularizer=l2(l2_coef)))
  model.add(Activation('relu'))
  model.add(Dropout(dropout_prob))
  model.add(Dense(nb_classes, kernel_regularizer=l2(l2_coef)))
  model.add(Activation('softmax'))
  model.compile(loss='categorical_crossentropy', optimizer='adam')
  return model

def run_model(train_data_indices, epochs=300, fit_batch_size=128, mc_iterations=100):
  """Train a fresh Keras model on the given training points and return its test accuracy.

  #Arguments:
      train_data_indices: indices of current training points within X_train_All
      epochs: number of training epochs (default 300)
      fit_batch_size: minibatch size for model.fit (default 128); unrelated to the
          active-learning acquisition batch size
      mc_iterations: number of MC dropout passes averaged per test prediction (default 100)
  #Returns:
      Test accuracy: fraction of X_test points whose argmax MC-dropout prediction
      equals y_test_original
  #Raises:
      ValueError: if any selected training label is not 7 or 9
  """
  X_train = np.expand_dims(X_train_All[train_data_indices], axis=1)
  labels = y_train_All[train_data_indices]
  # The task is binary 7-vs-9; any other digit would corrupt the 2-class one-hot encoding.
  if not np.isin(labels, (7, 9)).all():
    raise ValueError("run_model expects only digits 7 and 9 among the training labels")
  # Encode 7 -> 1 and 9 -> 0, the same encoding as y_test_original.
  y_train = keras.utils.to_categorical((labels == 7).astype(int), num_classes=2)
  model = initialize_model(train_data_indices)
  try:
    model.fit(X_train, y_train, epochs=epochs, batch_size=fit_batch_size, verbose=0)
    # learning_phase is an explicit input so predictions can be made with dropout on.
    f = K.function([model.layers[0].input, K.learning_phase()], [model.layers[-1].output])
    y_test_output = predict_with_uncertainty(f, X_test, n_iter=mc_iterations)
    y_test_predictions = np.argmax(y_test_output, axis=1)
    return np.sum(y_test_predictions == y_test_original) / y_test_original.shape[0]
  finally:
    # Every call builds a new model; without clearing, TF1-style Keras keeps all of them in
    # the global graph across hundreds of active learning iterations.
    K.clear_session()



In [0]:
# Active learning settings shared by the cells below.
# Acquisition budget: total points to acquire, taken batch_size points at a time.
num_acquisitions = 1000
batch_size = 1
# Number of independent experiments (each begins from a different initial training set).
num_experiments = 3
# NOTE(review): dropout_prob, pool_sample_size and num_masks are not referenced by the
# other cells of this random-acquisition notebook (initialize_model defines its own dropout
# rate); presumably kept for parity with uncertainty-based acquisition scripts — confirm.
dropout_prob = 0.30
pool_sample_size = 2500
num_masks = 50


In [0]:
#Loading data
# All inputs and outputs live under one Drive directory.
data_path = "/content/gdrive/My Drive/FINAL_PAPER_ACTIVE_LEARNING_EXP/MNIST/"
starting_ind_path = data_path + "Binary_7_9_AL_Scripts/"
results_path = data_path + "Binary_7_9_AL_Results/"
train_data = np.loadtxt(data_path + "mnist_train.csv", delimiter=",")
test_data = np.loadtxt(data_path + "mnist_test.csv", delimiter=",")
# Column 0 holds the digit label; the remaining columns are the 28x28 pixel values.
# reshape(-1, ...) infers the number of rows instead of hard-coding 60000 / 10000.
y_train_All = train_data[:, 0]
y_test = test_data[:, 0]
X_train_All = train_data[:, 1:].reshape((-1, 28, 28))
X_test = test_data[:, 1:].reshape((-1, 28, 28))
# Keep only the two digits of the binary task (all 7s first, then all 9s).
train_ind = np.concatenate((np.argwhere(y_train_All == 7), np.argwhere(y_train_All == 9))).flatten()
test_ind = np.concatenate((np.argwhere(y_test == 7), np.argwhere(y_test == 9))).flatten()
# Binary test labels: 7 -> 1, 9 -> 0 (the encoding run_model uses for training labels).
y_test_original = (y_test[test_ind] == 7).astype(float)
y_test = keras.utils.to_categorical(y_test_original, num_classes=2)
X_test = np.expand_dims(X_test[test_ind], axis=1)
assert X_test.shape[0] == y_test_original.shape[0] == y_test.shape[0]

#Iterating across experiments, each of which begins with a different training set (that is balanced across classes)
# Each experiment resumes automatically: if its saved accuracy/index files exist in
# results_path they are loaded, otherwise it starts from the initial training indices in
# starting_ind_path.
# Saved-state invariant: the saved training set always contains the batch acquired after the
# last recorded accuracy, so the current training set is evaluated (and recorded) first.

for e in range(0, 1):  # experiment 1 only; range(num_experiments) runs all of them
  acc_file = "Rand_BS"+str(batch_size)+"_Acc_Ind" + str(e+1) + ".npy"
  ind_file = "Rand_BS"+str(batch_size)+"_Ind_Ind" + str(e+1) + ".npy"
  try:
    exp_acc = list(np.load(results_path+acc_file))
    train_data_indices = list(np.load(results_path+ind_file))
  except FileNotFoundError:
    exp_acc = []
    train_data_indices = list(np.load(starting_ind_path + 'trainindices' + str(e+1) + '.npy'))
  # One accuracy for the initial training set plus one after every acquired batch. Kept in
  # locals so the global num_acquisitions is never mutated (re-running this cell, or running
  # several experiments, must not shrink the budget).
  num_evaluations = num_acquisitions // batch_size + 1
  remaining_evaluations = max(num_evaluations - len(exp_acc), 0)
  # Set membership keeps building the pool linear in len(train_ind).
  already_labelled = set(train_data_indices)
  pool_indices = [i for i in train_ind if i not in already_labelled]
  #Looping over active learning iterations
  for _ in range(remaining_evaluations):
    #Evaluating model trained on the current training set
    exp_acc.append(run_model(train_data_indices))
    #Random acquisition of new points
    all_acq_ind = np.random.choice(pool_indices, batch_size, replace=False)
    for acq_ind in all_acq_ind:
      train_data_indices.append(acq_ind)
      pool_indices.remove(acq_ind)
    #Saving results
    np.save(results_path+acc_file, np.array(exp_acc))
    np.save(results_path+ind_file, np.array(train_data_indices))
    print('Exp ' + str(e+1) + ', Number elapsed iterations: ' + str(len(exp_acc)) + ", last acc: " + str(exp_acc[-1]))


Instructions for updating:
If using Keras pass *_constraint arguments to layers.






Exp 1, Number elapsed iterations: 822, last acc: 0.9941089837997055




Exp 1, Number elapsed iterations: 823, last acc: 0.9936180657830143




Exp 1, Number elapsed iterations: 824, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 825, last acc: 0.9936180657830143




Exp 1, Number elapsed iterations: 826, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 827, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 828, last acc: 0.9916543937162494




Exp 1, Number elapsed iterations: 829, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 830, last acc: 0.9891998036327934




Exp 1, Number elapsed iterations: 831, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 832, last acc: 0.9877270495827197




Exp 1, Number elapsed iterations: 833, last acc: 0.9921453117329406




Exp 1, Number elapsed iterations: 834, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 835, last acc: 0.9941089837997055




Exp 1, Number elapsed iterations: 836, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 837, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 838, last acc: 0.9941089837997055




Exp 1, Number elapsed iterations: 839, last acc: 0.9891998036327934




Exp 1, Number elapsed iterations: 840, last acc: 0.9921453117329406




Exp 1, Number elapsed iterations: 841, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 842, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 843, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 844, last acc: 0.9891998036327934




Exp 1, Number elapsed iterations: 845, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 846, last acc: 0.9945999018163967




Exp 1, Number elapsed iterations: 847, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 848, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 849, last acc: 0.9950908198330879




Exp 1, Number elapsed iterations: 850, last acc: 0.9965635738831615




Exp 1, Number elapsed iterations: 851, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 852, last acc: 0.9896907216494846




Exp 1, Number elapsed iterations: 853, last acc: 0.9887088856161022




Exp 1, Number elapsed iterations: 854, last acc: 0.9941089837997055




Exp 1, Number elapsed iterations: 855, last acc: 0.9872361315660285




Exp 1, Number elapsed iterations: 856, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 857, last acc: 0.9916543937162494




Exp 1, Number elapsed iterations: 858, last acc: 0.9945999018163967




Exp 1, Number elapsed iterations: 859, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 860, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 861, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 862, last acc: 0.9916543937162494




Exp 1, Number elapsed iterations: 863, last acc: 0.9941089837997055




Exp 1, Number elapsed iterations: 864, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 865, last acc: 0.9882179675994109




Exp 1, Number elapsed iterations: 866, last acc: 0.9936180657830143




Exp 1, Number elapsed iterations: 867, last acc: 0.9872361315660285




Exp 1, Number elapsed iterations: 868, last acc: 0.9847815414825725




Exp 1, Number elapsed iterations: 869, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 870, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 871, last acc: 0.9891998036327934




Exp 1, Number elapsed iterations: 872, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 873, last acc: 0.9916543937162494




Exp 1, Number elapsed iterations: 874, last acc: 0.9896907216494846




Exp 1, Number elapsed iterations: 875, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 876, last acc: 0.9887088856161022




Exp 1, Number elapsed iterations: 877, last acc: 0.9847815414825725




Exp 1, Number elapsed iterations: 878, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 879, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 880, last acc: 0.9941089837997055




Exp 1, Number elapsed iterations: 881, last acc: 0.9896907216494846




Exp 1, Number elapsed iterations: 882, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 883, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 884, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 885, last acc: 0.9896907216494846




Exp 1, Number elapsed iterations: 886, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 887, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 888, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 889, last acc: 0.9862542955326461




Exp 1, Number elapsed iterations: 890, last acc: 0.9887088856161022




Exp 1, Number elapsed iterations: 891, last acc: 0.9882179675994109




Exp 1, Number elapsed iterations: 892, last acc: 0.9872361315660285




Exp 1, Number elapsed iterations: 893, last acc: 0.9950908198330879




Exp 1, Number elapsed iterations: 894, last acc: 0.9882179675994109




Exp 1, Number elapsed iterations: 895, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 896, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 897, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 898, last acc: 0.9936180657830143




Exp 1, Number elapsed iterations: 899, last acc: 0.9877270495827197




Exp 1, Number elapsed iterations: 900, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 901, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 902, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 903, last acc: 0.9916543937162494




Exp 1, Number elapsed iterations: 904, last acc: 0.9936180657830143




Exp 1, Number elapsed iterations: 905, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 906, last acc: 0.9945999018163967




Exp 1, Number elapsed iterations: 907, last acc: 0.9882179675994109




Exp 1, Number elapsed iterations: 908, last acc: 0.9936180657830143




Exp 1, Number elapsed iterations: 909, last acc: 0.9887088856161022




Exp 1, Number elapsed iterations: 910, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 911, last acc: 0.9936180657830143




Exp 1, Number elapsed iterations: 912, last acc: 0.9896907216494846




Exp 1, Number elapsed iterations: 913, last acc: 0.9921453117329406




Exp 1, Number elapsed iterations: 914, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 915, last acc: 0.9950908198330879




Exp 1, Number elapsed iterations: 916, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 917, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 918, last acc: 0.9891998036327934




Exp 1, Number elapsed iterations: 919, last acc: 0.9862542955326461




Exp 1, Number elapsed iterations: 920, last acc: 0.9950908198330879




Exp 1, Number elapsed iterations: 921, last acc: 0.9891998036327934




Exp 1, Number elapsed iterations: 922, last acc: 0.9891998036327934




Exp 1, Number elapsed iterations: 923, last acc: 0.9941089837997055




Exp 1, Number elapsed iterations: 924, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 925, last acc: 0.9887088856161022




Exp 1, Number elapsed iterations: 926, last acc: 0.9941089837997055




Exp 1, Number elapsed iterations: 927, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 928, last acc: 0.9887088856161022




Exp 1, Number elapsed iterations: 929, last acc: 0.9936180657830143




Exp 1, Number elapsed iterations: 930, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 931, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 932, last acc: 0.9916543937162494




Exp 1, Number elapsed iterations: 933, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 934, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 935, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 936, last acc: 0.9945999018163967




Exp 1, Number elapsed iterations: 937, last acc: 0.9921453117329406




Exp 1, Number elapsed iterations: 938, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 939, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 940, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 941, last acc: 0.9887088856161022




Exp 1, Number elapsed iterations: 942, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 943, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 944, last acc: 0.9936180657830143




Exp 1, Number elapsed iterations: 945, last acc: 0.9921453117329406




Exp 1, Number elapsed iterations: 946, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 947, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 948, last acc: 0.9936180657830143




Exp 1, Number elapsed iterations: 949, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 950, last acc: 0.9941089837997055




Exp 1, Number elapsed iterations: 951, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 952, last acc: 0.9950908198330879




Exp 1, Number elapsed iterations: 953, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 954, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 955, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 956, last acc: 0.9945999018163967




Exp 1, Number elapsed iterations: 957, last acc: 0.9936180657830143




Exp 1, Number elapsed iterations: 958, last acc: 0.9887088856161022




Exp 1, Number elapsed iterations: 959, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 960, last acc: 0.9887088856161022




Exp 1, Number elapsed iterations: 961, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 962, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 963, last acc: 0.9921453117329406




Exp 1, Number elapsed iterations: 964, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 965, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 966, last acc: 0.9896907216494846




Exp 1, Number elapsed iterations: 967, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 968, last acc: 0.9887088856161022




Exp 1, Number elapsed iterations: 969, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 970, last acc: 0.9941089837997055




Exp 1, Number elapsed iterations: 971, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 972, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 973, last acc: 0.9950908198330879




Exp 1, Number elapsed iterations: 974, last acc: 0.9882179675994109




Exp 1, Number elapsed iterations: 975, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 976, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 977, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 978, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 979, last acc: 0.9945999018163967




Exp 1, Number elapsed iterations: 980, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 981, last acc: 0.9941089837997055




Exp 1, Number elapsed iterations: 982, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 983, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 984, last acc: 0.9921453117329406




Exp 1, Number elapsed iterations: 985, last acc: 0.9945999018163967




Exp 1, Number elapsed iterations: 986, last acc: 0.9896907216494846




Exp 1, Number elapsed iterations: 987, last acc: 0.9916543937162494




Exp 1, Number elapsed iterations: 988, last acc: 0.9867452135493373




Exp 1, Number elapsed iterations: 989, last acc: 0.9921453117329406




Exp 1, Number elapsed iterations: 990, last acc: 0.9926362297496318




Exp 1, Number elapsed iterations: 991, last acc: 0.9921453117329406




Exp 1, Number elapsed iterations: 992, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 993, last acc: 0.9901816396661758




Exp 1, Number elapsed iterations: 994, last acc: 0.9960726558664703




Exp 1, Number elapsed iterations: 995, last acc: 0.993127147766323




Exp 1, Number elapsed iterations: 996, last acc: 0.9882179675994109




Exp 1, Number elapsed iterations: 997, last acc: 0.9891998036327934




Exp 1, Number elapsed iterations: 998, last acc: 0.9941089837997055




Exp 1, Number elapsed iterations: 999, last acc: 0.9911634756995582




Exp 1, Number elapsed iterations: 1000, last acc: 0.990672557682867




Exp 1, Number elapsed iterations: 1001, last acc: 0.990672557682867
