In [1]:
# Re-import edited local modules (e.g. the energizer package) automatically,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from typing import Dict, Tuple

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

from energizer.learners.acquisition_functions import entropy
from energizer.learners.base import Deterministic
from energizer.trainer import Trainer

In [3]:
# load and preprocess datasets
data_dir = "./data"
SEED = 42
VAL_SIZE = 5000  # examples carved out of the MNIST training split for validation

# Seed the global RNG (model initialisation, dropout) and give `random_split` its own
# seeded generator so the train/validation split is identical on every Restart & Run All.
torch.manual_seed(SEED)

preprocessing_pipe = transforms.Compose(
    [
        transforms.ToTensor(),
        # commonly quoted MNIST pixel mean / std
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
# `download=True` fetches the files only when they are missing under `data_dir`,
# so the notebook also runs on a fresh checkout.
full_train_set = MNIST(data_dir, train=True, download=True, transform=preprocessing_pipe)
test_set = MNIST(data_dir, train=False, download=True, transform=preprocessing_pipe)
train_set, val_set = random_split(
    full_train_set,
    [len(full_train_set) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SEED),
)

# create dataloaders
batch_size = 32
eval_batch_size = 32  # this is used when evaluating on the pool too
train_dl = DataLoader(train_set, batch_size=batch_size)
val_dl = DataLoader(val_set, batch_size=eval_batch_size)
test_dl = DataLoader(test_set, batch_size=eval_batch_size)

In [4]:
class MNISTModel(Deterministic):
    """Small convolutional classifier for MNIST used in an active-learning loop.

    The train/validation/test steps share :meth:`step`, which computes the
    cross-entropy loss and logs it as ``"{stage}/loss"``. :meth:`pool_step` scores
    unlabelled pool batches with the ``entropy`` acquisition function.
    """

    def __init__(self) -> None:
        super().__init__()
        # Shape trace for a single-channel 28x28 (MNIST) image:
        #   conv5 -> 32x24x24 -> pool -> 32x12x12
        #   conv5 -> 64x8x8   -> pool -> 64x4x4 -> flatten -> 1024 features
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5),
            nn.Dropout2d(),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5),
            nn.Dropout2d(),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(1024, 128),
            # Without a non-linearity here the two Linear layers collapse into a
            # single affine map and the hidden layer adds no capacity.
            # NOTE(review): this shifts the Sequential indices of the last layers,
            # so old checkpoints of this module will not load key-for-key.
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(128, 10),
        )
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x: Tensor) -> Tensor:
        """Return the unnormalised class logits (10 classes) for the inputs ``x``."""
        return self.model(x)

    def step(self, batch: Tuple[Tensor, Tensor], batch_idx: int, stage: str) -> Dict[str, Tensor]:
        """Shared forward + loss computation for the train/val/test loops.

        Args:
            batch: ``(inputs, targets)`` pair; targets are presumably integer class
                labels as produced by the MNIST dataset — confirm.
            batch_idx: Index of the batch within the epoch (unused here).
            stage: Prefix of the logged metric, e.g. ``"train"`` -> ``"train/loss"``.

        Returns:
            Dict with the mean cross-entropy ``"loss"`` and the raw ``"logits"``.
        """
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        self.log(f"{stage}/loss", loss)
        return {"loss": loss, "logits": logits}

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
        """Compute and log the training loss for one batch."""
        return self.step(batch, batch_idx, "train")

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
        """Compute and log the validation loss for one batch."""
        return self.step(batch, batch_idx, "val")

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
        """Compute and log the test loss for one batch."""
        return self.step(batch, batch_idx, "test")

    def pool_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """An implementation of the `Entropy` active learning strategy.

        Runs the forward pass on an unlabelled pool batch (labels are ignored) and
        returns the acquisition scores computed by ``entropy(logits)`` — presumably
        one score per example; confirm against energizer's ``entropy``.
        """
        # define how to perform the forward pass
        x, _ = batch
        logits = self(x)

        # use an acquisition/scoring function
        # NOTE(review): confirm `entropy` expects raw logits rather than probabilities.
        scores = entropy(logits)

        return scores

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Plain SGD over all parameters with a fixed learning rate of 0.01."""
        return torch.optim.SGD(self.parameters(), lr=0.01)


In [8]:
model = MNISTModel()

trainer = Trainer(
    # active-learning loop: how many instances are labelled per round and the overall caps
    # NOTE(review): total_budget=5 is not a multiple of query_size=2; the datamodule
    # stats at the end of the notebook report train_size=6 — confirm the intended budget.
    query_size=2,
    total_budget=5,
    max_labelling_epochs=4,
    test_after_labelling=True,
    # training performed in each labelling round
    max_epochs=3,
    log_every_n_steps=1,
    # for testing purposes: cap the number of batches in every loop to keep runs short
    limit_train_batches=10,
    limit_val_batches=10,
    limit_test_batches=10,
    limit_pool_batches=10,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
# Run the active-learning loop. The captured log alternates pool scoring, a short
# fit on the labelled data and (with test_after_labelling) a test pass.
trainer.active_fit(
    model=model,
    test_dataloaders=test_dl,
    val_dataloaders=val_dl,
    train_dataloaders=train_dl,
)


  | Name  | Type             | Params
-------------------------------------------
0 | model | Sequential       | 184 K 
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
184 K     Trainable params
0         Non-trainable params
184 K     Total params
0.738     Total estimated model params size (MB)


                                                                   



Pool DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 53.40it/s]




Epoch 2: 100%|██████████| 11/11 [00:00<00:00, 48.23it/s, loss=1.85, v_num=44]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 11/11 [00:00<00:00, 45.75it/s, loss=1.85, v_num=44]




Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 48.80it/s]
Pool DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 52.99it/s]
Epoch 2: 100%|██████████| 11/11 [00:00<00:00, 32.08it/s, loss=1.79, v_num=44]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 11/11 [00:00<00:00, 31.08it/s, loss=1.79, v_num=44]
Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 45.20it/s]
Pool DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 54.49it/s]
Epoch 2: 100%|██████████| 11/11 [00:00<00:00, 48.15it/s, loss=1.89, v_num=44]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 11/11 [00:00<00:00, 46.34it/s, loss=1.89, v_num=44]


In [10]:
# Sizes of the labelled (train) and unlabelled (pool) partitions after active learning.
datamodule_stats = trainer.datamodule.stats
# Sanity check: every example is either labelled or still waiting in the pool.
assert (
    datamodule_stats["train_size"] + datamodule_stats["pool_size"]
    == datamodule_stats["total_data_size"]
), "labelled + pool sizes do not add up to the dataset size"
datamodule_stats

{'total_data_size': 55000,
 'train_size': 6,
 'pool_size': 54994,
 'num_train_batches': 1,
 'num_pool_batches': 1719}