In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
MODEL_NAME = "google/bert_uncased_L-2_H-128_A-2"
BATCH_SIZE = 32
EVAL_BATCH_SIZE = 256
LEARNING_RATE = 0.0001
MAX_LABELLING_ITERS = 2

## Pre

In [3]:
import os
from copy import deepcopy
from typing import Any, Dict, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn.functional as F
from datasets import load_dataset
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer as PLTrainer
from pytorch_lightning import seed_everything
from torch import Tensor, nn
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, F1Score, MetricCollection, Precision, Recall
from transformers import (
    AdamW,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    get_constant_schedule_with_warmup,
)

from energizer import AccumulatorStrategy, RandomStrategy, Trainer
from energizer.acquisition_functions import entropy, expected_entropy
from energizer.data.datamodule import ActiveDataModuleWithIndex

In [4]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# renames "label" to "labels"
collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True, return_tensors="pt")

# load dataset
dataset = load_dataset("pietrolesci/ag_news", "concat")

# tokenize
dataset = dataset.map(lambda ex: tokenizer(ex["text"]), batched=True)
columns_to_keep = ["label", "input_ids", "token_type_ids", "attention_mask"]

# train-val split and record datasets
train_set, test_set = dataset["train"], dataset["test"]
_split = train_set.train_test_split(0.3)
_, val_set = _split["train"], _split["test"]

labels = train_set.features["label"].names
num_classes = len(labels)

# create dataloaders
batch_size = BATCH_SIZE
eval_batch_size = EVAL_BATCH_SIZE  # this is use when evaluating on the pool too
train_dl = DataLoader(
    train_set.with_format(columns=columns_to_keep),
    batch_size=batch_size,
    collate_fn=collator,
    num_workers=2,
)
val_dl = DataLoader(
    val_set.with_format(columns=columns_to_keep),
    batch_size=eval_batch_size,
    collate_fn=collator,
    num_workers=2,
)
test_dl = DataLoader(
    test_set.with_format(columns=columns_to_keep),
    batch_size=eval_batch_size,
    collate_fn=collator,
    num_workers=2,
)

Reusing dataset ag_news (/home/pl487/.cache/huggingface/datasets/pietrolesci___ag_news/concat/1.0.0/5ee6e111adc7a901ca734b79fbebff09d9dba91722387a794efff8d9c178a6a3)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at /home/pl487/.cache/huggingface/datasets/pietrolesci___ag_news/concat/1.0.0/5ee6e111adc7a901ca734b79fbebff09d9dba91722387a794efff8d9c178a6a3/cache-244c58eb4fe55230.arrow
Loading cached processed dataset at /home/pl487/.cache/huggingface/datasets/pietrolesci___ag_news/concat/1.0.0/5ee6e111adc7a901ca734b79fbebff09d9dba91722387a794efff8d9c178a6a3/cache-561b1c44fc31a694.arrow


In [5]:
len(train_dl.dataset)

120000

In [6]:
datamodule = ActiveDataModuleWithIndex(
    train_dataloader=train_dl,
    val_dataloaders=val_dl,
    test_dataloaders=test_dl,
    faiss_index_path="train_ag_news.faiss",
)

In [7]:
class TransformerModel(LightningModule):
    def __init__(
        self,
        model_name: str,
        num_classes: int,
        learning_rate: float = 0.00001,
        num_warmup_steps: int = 50,
    ) -> None:
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=self.num_classes,
        )
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        for stage in ("train", "val", "test"):
            metrics = MetricCollection(
                {
                    "accuracy": Accuracy(),
                    "precision_macro": Precision(num_classes=num_classes, average="macro"),
                    "precision_micro": Precision(num_classes=num_classes, average="micro"),
                    "recall_macro": Recall(num_classes=num_classes, average="macro"),
                    "recall_micro": Recall(num_classes=num_classes, average="micro"),
                    "f1_macro": F1Score(num_classes=num_classes, average="macro"),
                    "f1_micro": F1Score(num_classes=num_classes, average="micro"),
                }
            )
            setattr(self, f"{stage}_metrics", metrics)

    def common_step(self, batch: Any, stage: str):
        """Outputs loss and logits, logs loss and metrics."""
        out = self(batch)
        logits, loss = out.logits, out.loss
        self.log(f"{stage}/loss", loss)

        metrics = getattr(self, f"{stage}_metrics")(logits, batch["labels"])
        self.log_dict(metrics)

        return loss

    def forward(self, batch) -> torch.Tensor:
        return self.model(**batch)

    def training_step(self, batch: Any, batch_idx: int = 0, optimizer_idx: int = 0) -> Dict[str, Any]:
        return self.common_step(batch, "train")

    def validation_step(self, batch: Any, batch_idx: int = 0) -> Dict[str, Any]:
        return self.common_step(batch, "val")

    def test_step(self, batch: Any, batch_idx: int = 0) -> Dict[str, Any]:
        return self.common_step(batch, "test")

    def configure_optimizers(self) -> Dict[str, Any]:
        optimizer = AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": get_constant_schedule_with_warmup(
                    optimizer=optimizer, num_warmup_steps=self.num_warmup_steps
                ),
                "monitor": "val/loss",
                "frequency": 1,
                "interval": "step",
            },
        }

