In [1]:
from copy import deepcopy
from pprint import pprint

import torch.cuda
from torch import nn, optim
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10

from baal import ActiveLearningDataset, ModelWrapper
from baal.active import ActiveLearningLoop
from baal.active.heuristics import BALD
from baal.bayesian.dropout import patch_module
from baal.utils.metrics import Accuracy
import numpy as np

In [2]:
use_cuda = torch.cuda.is_available()
NO_OF_INITIAL_LABELLED = 20
DROPOUT_RATE = 0.4
EPOCH = 50
BATCH_SIZE = 128
NO_OF_ITERATIONS = 200

In [3]:
train_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

In [4]:
train_ds = MNIST("/dataset_mnist", train=True, transform=train_transform, download=True)
test_ds = MNIST("/dataset_mnist", train=False, transform=test_transform, download=True)

In [5]:
al_dataset = ActiveLearningDataset(train_ds, pool_specifics={"transform": test_transform})

In [6]:
def label_uniformly(al_dataset, no_per_class=2):
    initial_dataset_size = al_dataset.labelled.shape[0]
    indices_to_be_labelled = set()
    
    for class_label in range(10):
        n_iter = 0
        while(n_iter < no_per_class):
            idx = np.random.choice(initial_dataset_size, 1)[0]
            selected_label = al_dataset.get_raw(idx)[1]
            if idx not in indices_to_be_labelled and selected_label == class_label:
                indices_to_be_labelled.add(idx)
                n_iter = n_iter + 1
    #print(indices_to_be_labelled)
    for elt in indices_to_be_labelled:
        #print(elt.item())
        al_dataset.label(elt.item())

In [7]:
al_dataset = ActiveLearningDataset(train_ds, pool_specifics={"transform": test_transform})
label_uniformly(al_dataset)
#al_dataset.label_randomly(NO_OF_INITIAL_LABELLED)  # Start with 20 items labelled.

In [8]:
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 256),
    nn.Dropout(p=0.25),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.Dropout(p=0.25),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.Dropout(p=0.5),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.Softmax(dim=1),
)

In [9]:
model = patch_module(model)
if use_cuda:
    model = model.cuda()

In [10]:
wrapper = ModelWrapper(model=model, criterion=nn.CrossEntropyLoss())
#wrapper.metrics = dict()
#wrapper.add_metric("accuracy", Accuracy)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [11]:
bald = BALD()

al_loop = ActiveLearningLoop(
    dataset=al_dataset,
    get_probabilities=wrapper.predict_on_dataset,
    heuristic=bald,
    ndata_to_label=10,  # We will label 100 examples per step.
    # KWARGS for predict_on_dataset
    iterations=100,  # 20 sampling for MC-Dropout
    batch_size=BATCH_SIZE,
    use_cuda=use_cuda,
    verbose=False,
)

In [12]:
initial_weights = deepcopy(model.state_dict())

In [None]:
for step in range(NO_OF_ITERATIONS):
    model.load_state_dict(initial_weights)
    train_loss = wrapper.train_on_dataset(
        al_dataset, optimizer=optimizer, batch_size=BATCH_SIZE, epoch=EPOCH, use_cuda=use_cuda
    )
    test_loss = wrapper.test_on_dataset(test_ds, batch_size=BATCH_SIZE, use_cuda=use_cuda)

    pprint(
        {
            "dataset_size": len(al_dataset),
            #"train_loss": wrapper.metrics["train_loss"].value,
            #"test_loss": wrapper.metrics["test_loss"].value,
            "train_accuracy": wrapper.metrics['train_accuracy'].value,
            "test_accuracy": wrapper.metrics['test_accuracy'].value,
        }
    )
    flag = al_loop.step()
    if not flag:
        # We are done labelling! stopping
        break

[9608-MainThread ] [baal.modelwrapper:train_on_dataset:116] 2021-11-16T21:42:10.220703Z [info     ] Starting training              dataset=20 epoch=50
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:127] 2021-11-16T21:44:18.260256Z [info     ] Training complete              train_accuracy=0.6500000357627869
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:155] 2021-11-16T21:44:18.274351Z [info     ] Starting evaluating            dataset=10000
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:165] 2021-11-16T21:44:20.939488Z [info     ] Evaluation complete            test_accuracy=0.2988528609275818
{'dataset_size': 20,
 'test_accuracy': 0.2988528609275818,
 'train_accuracy': 0.6500000357627869}
[9608-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:263] 2021-11-16T21:44:20.987698Z [info     ] Start Predict                  dataset=59980
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:116] 2021-11-16T21:45:25.459887Z [info     ] Starting training     

[9608-MainThread ] [baal.modelwrapper:test_on_dataset:155] 2021-11-16T22:47:16.918254Z [info     ] Starting evaluating            dataset=10000
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:165] 2021-11-16T22:47:19.720834Z [info     ] Evaluation complete            test_accuracy=0.29736945033073425
{'dataset_size': 200,
 'test_accuracy': 0.29736945033073425,
 'train_accuracy': 0.5737847089767456}
[9608-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:263] 2021-11-16T22:47:19.767708Z [info     ] Start Predict                  dataset=59800
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:116] 2021-11-16T22:48:23.651954Z [info     ] Starting training              dataset=210 epoch=50
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:127] 2021-11-16T22:50:47.229936Z [info     ] Training complete              train_accuracy=0.6656821370124817
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:155] 2021-11-16T22:50:47.250135Z [info     ] Starting evaluating

