In [34]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.special as sp
from torch.utils.data import Dataset
from datasets import load_dataset
import numpy as np
from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback
from pytorch_lightning

from energizer.data import ActiveDataset, ActiveDataModule
from energizer.loops import ActiveLearningLoop
from energizer.strategies import LeastConfidenceStrategy
from energizer.inference import Deterministic

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
ds = load_dataset("pietrolesci/ag_news", name="concat")
ds = ds.map(lambda ex: tokenizer(ex["text"], return_token_type_ids=False), batched=True)
ds = ds.with_format(columns=["input_ids", "attention_mask", "label"])

Reusing dataset ag_news (/Users/pietrolesci/.cache/huggingface/datasets/pietrolesci___ag_news/concat/1.0.0/5ee6e111adc7a901ca734b79fbebff09d9dba91722387a794efff8d9c178a6a3)
100%|██████████| 2/2 [00:00<00:00, 218.80it/s]
Loading cached processed dataset at /Users/pietrolesci/.cache/huggingface/datasets/pietrolesci___ag_news/concat/1.0.0/5ee6e111adc7a901ca734b79fbebff09d9dba91722387a794efff8d9c178a6a3/cache-000bee98bdbc4cf0.arrow
Loading cached processed dataset at /Users/pietrolesci/.cache/huggingface/datasets/pietrolesci___ag_news/concat/1.0.0/5ee6e111adc7a901ca734b79fbebff09d9dba91722387a794efff8d9c178a6a3/cache-2492e1c9d423a80c.arrow


In [45]:
class Model(LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=4)
        self.loss = torch.nn.CrossEntropyLoss()
    
    def forward(self, batch):
        return self.backbone(**batch).logits

    def step(self, batch, *args, **kwargs):
        y = batch.pop("labels")        
        y_hat = self(batch)
        return self.loss(y_hat, y)
    
    def training_step(self, batch, *args, **kwargs):
        # self.print("TRAIN")
        loss = self.step(batch, *args, **kwargs)
        self.log("train_loss", loss)

    def validation_step(self, batch, *args, **kwargs):
        # self.print("VAL")
        loss = self.step(batch, *args, **kwargs)
        self.log("val_loss", loss)

    def test_step(self, batch, *args, **kwargs):
        # self.print("TEST")
        loss = self.step(batch, *args, **kwargs)
        self.log("test_loss", loss)

    def predict_step(self, batch, *args, **kwargs):
        # self.print("PREDICT")
        return self(batch)
    
    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def on_pool_batch_end(self, *args, **kwargs):
        self.print("PRINTING FROM MODULE")


class ActiveLearninigCallback(Callback):

    def __init__(self):
        super().__init__()

    def on_pool_batch_end(self, *args, **kwargs):
        print("FROM THE CALLBACK")

In [15]:
model = Model()

Some weights of the model checkpoint at prajjwal1/bert-tiny were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initia

In [17]:
def get_dm():
    train_ds = ds["train"].select(list(range(10)))
    return ActiveDataModule(
        num_classes=4,
        train_dataset=train_ds,
        # val_dataset=ds["test"],
        initial_labels=np.random.choice(list(range(len(train_ds))), size=4, replace=False).tolist(),
        val_split=0.5,
        test_dataset=ds["test"].select(range(5)),
        # predict_dataset=ds["test"].select(range(15)),
        batch_size=32,
        seed=42,
        collate_fn=DataCollatorWithPadding(tokenizer),
    )


In [47]:
seed_everything(1111)
dm = get_dm()
trainer = Trainer(
    max_epochs=100, 
    enable_progress_bar=False, 
    enable_model_summary=None, 
    callbacks=ActiveLearninigCallback(),
)
active_learning_loop = ActiveLearningLoop(
    strategy=LeastConfidenceStrategy(inference_module=Deterministic()),
    query_size=2,
)
active_learning_loop.connect(trainer)
trainer.fit_loop = active_learning_loop
trainer.fit(model, datamodule=dm)


Global seed set to 1111
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Global seed set to 1111


Validation dataset is not specfied and `val_split == 0.5`, therefore at each training step 0.5 of the labelled data will be added to the validation dataset.

EPOCH
Active learning dataset: ActiveDataset({
    original_dataset_size: 10,
    train_size: 2,
    val_size: 2,
    pool_size: 6,
    base_class: <class 'datasets.arrow_dataset.Dataset'>,
})
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 1.3116883039474487}
--------------------------------------------------------------------------------
HERE
PRINTING FROM MODULE
FROM THE CALLBACK
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------

EPOCH
Active learning dataset: ActiveDataset({
    original_dataset_size: 10,
    train_size: 3,
    val_size: 3,
    pool_size: 4,
    base_class: <class 'datasets.arrow_dataset.Dataset'>,
}

In [24]:
trainer.call_hook("on_pool_batch_end")

In [None]:
a = np.array(list(range(10)) * 2) 
a[a == 0] = len(a) + 1e6
np.random.shuffle(a)
ids = np.argsort(a)
a[a == len(a) + 1e6] = 0

In [None]:
a[ids]

In [None]:
ids

In [None]:
dm._active_dataset._pool_to_oracle(4)

In [None]:
dm.pool_dataset.indices

In [None]:
[0.1, 0.2, 0.3, 0.4], [0, 1, 2, 3]

In [None]:
[.1, .2], [0, 1]

In [None]:
[0, 1][1]

In [None]:
[0.3, 0.4], [0, 1]

In [None]:
[0, 1][1]

In [None]:
1 + (0 * 2)