In [1]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# reloaded before each cell runs (handy while editing the library under test).
%load_ext autoreload
%autoreload 2

In [None]:
import os
from typing import Dict, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule, seed_everything
from torch import Tensor, nn
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy, F1Score, MetricCollection, Precision, Recall
from torchvision import transforms
from torchvision.datasets import MNIST

from energizer import Trainer
from energizer.query_strategies import EntropyStrategy, LeastConfidenceStrategy, MarginStrategy, RandomStrategy

Load and preprocess data, and prepare dataloaders

In [None]:
# load and preprocess datasets
data_dir = "./data"
preprocessing_pipe = transforms.Compose(
    [
        transforms.ToTensor(),
        # presumably the MNIST training-set mean and std -- TODO confirm
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
train_set = MNIST(data_dir, train=True, download=True, transform=preprocessing_pipe)
test_set = MNIST(data_dir, train=False, download=True, transform=preprocessing_pipe)
# Seed the split explicitly: `seed_everything` is only called further down, so
# without a generator the train/val partition would change on every kernel restart.
train_set, val_set = random_split(train_set, [55000, 5000], generator=torch.Generator().manual_seed(42))

# create dataloaders
batch_size = 32
eval_batch_size = 128  # this is used when evaluating on the pool too
# `os.cpu_count()` can return None, which `DataLoader` does not accept;
# fall back to loading in the main process in that case.
num_workers = os.cpu_count() or 0
train_dl = DataLoader(train_set, batch_size=batch_size, num_workers=num_workers)
val_dl = DataLoader(val_set, batch_size=eval_batch_size, num_workers=num_workers)
test_dl = DataLoader(test_set, batch_size=eval_batch_size, num_workers=num_workers)

Define the model

In [None]:
class MNISTModel(LightningModule):
    """Small convolutional classifier producing 10-way logits.

    The network stacks two conv/dropout/max-pool/ReLU stages followed by two
    linear layers. The `nn.Linear(1024, 128)` layer fixes the expected input
    to single-channel 28x28 images (28 -> 24 -> 12 -> 8 -> 4, and 64 * 4 * 4 = 1024).
    One accuracy metric per stage is registered as `train_accuracy`,
    `val_accuracy` and `test_accuracy`.
    """

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5),
            nn.Dropout2d(),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5),
            nn.Dropout2d(),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.Dropout(),
            nn.Linear(128, 10),
        )
        # one metric object per stage so that their internal states never mix
        # NOTE(review): newer torchmetrics releases require a `task` argument here -- confirm the pinned version
        for stage in ("train", "val", "test"):
            setattr(self, f"{stage}_accuracy", Accuracy())

    def forward(self, batch: Union[Tensor, Tuple[Tensor, Tensor]]) -> Tensor:
        """Compute the logits for a batch.

        Args:
            batch: either an `(inputs, targets)` pair as produced by the
                dataloaders, or the input tensor alone. Both forms are accepted
                because the step methods pass the whole batch while other
                callers in this notebook pass already-unpacked inputs.

        Returns:
            The output of the underlying `nn.Sequential` model.
        """
        # NOTE: here I am unpacking the batch in the forward pass (when needed)
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        return self.model(x)

    def loss(self, logits: Tensor, targets: Tensor) -> Tensor:
        """Return the cross-entropy between `logits` and `targets`."""
        return F.cross_entropy(logits, targets)

    def common_step(self, batch: Tuple[Tensor, Tensor], stage: str) -> Dict[str, Tensor]:
        """Run the forward pass, compute loss and accuracy, and log both.

        Args:
            batch: an `(inputs, targets)` pair.
            stage: one of "train", "val" or "test"; selects the metric object
                and the prefix of the logged keys.

        Returns:
            A dictionary with the "loss" and the "logits" of the batch.
        """
        # the targets are needed for the loss and the metric (they were never
        # extracted before, which raised a NameError on `y`)
        _, y = batch
        logits = self(batch)
        loss = self.loss(logits, y)
        accuracy = getattr(self, f"{stage}_accuracy")(logits, y)
        self.log(f"{stage}/loss", loss, on_epoch=True, on_step=True, prog_bar=True)
        self.log(f"{stage}/accuracy", accuracy, on_epoch=True, on_step=True, prog_bar=True)
        return {"loss": loss, "logits": logits}

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
        """Training hook: delegates to `common_step` with the "train" stage."""
        return self.common_step(batch, "train")

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
        """Validation hook: delegates to `common_step` with the "val" stage."""
        return self.common_step(batch, "val")

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
        """Test hook: delegates to `common_step` with the "test" stage."""
        return self.common_step(batch, "test")

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return a plain SGD optimizer over all parameters (lr=0.01)."""
        return torch.optim.SGD(self.parameters(), lr=0.01)

### Define active learning strategies

We implement the following strategies:

- `RandomStrategy`: selects random instances from the pool. Since it does not need to run any computation on the pool, we inherit from the `NoAccumulatorStrategy` base class to speed things up. For the same reason we do not need to implement the `pool_step` method; we only need to implement the `query` method.

- `EntropyStrategy`: selects instances that the model is most uncertain about, where the uncertainty is defined as the entropy of the predicted probability distribution over the classes. It needs to run operations on the pool, thus we inherit from the `AccumulatorStrategy` base class. Since it needs to run on the pool, we need to implement the `pool_step` method, but we do not need to implement `query` as `AccumulatorStrategy` knows how to compute the top-k operation and return the indices.

- `LeastConfidenceStrategy`: selects instances that the model is most uncertain about, where the uncertainty is defined by the probability of the most likely class: the lower this top probability, the less confident the model is. It needs to run operations on the pool like the `EntropyStrategy`.

- `MarginStrategy`: selects instances that the model is most uncertain about, where the uncertainty is defined by the difference between the largest and the second-largest class probabilities: the smaller this margin, the more uncertain the model is. It needs to run operations on the pool like the `EntropyStrategy`.


Note that these strategies are already available in the library directly.

In [None]:
# NOTE(review): `NoAccumulatorStrategy` is not imported in the visible import cell,
# and this class shadows the `RandomStrategy` imported from
# `energizer.query_strategies` -- confirm both are intended.
class RandomStrategy(NoAccumulatorStrategy):
    """Query strategy that picks pool instances uniformly at random."""

    def query(self) -> List[int]:
        """Return the pool indices to query.

        Returns:
            A list of distinct indices in `[0, pool_size)`, of length
            `query_size` (or `pool_size` if the pool is smaller than that).
        """
        pool_size = self.trainer.datamodule.pool_size
        # never ask for more distinct indices than the pool contains
        n_queries = min(self.query_size, pool_size)
        # Sample WITHOUT replacement: `randint` could return the same index
        # several times, so fewer than `query_size` new instances would be
        # labelled. The global NumPy RNG is kept so that seeding it (e.g. via
        # `seed_everything`) still makes the query reproducible.
        return np.random.choice(pool_size, size=n_queries, replace=False).tolist()


class EntropyStrategy(AccumulatorStrategy):
    """Strategy that scores pool instances with the `entropy` function."""

    def pool_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Score one batch of the pool.

        NOTE: because `pool_step` is written by hand here, the batch can be
        unpacked right inside this method. With a pre-defined strategy you
        would implement the `get_inputs_from_batch` hook instead, unless the
        forward method of your model can consume the batch "as-is" from the
        dataloader.
        """
        inputs, _ = batch
        # forward pass through the strategy, then the entropy scoring function
        return entropy(self(inputs))


