In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [10]:
# Pretrained checkpoint and optimisation hyper-parameters used by the cells below.
MODEL_NAME = "google/bert_uncased_L-2_H-128_A-2"
BATCH_SIZE = 32
EVAL_BATCH_SIZE = 512
LEARNING_RATE = 1e-4

# Keyword arguments expanded into `Trainer(**trainer_kwargs)` by every run cell.
trainer_kwargs = dict(
    query_size=1,
    max_epochs=3,
    max_labelling_epochs=5,
    test_after_labelling=True,
    accelerator="gpu",
    limit_val_batches=1,
    # total_budget=5,
    # Uncomment any of the following to shorten a run for testing purposes:
    # limit_train_batches=10,
    # limit_test_batches=10,
    # limit_pool_batches=10,
    # log_every_n_steps=1,
)

## Setup

In [4]:
import json
import os
from copy import deepcopy
from typing import Any, Dict, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn.functional as F
from datasets import load_dataset
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer as PLTrainer
from pytorch_lightning import seed_everything
from torch import Tensor, nn
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, F1Score, MetricCollection, Precision, Recall
from transformers import (
    AdamW,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    get_constant_schedule_with_warmup,
)

from energizer import Trainer
from energizer.acquisition_functions import entropy, expected_entropy
from energizer.data.datamodule import ActiveDataModuleWithIndex
from energizer.query_strategies.strategies import RandomArchorPointsStrategy

In [5]:
# Seed for the train/validation split. The split runs before any
# `seed_everything(...)` call, so without an explicit seed the validation
# subset would differ on every kernel restart. Same value as the run cells use.
SPLIT_SEED = 1994

# Set before the tokenizer is first used.
# NOTE(review): presumably avoids the tokenizers fork warning triggered by
# the DataLoader worker processes created below — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# renames "label" to "labels"
collator = DataCollatorWithPadding(
    tokenizer=tokenizer, padding=True, return_tensors="pt"
)

# load dataset
dataset = load_dataset("pietrolesci/ag_news", "concat")

# tokenize
dataset = dataset.map(lambda ex: tokenizer(ex["text"]), batched=True)
columns_to_keep = ["label", "input_ids", "token_type_ids", "attention_mask"]

# train-val split and record datasets
train_set, test_set = dataset["train"], dataset["test"]
_split = train_set.train_test_split(test_size=0.3, seed=SPLIT_SEED)
# NOTE(review): only the 30% part of the split is kept; `train_set` stays the
# full training split, so `val_set` is a subset of it — confirm this overlap
# is intended.
val_set = _split["test"]

labels = train_set.features["label"].names
num_classes = len(labels)


def _make_dataloader(split, batch_size: int) -> DataLoader:
    """Build a DataLoader over `split`, exposing only `columns_to_keep`.

    Batches are assembled by `collator`; two worker processes are used.
    """
    return DataLoader(
        split.with_format(columns=columns_to_keep),
        batch_size=batch_size,
        collate_fn=collator,
        num_workers=2,
    )


# create dataloaders
batch_size = BATCH_SIZE
eval_batch_size = EVAL_BATCH_SIZE  # this is used when evaluating on the pool too
train_dl = _make_dataloader(train_set, batch_size)
val_dl = _make_dataloader(val_set, eval_batch_size)
test_dl = _make_dataloader(test_set, eval_batch_size)

Reusing dataset ag_news (/home/pl487/.cache/huggingface/datasets/pietrolesci___ag_news/concat/1.0.0/5ee6e111adc7a901ca734b79fbebff09d9dba91722387a794efff8d9c178a6a3)


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?ba/s]

  0%|          | 0/8 [00:00<?, ?ba/s]

