In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from typing import Dict, Tuple

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torch import Tensor, nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

from energizer.learners.acquisition_functions import entropy, expected_entropy
from energizer.learners.base import DeterministicMixin, MCDropoutMixin
from energizer.trainer import Trainer

In [3]:
# load and preprocess datasets
data_dir = "./data"
SEED = 42  # fixes the train/validation split so re-runs see the same data
val_size = 5000

preprocessing_pipe = transforms.Compose(
    [
        transforms.ToTensor(),
        # mean/std normalisation constants commonly used for MNIST
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
# download=True is a no-op when the files already exist in `data_dir`,
# and lets the notebook run end-to-end on a fresh machine
full_train_set = MNIST(data_dir, train=True, download=True, transform=preprocessing_pipe)
test_set = MNIST(data_dir, train=False, download=True, transform=preprocessing_pipe)
# seeded generator -> identical train/val split on every Restart & Run All
train_set, val_set = random_split(
    full_train_set,
    [len(full_train_set) - val_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

# create dataloaders
batch_size = 32
eval_batch_size = 32  # this is used when evaluating on the pool too
train_dl = DataLoader(train_set, batch_size=batch_size)
val_dl = DataLoader(val_set, batch_size=eval_batch_size)
test_dl = DataLoader(test_set, batch_size=eval_batch_size)

In [4]:
class MNISTModel(LightningModule):
    """Small CNN classifier for MNIST (10 classes).

    Shared base for the active-learning variants below: it defines the
    network, the loss and the train/val/test steps, all of which delegate
    to `step`.

    Args:
        learning_rate: SGD learning rate used by `configure_optimizers`.
    """

    def __init__(self, learning_rate: float = 0.01) -> None:
        super().__init__()
        self.learning_rate = learning_rate
        self.model = nn.Sequential(
            # 1x28x28 -> 32x24x24 -> 32x12x12
            nn.Conv2d(1, 32, kernel_size=5),
            nn.Dropout2d(),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            # 32x12x12 -> 64x8x8 -> 64x4x4
            nn.Conv2d(32, 64, kernel_size=5),
            nn.Dropout2d(),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            # 64 * 4 * 4 = 1024 flattened features
            nn.Flatten(),
            nn.Linear(1024, 128),
            # without a non-linearity here the two Linear layers would
            # collapse into a single affine map
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(128, 10),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return raw (unnormalised) class logits for a batch of images."""
        return self.model(x)

    def loss(self, logits: Tensor, targets: Tensor) -> Tensor:
        """Cross-entropy between `logits` and integer class `targets`."""
        return F.cross_entropy(logits, targets)

    def step(self, batch: Tuple[Tensor, Tensor], stage: str) -> Dict[str, Tensor]:
        """Shared train/val/test logic: forward pass, loss and logging.

        Args:
            batch: ``(inputs, targets)`` pair from the dataloader.
            stage: prefix of the logged metric, e.g. ``"train"``.

        Returns:
            Dict with the ``"loss"`` (the key Lightning optimises during
            training) and the raw ``"logits"``.
        """
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        self.log(f"{stage}/loss", loss)
        return {"loss": loss, "logits": logits}

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
        return self.step(batch, "train")

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
        return self.step(batch, "val")

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
        return self.step(batch, "test")

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Plain SGD over all parameters with ``self.learning_rate``."""
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)


class DeterministicMNISTModel(MNISTModel, DeterministicMixin):
    """An implementation of the `Entropy` active learning strategy.

    Pool instances are scored with a single forward pass, so the logits
    passed to `entropy` have shape ``[num_samples, num_classes]``.
    """

    def pool_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Score one batch of unlabelled pool data by predictive entropy.

        Only the inputs of ``batch`` are used; its targets are ignored.
        """
        inputs, _targets = batch
        return entropy(self(inputs))


class StochasticMNISTModel(MNISTModel, MCDropoutMixin):
    """An implementation of the `Entropy` active learning strategy.

    In this case we use the MCDropout technique to compute model
    uncertainty. Accordingly, we need to use `expected_entropy` as
    the acquisition function.
    """

    def loss(self, logits: Tensor, targets: Tensor) -> Tensor:
        """Cross-entropy on logits averaged over the MC-dropout iterations.

        Args:
            logits: ``[num_samples, num_classes, num_iterations]`` as produced
                by the MC-dropout forward pass; a plain
                ``[num_samples, num_classes]`` tensor is passed through as-is.
            targets: integer class labels, one per sample.
        """
        # `self` is required: `MNISTModel.step` calls this as `self.loss(...)`.
        if logits.ndim == 3:
            # average over num_iterations (never over classes)
            logits = logits.mean(-1)
        return super().loss(logits, targets)

    def pool_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Score one pool batch with `expected_entropy` over MC-dropout samples."""
        # define how to perform the forward pass
        x, _ = batch
        logits = self(x)
        # use an acquisition/scoring function
        # NOTE: since we are using MCDropout we need to use the
        # `expected_entropy` acquisition function
        scores = expected_entropy(logits)
        return scores

In [5]:
# seed weight initialisation so every Restart & Run All starts from the
# same parameters (dropout masks drawn later are seeded as well)
torch.manual_seed(42)
model = MNISTModel()
deterministic_model = DeterministicMNISTModel()
stochastic_model = StochasticMNISTModel()

In [6]:
# Sanity-check the output shape of each model on one training batch.
# The labels are bound to `_labels` rather than `_`: assigning `_` at top
# level stops IPython from updating its `_` last-output variable.
x, _labels = next(iter(train_dl))
with torch.no_grad():  # shapes only, no need to build an autograd graph
    output_shapes = (
        model(x).shape,
        deterministic_model(x).shape,
        stochastic_model(x).shape,
    )
output_shapes

(torch.Size([32, 10]), torch.Size([32, 10]), torch.Size([32, 10, 10]))

In [7]:
# Trainer configuration, grouped by purpose.
trainer_config = dict(
    # active-learning schedule
    query_size=2,
    total_budget=5,
    max_labelling_epochs=4,
    test_after_labelling=True,
    # per-round training and logging
    max_epochs=3,
    log_every_n_steps=1,
    # for testing purposes: cap every loop at 10 batches
    # (see the 10/10 progress bars in the output)
    limit_train_batches=10,
    limit_val_batches=10,
    limit_test_batches=10,
    limit_pool_batches=10,
)
trainer = Trainer(**trainer_config)

trainer.active_fit(
    model=deterministic_model,
    train_dataloaders=train_dl,
    val_dataloaders=val_dl,
    test_dataloaders=test_dl,
)

# NOTE(review): the stats below report train_size=6 although total_budget=5;
# the last query of `query_size` instances appears to overshoot the budget.
# Confirm whether this is the intended budget semantics.
trainer.datamodule.stats

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 184 K 
-------------------------------------
184 K     Trainable params
0         Non-trainable params
184 K     Total params
0.738     Total estimated model params size (MB)


                                                                   



Pool DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 52.90it/s]




Epoch 2: 100%|██████████| 11/11 [00:00<00:00, 47.72it/s, loss=1.61, v_num=48]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 11/11 [00:00<00:00, 46.10it/s, loss=1.61, v_num=48]




Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 58.84it/s]
Pool DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 47.44it/s]
Epoch 2: 100%|██████████| 11/11 [00:00<00:00, 25.03it/s, loss=1.81, v_num=48]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 11/11 [00:00<00:00, 24.08it/s, loss=1.81, v_num=48]
Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 36.75it/s]
Pool DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 37.09it/s]
Epoch 2: 100%|██████████| 11/11 [00:00<00:00, 41.03it/s, loss=1.95, v_num=48]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 11/11 [00:00<00:00, 39.38it/s, loss=1.95, v_num=48]


{'total_data_size': 55000,
 'train_size': 6,
 'pool_size': 54994,
 'num_train_batches': 1,
 'num_pool_batches': 1719}