class LeastConfidenceStrategy(AccumulatorStrategy):
    """Strategy that scores pool instances with the `least_confidence` function."""

    def pool_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Score one batch of the pool; the targets in the batch are ignored."""
        inputs, _ = batch
        # forward pass through the strategy, then the least-confidence scoring function
        return least_confidence(self(inputs))


class MarginStrategy(AccumulatorStrategy):
    """Strategy that scores pool instances with the `margin_confidence` function."""

    def pool_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Score one batch of the pool; the targets in the batch are ignored."""
        inputs, _ = batch
        # forward pass through the strategy, then the margin-confidence scoring function
        return margin_confidence(self(inputs))


class ExpectedMarginStrategy(MCAccumulatorStrategy):
    """Strategy that scores pool instances with the `expected_margin_confidence` function."""

    def pool_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Score one batch of the pool; the targets in the batch are ignored."""
        inputs, _ = batch
        # forward pass through the strategy, then the expected margin-confidence scoring function
        return expected_margin_confidence(self(inputs))


class ExpectedEntropyStrategy(MCAccumulatorStrategy):
    """Strategy that scores pool instances with the `expected_entropy` function."""

    def pool_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Score one batch of the pool; the targets in the batch are ignored."""
        inputs, _ = batch
        # forward pass through the strategy, then the expected-entropy scoring function
        return expected_entropy(self(inputs))


