In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from energizer.datastores import PandasDataStoreForSequenceClassification
from energizer.estimators.estimator import Estimator
from transformers import AutoModelForSequenceClassification
from typing import Dict, List
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy, F1Score, Precision, Recall
from transformers import AutoModelForSequenceClassification
from energizer.enums import InputKeys, OutputKeys, RunningStage
import numpy as np
from energizer.utilities import move_to_cpu
from lightning.fabric.loggers import TensorBoardLogger
from lightning.fabric import seed_everything
from energizer.callbacks import GradNorm, PytorchTensorboardProfiler
from energizer.strategies import RandomStrategy

In [3]:
# On-disk location of the AG News datastore; the loaded object supplies the
# tokenizer and label mappings used to build the model below.
AGNEWS_DATASTORE_DIR = "./agnews_datastore/"
ds = PandasDataStoreForSequenceClassification.load(AGNEWS_DATASTORE_DIR)

In [4]:
class ActiveEstimatorForSequenceClassification(RandomStrategy):
    """Random-query active-learning estimator for sequence-classification models.

    Train, validation and test share one ``step`` and one ``epoch_end``
    implementation, parameterised by ``RunningStage``. Per-batch loss/metrics
    are logged only while training; epoch-level aggregates are logged for
    every stage.
    """

    def train_step(self, model, batch, batch_idx, loss_fn, metrics: MetricCollection) -> torch.Tensor:
        """Run one training batch. ``batch_idx`` and ``loss_fn`` are unused: the loss comes from the model output."""
        return self.step(model, batch, metrics, RunningStage.TRAIN)

    def validation_step(self, model, batch, batch_idx, loss_fn, metrics: MetricCollection) -> torch.Tensor:
        """Run one validation batch. ``batch_idx`` and ``loss_fn`` are unused."""
        return self.step(model, batch, metrics, RunningStage.VALIDATION)

    def test_step(self, model, batch, batch_idx, loss_fn, metrics: MetricCollection) -> torch.Tensor:
        """Run one test batch. ``batch_idx`` and ``loss_fn`` are unused."""
        return self.step(model, batch, metrics, RunningStage.TEST)

    def train_epoch_end(self, output: List[np.ndarray], metrics: MetricCollection) -> float:
        """Aggregate and log the training epoch; returns the mean loss."""
        return self.epoch_end(output, metrics, RunningStage.TRAIN)

    def validation_epoch_end(self, output: List[np.ndarray], metrics: MetricCollection) -> float:
        """Aggregate and log the validation epoch; returns the mean loss."""
        return self.epoch_end(output, metrics, RunningStage.VALIDATION)

    def test_epoch_end(self, output: List[np.ndarray], metrics: MetricCollection) -> float:
        """Aggregate and log the test epoch; returns the mean loss."""
        return self.epoch_end(output, metrics, RunningStage.TEST)

    def step(
        self,
        model,
        batch: Dict,
        metrics: MetricCollection,
        stage: RunningStage,
    ) -> torch.Tensor:
        """Forward one batch, accumulate ``metrics`` and return the model's loss.

        Args:
            model: callable whose output exposes ``.loss`` and ``.logits``.
            batch: keyword inputs for ``model``; must contain ``InputKeys.TARGET``.
                ``InputKeys.ON_CPU`` (if present) is removed in place so it is
                not forwarded to the model.
            metrics: metric collection accumulated over the epoch.
            stage: current stage; only ``RunningStage.TRAIN`` logs per-batch values.

        Returns:
            The loss tensor produced by the model.
        """
        # NOTE: mutates the caller's dict; kept as-is for backward compatibility.
        batch.pop(InputKeys.ON_CPU, None)

        out = model(**batch)
        targets = batch[InputKeys.TARGET]

        if stage == RunningStage.TRAIN:
            # forward() accumulates epoch state AND returns batch-level values, which we log.
            out_metrics = metrics(out.logits, targets)
            logs = {OutputKeys.LOSS: out.loss, **out_metrics}
            self.log_dict({f"{stage}/{k}": v for k, v in logs.items()}, step=self.progress_tracker.global_batch)
        else:
            # Batch-level values would be discarded here, so only accumulate state;
            # update() skips the extra per-batch computation forward() performs.
            metrics.update(out.logits, targets)

        return out.loss

    def epoch_end(self, output: List[np.ndarray], metrics: MetricCollection, stage: RunningStage) -> float:
        """Log epoch-level loss and metrics for ``stage`` and return the mean loss.

        Args:
            output: per-batch values collected over the epoch (presumably the
                losses returned by ``step`` — TODO confirm against the loop).
            metrics: collection whose accumulated state is computed here.
            stage: stage used as the log-key prefix (``"{stage}_end/..."``).

        Returns:
            Mean of ``output`` rounded to 6 decimals, or ``nan`` if ``output`` is empty.
        """
        aggregated_metrics = move_to_cpu(metrics.compute())  # NOTE: metrics are still on device
        # np.mean([]) emits a RuntimeWarning and yields nan; return nan explicitly instead.
        aggregated_loss = round(np.mean(output).item(), 6) if len(output) > 0 else float("nan")

        logs = {OutputKeys.LOSS: aggregated_loss, **aggregated_metrics}
        self.log_dict({f"{stage}_end/{k}": v for k, v in logs.items()}, step=self.progress_tracker.safe_global_epoch)

        return aggregated_loss

    def configure_metrics(self, *_) -> MetricCollection:
        """Build the multiclass metric collection on ``self.device``.

        The number of classes is read from ``self.model.num_labels``. Includes
        accuracy plus macro- and micro-averaged F1, precision and recall.
        """
        num_classes = self.model.num_labels
        task = "multiclass"
        # NOTE: you are in charge of moving it to the correct device
        return MetricCollection(
            {
                "accuracy": Accuracy(task, num_classes=num_classes),
                "f1_macro": F1Score(task, num_classes=num_classes, average="macro"),
                "precision_macro": Precision(task, num_classes=num_classes, average="macro"),
                "recall_macro": Recall(task, num_classes=num_classes, average="macro"),
                "f1_micro": F1Score(task, num_classes=num_classes, average="micro"),
                "precision_micro": Precision(task, num_classes=num_classes, average="micro"),
                "recall_micro": Recall(task, num_classes=num_classes, average="micro"),
            }
        ).to(self.device)

