In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from energizer.datastores import PandasDataStoreForSequenceClassification
from energizer.estimators.estimator import Estimator
from transformers import AutoModelForSequenceClassification
from typing import Dict, List
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy, F1Score, Precision, Recall
from transformers import AutoModelForSequenceClassification
from energizer.enums import InputKeys, OutputKeys, RunningStage
import numpy as np
from energizer.utilities import move_to_cpu
from lightning.fabric.loggers import TensorBoardLogger
from lightning.fabric import seed_everything
from energizer.callbacks import GradNorm, PytorchTensorboardProfiler, EarlyStopping, ModelCheckpoint
from energizer.strategies import RandomStrategy, UncertaintyBasedStrategy
from energizer.strategies.random import RandomStrategySEALS
from energizer.strategies.uncertainty import UncertaintyBasedStrategySEALS

In [3]:
class EstimatorForSequenceClassification(Estimator):
    """Estimator for `transformers` sequence-classification models.

    The wrapped model must accept the batch as keyword arguments and return an object exposing
    `.loss` and `.logits`. All stages share a single `step` (forward pass + metric update) and a
    single `epoch_end` (metric aggregation + logging); the stage-specific hooks only forward the
    `RunningStage` they correspond to.
    """

    def train_step(self, model, batch, batch_idx, loss_fn, metrics: MetricCollection) -> torch.Tensor:
        return self.step(model, batch, metrics, RunningStage.TRAIN)

    def validation_step(self, model, batch, batch_idx, loss_fn, metrics: MetricCollection) -> torch.Tensor:
        return self.step(model, batch, metrics, RunningStage.VALIDATION)

    def test_step(self, model, batch, batch_idx, loss_fn, metrics: MetricCollection) -> torch.Tensor:
        return self.step(model, batch, metrics, RunningStage.TEST)

    def train_epoch_end(self, output: List[np.ndarray], metrics: MetricCollection) -> Dict:
        return self.epoch_end(output, metrics, RunningStage.TRAIN)

    def validation_epoch_end(self, output: List[np.ndarray], metrics: MetricCollection) -> Dict:
        return self.epoch_end(output, metrics, RunningStage.VALIDATION)

    def test_epoch_end(self, output: List[np.ndarray], metrics: MetricCollection) -> Dict:
        return self.epoch_end(output, metrics, RunningStage.TEST)

    def step(self, model, batch: Dict, metrics: MetricCollection, stage: RunningStage) -> torch.Tensor:
        """Run the forward pass on `batch` and feed the logits to `metrics`.

        The `InputKeys.ON_CPU` entry (CPU-side metadata such as instance ids) is removed from
        `batch` in place before the forward pass, if present. During training, the batch loss and
        the batch-level metric values are logged under `"{stage}/<name>"` at the current global
        batch.

        Returns:
            The loss reported by the model (`out.loss`).
        """
        batch.pop(InputKeys.ON_CPU, None)

        out = model(**batch)
        out_metrics = metrics(out.logits, batch[InputKeys.TARGET])

        if stage == RunningStage.TRAIN:
            logs = {OutputKeys.LOSS: out.loss, **out_metrics}
            self.log_dict({f"{stage}/{k}": v for k, v in logs.items()}, step=self.progress_tracker.global_batch)

        return out.loss

    def epoch_end(self, output: List[np.ndarray], metrics: MetricCollection, stage: RunningStage) -> Dict:
        """Aggregate and log the epoch's metrics and mean step loss.

        Args:
            output: per-step values collected over the epoch; their mean is logged as the epoch loss.
            metrics: the stage's metric collection, whose accumulated state is computed here.
            stage: the running stage, used as the `"{stage}_end/"` log prefix.

        Returns:
            The aggregated metrics (a name -> value mapping), moved to CPU.
        """
        aggregated_metrics = move_to_cpu(metrics.compute())  # NOTE: metrics are still on device
        aggregated_loss = round(np.mean(output).item(), 6)

        logs = {OutputKeys.LOSS: aggregated_loss, **aggregated_metrics}
        self.log_dict({f"{stage}_end/{k}": v for k, v in logs.items()}, step=self.progress_tracker.safe_global_epoch)

        return aggregated_metrics

    def configure_metrics(self, *_) -> MetricCollection:
        """Build multiclass accuracy plus macro/micro F1, precision and recall on `self.device`.

        The number of classes is read from `self.model.num_labels`.
        """
        num_classes = self.model.num_labels
        task = "multiclass"
        # NOTE: you are in charge of moving it to the correct device
        return MetricCollection(
            {
                "accuracy": Accuracy(task, num_classes=num_classes),
                "f1_macro": F1Score(task, num_classes=num_classes, average="macro"),
                "precision_macro": Precision(task, num_classes=num_classes, average="macro"),
                "recall_macro": Recall(task, num_classes=num_classes, average="macro"),
                "f1_micro": F1Score(task, num_classes=num_classes, average="micro"),
                "precision_micro": Precision(task, num_classes=num_classes, average="micro"),
                "recall_micro": Recall(task, num_classes=num_classes, average="micro"),
            }
        ).to(self.device)

