In [1]:
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from skorch import NeuralNetClassifier
from modAL.models import ActiveLearner
from batchbald_redux import batchbald
from acquisition_functions import *
import os
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from keras.layers import Input, Dense, Lambda, Layer
from keras.models import Model
from keras import backend as K
import keras
from keras.datasets import mnist

In [3]:
LATENT_DIMS = [64]  # latent sizes to evaluate; each must equal the flattened width of the enhanced features
LATENT_DIM = 64
MAX_EPOCHS = 200
BATCH_SIZE = 128
LEARNING_RATE = 0.001
EXPERIMENT_COUNT = 3  # repetitions per acquisition function (numpy seeds 0..EXPERIMENT_COUNT-1)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
HIDDEN_DIM = 512
ORIGINAL_DIM = 784  # flattened 28x28 image width, used by the raw-pixel baseline
# NOTE(review): LATENT_DIM, HIDDEN_DIM and epsilon_std are not referenced by the
# visible cells (MLP_REG hard-codes its layer widths); they look like leftovers.
# Build the path portably: a literal backslash is only a separator on Windows.
results_path = os.path.join('results', 'mnist_infovae_dbal')
epsilon_std = 1.0

# exist_ok makes this idempotent on re-run and avoids a check-then-create race.
os.makedirs(results_path, exist_ok=True)

# Acquisition functions to compare: result-file prefix -> modAL query strategy.
ACQ_FUNCS = {
    "var_ratios": var_ratios,
    "mean_std": mean_std,
    "max_entropy": max_entropy,
    "bald": bald,
    "uniform": uniform,
#     "batch_bald": batch_bald
}


In [4]:
class MLP_REG(nn.Module):
    """Dropout-regularised MLP producing 10 class logits.

    The input is flattened, then passed through three ReLU hidden layers of
    512, 256 and 128 units. Each hidden layer is followed by dropout with
    p = 0.25, 0.25 and 0.5 respectively. A final linear layer maps to 10
    unnormalised scores, which are fed to CrossEntropyLoss by the training code.

    Args:
        latent_dim: number of input features per sample after flattening.
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = (latent_dim, 512, 256, 128)
        drop_rates = (0.25, 0.25, 0.5)

        # Creation order matches Flatten, then (Linear, ReLU, Dropout) x3, then
        # the output Linear. This keeps Sequential indices (state_dict keys) and
        # the parameter-initialisation RNG draws in a fixed, predictable order.
        stack = [nn.Flatten()]
        for fan_in, fan_out, rate in zip(widths[:-1], widths[1:], drop_rates):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(p=rate)]
        stack.append(nn.Linear(widths[-1], 10))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return raw (pre-softmax) logits for the batch ``x``."""
        return self.layers(x)


### Read training data

In [5]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Add a trailing channel axis and divide by 255 to scale pixel intensities
# into [0, 1]. The sample count is taken from the data (-1) rather than
# hard-coded as 60000 / 10000.
X_train = X_train.reshape(-1, 28, 28, 1).astype('float32') / 255.
X_test = X_test.reshape(-1, 28, 28, 1).astype('float32') / 255.

# Images and labels must stay row-aligned; later cells index both by position.
assert len(X_train) == len(y_train) and len(X_test) == len(y_test)


In [6]:
# Directory holding the InfoVAE-encoded features. Override it with the
# ENHANCED_DATA_DIR environment variable instead of editing a per-user absolute
# path; the default keeps the original location.
ENHANCED_DATA_DIR = os.environ.get(
    "ENHANCED_DATA_DIR",
    r"C:\Users\pinar\OneDrive\Masaüstü\masterthesis\src\Generative Models\mnist",
)
X_train_enhanced = np.load(os.path.join(ENHANCED_DATA_DIR, "x_train_enhanced_infovae.npy"))
X_test_enhanced = np.load(os.path.join(ENHANCED_DATA_DIR, "x_test_enhanced_infovae.npy"))

# The experiment indexes these arrays with positions taken from y_train / y_test,
# so each encoded row must correspond to the same-position MNIST label.
assert len(X_train_enhanced) == len(y_train), "train features/labels misaligned"
assert len(X_test_enhanced) == len(y_test), "test features/labels misaligned"

