In [None]:
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("..")
from typing import Any, Dict, Optional

import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from pytorch_lightning.trainer.states import RunningStage
from torch import Tensor, nn
from torch.utils.data import DataLoader, random_split
from torchmetrics import F1, Accuracy, MetricCollection
from torchvision import transforms
from torchvision.datasets import MNIST

from energizer.data import ActiveDataModule
from energizer.inference import Deterministic
from energizer.loops import ActiveLearningLoop
from energizer.strategies import LeastConfidenceStrategy, RandomStrategy

In [None]:
class MNISTDataModule(LightningDataModule):
    def __init__(
        self,
        data_dir: str = "./data",
        batch_size: int = 32,
        shuffle: Optional[bool] = False,
        num_workers: int = 2,
        pin_memory: bool = False,
        drop_last: bool = False,
        persistent_workers: bool = False,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.persistent_workers = persistent_workers
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict" or stage is None:
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def _make_dataloader(self, dataset):
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
            persistent_workers=self.persistent_workers,
        )

    def train_dataloader(self):
        return self._make_dataloader(self.mnist_train)

    def val_dataloader(self):
        return self._make_dataloader(self.mnist_val)

    def test_dataloader(self):
        return self._make_dataloader(self.mnist_test)

    def predict_dataloader(self):
        return self._make_dataloader(self.mnist_predict)

In [None]:
dm = MNISTDataModule(batch_size=128, num_workers=0)
dm.prepare_data()
dm.setup()

In [54]:
class Model(LightningModule):
    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5),
            nn.Dropout2d(),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5),
            nn.Dropout2d(),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.Dropout(),
            nn.Linear(128, num_classes),
        )
        self.loss = nn.CrossEntropyLoss()
        metrics = MetricCollection([Accuracy(), F1(num_classes=num_classes, average="macro")])
        setattr(self, f"{RunningStage.TRAINING}_metrics", metrics.clone(prefix="train_"))
        setattr(self, f"{RunningStage.VALIDATING}_metrics", metrics.clone(prefix="val_"))
        setattr(self, f"{RunningStage.TESTING}_metrics", metrics.clone(prefix="test_"))

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    def step(self, batch) -> None:
        outputs = {}
        x, y = batch
        logits = self(x)
        outputs["loss"] = self.loss(logits, y)

        # detach to avoid future warning
        outputs["targets"] = y.detach()
        outputs["logits"] = logits.detach()
        return outputs

    def step_end(
        self, outputs, running_stage: RunningStage, on_step: bool = True, on_epoch: bool = True, prog_bar: bool = True
    ) -> None:
        self.log(f"{running_stage}_loss", outputs["loss"], on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar)
        self.log_dict(
            getattr(self, f"{running_stage}_metrics")(outputs["logits"], outputs["targets"]),
            on_step=on_step,
            on_epoch=on_epoch,
            prog_bar=prog_bar,
        )

    def training_step(self, batch, *args, **kwargs) -> Dict[str, Any]:
        return self.step(batch)

    def training_step_end(self, outputs: Dict[str, Any]) -> None:
        self.step_end(outputs, running_stage=RunningStage.TRAINING)

    def validation_step(self, batch, *args, **kwargs) -> Dict[str, Any]:
        return self.step(batch)

    def validation_step_end(self, outputs: Dict[str, Any]) -> None:
        self.step_end(outputs, running_stage=RunningStage.VALIDATING, on_step=False)

    def test_step(self, batch, *args, **kwargs) -> Dict[str, Any]:
        return self.step(batch)

    def test_step_end(self, outputs: Dict[str, Any]) -> None:
        self.step_end(outputs, running_stage=RunningStage.TESTING, on_step=False)

    def predict_step(self, batch, *args, **kwargs) -> Tensor:
        x, _ = batch
        return self(x)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

In [None]:
seed_everything(1111)
model = Model()
trainer = Trainer(max_epochs=1)
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)

In [55]:
seed_everything(1111)
adm = ActiveDataModule(
    num_workers=0,
    train_dataset=dm.mnist_val,
    val_dataset=dm.mnist_val,
    test_dataset=dm.mnist_val,
    num_classes=10,
    initial_labels=1_000,
    batch_size=1_000,
)
model = Model(num_classes=adm.num_classes)
trainer = Trainer(max_epochs=1)
active_learning_loop = ActiveLearningLoop(
    # strategy=RandomStrategy(),
    strategy=LeastConfidenceStrategy(inference_module=Deterministic()),
    query_size=1,
    total_budget=2_000,
    reset_weights=True,
)
active_learning_loop.connect(trainer)
trainer.fit_loop = active_learning_loop
trainer.fit(model, datamodule=adm)

Global seed set to 1111
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name             | Type             | Params
------------------------------------------------------
0 | model            | Sequential       | 184 K 
1 | loss             | CrossEntropyLoss | 0     
2 | train_metrics    | MetricCollection | 0     
3 | validate_metrics | MetricCollection | 0     
4 | test_metrics     | MetricCollection | 0     
------------------------------------------------------
184 K     Trainable params
0         Non-trainable params
184 K     Total params
0.738     Total estimated model params size (MB)


                                                                      

Global seed set to 1111



EPOCH
Active learning dataset: ActiveDataset({
    original_dataset_size: 5000,
    labelled_size: 1000,
    pool_size: 4000,
    base_class: <class 'torch.utils.data.dataset.Subset'>,
})
Testing:   0%|          | 0/5 [05:34<?, ?it/s].07it/s, loss=2.34, v_num=46, train_loss_step=2.340, train_Accuracy_step=0.108, train_F1_step=0.0954]
Epoch 0: 100%|██████████| 6/6 [00:02<00:00,  2.45it/s, loss=2.34, v_num=46, train_loss_step=2.340, train_Accuracy_step=0.108, train_F1_step=0.0954, validate_loss=2.310, val_Accuracy=0.102, val_F1=0.0364, train_loss_epoch=2.340, train_Accuracy_epoch=0.108, train_F1_epoch=0.0954]
Testing: 100%|██████████| 5/5 [00:01<00:00,  4.15it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_Accuracy': 0.10180000215768814,
 'test_F1': 0.03642941266298294,
 'test_loss': 2.3051962852478027}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 5/5 [00:01