### Bayesian MC dropout query strategies

Bayesian query strategies use Monte Carlo (MC) dropout to approximate uncertainty in deep learning models. This works by running several forward passes through a neural network with the dropout layers kept active, so that each pass yields a slightly different prediction and the spread across passes serves as an uncertainty estimate. For this example we are going to use a subset of the [_MNIST_](https://archive.ics.uci.edu/dataset/683/mnist+database+of+handwritten+digits) dataset, loaded with _torchvision_.
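
To make the idea of repeated stochastic forward passes concrete, here is a minimal NumPy sketch. It is not the implementation used by the `activelearning` package: the two-layer network, its random untrained weights, and the names `mc_dropout_predict`, `W1`, `W2` and `X_demo` are assumptions made purely for illustration.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def mc_dropout_predict(X, W1, W2, n_passes=20, p_drop=0.5, rng=None):
    """Return class probabilities of shape (n_passes, n_samples, n_classes).

    Hypothetical two-layer network. Dropout stays active at prediction time,
    so every pass samples a fresh mask over the hidden units.
    """
    rng = np.random.default_rng() if rng is None else rng
    hidden = np.maximum(X @ W1, 0.0)  # ReLU hidden layer
    passes = []
    for _ in range(n_passes):
        # Inverted dropout: zero units with probability p_drop, rescale the rest.
        mask = rng.random(hidden.shape) >= p_drop
        dropped = hidden * mask / (1.0 - p_drop)
        passes.append(softmax(dropped @ W2))
    return np.stack(passes)


rng = np.random.default_rng(123)
X_demo = rng.normal(size=(5, 8))  # 5 samples with 8 features each
W1 = rng.normal(size=(8, 16))     # untrained, illustrative weights
W2 = rng.normal(size=(16, 10))    # 10 output classes, as in MNIST

mc_probs = mc_dropout_predict(X_demo, W1, W2, n_passes=20, rng=rng)
mean_probs = mc_probs.mean(axis=0)  # MC estimate of the predictive distribution
print(mc_probs.shape, mean_probs.shape)
```

In PyTorch the same effect is usually obtained by keeping the dropout modules in training mode while predicting; the stacked predictions then feed the query strategies.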

In [None]:
import numpy as np
import torch
from skorch import NeuralNetClassifier
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from activelearning.AL_cycle import plot_results, strategy_comparison
from activelearning.queries.bayesian.mc_bald import mc_bald
from activelearning.queries.bayesian.mc_max_entropy import mc_max_entropy
from activelearning.queries.bayesian.mc_max_meanstd import mc_max_meanstd
from activelearning.queries.bayesian.mc_max_varratios import mc_max_varratios
from activelearning.queries.representative.random_query import query_random
from activelearning.utils.skorch_nnet import reshapedVGG

torch.manual_seed(123)
np.random.seed(123)

In [None]:
preprocess = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48235, 0.45882, 0.40784], std=[0.229, 0.224, 0.225]),
    ]
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In this case we are going to run the images through a [VGG16](https://pytorch.org/vision/main/models/generated/torchvision.models.vgg16.html) neural network, as this architecture includes dropout layers. It already contains a feature extraction part, so we can pass the images directly to the network.

In [None]:
mnist_data = datasets.MNIST("./data", download=True, transform=preprocess)
dataloader = DataLoader(mnist_data, shuffle=True, batch_size=1500)
X, y = next(iter(dataloader))

# read training data (subset of 1500 images from mnist)
X_train, X_test, y_train, y_test = X[:1000], X[1000:1500], y[:1000], y[1000:1500]
X_train = X_train.reshape(1000, -1)
X_test = X_test.reshape(500, -1)

# assemble initial data
n_initial = 100
initial_idx = np.random.choice(range(len(X_train)), size=n_initial, replace=False)
X_initial = X_train[initial_idx]
y_initial = y_train[initial_idx]


# generate the pool
# remove the initial data from the training dataset
X_pool = np.delete(X_train, initial_idx, axis=0)
y_pool = np.delete(y_train, initial_idx, axis=0)

We first evaluate the VGG classifier after training it on the complete training set. Its test accuracy will serve as the reference metric for the active learning query strategies, as we want to reach the same accuracy with less labeled data.

In [None]:
classifier = NeuralNetClassifier(
    reshapedVGG(num_classes=10),
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam,
    train_split=None,
    max_epochs=15,
    device=device,
)
classifier.fit(X_train, y_train)

goal_acc = classifier.score(X_test, y_test)
print(goal_acc)
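
As a rough illustration of what "the same accuracy with less labeled data" means in practice, the sketch below finds the first point on a learning curve that meets a goal accuracy. The helper `samples_to_reach` is hypothetical, and the curve values are made-up placeholders, not results from this notebook.

```python
def samples_to_reach(n_labeled, accuracies, goal_acc):
    """Return the smallest labeled-set size whose accuracy meets goal_acc.

    Returns None when the goal is never reached.
    """
    for n, acc in zip(n_labeled, accuracies):
        if acc >= goal_acc:
            return n
    return None


# Made-up learning curve: 100 initial samples, then batches of 32 queries.
n_labeled = [100, 132, 164, 196, 228]
accuracies = [0.71, 0.80, 0.86, 0.91, 0.93]  # placeholder values only

needed = samples_to_reach(n_labeled, accuracies, goal_acc=0.90)
saving = 1 - needed / 1000  # fraction of a 1000-sample training set left unlabeled
print(needed, saving)
```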

To compare how different query strategies perform, we can use the *strategy_comparison* function and pass the strategies to be used. We can also pass several values for the number of instances queried per iteration, to check whether a different batch size influences performance. *plot_results* can be used to immediately plot the output from *strategy_comparison*, or a custom graph can be created from the scores data frame.

In [None]:
n_instances = [32]

scores = strategy_comparison(
    X_train=X_initial,
    y_train=y_initial,
    X_pool=X_pool,
    y_pool=y_pool,
    X_test=X_test,
    y_test=y_test,
    classifier="nnet_bo",
    query_strategies=[mc_bald, mc_max_entropy, mc_max_varratios, query_random, mc_max_meanstd],
    n_instances=n_instances,  # reuse the list defined above so the plot stays in sync
    goal_acc=goal_acc,
    max_epochs=15,
)
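
The Bayesian strategies compared above correspond to well-known acquisition scores computed from stacked MC dropout predictions. The sketch below shows the textbook definitions in NumPy; it assumes an array `mc_probs` of shape `(n_passes, n_samples, n_classes)`, and the function names are illustrative. The `activelearning` implementations may differ in detail.

```python
import numpy as np

EPS = 1e-12  # avoids log(0)


def max_entropy(mc_probs):
    """Entropy of the mean predictive distribution, one score per sample."""
    mean = mc_probs.mean(axis=0)
    return -(mean * np.log(mean + EPS)).sum(axis=-1)


def bald(mc_probs):
    """Predictive entropy minus the average entropy of the individual passes."""
    per_pass = -(mc_probs * np.log(mc_probs + EPS)).sum(axis=-1)
    return max_entropy(mc_probs) - per_pass.mean(axis=0)


def variation_ratios(mc_probs):
    """One minus the fraction of passes that vote for the most common class."""
    n_passes, _, n_classes = mc_probs.shape
    votes = mc_probs.argmax(axis=-1)  # shape (n_passes, n_samples)
    counts = np.stack([(votes == c).sum(axis=0) for c in range(n_classes)], axis=-1)
    return 1.0 - counts.max(axis=-1) / n_passes


def mean_std(mc_probs):
    """Standard deviation across passes, averaged over classes."""
    return mc_probs.std(axis=0).mean(axis=-1)


def select_top(scores, n_instances):
    """Indices of the n_instances highest-scoring pool samples."""
    return np.argsort(scores)[::-1][:n_instances]


# Toy predictions: 4 passes, 3 samples, 3 classes.
confident = [0.98, 0.01, 0.01]
flipped = [0.01, 0.98, 0.01]
uniform = [1 / 3, 1 / 3, 1 / 3]
mc_probs = np.array([
    [confident, uniform, confident],
    [confident, uniform, flipped],
    [confident, uniform, confident],
    [confident, uniform, flipped],
])

for name, fn in [("entropy", max_entropy), ("bald", bald),
                 ("var_ratios", variation_ratios), ("mean_std", mean_std)]:
    print(name, fn(mc_probs).round(3), "-> query", select_top(fn(mc_probs), 1))
```

In this toy case the uniform sample gets the highest entropy, while BALD, variation ratios and mean standard deviation prefer the sample on which the passes disagree. This is why the strategies can select different instances from the same pool.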

In [None]:
plot_results(
    scores,  # output data frame from strategy_comparison
    n_instances=n_instances,
    tot_samples=X_train.shape[0],  # size of the original training set, for scale
    goal_acc=goal_acc,
    figsize=(7, 4),
)