In [1]:
%load_ext autoreload
%autoreload 2

In [13]:
from copy import deepcopy
from typing import Dict, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer as PLTrainer
from torch import Tensor, nn
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy, F1Score, MetricCollection, Precision, Recall
from torchvision import transforms
from torchvision.datasets import MNIST

from energizer import AccumulatorStrategy, RandomStrategy, Trainer
from energizer.acquisition_functions import entropy, expected_entropy

In [22]:
# load and preprocess datasets
data_dir = "./data"
preprocessing_pipe = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
train_set = MNIST(data_dir, train=True, download=True, transform=preprocessing_pipe)
test_set = MNIST(data_dir, train=False, download=True, transform=preprocessing_pipe)
train_set, val_set = random_split(train_set, [55000, 5000])

# create dataloaders
batch_size = 32
eval_batch_size = 128  # this is use when evaluating on the pool too
train_dl = DataLoader(train_set, batch_size=batch_size)
val_dl = DataLoader(val_set, batch_size=eval_batch_size)
test_dl = DataLoader(test_set, batch_size=eval_batch_size)

AssertionError: 

In [4]:
class MNISTModel(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5),
            nn.Dropout2d(),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5),
            nn.Dropout2d(),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.Dropout(),
            nn.Linear(128, 10),
        )
        for stage in ("train", "val", "test"):
            setattr(self, f"{stage}_accuracy", Accuracy())

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    def loss(self, logits: Tensor, targets: Tensor) -> Tensor:
        return F.cross_entropy(logits, targets)

    def step(self, batch: Tuple[Tensor, Tensor], stage: str) -> Dict[str, Tensor]:
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        accuracy = getattr(self, f"{stage}_accuracy")(logits, y)
        self.log(f"{stage}/loss", loss, on_epoch=True, on_step=True, prog_bar=True)
        self.log(f"{stage}/accuracy", accuracy, on_epoch=True, on_step=True, prog_bar=True)
        return {"loss": loss, "logits": logits}

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
        return self.step(batch, "train")

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
        return self.step(batch, "val")

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
        return self.step(batch, "test")

    def configure_optimizers(self) -> None:
        return torch.optim.SGD(self.parameters(), lr=0.01)

In [5]:
class EntropyStrategy(AccumulatorStrategy):
    """A implememntation of the `Entropy` active learning strategy."""

    def pool_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        # define how to perform the forward pass
        x, _ = batch
        logits = self(x)
        # use an acquisition/scoring function
        scores = entropy(logits)
        return scores

In [6]:
model = MNISTModel()
entropy_strategy = EntropyStrategy(model)
random_strategy = RandomStrategy(model)

x, _ = next(iter(train_dl))
model(x).shape, entropy_strategy(x).shape, random_strategy(x).shape

(torch.Size([32, 10]), torch.Size([32, 10]), torch.Size([32, 10]))

## Active fit

In [7]:
model = MNISTModel()

### Random strategy

In [8]:
random_strategy = RandomStrategy(deepcopy(model))


trainer = Trainer(
    query_size=100,
    max_epochs=3,
    max_labelling_epochs=5,
    accelerator="gpu",
    # total_budget=5,
    test_after_labelling=True,
    # for testing purposes
    # limit_train_batches=10,
    limit_val_batches=1,
    # limit_test_batches=10,
    # limit_pool_batches=10,
    # log_every_n_steps=1,
)