In [7]:
# Collapse every image to a single feature vector (one row per sample) for the
# raw-pixel MLP baseline.
X_train = np.reshape(X_train, (len(X_train), -1))
X_test = np.reshape(X_test, (len(X_test), -1))

In [8]:
def active_learning_procedure(query_strategy,
                              X_test,
                              y_test,
                              X_pool,
                              y_pool,
                              X_initial,
                              y_initial,
                              estimator,
                              n_queries=98,
                              n_instances=10):
    """Run a pool-based active-learning loop and record the test score per round.

    A modAL ``ActiveLearner`` is fitted on the initial labelled set. Then, for
    up to ``n_queries`` rounds, it:

    1. selects a batch from the unlabelled pool with ``query_strategy``;
    2. is taught those samples;
    3. has the selected rows removed from the pool.

    Args:
        query_strategy: acquisition function passed to ``ActiveLearner``.
        X_test, y_test: held-out data passed to ``learner.score`` after each round.
        X_pool, y_pool: unlabelled pool and its oracle labels. The caller's
            arrays are not modified; a shrinking local copy is kept.
        X_initial, y_initial: initial labelled set.
        estimator: estimator wrapped by the learner.
        n_queries: maximum number of query rounds.
        n_instances: samples requested per round. Clipped to the remaining
            pool size.

    Returns:
        ``(perf_hist, active_pool_size)``: the score after the initial fit and
        after each completed round, and the number of labelled samples at each
        of those points. The loop stops early once the pool is exhausted.
    """
    learner = ActiveLearner(estimator=estimator,
                            X_training=X_initial,
                            y_training=y_initial,
                            query_strategy=query_strategy,
                           )
    n_labelled = len(X_initial)
    perf_hist = [learner.score(X_test, y_test)]
    active_pool_size = [n_labelled]
    for index in range(n_queries):
        if len(X_pool) == 0:
            # Nothing left to label; another query would fail.
            break
        # Never request more samples than the pool still holds.
        batch_size = min(n_instances, len(X_pool))
        query_idx, _ = learner.query(X_pool, batch_size)
        learner.teach(X_pool[query_idx], y_pool[query_idx])
        # np.delete returns new arrays, so the caller's pool is left intact.
        X_pool = np.delete(X_pool, query_idx, axis=0)
        y_pool = np.delete(y_pool, query_idx, axis=0)
        # Count what was actually labelled instead of assuming n_instances.
        n_labelled += np.size(query_idx)
        model_accuracy = learner.score(X_test, y_test)
        print('Accuracy after query {n}: {acc:0.4f}'.format(n=index + 1, acc=model_accuracy))
        perf_hist.append(model_accuracy)
        active_pool_size.append(n_labelled)
    return perf_hist, active_pool_size

In [9]:
# Set True to also run the raw-pixel (ORIGINAL_DIM-input) baseline. It was
# previously kept as a disabled string literal; it writes to results/mnist_dbal.
RUN_RAW_PIXEL_BASELINE = False


def _sample_initial_idx(labels, per_class=2, n_classes=10):
    """Draw a class-balanced initial labelled set.

    Picks ``per_class`` distinct positions for each label ``0..n_classes-1``
    using NumPy's global RNG, so seed it first for reproducibility.
    """
    picks = [np.random.choice(np.where(labels == cls)[0], size=per_class, replace=False)
             for cls in range(n_classes)]
    return np.concatenate(picks).astype(int)