In [6]:
class TransformerModel(LightningModule):
    """Sequence classifier built on `AutoModelForSequenceClassification`.

    Holds one `MetricCollection` per stage (`train_metrics`, `val_metrics`,
    `test_metrics`) and shares a single step implementation across the
    training, validation and test hooks.
    """

    def __init__(
        self,
        model_name: str,
        num_classes: int,
        learning_rate: float = 0.00001,
        num_warmup_steps: int = 50,
    ) -> None:
        """Load the pretrained model and create the per-stage metrics.

        Args:
            model_name: Identifier passed to `from_pretrained`.
            num_classes: Passed as `num_labels` to the model and as
                `num_classes` to the precision/recall/F1 metrics.
            learning_rate: Learning rate given to the optimizer.
            num_warmup_steps: Warmup steps given to the LR schedule.
        """
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=self.num_classes,
        )
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        # One independent collection per stage, stored as `<stage>_metrics`.
        for stage in ("train", "val", "test"):
            metrics = MetricCollection(
                {
                    "accuracy": Accuracy(),
                    "precision_macro": Precision(
                        num_classes=num_classes, average="macro"
                    ),
                    "recall_macro": Recall(num_classes=num_classes, average="macro"),
                    "f1_macro": F1Score(num_classes=num_classes, average="macro"),
                    "f1_micro": F1Score(num_classes=num_classes, average="micro"),
                }
            )
            setattr(self, f"{stage}_metrics", metrics)

    def common_step(self, batch: Any, stage: str) -> Tensor:
        """Run the model on `batch`, log loss and metrics, and return the loss.

        The loss is logged as `"{stage}/loss"`. The metric values are logged
        under the collection's own keys (e.g. `"accuracy"`), which carry no
        stage prefix. `batch` must contain a `"labels"` entry.
        """
        out = self(batch)
        logits, loss = out.logits, out.loss
        self.log(f"{stage}/loss", loss)

        metrics = getattr(self, f"{stage}_metrics")(logits, batch["labels"])
        self.log_dict(metrics)

        return loss

    def forward(self, batch: Any) -> Any:
        """Unpack `batch` as keyword arguments into the wrapped model.

        Returns the wrapped model's output object; `common_step` reads its
        `.logits` and `.loss` attributes.
        """
        return self.model(**batch)

    def training_step(
        self, batch: Any, batch_idx: int = 0, optimizer_idx: int = 0
    ) -> Tensor:
        """Training hook: delegate to `common_step` with stage `"train"`."""
        return self.common_step(batch, "train")

    def validation_step(self, batch: Any, batch_idx: int = 0) -> Tensor:
        """Validation hook: delegate to `common_step` with stage `"val"`."""
        return self.common_step(batch, "val")

    def test_step(self, batch: Any, batch_idx: int = 0) -> Tensor:
        """Test hook: delegate to `common_step` with stage `"test"`."""
        return self.common_step(batch, "test")

    def configure_optimizers(self) -> Dict[str, Any]:
        """Return AdamW plus a constant-with-warmup LR schedule.

        Only parameters with `requires_grad=True` are optimized. The scheduler
        entry is configured with `"interval": "step"` and `"frequency": 1`.
        """
        optimizer = AdamW(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=self.learning_rate,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": get_constant_schedule_with_warmup(
                    optimizer=optimizer, num_warmup_steps=self.num_warmup_steps
                ),
                "monitor": "val/loss",
                "frequency": 1,
                "interval": "step",
            },
        }

## Active fit

In [7]:
# Base classifier; every strategy cell below wraps its own `deepcopy(model)`.
model = TransformerModel(
    model_name=MODEL_NAME, num_classes=num_classes, learning_rate=LEARNING_RATE
)

Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification w

### Random strategy

In [None]:
# NOTE(review): `RandomStrategy` is not among the names imported in the
# import cell of this notebook — confirm its import path before running.
random_strategy = RandomStrategy(deepcopy(model))

# Same seed as the other strategy runs.
seed_everything(1994)
trainer = Trainer(**trainer_kwargs)
results = trainer.active_fit(
    model=random_strategy,
    train_dataloaders=train_dl,
    val_dataloaders=val_dl,
    test_dataloaders=test_dl,
)
# Keep the results as a DataFrame for the comparison cells at the end.
random_df = results.to_pandas()
random_df

### AccumulatorStrategy

