In [None]:
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("..")
import warnings

warnings.filterwarnings('ignore')
import numpy as np
import torch
import torch.special as sp
from datasets import load_dataset
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import Callback
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding

from energizer.data import ActiveDataModule, ActiveDataset
from energizer.inference import Deterministic
from energizer.loops import ActiveLearningLoop
from energizer.strategies import LeastConfidenceStrategy

In [None]:
# Tokenizer for the same "prajjwal1/bert-tiny" checkpoint that the model loads.
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")

# Load the dataset, tokenize the "text" field in batches (token type ids are not
# requested), and restrict the formatted output to the three listed columns.
ds = (
    load_dataset("pietrolesci/ag_news", name="concat")
    .map(lambda ex: tokenizer(ex["text"], return_token_type_ids=False), batched=True)
    .with_format(columns=["input_ids", "attention_mask", "label"])
)

In [None]:
class Model(LightningModule):
    """Sequence classifier built on the "prajjwal1/bert-tiny" checkpoint.

    The backbone is loaded with ``num_labels=4`` and optimised with
    cross-entropy computed on its logits. The train/validation/test steps
    share one loss computation (:meth:`step`) and differ only in the name
    under which the loss is logged.
    """

    def __init__(self):
        super().__init__()
        self.backbone = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=4)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, batch):
        """Return the backbone logits.

        Args:
            batch: mapping whose entries are passed to the backbone as
                keyword arguments.
        """
        return self.backbone(**batch).logits

    def step(self, batch, *args, **kwargs):
        """Compute the cross-entropy loss for one labelled batch.

        Args:
            batch: mapping that contains a ``"labels"`` entry plus the
                keyword inputs expected by the backbone.

        Returns:
            The loss produced by ``self.loss`` on the logits and labels.

        Raises:
            KeyError: if ``batch`` has no ``"labels"`` entry.
        """
        y = batch["labels"]
        # Forward a label-free copy rather than popping from `batch`: the batch
        # object belongs to the caller, and removing a key from it would hide the
        # labels from anything that inspects the same batch afterwards.
        inputs = {name: value for name, value in batch.items() if name != "labels"}
        y_hat = self(inputs)
        return self.loss(y_hat, y)

    def training_step(self, batch, *args, **kwargs):
        """Compute the batch loss, log it as ``train_loss`` and return it."""
        loss = self.step(batch, *args, **kwargs)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, *args, **kwargs):
        """Compute the batch loss, log it as ``val_loss`` and return it."""
        loss = self.step(batch, *args, **kwargs)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, *args, **kwargs):
        """Compute the batch loss, log it as ``test_loss`` and return it."""
        loss = self.step(batch, *args, **kwargs)
        self.log("test_loss", loss)
        return loss

    def predict_step(self, batch, *args, **kwargs):
        """Return the logits for ``batch``; the batch is forwarded unchanged."""
        return self(batch)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``lr=0.001``."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def on_pool_batch_end(self, *args, **kwargs):
        """Print a marker message.

        NOTE(review): presumably invoked by the active learning loop after
        each pool batch -- confirm against the energizer hook API.
        """
        self.print("PRINTING FROM MODULE")


class ActiveLearninigCallback(Callback):
    """Callback that prints a marker message from its pool-batch-end hook.

    No constructor is defined because the callback holds no state of its own;
    instantiation falls through to the base class.
    """

    def on_pool_batch_end(self, *args, **kwargs):
        """Print a fixed message; all hook arguments are ignored."""
        print("FROM THE CALLBACK")

In [None]:
# Seed before constructing the model so that any weights that are randomly
# initialised at load time (presumably the 4-way classification head -- confirm)
# are identical on every fresh-kernel run. The training cell seeds again with the
# same value before building the datamodule; keep the two values in sync.
seed_everything(1111)
model = Model()

In [None]:
def get_dm():
    """Build the ActiveDataModule used by the experiment.

    The training pool is the first 1,000 rows of ``ds["train"]``; 100 distinct
    pool indices are drawn with NumPy's global random state and passed as the
    initial labels. The test set is the first 1,000 rows of ``ds["test"]``.
    Batches are collated with ``DataCollatorWithPadding`` built from the
    notebook-level ``tokenizer``.

    Returns:
        The configured ``ActiveDataModule``.
    """
    pool = ds["train"].select(range(1_000))

    # Sampling without replacement from the global NumPy RNG: the result depends
    # on whatever seed was set before this function is called.
    initially_labelled = np.random.choice(len(pool), size=100, replace=False).tolist()

    held_out = ds["test"].select(range(1_000))
    collator = DataCollatorWithPadding(tokenizer)

    return ActiveDataModule(
        num_classes=4,
        train_dataset=pool,
        initial_labels=initially_labelled,
        val_split=0.5,
        test_dataset=held_out,
        batch_size=32,
        seed=42,
        collate_fn=collator,
    )

In [None]:
# Seed first: get_dm() draws the initial labelled indices from NumPy's global RNG.
seed_everything(1111)

# Data
dm = get_dm()

# Trainer, with the demo callback attached
pool_callback = ActiveLearninigCallback()
trainer = Trainer(
    max_epochs=100,
    enable_progress_bar=True,
    enable_model_summary=True,
    callbacks=pool_callback,
)

# Query strategy: least-confidence scoring on top of a deterministic inference module
inference_module = Deterministic()
strategy = LeastConfidenceStrategy(inference_module=inference_module)

# Active learning loop: connect it to the trainer, then install it as the fit loop
active_learning_loop = ActiveLearningLoop(
    al_strategy=strategy,
    query_size=2,
    reset_weights=True,
    n_epochs_between_labelling=3,
)
active_learning_loop.connect(trainer)
trainer.fit_loop = active_learning_loop

# Fit the model with active learning
trainer.fit(model, datamodule=dm)