[9608-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:263] 2021-11-16T23:50:50.314353Z [info     ] Start Predict                  dataset=59620
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:116] 2021-11-16T23:51:53.950857Z [info     ] Starting training              dataset=390 epoch=50
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:127] 2021-11-16T23:54:21.275323Z [info     ] Training complete              train_accuracy=0.2721354067325592
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:155] 2021-11-16T23:54:21.285249Z [info     ] Starting evaluating            dataset=10000
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:165] 2021-11-16T23:54:24.061983Z [info     ] Evaluation complete            test_accuracy=0.17879746854305267
{'dataset_size': 390,
 'test_accuracy': 0.17879746854305267,
 'train_accuracy': 0.2721354067325592}
[9608-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:263] 2021-11-16T23:54:24.119819Z [info     ] Start 

[9608-MainThread ] [baal.modelwrapper:train_on_dataset:127] 2021-11-17T00:59:50.783553Z [info     ] Training complete              train_accuracy=0.7661098837852478
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:155] 2021-11-17T00:59:50.796970Z [info     ] Starting evaluating            dataset=10000
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:165] 2021-11-17T00:59:53.524396Z [info     ] Evaluation complete            test_accuracy=0.5918710231781006
{'dataset_size': 570,
 'test_accuracy': 0.5918710231781006,
 'train_accuracy': 0.7661098837852478}
[9608-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:263] 2021-11-17T00:59:53.576613Z [info     ] Start Predict                  dataset=59430
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:116] 2021-11-17T01:00:57.255620Z [info     ] Starting training              dataset=580 epoch=50
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:127] 2021-11-17T01:03:25.662645Z [info     ] Training complete   

[9608-MainThread ] [baal.modelwrapper:test_on_dataset:165] 2021-11-17T02:05:24.764923Z [info     ] Evaluation complete            test_accuracy=0.609968364238739
{'dataset_size': 750,
 'test_accuracy': 0.609968364238739,
 'train_accuracy': 0.7453362345695496}
[9608-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:263] 2021-11-17T02:05:24.815253Z [info     ] Start Predict                  dataset=59250
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:116] 2021-11-17T02:06:27.654774Z [info     ] Starting training              dataset=760 epoch=50
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:127] 2021-11-17T02:08:57.249445Z [info     ] Training complete              train_accuracy=0.7202256321907043
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:155] 2021-11-17T02:08:57.257424Z [info     ] Starting evaluating            dataset=10000
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:165] 2021-11-17T02:09:00.078558Z [info     ] Evaluation complete    

[9608-MainThread ] [baal.modelwrapper:train_on_dataset:116] 2021-11-17T03:12:07.533366Z [info     ] Starting training              dataset=940 epoch=50
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:127] 2021-11-17T03:14:37.321043Z [info     ] Training complete              train_accuracy=0.7724609375
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:155] 2021-11-17T03:14:37.339843Z [info     ] Starting evaluating            dataset=10000
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:165] 2021-11-17T03:14:40.092430Z [info     ] Evaluation complete            test_accuracy=0.6633702516555786
{'dataset_size': 940,
 'test_accuracy': 0.6633702516555786,
 'train_accuracy': 0.7724609375}
[9608-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:263] 2021-11-17T03:14:40.140009Z [info     ] Start Predict                  dataset=59060
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:116] 2021-11-17T03:15:43.247702Z [info     ] Starting training              d

[9608-MainThread ] [baal.modelwrapper:test_on_dataset:155] 2021-11-17T04:19:57.967099Z [info     ] Starting evaluating            dataset=10000
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:165] 2021-11-17T04:20:00.788694Z [info     ] Evaluation complete            test_accuracy=0.8679786324501038
{'dataset_size': 1120,
 'test_accuracy': 0.8679786324501038,
 'train_accuracy': 0.8894676566123962}
[9608-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:263] 2021-11-17T04:20:00.838014Z [info     ] Start Predict                  dataset=58880
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:116] 2021-11-17T04:21:03.427704Z [info     ] Starting training              dataset=1130 epoch=50
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:127] 2021-11-17T04:23:35.347640Z [info     ] Training complete              train_accuracy=0.8719044923782349
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:155] 2021-11-17T04:23:35.360062Z [info     ] Starting evaluating

 'train_accuracy': 0.7613636255264282}
[9608-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:263] 2021-11-17T05:26:09.180299Z [info     ] Start Predict                  dataset=58700
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:116] 2021-11-17T05:27:13.337611Z [info     ] Starting training              dataset=1310 epoch=50
[9608-MainThread ] [baal.modelwrapper:train_on_dataset:127] 2021-11-17T05:29:54.520017Z [info     ] Training complete              train_accuracy=0.8399621844291687
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:155] 2021-11-17T05:29:54.527996Z [info     ] Starting evaluating            dataset=10000
[9608-MainThread ] [baal.modelwrapper:test_on_dataset:165] 2021-11-17T05:29:57.353359Z [info     ] Evaluation complete            test_accuracy=0.8639240264892578
{'dataset_size': 1310,
 'test_accuracy': 0.8639240264892578,
 'train_accuracy': 0.8399621844291687}
[9608-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:263] 2021-11

In [None]:
torch.save(wrapper.model.state_dict(), 'model_mnist.pth')