---
### Random strategy

In [None]:
class RandomStrategySEALSForSequenceClassification(EstimatorForSequenceClassification, RandomStrategySEALS):
    """Sequence-classification estimator that queries with `RandomStrategySEALS`."""


# Start from a freshly loaded datastore with the same initial labelled set as the data-loading
# cells in the "Entropy SEALS strategy" section. This makes the cell runnable on a fresh kernel
# (otherwise `ds` is only defined further down the notebook) and keeps experiments independent,
# since `active_fit` presumably labels queried instances in `ds` in place — confirm against energizer.
ds = PandasDataStoreForSequenceClassification.load("../data/prepared/agnews_binarised_bert-tiny/")
ds.label(ds.sample_from_pool(size=100, mode="stratified", random_state=42), -1)

# Seed right before building the model so any randomly initialised weights are reproducible.
seed_everything(42)
model = AutoModelForSequenceClassification.from_pretrained(
    ds.tokenizer.name_or_path,
    id2label=ds.id2label,
    label2id=ds.label2id,
    num_labels=len(ds.labels),
)

estimator = RandomStrategySEALSForSequenceClassification(
    model,
    accelerator="gpu",
    loggers=[TensorBoardLogger("./", name="tb_logs")],
    callbacks=[
        GradNorm(2),
        ModelCheckpoint("./checkpoints", monitor="f1_macro", stage="train", mode="max"),
        EarlyStopping(monitor="f1_macro", stage="train", interval="epoch", mode="max"),
    ],
    seed=42,
    num_neighbours=100,
)

In [None]:
# Run the active-learning loop for the estimator built in the cell above.
# For quick debugging runs, `limit_pool_batches` / `limit_test_batches` can be added to `fit_config`.
fit_config = dict(query_size=50, max_rounds=20, min_steps=50, reinit_model=True)
ds.prepare_for_loading()
results = estimator.active_fit(datastore=ds, **fit_config)

In [None]:
class RandomStrategyForSequenceClassification(EstimatorForSequenceClassification, RandomStrategy):
    """Sequence-classification estimator that queries with `RandomStrategy`."""


# Start from a freshly loaded datastore with the same initial labelled set as the data-loading
# cells in the "Entropy SEALS strategy" section. This makes the cell runnable on a fresh kernel
# (otherwise `ds` is only defined further down the notebook) and keeps experiments independent,
# since `active_fit` presumably labels queried instances in `ds` in place — confirm against energizer.
ds = PandasDataStoreForSequenceClassification.load("../data/prepared/agnews_binarised_bert-tiny/")
ds.label(ds.sample_from_pool(size=100, mode="stratified", random_state=42), -1)

# Seed right before building the model so any randomly initialised weights are reproducible.
seed_everything(42)
model = AutoModelForSequenceClassification.from_pretrained(
    ds.tokenizer.name_or_path,
    id2label=ds.id2label,
    label2id=ds.label2id,
    num_labels=len(ds.labels),
)

estimator = RandomStrategyForSequenceClassification(
    model,
    accelerator="gpu",
    loggers=[TensorBoardLogger("./", name="tb_logs")],
    callbacks=[
        GradNorm(2),
        ModelCheckpoint("./checkpoints", monitor="f1_macro", stage="train", mode="max"),
        EarlyStopping(monitor="f1_macro", stage="train", interval="epoch", mode="max"),
    ],
    seed=42,
)

In [None]:
# Run the active-learning loop for the estimator built in the cell above.
# For quick debugging runs, `limit_pool_batches` / `limit_test_batches` can be added to `fit_config`.
fit_config = dict(query_size=50, max_rounds=20, min_steps=50, reinit_model=True)
ds.prepare_for_loading()
results = estimator.active_fit(datastore=ds, **fit_config)

---
### Entropy SEALS strategy

In [4]:
# Prepared AG News (binarised) datastore, tokenised for bert-tiny.
data_dir = "../data/prepared/agnews_binarised_bert-tiny/"
ds = PandasDataStoreForSequenceClassification.load(data_dir)

In [5]:
# Initial labelled set: a stratified sample of 100 pool instances (random_state=42).
# `-1` is presumably the round index marking this initial set — confirm against energizer.
# NOTE(review): re-running this cell presumably labels a further 100 instances (the first sample
# has left the pool); reload `ds` before re-running it.
ids = ds.sample_from_pool(size=100, mode="stratified", random_state=42)
ds.label(ids, -1)

100