results = trainer.active_fit(
    model=random_strategy,
    train_dataloaders=train_dl,
    val_dataloaders=val_dl,
    test_dataloaders=test_dl,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
[1;36m[2022-08-27 14:36:30] energizer/DETAIL[0m ~ [1;33mtrainer:227[0m$ Trainer: trainer active_fit stage[0m
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type       | Params
----------------------------------------------
0 | model          | Sequential | 184 K 
1 | train_accuracy | Accuracy   | 0     
2 | val_accuracy   | Accuracy   | 0     
3 | test_accuracy  | Accuracy   | 0     
----------------------------------------------
184 K     Trainable params
0         Non-trainable params
184 K     Total params
0.738     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
[1;32m[2022-08-27 14:36:33] energizer/INFO[0m ~ [1;33mtrainer:448[0m$ [37mUsing `RandomStrategy`[0m


-------------------------Labelling Iteration 0--------------------------


[1;32m[2022-08-27 14:36:33] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `MNISTModel`[0m
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

[1;32m[2022-08-27 14:36:35] energizer/INFO[0m ~ [1;33mtrainer:448[0m$ [37mUsing `RandomStrategy`[0m
  rank_zero_warn(
[1;32m[2022-08-27 14:36:35] energizer/INFO[0m ~ [1;33mactive_learning_loop:195[0m$ [37mQueried 100 instance[0m
[1;32m[2022-08-27 14:36:35] energizer/INFO[0m ~ [1;33mactive_learning_loop:283[0m$ [37mAnnotated 100 instances.[0m
[1;32m[2022-08-27 14:36:35] energizer/INFO[0m ~ [1;33mactive_learning_loop:284[0m$ [37mNew data statistics
num_pool_batches: 429
num_train_batches: 4
pool_size: 54900
total_data_size: 55000
train_size: 100
[0m


-------------------------Labelling Iteration 1--------------------------


[1;32m[2022-08-27 14:36:35] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `MNISTModel`[0m
[1;32m[2022-08-27 14:36:35] energizer/INFO[0m ~ [1;33mactive_learning_loop:252[0m$ [37mMNISTModel state dict has been re-initialized[0m
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
[1;32m[2022-08-27 14:36:36] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `MNISTModel`[0m


Testing: 0it [00:00, ?it/s]

[1;32m[2022-08-27 14:36:38] energizer/INFO[0m ~ [1;33mtrainer:448[0m$ [37mUsing `RandomStrategy`[0m
[1;32m[2022-08-27 14:36:38] energizer/INFO[0m ~ [1;33mactive_learning_loop:195[0m$ [37mQueried 100 instance[0m
[1;32m[2022-08-27 14:36:38] energizer/INFO[0m ~ [1;33mactive_learning_loop:283[0m$ [37mAnnotated 100 instances.[0m
[1;32m[2022-08-27 14:36:38] energizer/INFO[0m ~ [1;33mactive_learning_loop:284[0m$ [37mNew data statistics
num_pool_batches: 429
num_train_batches: 7
pool_size: 54800
total_data_size: 55000
train_size: 200
[0m


-------------------------Labelling Iteration 2--------------------------


[1;32m[2022-08-27 14:36:38] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `MNISTModel`[0m
[1;32m[2022-08-27 14:36:38] energizer/INFO[0m ~ [1;33mactive_learning_loop:252[0m$ [37mMNISTModel state dict has been re-initialized[0m
  rank_zero_warn(


Training: 4it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
[1;32m[2022-08-27 14:36:38] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `MNISTModel`[0m


Testing: 0it [00:00, ?it/s]

[1;32m[2022-08-27 14:36:40] energizer/INFO[0m ~ [1;33mtrainer:448[0m$ [37mUsing `RandomStrategy`[0m
[1;32m[2022-08-27 14:36:40] energizer/INFO[0m ~ [1;33mactive_learning_loop:195[0m$ [37mQueried 100 instance[0m
[1;32m[2022-08-27 14:36:40] energizer/INFO[0m ~ [1;33mactive_learning_loop:283[0m$ [37mAnnotated 100 instances.[0m
[1;32m[2022-08-27 14:36:40] energizer/INFO[0m ~ [1;33mactive_learning_loop:284[0m$ [37mNew data statistics
num_pool_batches: 428
num_train_batches: 10
pool_size: 54700
total_data_size: 55000
train_size: 300
[0m


-------------------------Labelling Iteration 3--------------------------


[1;32m[2022-08-27 14:36:40] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `MNISTModel`[0m
[1;32m[2022-08-27 14:36:40] energizer/INFO[0m ~ [1;33mactive_learning_loop:252[0m$ [37mMNISTModel state dict has been re-initialized[0m
  rank_zero_warn(


Training: 7it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
[1;32m[2022-08-27 14:36:41] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `MNISTModel`[0m


Testing: 0it [00:00, ?it/s]

[1;32m[2022-08-27 14:36:43] energizer/INFO[0m ~ [1;33mtrainer:448[0m$ [37mUsing `RandomStrategy`[0m
[1;32m[2022-08-27 14:36:43] energizer/INFO[0m ~ [1;33mactive_learning_loop:195[0m$ [37mQueried 100 instance[0m
[1;32m[2022-08-27 14:36:43] energizer/INFO[0m ~ [1;33mactive_learning_loop:283[0m$ [37mAnnotated 100 instances.[0m
[1;32m[2022-08-27 14:36:43] energizer/INFO[0m ~ [1;33mactive_learning_loop:284[0m$ [37mNew data statistics
num_pool_batches: 427
num_train_batches: 13
pool_size: 54600
total_data_size: 55000
train_size: 400
[0m


-------------------------Labelling Iteration 4--------------------------


[1;32m[2022-08-27 14:36:43] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `MNISTModel`[0m
[1;32m[2022-08-27 14:36:43] energizer/INFO[0m ~ [1;33mactive_learning_loop:252[0m$ [37mMNISTModel state dict has been re-initialized[0m
  rank_zero_warn(


Training: 10it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
[1;32m[2022-08-27 14:36:44] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `MNISTModel`[0m


Testing: 0it [00:00, ?it/s]

[1;32m[2022-08-27 14:36:46] energizer/INFO[0m ~ [1;33mtrainer:448[0m$ [37mUsing `RandomStrategy`[0m
[1;32m[2022-08-27 14:36:46] energizer/INFO[0m ~ [1;33mactive_learning_loop:195[0m$ [37mQueried 100 instance[0m
[1;32m[2022-08-27 14:36:46] energizer/INFO[0m ~ [1;33mactive_learning_loop:283[0m$ [37mAnnotated 100 instances.[0m
[1;32m[2022-08-27 14:36:46] energizer/INFO[0m ~ [1;33mactive_learning_loop:284[0m$ [37mNew data statistics
num_pool_batches: 426
num_train_batches: 16
pool_size: 54500
total_data_size: 55000
train_size: 500
[0m


-----------------------------Last fit_loop------------------------------
-------------------------Labelling Iteration 5--------------------------


[1;32m[2022-08-27 14:36:46] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `MNISTModel`[0m
[1;32m[2022-08-27 14:36:46] energizer/INFO[0m ~ [1;33mactive_learning_loop:252[0m$ [37mMNISTModel state dict has been re-initialized[0m
  rank_zero_warn(


Training: 13it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
[1;32m[2022-08-27 14:36:47] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `MNISTModel`[0m


Testing: 0it [00:00, ?it/s]

[1;32m[2022-08-27 14:36:49] energizer/INFO[0m ~ [1;33mtrainer:448[0m$ [37mUsing `RandomStrategy`[0m


In [12]:
random_df = results.to_pandas()
random_df

Unnamed: 0,train_size,test/loss_epoch,test/accuracy_epoch
0,0,2.327604,0.0754
1,100,2.260243,0.2224
2,200,2.235868,0.1765
3,300,2.188535,0.245
4,400,2.141908,0.4123
5,500,2.096199,0.5322


### Entropy strategy

In [None]:
results

In [None]:
entropy_strategy = EntropyStrategy(deepcopy(model))

trainer = Trainer(
    query_size=100,
    max_epochs=3,
    max_labelling_epochs=5,
    accelerator="gpu",
    # total_budget=5,
    test_after_labelling=True,
    # for testing purposes
    # limit_train_batches=10,
    limit_val_batches=1,
    # limit_test_batches=10,
    # limit_pool_batches=10,
    # log_every_n_steps=1,
)

results = trainer.active_fit(
    model=entropy_strategy,
    train_dataloaders=train_dl,
    val_dataloaders=val_dl,
    test_dataloaders=test_dl,
)

In [None]:
entropy_df = pd.DataFrame(
    data=[(l.data_stats["train_size"], *l.test_outputs[0].values()) for l in results],
    columns=("train_size", *results[0].test_outputs[0].keys()),
)
entropy_df

In [None]:
from collections import UserList

### Fit Logging

In [None]:
model = MNISTModel()

trainer = Trainer(
    query_size=2,
    max_epochs=3,
    max_labelling_epochs=4,
    total_budget=5,
    log_every_n_steps=1,
    test_after_labelling=True,
    # for testing purposes
    limit_train_batches=10,
    limit_val_batches=10,
    limit_test_batches=10,
    limit_pool_batches=10,
)

trainer.fit(
    model=model,
    train_dataloaders=train_dl,
    val_dataloaders=val_dl,
)

In [None]:
model = MNISTModel()

trainer = Trainer(
    query_size=2,
    max_epochs=3,
    max_labelling_epochs=4,
    total_budget=5,
    log_every_n_steps=1,
    test_after_labelling=True,
    # for testing purposes
    limit_train_batches=10,
    limit_val_batches=10,
    limit_test_batches=10,
    limit_pool_batches=10,
)

trainer.test(
    model=model,
    dataloaders=test_dl,
)

In [None]:
model = MNISTModel()

pl_trainer = PLTrainer(
    max_epochs=3,
    log_every_n_steps=1,
    # for testing purposes
    limit_train_batches=10,
    limit_val_batches=10,
    limit_test_batches=10,
)

results = pl_trainer.test(
    model=model,
    dataloaders=test_dl,
)

In [None]:
"c".center(3, "-")

In [None]:
trainer.validate(
    model=random_strategy.model,
    dataloaders=test_dl,
)