In [5]:
# Single source of truth for the seed: used both globally and by the estimator.
SEED = 42
seed_everything(SEED)

# Build the classifier from the same checkpoint as the datastore's tokenizer,
# with the datastore's label mappings.
model = AutoModelForSequenceClassification.from_pretrained(
    ds.tokenizer.name_or_path,
    id2label=ds.id2label,
    label2id=ds.label2id,
    num_labels=len(ds.labels),
)

# Use the GPU when CUDA is available; fall back to CPU so the notebook still
# runs on machines without one (identical behavior on GPU hosts).
ACCELERATOR = "gpu" if torch.cuda.is_available() else "cpu"

estimator = ActiveEstimatorForSequenceClassification(
    model,
    accelerator=ACCELERATOR,
    loggers=[TensorBoardLogger("./")],
    callbacks=[GradNorm(2), PytorchTensorboardProfiler("./profiler_logs")],
    seed=SEED,
)

Global seed set to 42
Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSe

In [6]:
# Called right before `active_fit` below. NOTE(review): presumably readies the
# datastore for building dataloaders — exact effect is defined in energizer; confirm.
ds.prepare_for_loading()

In [7]:
# Active-learning loop settings. The returned list holds one entry per round;
# the first entry carries only a test result (4 entries for MAX_ROUNDS = 3,
# see the output below). The limit_* caps presumably keep this a quick demo run.
QUERY_SIZE = 50
MAX_ROUNDS = 3
LIMIT_POOL_BATCHES = 10
LIMIT_TEST_BATCHES = 10

results = estimator.active_fit(
    datastore=ds,
    query_size=QUERY_SIZE,
    max_rounds=MAX_ROUNDS,
    limit_pool_batches=LIMIT_POOL_BATCHES,
    limit_test_batches=LIMIT_TEST_BATCHES,
)

Completed rounds:   0%|          | 0/4 [00:00<?, ?it/s]

Completed epochs: 0it [00:00, ?it/s]

Epoch 0: 0it [00:00, ?it/s]

Test: 0it [00:00, ?it/s]

STAGE:2023-05-02 16:07:11 1001783:1001783 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-05-02 16:07:11 1001783:1001783 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-05-02 16:07:11 1001783:1001783 ActivityProfilerController.cpp:321] Completed Stage: Post Processing
STAGE:2023-05-02 16:07:11 1001783:1001783 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-05-02 16:07:11 1001783:1001783 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-05-02 16:07:11 1001783:1001783 ActivityProfilerController.cpp:321] Completed Stage: Post Processing
STAGE:2023-05-02 16:07:11 1001783:1001783 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-05-02 16:07:11 1001783:1001783 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-05-02 16:07:11 1001783:1001783 ActivityProfilerController.cpp:321] Completed Stage: Post Processing
STAGE:2023-05-02 16:07:12 1001783:100178

[{<RunningStage.TEST: 'test'>: 1.377933},
 {'fit': [(1.400006, []), (1.288729, []), (1.22965, [])],
  <RunningStage.TEST: 'test'>: 1.196972},
 {'fit': [(1.374332, []), (1.221406, []), (1.005468, [])],
  <RunningStage.TEST: 'test'>: 1.05405},
 {'fit': [(1.369444, []), (1.147419, []), (0.792541, [])],
  <RunningStage.TEST: 'test'>: 0.788653}]