def _run_acquisition_sweep(train_feats, train_labels, test_feats, test_labels,
                           input_dim, initial_idx, out_dir, file_pattern, seed,
                           **name_fields):
    """Run active learning once per entry of ``ACQ_FUNCS`` and save each curve.

    Every acquisition function starts from the same initial labelled set and
    pool. torch and numpy are re-seeded with ``seed`` before each run, so
    initial weights and other random draws match across strategies.
    Each result is saved to ``out_dir`` under ``file_pattern``, formatted with
    ``func_name`` plus ``name_fields``.
    """
    X_initial = train_feats[initial_idx]
    y_initial = train_labels[initial_idx]
    # active_learning_procedure works on its own copies, so the pool can be
    # built once and shared by every strategy.
    X_pool = np.delete(train_feats, initial_idx, axis=0)
    y_pool = np.delete(train_labels, initial_idx, axis=0)
    os.makedirs(out_dir, exist_ok=True)

    for func_name, acquisition_func in ACQ_FUNCS.items():
        torch.manual_seed(seed)
        np.random.seed(seed)
        # NOTE(review): skorch is given an already-instantiated module, so it is
        # likely reused (not re-initialised) on each refit triggered by teach();
        # confirm whether a from-scratch retrain per query is intended.
        model = MLP_REG(input_dim).to(DEVICE)

        estimator = NeuralNetClassifier(model,
                                        max_epochs=MAX_EPOCHS,
                                        batch_size=BATCH_SIZE,
                                        lr=LEARNING_RATE,
                                        optimizer=torch.optim.Adam,
                                        criterion=torch.nn.CrossEntropyLoss,
                                        train_split=None,
                                        verbose=0,
                                        device=DEVICE)

        acc_arr, dataset_size_arr = active_learning_procedure(acquisition_func,
                                                              test_feats,
                                                              test_labels,
                                                              X_pool,
                                                              y_pool,
                                                              X_initial,
                                                              y_initial,
                                                              estimator,)
        file_name = os.path.join(out_dir, file_pattern.format(func_name=func_name, **name_fields))
        np.save(file_name, (acc_arr, dataset_size_arr))


for latent_dim in LATENT_DIMS:
    # MLP_REG flattens its input, so the per-sample feature count must match.
    feat_width = int(np.prod(X_train_enhanced.shape[1:]))
    assert feat_width == latent_dim, (
        "enhanced features have width {} but latent_dim is {}".format(feat_width, latent_dim))

    for exp_iter in range(EXPERIMENT_COUNT):
        np.random.seed(exp_iter)
        initial_idx = _sample_initial_idx(y_train)

        _run_acquisition_sweep(X_train_enhanced, y_train, X_test_enhanced, y_test,
                               latent_dim, initial_idx, results_path,
                               "{func_name}_latent_dim_{latent_dim}_exp_{exp_iter}.npy",
                               seed=exp_iter, latent_dim=latent_dim, exp_iter=exp_iter)

        if RUN_RAW_PIXEL_BASELINE:
            _run_acquisition_sweep(X_train, y_train, X_test, y_test,
                                   ORIGINAL_DIM, initial_idx,
                                   os.path.join("results", "mnist_dbal"),
                                   "{func_name}_exp_{exp_iter}.npy",
                                   seed=exp_iter, exp_iter=exp_iter)

Accuracy after query 1: 0.6432
Accuracy after query 2: 0.6934
Accuracy after query 3: 0.7153
Accuracy after query 4: 0.7229
Accuracy after query 5: 0.7318
Accuracy after query 6: 0.7767
Accuracy after query 7: 0.7925
Accuracy after query 8: 0.7949
Accuracy after query 9: 0.8046
Accuracy after query 10: 0.8079
Accuracy after query 11: 0.8037
Accuracy after query 12: 0.8442
Accuracy after query 13: 0.8534
Accuracy after query 14: 0.8594
Accuracy after query 15: 0.8662
Accuracy after query 16: 0.8703
Accuracy after query 17: 0.8820
Accuracy after query 18: 0.8910
Accuracy after query 19: 0.8948
Accuracy after query 20: 0.8923
Accuracy after query 21: 0.9064
Accuracy after query 22: 0.9019
Accuracy after query 23: 0.9149
Accuracy after query 24: 0.9204
Accuracy after query 25: 0.9153
Accuracy after query 26: 0.9230
Accuracy after query 27: 0.9083
Accuracy after query 28: 0.9157
Accuracy after query 29: 0.9240
Accuracy after query 30: 0.9257
Accuracy after query 31: 0.9276
Accuracy after qu