## Active fit

In [8]:
model = TransformerModel(model_name=MODEL_NAME, num_classes=num_classes, learning_rate=LEARNING_RATE)

Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification w

### Random strategy

In [None]:
random_strategy = RandomStrategy(deepcopy(model))

seed_everything(1994)

trainer = Trainer(
    query_size=50,
    max_epochs=3,
    max_labelling_epochs=MAX_LABELLING_ITERS,
    # total_budget=5,
    test_after_labelling=True,
    accelerator="gpu",
    # for testing purposes
    # limit_train_batches=10,
    limit_val_batches=1,
    # limit_test_batches=10,
    # limit_pool_batches=10,
    # log_every_n_steps=1,
)

results = trainer.active_fit(
    model=random_strategy,
    # train_dataloaders=train_dl,
    # val_dataloaders=val_dl,
    # test_dataloaders=test_dl,
    datamodule=datamodule,
)

In [None]:
random_df = results.to_pandas()
random_df

### AccumulatorStrategy

In [None]:
class EntropyStrategy(AccumulatorStrategy):
    """A implememntation of the `Entropy` active learning strategy."""

    def get_inputs_from_batch(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
        batch.pop("labels")
        return batch

    def pool_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Tensor:
        logits = self(batch).logits
        return entropy(logits)

In [None]:
entropy_strategy = EntropyStrategy(deepcopy(model))

seed_everything(1994)

trainer = Trainer(
    query_size=50,
    max_epochs=3,
    max_labelling_epochs=MAX_LABELLING_ITERS,
    # total_budget=5,
    test_after_labelling=True,
    accelerator="gpu",
    # for testing purposes
    # limit_train_batches=10,
    limit_val_batches=1,
    limit_test_batches=10,
    limit_pool_batches=10,
    # log_every_n_steps=1,
)

results = trainer.active_fit(
    model=entropy_strategy,
    train_dataloaders=train_dl,
    val_dataloaders=val_dl,
    test_dataloaders=test_dl,
)

In [None]:
entropy_df = pd.DataFrame(
    data=[(l.data_stats["train_size"], *l.test_outputs[0].values()) for l in results],
    columns=("train_size", *results[0].test_outputs[0].keys()),
)
entropy_df

### AnchorPointsStrategy

In [9]:
from energizer.query_strategies.base import RandomArchorPointsStrategy

In [10]:
class MyRandomArchorPointsStrategy(RandomArchorPointsStrategy):
    def get_search_query_from_batch(self, batch: Any) -> Tensor:
        return batch["input_ids"]

In [11]:
random_anchor_points_strategy = MyRandomArchorPointsStrategy(deepcopy(model), 10)
datamodule = ActiveDataModuleWithIndex(
    train_dataloader=train_dl,
    val_dataloaders=val_dl,
    test_dataloaders=test_dl,
    faiss_index_path="train_ag_news.faiss",
)
seed_everything(1994)

trainer = Trainer(
    query_size=50,
    max_epochs=3,
    max_labelling_epochs=MAX_LABELLING_ITERS,
    # total_budget=5,
    test_after_labelling=True,
    accelerator="gpu",
    # for testing purposes
    # limit_train_batches=10,
    limit_val_batches=1,
    limit_test_batches=10,
    limit_pool_batches=10,
    # log_every_n_steps=1,
)

results = trainer.active_fit(
    model=random_anchor_points_strategy,
    datamodule=datamodule,
)

Global seed set to 1994
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
[1;36m[2022-08-27 21:25:07] energizer/DETAIL[0m ~ [1;33mtrainer:227[0m$ Trainer: trainer active_fit stage[0m
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                          | Params
----------------------------------------------------------------
0 | model         | BertForSequenceClassification | 4.4 M 
1 | train_metrics | MetricCollection              | 0     
2 | val_metrics   | MetricCollection              | 0     
3 | test_metrics  | MetricCollection              | 0     
----------------------------------------------------------------
4.4 M     Trainable params
0         Non-trainable params
4.4 M     Total params
17.546    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
[1;32m[2022-08-27 21:25:11] energizer/INFO[0m ~ [1;33mtrainer:448[0m$ [37mUsing `MyRandomArchorPointsStrategy`[0m


-------------------------Labelling Iteration 0--------------------------


[1;32m[2022-08-27 21:25:11] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `TransformerModel`[0m
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

[1;32m[2022-08-27 21:25:11] energizer/INFO[0m ~ [1;33mtrainer:448[0m$ [37mUsing `MyRandomArchorPointsStrategy`[0m
  rank_zero_warn(
[1;32m[2022-08-27 21:25:11] energizer/INFO[0m ~ [1;33mactive_learning_loop:195[0m$ [37mQueried 50 instance[0m
[1;32m[2022-08-27 21:25:11] energizer/INFO[0m ~ [1;33mdatamodule:304[0m$ [37mUpdating `faiss_index`[0m
[1;32m[2022-08-27 21:25:12] energizer/INFO[0m ~ [1;33mactive_learning_loop:283[0m$ [37mAnnotated 50 instances[0m
[1;32m[2022-08-27 21:25:12] energizer/INFO[0m ~ [1;33mactive_learning_loop:284[0m$ [37mNew data statistics
num_pool_batches: 469
num_train_batches: 2
pool_size: 119950
total_data_size: 120000
train_size: 50
[0m


-------------------------Labelling Iteration 1--------------------------


[1;32m[2022-08-27 21:25:12] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `TransformerModel`[0m
[1;32m[2022-08-27 21:25:12] energizer/INFO[0m ~ [1;33mactive_learning_loop:252[0m$ [37mTransformerModel state dict has been re-initialized[0m
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
[1;32m[2022-08-27 21:25:14] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `TransformerModel`[0m


Testing: 0it [00:00, ?it/s]

[1;32m[2022-08-27 21:25:15] energizer/INFO[0m ~ [1;33mtrainer:448[0m$ [37mUsing `MyRandomArchorPointsStrategy`[0m
[1;32m[2022-08-27 21:25:15] energizer/INFO[0m ~ [1;33mdatamodule:298[0m$ [37mSearching `faiss_index`[0m
[1;32m[2022-08-27 21:25:19] energizer/INFO[0m ~ [1;33mactive_learning_loop:195[0m$ [37mQueried 500 instance[0m
[1;32m[2022-08-27 21:25:19] energizer/INFO[0m ~ [1;33mdatamodule:304[0m$ [37mUpdating `faiss_index`[0m
[1;32m[2022-08-27 21:25:19] energizer/INFO[0m ~ [1;33mactive_learning_loop:283[0m$ [37mAnnotated 500 instances[0m
[1;32m[2022-08-27 21:25:19] energizer/INFO[0m ~ [1;33mactive_learning_loop:284[0m$ [37mNew data statistics
num_pool_batches: 467
num_train_batches: 18
pool_size: 119450
total_data_size: 120000
train_size: 550
[0m
[1;32m[2022-08-27 21:25:19] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `TransformerModel`[0m
[1;32m[2022-08-27 21:25:19] energizer/INFO[0m ~ [1;33mactive_learning_loop:252[0m$

-----------------------------Last fit_loop------------------------------


  rank_zero_warn(


Training: 2it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
[1;32m[2022-08-27 21:25:24] energizer/INFO[0m ~ [1;33mtrainer:452[0m$ [37mUsing underlying `TransformerModel`[0m


Testing: 0it [00:00, ?it/s]

[1;32m[2022-08-27 21:25:24] energizer/INFO[0m ~ [1;33mtrainer:448[0m$ [37mUsing `MyRandomArchorPointsStrategy`[0m


In [12]:
results.to_pandas()

Unnamed: 0,train_size,test/loss,accuracy,f1_macro,f1_micro,precision_macro,precision_micro,recall_macro,recall_micro
0,0,1.406698,0.160938,0.122527,0.160938,0.188859,0.160938,0.155499,0.160938
1,50,1.405013,0.163281,0.129238,0.163281,0.203633,0.163281,0.157893,0.163281
2,550,1.184808,0.596484,0.54321,0.596484,0.664054,0.596484,0.589746,0.596484