class BALDStrategy(MCAccumulatorStrategy):
    """Strategy that scores pool instances with the `bald` function."""

    def pool_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Score one batch of the pool; the targets in the batch are ignored."""
        inputs, _ = batch
        # forward pass through the strategy, then the BALD scoring function
        return bald(self(inputs))

In [None]:
# Sanity check of the `entropy` scoring function on a sharply peaked two-value input.
# NOTE(review): the strategies pass logits to `entropy` -- confirm how it interprets these raw values.
entropy(torch.tensor([0.01, 0.99]))

In [None]:
# Same check on a perfectly balanced two-value input, for comparison with the peaked case.
entropy(torch.tensor([0.5, 0.5]))

In [None]:
# instantiate the base model that every strategy below is built from
model = MNISTModel()

NOTE: when passing a model to build a strategy, internally a `deepcopy` will be created. This is done to avoid modifying the model state and passing it around when trying other strategies. It avoids messing up benchmarks

In [None]:
# strategies deriving from the (No)AccumulatorStrategy base classes
random_strategy = RandomStrategy(model)
entropy_strategy = EntropyStrategy(model)
leastconfidence_strategy = LeastConfidenceStrategy(model)
margin_strategy = MarginStrategy(model)

# strategies deriving from MCAccumulatorStrategy
expected_entropy_strategy = ExpectedEntropyStrategy(model)
expected_margin_strategy = ExpectedMarginStrategy(model)
bald_strategy = BALDStrategy(model)

# keep them all in one list so that they can be handled in bulk
strategies = [random_strategy, entropy_strategy, leastconfidence_strategy, margin_strategy]
strategies += [expected_entropy_strategy, expected_margin_strategy, bald_strategy]

The forward pass of the strategy internally calls the forward of the underlying module.

In [None]:
# Take one full (inputs, targets) batch: `MNISTModel.forward` unpacks the batch
# itself, so passing only the input tensor would fail when it is unpacked.
batch = next(iter(train_dl))

# evaluation mode switches dropout off, making the forward passes comparable
model.eval()
for strategy in strategies:
    strategy.eval()

out_model = model(batch)
out_random = random_strategy(batch)
out_entropy = entropy_strategy(batch)

# put back in training mode everything that was switched to eval above
# (previously only two of the strategies were restored)
model.train()
for strategy in strategies:
    strategy.train()

# the strategies wrap copies of the same (untrained) model, so outputs must match
assert torch.all(out_model == out_random)
assert torch.all(out_model == out_entropy)

out_model.shape, out_random.shape, out_entropy.shape

## Active fit

For clarity let's pack the trainer kwargs in a dictionary

In [None]:
# keyword arguments shared by every `Trainer` created below
trainer_kwargs = {
    "query_size": 10,  # 10 new instances will be queried at each iteration
    "max_epochs": 3,  # the underlying model will be fit for 3 epochs
    "max_labelling_epochs": 50,  # how many times to run the active learning loop
    # use the gpu when one is available, otherwise fall back to the cpu so
    # that the notebook still runs on machines without CUDA
    "accelerator": "gpu" if torch.cuda.is_available() else "cpu",
    "test_after_labelling": True,  # since we have a test set, we test after each labelling iteration
    "limit_val_batches": 0,  # do not validate
    "log_every_n_steps": 1,  # we will have a few batches while training, so log on each
}

# results of each strategy, keyed by strategy name
results_dict = {}

### Random strategy

In [None]:
# reseed before every run for reproducibility (e.g., dropout)
seed_everything(42)
# a fresh trainer per strategy, always configured from the shared kwargs
trainer = Trainer(**trainer_kwargs)
results = trainer.active_fit(
    model=random_strategy, train_dataloaders=train_dl, val_dataloaders=val_dl, test_dataloaders=test_dl
)

In [None]:
# convert the run results to pandas and register them under the strategy name
results_dict["random"] = random_df = results.to_pandas()
random_df

### Entropy strategy

In [None]:
# reseed before every run for reproducibility (e.g., dropout)
seed_everything(42)
# a fresh trainer per strategy, always configured from the shared kwargs
trainer = Trainer(**trainer_kwargs)
results = trainer.active_fit(
    model=entropy_strategy, train_dataloaders=train_dl, val_dataloaders=val_dl, test_dataloaders=test_dl
)

In [None]:
# convert the run results to pandas and register them under the strategy name
results_dict["entropy"] = entropy_df = results.to_pandas()
entropy_df

### Least confidence strategy

In [None]:
# reseed before every run for reproducibility (e.g., dropout)
seed_everything(42)
# a fresh trainer per strategy, always configured from the shared kwargs
trainer = Trainer(**trainer_kwargs)
results = trainer.active_fit(
    model=leastconfidence_strategy, train_dataloaders=train_dl, val_dataloaders=val_dl, test_dataloaders=test_dl
)

In [None]:
# convert the run results to pandas and register them under the strategy name
results_dict["leastconfidence"] = leastconfidence_df = results.to_pandas()
leastconfidence_df

### Margin strategy

In [None]:
# reseed before every run for reproducibility (e.g., dropout)
seed_everything(42)
# a fresh trainer per strategy, always configured from the shared kwargs
trainer = Trainer(**trainer_kwargs)
results = trainer.active_fit(
    model=margin_strategy, train_dataloaders=train_dl, val_dataloaders=val_dl, test_dataloaders=test_dl
)

In [None]:
# convert the run results to pandas and register them under the strategy name
results_dict["margin"] = margin_df = results.to_pandas()
margin_df

### Expected entropy strategy

In [None]:
# reseed before every run for reproducibility (e.g., dropout)
seed_everything(42)
# a fresh trainer per strategy, always configured from the shared kwargs
trainer = Trainer(**trainer_kwargs)
results = trainer.active_fit(
    model=expected_entropy_strategy, train_dataloaders=train_dl, val_dataloaders=val_dl, test_dataloaders=test_dl
)

In [None]:
# convert the run results to pandas and register them under the strategy name
results_dict["expected_entropy"] = expected_entropy_df = results.to_pandas()
expected_entropy_df

### Expected margin confidence

In [None]:
# reseed before every run for reproducibility (e.g., dropout)
seed_everything(42)
# a fresh trainer per strategy, always configured from the shared kwargs
trainer = Trainer(**trainer_kwargs)
results = trainer.active_fit(
    model=expected_margin_strategy, train_dataloaders=train_dl, val_dataloaders=val_dl, test_dataloaders=test_dl
)

In [None]:
# convert the run results to pandas and register them under the strategy name
results_dict["expected_margin"] = expected_margin_df = results.to_pandas()
expected_margin_df

### BALD

In [None]:
# reseed before every run for reproducibility (e.g., dropout)
seed_everything(42)
# a fresh trainer per strategy, always configured from the shared kwargs
trainer = Trainer(**trainer_kwargs)
results = trainer.active_fit(
    model=bald_strategy, train_dataloaders=train_dl, val_dataloaders=val_dl, test_dataloaders=test_dl
)

In [None]:
# convert the run results to pandas and register them under the strategy name
results_dict["bald"] = bald_df = results.to_pandas()
bald_df

### Results
Now let's look at the results

In [None]:
# Test accuracy against training-set size, one curve per strategy. The explicit
# figure/axes interface is used and the axes are labelled so that the figure
# can be read on its own.
fig, ax = plt.subplots()
for strategy_name, strategy_df in results_dict.items():
    ax.plot(strategy_df["train_size"], strategy_df["test/accuracy_epoch"], label=strategy_name)
ax.set(xlabel="Training set size", ylabel="Test accuracy", title="Test accuracy per query strategy")
ax.legend()
plt.show()

In [None]:
# Test loss against training-set size, one curve per strategy. The explicit
# figure/axes interface is used and the axes are labelled so that the figure
# can be read on its own.
fig, ax = plt.subplots()
for strategy_name, strategy_df in results_dict.items():
    ax.plot(strategy_df["train_size"], strategy_df["test/loss_epoch"], label=strategy_name)
ax.set(xlabel="Training set size", ylabel="Test loss", title="Test loss per query strategy")
ax.legend()
plt.show()