Accuracy after query 61: 0.9511
Accuracy after query 62: 0.9510
Accuracy after query 63: 0.9528
Accuracy after query 64: 0.9460
Accuracy after query 65: 0.9531
Accuracy after query 66: 0.9555
Accuracy after query 67: 0.9544
Accuracy after query 68: 0.9512
Accuracy after query 69: 0.9517
Accuracy after query 70: 0.9517
Accuracy after query 71: 0.9538
Accuracy after query 72: 0.9533
Accuracy after query 73: 0.9551
Accuracy after query 74: 0.9593
Accuracy after query 75: 0.9574
Accuracy after query 76: 0.9536
Accuracy after query 77: 0.9560
Accuracy after query 78: 0.9593
Accuracy after query 79: 0.9573
Accuracy after query 80: 0.9574
Accuracy after query 81: 0.9608
Accuracy after query 82: 0.9608
Accuracy after query 83: 0.9594
Accuracy after query 84: 0.9493
Accuracy after query 85: 0.9595
Accuracy after query 86: 0.9608
Accuracy after query 87: 0.9524
Accuracy after query 88: 0.9646
Accuracy after query 89: 0.9645
Accuracy after query 90: 0.9601
Accuracy after query 91: 0.9641
Accuracy

Accuracy after query 19: 0.8336
(2000,)
Accuracy after query 20: 0.8196
(2000,)
Accuracy after query 21: 0.8469
(2000,)
Accuracy after query 22: 0.8515
(2000,)
Accuracy after query 23: 0.8490
(2000,)
Accuracy after query 24: 0.8538
(2000,)
Accuracy after query 25: 0.8536
(2000,)
Accuracy after query 26: 0.8616
(2000,)
Accuracy after query 27: 0.8584
(2000,)
Accuracy after query 28: 0.8624
(2000,)
Accuracy after query 29: 0.8473
(2000,)
Accuracy after query 30: 0.8608
(2000,)
Accuracy after query 31: 0.8642
(2000,)
Accuracy after query 32: 0.8664
(2000,)
Accuracy after query 33: 0.8841
(2000,)
Accuracy after query 34: 0.8863
(2000,)
Accuracy after query 35: 0.8756
(2000,)
Accuracy after query 36: 0.8892
(2000,)
Accuracy after query 37: 0.8838
(2000,)
Accuracy after query 38: 0.8880
(2000,)
Accuracy after query 39: 0.8919
(2000,)
Accuracy after query 40: 0.8884
(2000,)
Accuracy after query 41: 0.8919
(2000,)
Accuracy after query 42: 0.8967
(2000,)
Accuracy after query 43: 0.8917
(2000,)


Accuracy after query 59: 0.9558
Accuracy after query 60: 0.9541
Accuracy after query 61: 0.9530
Accuracy after query 62: 0.9537
Accuracy after query 63: 0.9558
Accuracy after query 64: 0.9565
Accuracy after query 65: 0.9521
Accuracy after query 66: 0.9592
Accuracy after query 67: 0.9566
Accuracy after query 68: 0.9581
Accuracy after query 69: 0.9620
Accuracy after query 70: 0.9583
Accuracy after query 71: 0.9629
Accuracy after query 72: 0.9611
Accuracy after query 73: 0.9625
Accuracy after query 74: 0.9634
Accuracy after query 75: 0.9619
Accuracy after query 76: 0.9639
Accuracy after query 77: 0.9647
Accuracy after query 78: 0.9638
Accuracy after query 79: 0.9639
Accuracy after query 80: 0.9617
Accuracy after query 81: 0.9613
Accuracy after query 82: 0.9631
Accuracy after query 83: 0.9651
Accuracy after query 84: 0.9607
Accuracy after query 85: 0.9616
Accuracy after query 86: 0.9600
Accuracy after query 87: 0.9627
Accuracy after query 88: 0.9676
Accuracy after query 89: 0.9677
Accuracy

Accuracy after query 21: 0.8579
Accuracy after query 22: 0.8461
Accuracy after query 23: 0.8468
Accuracy after query 24: 0.8613
Accuracy after query 25: 0.8654
Accuracy after query 26: 0.8503
Accuracy after query 27: 0.8600
Accuracy after query 28: 0.8575
Accuracy after query 29: 0.8556
Accuracy after query 30: 0.8619
Accuracy after query 31: 0.8658
Accuracy after query 32: 0.8632
Accuracy after query 33: 0.8716
Accuracy after query 34: 0.8831
Accuracy after query 35: 0.8828
Accuracy after query 36: 0.8822
Accuracy after query 37: 0.8887
Accuracy after query 38: 0.8911
Accuracy after query 39: 0.8856
Accuracy after query 40: 0.8953
Accuracy after query 41: 0.8867
Accuracy after query 42: 0.8779
Accuracy after query 43: 0.8958
Accuracy after query 44: 0.8965
Accuracy after query 45: 0.8857
Accuracy after query 46: 0.9009
Accuracy after query 47: 0.9059
Accuracy after query 48: 0.9099
Accuracy after query 49: 0.9037
Accuracy after query 50: 0.9126
Accuracy after query 51: 0.9074
Accuracy