In [None]:
# NOTE(review): `AccumulatorStrategy` is not among the names imported in the
# import cell of this notebook — confirm its import path before running.
class EntropyStrategy(AccumulatorStrategy):
    """An implementation of the `Entropy` active learning strategy."""

    def get_inputs_from_batch(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Return the entries of `batch` without its `"labels"` key.

        Works on a shallow copy so the caller's batch is left untouched, and
        tolerates batches that carry no `"labels"` entry.

        Args:
            batch: Mapping of input names to tensors.

        Returns:
            A shallow copy of `batch` (same container type) with `"labels"`
            removed if it was present.
        """
        from copy import copy  # local import keeps the cell self-contained

        inputs = copy(batch)
        inputs.pop("labels", None)
        return inputs

    def pool_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Tensor:
        """Score a pool batch by applying `entropy` to the model's logits."""
        logits = self(batch).logits
        return entropy(logits)

In [None]:
# Wrap a fresh copy of the base model so this run does not share weights
# with the other strategies.
entropy_strategy = EntropyStrategy(deepcopy(model))

# Same seed as the other strategy runs.
seed_everything(1994)
trainer = Trainer(**trainer_kwargs)
results = trainer.active_fit(
    model=entropy_strategy,
    train_dataloaders=train_dl,
    val_dataloaders=val_dl,
    test_dataloaders=test_dl,
)
# Keep the results as a DataFrame for the comparison cells at the end.
entropy_df = results.to_pandas()
entropy_df

### AnchorPointsStrategy

In [8]:
class MyRandomArchorPointsStrategy(RandomArchorPointsStrategy):
    """`RandomArchorPointsStrategy` variant that searches with the token ids."""

    def get_search_query_from_batch(self, batch: Any) -> Tensor:
        """Return the `"input_ids"` entry of `batch` as the search query."""
        search_query = batch["input_ids"]
        return search_query

In [12]:
# NOTE(review): this instantiates the base `RandomArchorPointsStrategy`; the
# `MyRandomArchorPointsStrategy` subclass defined in the previous cell is not
# used — confirm which one is intended.
# NOTE(review): the meaning of the positional `10` is not visible here
# (presumably the number of anchor points) — TODO confirm.
random_anchor_points_strategy = RandomArchorPointsStrategy(deepcopy(model), 10)

# Unlike the other runs, the dataloaders are wrapped in a datamodule that is
# also given the path of a faiss index file (relative to the working directory).
datamodule = ActiveDataModuleWithIndex(
    train_dataloader=train_dl,
    val_dataloaders=val_dl,
    test_dataloaders=test_dl,
    faiss_index_path="all-mpnet-base-v2_ag-news_train.faiss",
)

# Same seed as the other strategy runs.
seed_everything(1994)
trainer = Trainer(**trainer_kwargs)
results = trainer.active_fit(
    model=random_anchor_points_strategy,
    datamodule=datamodule,
)
# Keep the results as a DataFrame for the comparison cells at the end.
rap_df = results.to_pandas()
rap_df

Global seed set to 1994
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
[1;32m[2022-09-11 23:15:10] energizer/INFO[0m ~ [1;33mtrainer:269[0m$ [37mTrainer: trainer active_fit stage[0m
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                          | Params
----------------------------------------------------------------
0 | model         | BertForSequenceClassification | 4.4 M 
1 | train_metrics | MetricCollection              | 0     
2 | val_metrics   | MetricCollection              | 0     
3 | test_metrics  | MetricCollection              | 0     
----------------------------------------------------------------
4.4 M     Trainable params
0         Non-trainable params
4.4 M     Total params
17.546    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  value = torch.tensor(value, device=self.device)
[1;36m[2022-09-11 23:15:11] energizer/DEBUG[0m ~ [1;33mtrainer:464[0m$ [37mUsing `RandomArchorPointsStrategy`[0m


-------------------------Labelling Iteration 0--------------------------


[1;36m[2022-09-11 23:15:11] energizer/DEBUG[0m ~ [1;33mtrainer:468[0m$ [37mUsing underlying `TransformerModel`[0m
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

[1;36m[2022-09-11 23:15:12] energizer/DEBUG[0m ~ [1;33mtrainer:464[0m$ [37mUsing `RandomArchorPointsStrategy`[0m
  rank_zero_warn(
[1;32m[2022-09-11 23:15:12] energizer/INFO[0m ~ [1;33mactive_learning_loop:193[0m$ [37mQueried 1 instance[0m
[1;36m[2022-09-11 23:15:12] energizer/DEBUG[0m ~ [1;33mdatamodule:322[0m$ [37mUpdating `faiss_index`[0m
[1;32m[2022-09-11 23:15:12] energizer/INFO[0m ~ [1;33mactive_learning_loop:281[0m$ [37mAnnotated 1 instances[0m
[1;32m[2022-09-11 23:15:12] energizer/INFO[0m ~ [1;33mactive_learning_loop:282[0m$ [37mNew data statistics
num_pool_batches: 235
num_train_batches: 1
pool_size: 119999
total_data_size: 120000
train_size: 1
[0m


-------------------------Labelling Iteration 1--------------------------


[1;36m[2022-09-11 23:15:12] energizer/DEBUG[0m ~ [1;33mtrainer:468[0m$ [37mUsing underlying `TransformerModel`[0m
[1;36m[2022-09-11 23:15:12] energizer/DEBUG[0m ~ [1;33mactive_learning_loop:250[0m$ [37mTransformerModel state dict has been re-initialized[0m
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
[1;36m[2022-09-11 23:15:15] energizer/DEBUG[0m ~ [1;33mtrainer:468[0m$ [37mUsing underlying `TransformerModel`[0m


Testing: 0it [00:00, ?it/s]

[1;36m[2022-09-11 23:15:16] energizer/DEBUG[0m ~ [1;33mtrainer:464[0m$ [37mUsing `RandomArchorPointsStrategy`[0m
[1;36m[2022-09-11 23:15:16] energizer/DEBUG[0m ~ [1;33mdatamodule:306[0m$ [37mSearching `faiss_index`[0m
[1;32m[2022-09-11 23:15:16] energizer/INFO[0m ~ [1;33mactive_learning_loop:193[0m$ [37mQueried 1 instance[0m
[1;36m[2022-09-11 23:15:16] energizer/DEBUG[0m ~ [1;33mdatamodule:322[0m$ [37mUpdating `faiss_index`[0m
[1;32m[2022-09-11 23:15:16] energizer/INFO[0m ~ [1;33mactive_learning_loop:281[0m$ [37mAnnotated 1 instances[0m
[1;32m[2022-09-11 23:15:16] energizer/INFO[0m ~ [1;33mactive_learning_loop:282[0m$ [37mNew data statistics
num_pool_batches: 235
num_train_batches: 1
pool_size: 119998
total_data_size: 120000
train_size: 2
[0m


-------------------------Labelling Iteration 2--------------------------


[1;36m[2022-09-11 23:15:16] energizer/DEBUG[0m ~ [1;33mtrainer:468[0m$ [37mUsing underlying `TransformerModel`[0m
[1;36m[2022-09-11 23:15:16] energizer/DEBUG[0m ~ [1;33mactive_learning_loop:250[0m$ [37mTransformerModel state dict has been re-initialized[0m


Training: 1it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
[1;36m[2022-09-11 23:15:19] energizer/DEBUG[0m ~ [1;33mtrainer:468[0m$ [37mUsing underlying `TransformerModel`[0m


Testing: 0it [00:00, ?it/s]

[1;36m[2022-09-11 23:15:20] energizer/DEBUG[0m ~ [1;33mtrainer:464[0m$ [37mUsing `RandomArchorPointsStrategy`[0m
[1;36m[2022-09-11 23:15:20] energizer/DEBUG[0m ~ [1;33mdatamodule:306[0m$ [37mSearching `faiss_index`[0m
[1;32m[2022-09-11 23:15:20] energizer/INFO[0m ~ [1;33mactive_learning_loop:193[0m$ [37mQueried 1 instance[0m
[1;36m[2022-09-11 23:15:20] energizer/DEBUG[0m ~ [1;33mdatamodule:322[0m$ [37mUpdating `faiss_index`[0m
[1;32m[2022-09-11 23:15:20] energizer/INFO[0m ~ [1;33mactive_learning_loop:281[0m$ [37mAnnotated 1 instances[0m
[1;32m[2022-09-11 23:15:20] energizer/INFO[0m ~ [1;33mactive_learning_loop:282[0m$ [37mNew data statistics
num_pool_batches: 235
num_train_batches: 1
pool_size: 119997
total_data_size: 120000
train_size: 3
[0m


-------------------------Labelling Iteration 3--------------------------


[1;36m[2022-09-11 23:15:20] energizer/DEBUG[0m ~ [1;33mtrainer:468[0m$ [37mUsing underlying `TransformerModel`[0m
[1;36m[2022-09-11 23:15:20] energizer/DEBUG[0m ~ [1;33mactive_learning_loop:250[0m$ [37mTransformerModel state dict has been re-initialized[0m


Training: 1it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
[1;36m[2022-09-11 23:15:23] energizer/DEBUG[0m ~ [1;33mtrainer:468[0m$ [37mUsing underlying `TransformerModel`[0m


Testing: 0it [00:00, ?it/s]

[1;36m[2022-09-11 23:15:24] energizer/DEBUG[0m ~ [1;33mtrainer:464[0m$ [37mUsing `RandomArchorPointsStrategy`[0m
[1;36m[2022-09-11 23:15:24] energizer/DEBUG[0m ~ [1;33mdatamodule:306[0m$ [37mSearching `faiss_index`[0m
[1;32m[2022-09-11 23:15:24] energizer/INFO[0m ~ [1;33mactive_learning_loop:193[0m$ [37mQueried 1 instance[0m
[1;36m[2022-09-11 23:15:24] energizer/DEBUG[0m ~ [1;33mdatamodule:322[0m$ [37mUpdating `faiss_index`[0m
[1;32m[2022-09-11 23:15:25] energizer/INFO[0m ~ [1;33mactive_learning_loop:281[0m$ [37mAnnotated 1 instances[0m
[1;32m[2022-09-11 23:15:25] energizer/INFO[0m ~ [1;33mactive_learning_loop:282[0m$ [37mNew data statistics
num_pool_batches: 235
num_train_batches: 1
pool_size: 119996
total_data_size: 120000
train_size: 4
[0m


-------------------------Labelling Iteration 4--------------------------


[1;36m[2022-09-11 23:15:25] energizer/DEBUG[0m ~ [1;33mtrainer:468[0m$ [37mUsing underlying `TransformerModel`[0m
[1;36m[2022-09-11 23:15:25] energizer/DEBUG[0m ~ [1;33mactive_learning_loop:250[0m$ [37mTransformerModel state dict has been re-initialized[0m


Training: 1it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
[1;36m[2022-09-11 23:15:27] energizer/DEBUG[0m ~ [1;33mtrainer:468[0m$ [37mUsing underlying `TransformerModel`[0m


Testing: 0it [00:00, ?it/s]

[1;36m[2022-09-11 23:15:28] energizer/DEBUG[0m ~ [1;33mtrainer:464[0m$ [37mUsing `RandomArchorPointsStrategy`[0m
[1;36m[2022-09-11 23:15:28] energizer/DEBUG[0m ~ [1;33mdatamodule:306[0m$ [37mSearching `faiss_index`[0m
[1;32m[2022-09-11 23:15:29] energizer/INFO[0m ~ [1;33mactive_learning_loop:193[0m$ [37mQueried 1 instance[0m
[1;36m[2022-09-11 23:15:29] energizer/DEBUG[0m ~ [1;33mdatamodule:322[0m$ [37mUpdating `faiss_index`[0m
[1;32m[2022-09-11 23:15:29] energizer/INFO[0m ~ [1;33mactive_learning_loop:281[0m$ [37mAnnotated 1 instances[0m
[1;32m[2022-09-11 23:15:29] energizer/INFO[0m ~ [1;33mactive_learning_loop:282[0m$ [37mNew data statistics
num_pool_batches: 235
num_train_batches: 1
pool_size: 119995
total_data_size: 120000
train_size: 5
[0m
[1;36m[2022-09-11 23:15:29] energizer/DEBUG[0m ~ [1;33mtrainer:468[0m$ [37mUsing underlying `TransformerModel`[0m
[1;36m[2022-09-11 23:15:29] energizer/DEBUG[0m ~ [1;33mactive_learning_loop:250[0m$ [3

-----------------------------Last fit_loop------------------------------


Training: 1it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
[1;36m[2022-09-11 23:15:32] energizer/DEBUG[0m ~ [1;33mtrainer:468[0m$ [37mUsing underlying `TransformerModel`[0m


Testing: 0it [00:00, ?it/s]

[1;36m[2022-09-11 23:15:33] energizer/DEBUG[0m ~ [1;33mtrainer:464[0m$ [37mUsing `RandomArchorPointsStrategy`[0m


Unnamed: 0,train_size,test/loss,accuracy,f1_macro,f1_micro,precision_macro,recall_macro
0,0,1.401348,0.197368,0.121583,0.197368,0.104766,0.197283
1,1,1.401223,0.196842,0.122097,0.196842,0.104648,0.196799
2,2,1.401255,0.198421,0.121753,0.198421,0.105082,0.198328
3,3,1.401456,0.201974,0.120184,0.201974,0.10558,0.201663
4,4,1.401842,0.208553,0.117121,0.208553,0.106078,0.208357
5,5,1.402168,0.213947,0.115544,0.213947,0.106582,0.213659


----

In [None]:
# Compare the strategies: accuracy as a function of the labelled-set size.
# One line per strategy, drawn on an explicit Axes so the figure carries its
# own labels and title.
fig, ax = plt.subplots()
for frame, strategy_label in (
    (random_df, "random"),
    (entropy_df, "entropy"),
    (rap_df, "random anchors"),
):
    ax.plot(frame["train_size"], frame["accuracy"], label=strategy_label)
ax.set(
    xlabel="Number of labelled training instances (train_size)",
    ylabel="accuracy",
    title="Accuracy by labelled-set size",
)
ax.legend()
plt.show()

In [None]:
# Tag each per-strategy frame with its strategy name (added in place as a
# `strategy` column), then stack the three frames row-wise.
for frame, strategy_name in (
    (random_df, "random"),
    (entropy_df, "entropy"),
    (rap_df, "random_anchors"),
):
    frame["strategy"] = strategy_name

# The original row indices are kept, so index values repeat across strategies.
results = pd.concat([random_df, entropy_df, rap_df], ignore_index=False, axis=0)

In [None]:
# results.to_parquet("results_al.parquet", index=False)
# with open("results_al_metadata.json", "w") as fl:
#     json.dump(trainer_kwargs, fl)