In [6]:
# Display one training batch: input ids, attention mask, labels and the `InputKeys.ON_CPU` metadata.
ds.show_batch("train")

{'input_ids': tensor([[  101,  4108,  2373,  5222,  6996,  2006,  2470,  2817,  4108,  2373,
           1998, 10891,  3751,  2097,  2025,  2031,  2000,  3477,  2005,  2047,
           2470,  2046,  2943,  1011,  8122,  3454,  1010,  1996,  2110,  2270,
           2326,  3222,  5451,  9857,  1012,   102]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
 'labels': tensor([0]),
 <InputKeys.ON_CPU: 'on_cpu'>: {<SpecialKeys.ID: 'unique_id'>: [2081]}}

In [7]:
from src.estimators import UncertaintyBasedStrategyGradSub

In [24]:

class MyEstimator(EstimatorForSequenceClassification, UncertaintyBasedStrategyGradSub):
    """Sequence-classification estimator that queries with `UncertaintyBasedStrategyGradSub`."""


# Start from a freshly loaded datastore with the same initial labelled set as In [5], so that
# re-running this cell does not inherit labels acquired by a previous run (`active_fit`
# presumably labels queried instances in `ds` in place — confirm against energizer).
ds = PandasDataStoreForSequenceClassification.load("../data/prepared/agnews_binarised_bert-tiny/")
ds.label(ds.sample_from_pool(size=100, mode="stratified", random_state=42), -1)

# Seed right before building the model so any randomly initialised weights are reproducible.
seed_everything(42)
model = AutoModelForSequenceClassification.from_pretrained(
    ds.tokenizer.name_or_path,
    id2label=ds.id2label,
    label2id=ds.label2id,
    num_labels=len(ds.labels),
)

estimator = MyEstimator(
    score_fn="entropy",
    model=model,
    accelerator="gpu",
    loggers=[TensorBoardLogger("./", name="tb_logs")],
    callbacks=[
        GradNorm(2),
        ModelCheckpoint("./checkpoints", monitor="f1_macro", stage="train", mode="max"),
        EarlyStopping(monitor="f1_macro", stage="train", interval="epoch", mode="max"),
    ],
    num_neighbours=100,
    num_influential=10,
)

ds.prepare_for_loading()
# NOTE(review): the recorded run of this cell failed with `KeyError: <OutputKeys.METRICS: 'metrics'>`.
# `EstimatorForSequenceClassification` returns the bare loss from its step hooks and a plain metrics
# dict from its epoch-end hooks, which may not be the output format this strategy expects — TODO investigate.
# The `limit_*` caps below are debugging settings; the other strategies run without them, so drop
# them before comparing results.
results = estimator.active_fit(
    datastore=ds,
    query_size=50,
    max_rounds=20,
    min_steps=50,
    limit_test_batches=2,
    limit_pool_batches=2,
)

Global seed set to 42
Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSe

Completed rounds:   0%|          | 0/21 [00:00<?, ?it/s]

Completed epochs: 0it [00:00, ?it/s]

Epoch 0: 0it [00:00, ?it/s]

Test: 0it [00:00, ?it/s]

Pool: 0it [00:00, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/35 [00:00<?, ?it/s]

  0%|          | 0/36 [00:00<?, ?it/s]

KeyError: <OutputKeys.METRICS: 'metrics'>

In [None]:
# Grab one batch for manual debugging, drop its CPU-side metadata and move it to the estimator's device.
batch = ds.show_batch()
# Use the enum key, as `step`/`pool_step` do, rather than relying on the string "on_cpu" matching it.
batch.pop(InputKeys.ON_CPU, None)
batch = estimator.transfer_to_device(batch)

In [None]:
# First parameter of `model` (the model passed to the estimator in the GradSub cell), for comparison
# with the estimator's parameters below.
# NOTE(review): this cell referenced an undefined `model1` (NameError on any run); `model` is assumed
# to be the intended object — confirm.
next(model.parameters())

In [None]:
# Materialise the estimator's model parameters so individual tensors can be inspected.
params = [*estimator.model.parameters()]

In [None]:
# Display the first parameter of the estimator's model.
params[0]

In [None]:
# NOTE(review): `torch._C._functorch.get_unwrapped` is a private PyTorch API (no stability guarantee);
# used here to inspect the tensor underneath a presumably functorch-wrapped parameter.
torch._C._functorch.get_unwrapped(params[0])

In [None]:
# Wrap the estimator's model with `estimator.fabric.setup` for manual debugging.
# NOTE: this rebinds the global `model` (previously the model passed to the estimator).
model = estimator.fabric.setup(estimator.model)

In [None]:
 (**batch)

In [None]:
class UncertaintyBasedStrategySEALSForSequenceClassification(EstimatorForSequenceClassification, UncertaintyBasedStrategySEALS):
    """Sequence-classification estimator that queries with `UncertaintyBasedStrategySEALS`."""

    def pool_step(self, model, batch: Dict, batch_idx: int, metrics: MetricCollection) -> Dict:
        """Score a pool batch by applying `self.score_fn` to the model's logits.

        NOTE(review): annotated as `Dict`, but this returns whatever `self.score_fn` returns for the
        logits — confirm the expected type.
        """
        # The ON_CPU metadata is handled in `evaluation_step`; drop it before the forward pass.
        # Like `step`, tolerate batches where it is already absent instead of raising KeyError.
        batch.pop(InputKeys.ON_CPU, None)
        logits = model(**batch).logits
        return self.score_fn(logits)


# Start from a freshly loaded datastore with the same initial labelled set as In [5], so this
# experiment does not inherit labels acquired by earlier runs (`active_fit` presumably labels
# queried instances in `ds` in place — confirm against energizer).
ds = PandasDataStoreForSequenceClassification.load("../data/prepared/agnews_binarised_bert-tiny/")
ds.label(ds.sample_from_pool(size=100, mode="stratified", random_state=42), -1)

# Seed right before building the model so any randomly initialised weights are reproducible.
seed_everything(42)
model = AutoModelForSequenceClassification.from_pretrained(
    ds.tokenizer.name_or_path,
    id2label=ds.id2label,
    label2id=ds.label2id,
    num_labels=len(ds.labels),
)

estimator = UncertaintyBasedStrategySEALSForSequenceClassification(
    score_fn="entropy",
    model=model,
    accelerator="gpu",
    loggers=[TensorBoardLogger("./", name="tb_logs")],
    callbacks=[
        GradNorm(2),
        ModelCheckpoint("./checkpoints", monitor="f1_macro", stage="train", mode="max"),
        EarlyStopping(monitor="f1_macro", stage="train", interval="epoch", mode="max"),
    ],
    num_neighbours=100,
)

In [None]:
# Run the active-learning loop for the estimator built in the cell above.
# NOTE(review): unlike the random-strategy runs, `reinit_model` is not passed here, so the
# library default applies — confirm this is intended when comparing strategies.
fit_config = dict(query_size=50, max_rounds=20, min_steps=50)
ds.prepare_for_loading()
results = estimator.active_fit(datastore=ds, **fit_config)

### Entropy strategy

In [None]:
class UncertaintyBasedStrategyForSequenceClassification(EstimatorForSequenceClassification, UncertaintyBasedStrategy):
    """Sequence-classification estimator that queries with `UncertaintyBasedStrategy`."""

    def pool_step(self, model, batch: Dict, batch_idx: int, metrics: MetricCollection) -> Dict:
        """Score a pool batch by applying `self.score_fn` to the model's logits.

        NOTE(review): annotated as `Dict`, but this returns whatever `self.score_fn` returns for the
        logits — confirm the expected type.
        """
        # The ON_CPU metadata is handled in `evaluation_step`; drop it before the forward pass.
        # Like `step`, tolerate batches where it is already absent instead of raising KeyError.
        batch.pop(InputKeys.ON_CPU, None)
        logits = model(**batch).logits
        return self.score_fn(logits)


# Start from a freshly loaded datastore with the same initial labelled set as In [5], so this
# experiment does not inherit labels acquired by earlier runs (`active_fit` presumably labels
# queried instances in `ds` in place — confirm against energizer).
ds = PandasDataStoreForSequenceClassification.load("../data/prepared/agnews_binarised_bert-tiny/")
ds.label(ds.sample_from_pool(size=100, mode="stratified", random_state=42), -1)

# Seed right before building the model so any randomly initialised weights are reproducible.
seed_everything(42)
model = AutoModelForSequenceClassification.from_pretrained(
    ds.tokenizer.name_or_path,
    id2label=ds.id2label,
    label2id=ds.label2id,
    num_labels=len(ds.labels),
)

estimator = UncertaintyBasedStrategyForSequenceClassification(
    score_fn="entropy",
    model=model,
    accelerator="gpu",
    loggers=[TensorBoardLogger("./", name="tb_logs")],
    callbacks=[
        GradNorm(2),
        ModelCheckpoint("./checkpoints", monitor="f1_macro", stage="train", mode="max"),
        EarlyStopping(monitor="f1_macro", stage="train", interval="epoch", mode="max"),
    ],
)

In [None]:
# Run the active-learning loop for the estimator built in the cell above.
# NOTE(review): unlike the random-strategy runs, `reinit_model` is not passed here, so the
# library default applies — confirm this is intended when comparing strategies.
fit_config = dict(query_size=50, max_rounds=20, min_steps=50)
ds.prepare_for_loading()
results = estimator.active_fit(datastore=ds, **fit_config)