Accuracy after query 57: 0.9545
Accuracy after query 58: 0.9540
Accuracy after query 59: 0.9553
Accuracy after query 60: 0.9542
Accuracy after query 61: 0.9551
Accuracy after query 62: 0.9526
Accuracy after query 63: 0.9559
Accuracy after query 64: 0.9568
Accuracy after query 65: 0.9537
Accuracy after query 66: 0.9586
Accuracy after query 67: 0.9604
Accuracy after query 68: 0.9585
Accuracy after query 69: 0.9573
Accuracy after query 70: 0.9573
Accuracy after query 71: 0.9609
Accuracy after query 72: 0.9588
Accuracy after query 73: 0.9566
Accuracy after query 74: 0.9620
Accuracy after query 75: 0.9594
Accuracy after query 76: 0.9598
Accuracy after query 77: 0.9607
Accuracy after query 78: 0.9618
Accuracy after query 79: 0.9653
Accuracy after query 80: 0.9648
Accuracy after query 81: 0.9599
Accuracy after query 82: 0.9621
Accuracy after query 83: 0.9631
Accuracy after query 84: 0.9641
Accuracy after query 85: 0.9683
Accuracy after query 86: 0.9635
Accuracy after query 87: 0.9642
Accuracy

Accuracy after query 19: 0.8885
Accuracy after query 20: 0.8844
Accuracy after query 21: 0.8955
Accuracy after query 22: 0.8995
Accuracy after query 23: 0.9099
Accuracy after query 24: 0.9149
Accuracy after query 25: 0.9161
Accuracy after query 26: 0.9094
Accuracy after query 27: 0.9173
Accuracy after query 28: 0.9227
Accuracy after query 29: 0.9207
Accuracy after query 30: 0.9265
Accuracy after query 31: 0.9277
Accuracy after query 32: 0.9302
Accuracy after query 33: 0.9322
Accuracy after query 34: 0.9343
Accuracy after query 35: 0.9369
Accuracy after query 36: 0.9349
Accuracy after query 37: 0.9423
Accuracy after query 38: 0.9341
Accuracy after query 39: 0.9425
Accuracy after query 40: 0.9449
Accuracy after query 41: 0.9408
Accuracy after query 42: 0.9378
Accuracy after query 43: 0.9404
Accuracy after query 44: 0.9435
Accuracy after query 45: 0.9476
Accuracy after query 46: 0.9501
Accuracy after query 47: 0.9523
Accuracy after query 48: 0.9473
Accuracy after query 49: 0.9511
Accuracy

Accuracy after query 63: 0.9283
(2000,)
Accuracy after query 64: 0.9254
(2000,)
Accuracy after query 65: 0.9289
(2000,)
Accuracy after query 66: 0.9278
(2000,)
Accuracy after query 67: 0.9282
(2000,)
Accuracy after query 68: 0.9286
(2000,)
Accuracy after query 69: 0.9252
(2000,)
Accuracy after query 70: 0.9287
(2000,)
Accuracy after query 71: 0.9296
(2000,)
Accuracy after query 72: 0.9264
(2000,)
Accuracy after query 73: 0.9286
(2000,)
Accuracy after query 74: 0.9300
(2000,)
Accuracy after query 75: 0.9303
(2000,)
Accuracy after query 76: 0.9292
(2000,)
Accuracy after query 77: 0.9349
(2000,)
Accuracy after query 78: 0.9386
(2000,)
Accuracy after query 79: 0.9340
(2000,)
Accuracy after query 80: 0.9348
(2000,)
Accuracy after query 81: 0.9332
(2000,)
Accuracy after query 82: 0.9353
(2000,)
Accuracy after query 83: 0.9350
(2000,)
Accuracy after query 84: 0.9336
(2000,)
Accuracy after query 85: 0.9342
(2000,)
Accuracy after query 86: 0.9385
(2000,)
Accuracy after query 87: 0.9396
(2000,)
