In [None]:
# Re-import edited modules (e.g. the local `src.*` / `energizer.*` code imported below)
# automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from energizer.datastores import PandasDataStoreForSequenceClassification
from src.strategies import RandomStrategy

In [None]:
model_name = "google/bert_uncased_L-2_H-128_A-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Keep the untokenized splits under their own name so they stay inspectable;
# `ds_dict` is the tokenized version used by the rest of the notebook.
raw_ds_dict = load_dataset("pietrolesci/wiki_toxic_indexed")
ds_dict = raw_ds_dict.map(
    lambda ex: tokenizer(ex["comment_text"]),
    batched=True,
    num_proc=4,
)

In [None]:
# Options describing which columns are model inputs, which is the target and which is the uid.
datastore_kwargs = {
    "input_names": ["input_ids", "attention_mask"],
    "target_name": "labels",
    "tokenizer": tokenizer,
    "uid_name": "uid",
}
ds = PandasDataStoreForSequenceClassification()
# Called for its effect on `ds` (its return value is not used); later cells read
# `ds.id2label`, `ds.label2id` and `ds.labels`.
ds.from_dataset_dict(ds_dict, **datastore_kwargs)

In [None]:
# Size the classification head and its label names from the datastore's label set.
label_config = dict(
    id2label=ds.id2label,
    label2id=ds.label2id,
    num_labels=len(ds.labels),
)
model = AutoModelForSequenceClassification.from_pretrained(model_name, **label_config)

In [None]:
# NOTE(review): accelerator="gpu" presumably requires a CUDA device — confirm before
# running this notebook on a CPU-only machine.
estimator = RandomStrategy(
    model=model,
    accelerator="gpu",
)

In [None]:
# Must run before the loaders are requested from `ds` in the following cells;
# presumably builds the per-split datasets used by `*_loader()` — TODO confirm.
ds.prepare_for_loading()

In [None]:
# Smoke test: train for one epoch on just 2 batches drawn from the *test* loader.
# NOTE(review): training on test data is only acceptable as a quick sanity check
# (presumably there is no labelled train split yet at this point).
smoke_loader = ds.test_loader()
r = estimator.fit(
    train_loader=smoke_loader,
    max_epochs=1,
    limit_train_batches=2,
)

In [None]:
# Short active-learning run with the random strategy: at most 2 rounds, query_size=50,
# one epoch per fit and the test pass capped at 2 batches.
r = estimator.active_fit(
    ds,
    max_rounds=2,
    query_size=50,
    max_epochs=1,
    limit_test_batches=2,
)

In [None]:
# NOTE(review): this cell fails on a fresh kernel. `ClassificationActiveDataModule` is not
# imported anywhere in this notebook (a later cell imports `ActiveClassificationDataModule`
# instead), and `dataset_dict` is only assigned further down. TODO fix before Run All.
datamodule = ClassificationActiveDataModule.from_dataset_dict(
    dataset_dict, tokenizer=tokenizer
)

In [None]:
# Load the index stored at meta["hnsw_index_path"] into the datamodule.
# NOTE(review): `meta` is never defined in this notebook (presumably the dataset metadata
# dict, cf. `metadata` read from metadata.yaml further down) — TODO define it before Run All.
datamodule.load_index(meta["hnsw_index_path"], embedding_dim=meta["embedding_dim"])

In [None]:
# Fresh model sized from the datamodule's labels, wrapped in a similarity-search strategy.
# NOTE(review): `meta` is never defined and `SimilaritySearchStrategyForSequenceClassification`
# is never imported in this notebook, so this cell fails on a fresh kernel.
model = AutoModelForSequenceClassification.from_pretrained(
    meta["name_or_path"],
    id2label=datamodule.id2label,
    label2id=datamodule.label2id,
    num_labels=len(datamodule.labels),
)
# seed=42 fixes the strategy's seed for reproducibility.
active_estimator = SimilaritySearchStrategyForSequenceClassification(
    model=model, seed=42
)

In [None]:
# Run the similarity-search strategy: at most 2 rounds, query_size=100,
# with the test pass capped at 10 batches.
active_estimator.active_fit(
    active_datamodule=datamodule,
    max_rounds=2,
    query_size=100,
    limit_test_batches=10,
)

In [None]:
# Inspect the budget tracker after the active-learning run above.
active_estimator.progress_tracker.budget_tracker

In [None]:
# Inspect the current size of the datamodule's train split.
datamodule.train_size()

In [None]:
# Grab one training batch to experiment with per-sample gradients.
loader = datamodule.train_loader()
batch = next(iter(loader))
# Drop the "on_cpu" entry so the remaining keys can be passed as model kwargs
# (`model(**batch)` below). Use a bare call rather than `_ = ...`: assigning to `_`
# makes IPython stop updating its last-output variables `_`, `__`, `___`.
batch.pop("on_cpu")
# Use the real number of rows in this batch. `loader.batch_size` is None when the
# loader is built from a batch_sampler, and it can exceed the rows actually returned
# (e.g. a split smaller than one batch). Either case would break `range(batch_size)` below.
batch_size = len(batch["input_ids"])

In [None]:
import torch

In [None]:
# Full-batch forward pass; keep only the loss for the gradient below.
outputs = model(**batch)
loss = outputs.loss

In [None]:
# Gradient of the batch loss w.r.t. every model parameter (one tensor per parameter).
model_params = list(model.parameters())
grads = torch.autograd.grad(loss, model_params)

In [None]:
def compute_grad(model, input_ids, attn_mask, target):
    """Gradient of ``model``'s loss on a single example w.r.t. every parameter.

    ``input_ids``, ``attn_mask`` and ``target`` describe one example without a
    batch axis; each gets a leading batch axis of size 1 before the forward pass.
    Returns the tuple from ``torch.autograd.grad``: one tensor per entry of
    ``model.parameters()``.
    """
    single_example_batch = {
        "input_ids": input_ids.unsqueeze(0),
        "attention_mask": attn_mask.unsqueeze(0),
        "labels": target.unsqueeze(0),
    }
    loss = model(**single_example_batch).loss
    return torch.autograd.grad(loss, list(model.parameters()))

In [None]:
# Pull out the three batch tensors that `select` indexes per sample.
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
target = batch["labels"]

In [None]:
def select(i):
    """Return example ``i`` as ``(input_ids, attention_mask, target)``.

    Indexes the module-level ``input_ids``, ``attention_mask`` and ``target``.
    """
    sample_ids = input_ids[i]
    sample_mask = attention_mask[i]
    sample_target = target[i]
    return sample_ids, sample_mask, sample_target

In [None]:
# Reference result: per-sample gradient norms, one example at a time with plain autograd.
# Row = sample, column = parameter tensor (in `model.parameters()` order).
per_sample_rows = []
for sample_idx in range(batch_size):
    sample_grads = compute_grad(model, *select(sample_idx))
    per_sample_rows.append([g.norm(2).item() for g in sample_grads])
norms = np.array(per_sample_rows)

In [None]:
# Expected: (batch_size, number of parameter tensors).
norms.shape

In [None]:
from functorch import grad, make_functional_with_buffers, vmap

# Split `model` into a stateless callable plus its parameters and buffers, so that
# gradients can be taken w.r.t. `params` explicitly.
# NOTE(review): functorch's `make_functional*` API is deprecated since PyTorch 2.0;
# `torch.func.functional_call` is the suggested replacement.
fmodel, params, buffers = make_functional_with_buffers(model)

In [None]:
def compute_loss_stateless_model(
    fmodel, params, buffers, input_ids, att_mask, label
):
    """Loss of the functional model ``fmodel`` on a single example.

    ``input_ids``, ``att_mask`` and ``label`` describe one example; each gets a
    leading batch axis of size 1 before calling ``fmodel(params, buffers, ...)``.
    ``params`` is positional argument 1, which is what ``grad(..., argnums=1)``
    differentiates with respect to.
    """
    model_inputs = dict(
        input_ids=input_ids.unsqueeze(0),
        attention_mask=att_mask.unsqueeze(0),
        labels=label.unsqueeze(0),
    )
    return fmodel(params, buffers, **model_inputs).loss

In [None]:
# Sanity check: loss of the functional model on the first example of the batch.
compute_loss_stateless_model(fmodel, params, buffers, *select(0))

In [None]:
# Gradient transform of the per-example loss w.r.t. positional argument 1 (`params`);
# the result has one gradient tensor per entry of `params`.
ft_compute_grad = grad(compute_loss_stateless_model, argnums=1)

In [None]:
# Check whether the functorch gradient of the first parameter tensor (sample 0) has requires_grad set.
ft_compute_grad(fmodel, params, buffers, *select(0))[0].requires_grad

In [None]:
%%timeit
fnorms = np.array(
    [
        [
            g.norm(2).item()
            for g in ft_compute_grad(fmodel, params, buffers, *select(i))
        ]
        for i in range(batch_size)
    ]
)

In [None]:
def compute_norm(fmodel, params, buffers, input_ids, attention_mask, target):
    """Norm of each parameter's gradient for a single example.

    Delegates to the module-level ``ft_compute_grad`` and returns a tuple with
    one ``.norm()`` tensor per gradient it produces.
    """
    per_param_grads = ft_compute_grad(
        fmodel, params, buffers, input_ids, attention_mask, target
    )
    norms_out = []
    for grad_tensor in per_param_grads:
        norms_out.append(grad_tensor.norm())
    return tuple(norms_out)

In [None]:
# Share fmodel/params/buffers (None) and map over axis 0 of input_ids/attention_mask/target.
# randomness="same": random ops (e.g. dropout) draw the same values for every sample.
per_sample_in_dims = (None,) * 3 + (0,) * 3
ft_compute_sample_grad = vmap(
    compute_norm,
    in_dims=per_sample_in_dims,
    randomness="same",
)

In [None]:
# Per-sample, per-parameter gradient norms in one vectorized call.
# vmap yields one (batch,)-shaped tensor per parameter; stacking gives
# (n_params, batch) and `.T` turns it into (batch, n_params), matching `norms`.
vmap_norms_per_param = ft_compute_sample_grad(
    fmodel, params, buffers, input_ids, attention_mask, target
)
fnorms_vmap = torch.stack(vmap_norms_per_param).T

In [None]:
# Vectorized per-sample gradient norms, shape (batch, n_params).
fnorms_vmap

In [None]:
# Looped functorch per-sample gradient norms.
# NOTE(review): `fnorms` must have been bound by an earlier cell; assignments made
# inside a `%%timeit` cell are not kept in the notebook namespace.
fnorms

In [None]:
# Double-check that the functorch grad + vmap norms match the per-sample norms
# computed one example at a time with plain autograd (`norms`), sample by sample
# and parameter by parameter.
vmap_norms = fnorms_vmap.detach().cpu().numpy()
assert norms.shape == vmap_norms.shape, (norms.shape, vmap_norms.shape)
np.testing.assert_allclose(norms, vmap_norms, atol=3e-3, rtol=1e-5)

In [None]:
# Per-parameter gradient norms of the last sample from the vectorized computation
# (`ft_per_sample_grad` was never defined, so the cell previously raised NameError).
fnorms_vmap[-1]

In [None]:
import time
from pathlib import Path

import pandas as pd
import srsly
from datasets import load_from_disk
from torch.utils.data import DataLoader
from tqdm.auto import tqdm, trange
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from src.data.active_datamodule import ActiveClassificationDataModule
from src.data.datamodule import ClassificationDataModule
from src.enums import SpecialKeys
from src.estimator import Estimator
from src.huggingface import (
    EstimatorForSequenceClassification,
    UncertaintyBasedStrategyForSequenceClassification,
)

In [None]:
# Prepared AG News dataset plus the metadata written alongside it.
data_path = Path("..", "data", "prepared", "ag_news")
metadata = srsly.read_yaml(data_path.joinpath("metadata.yaml"))
dataset_dict = load_from_disk(data_path)

In [None]:
# Tokenizer for the checkpoint recorded under "name_or_path" in the dataset metadata.
tokenizer = AutoTokenizer.from_pretrained(metadata["name_or_path"])

In [None]:
# Datamodule built from the prepared dataset splits, with the tokenizer passed in.
datamodule = ClassificationDataModule.from_dataset_dict(dataset_dict, tokenizer=tokenizer)

In [None]:
# Classification head sized from the datamodule's labels.
head_kwargs = dict(
    num_labels=len(datamodule.labels),
    id2label=datamodule.id2label,
    label2id=datamodule.label2id,
)
model = AutoModelForSequenceClassification.from_pretrained(
    metadata["name_or_path"], **head_kwargs
)

In [None]:
# Plain estimator wrapping `model`, used for the quick supervised fit below.
estimator = EstimatorForSequenceClassification(model)

In [None]:
# Quick supervised smoke test: one epoch, at most 10 train and 10 validation batches.
smoke_fit_kwargs = dict(max_epochs=1, limit_train_batches=10, limit_validation_batches=10)
out = estimator.fit(
    train_loader=datamodule.train_loader(),
    validation_loader=datamodule.validation_loader(),
    **smoke_fit_kwargs,
)

In [None]:
# Uncertainty-based strategy with score_fn="margin_confidence".
# NOTE(review): this wraps the same `model` object that `estimator.fit` just trained,
# not a freshly initialised one.
active_estimator = UncertaintyBasedStrategyForSequenceClassification(
    model, score_fn="margin_confidence"
)

In [None]:
# Plain (non-active) fit through the strategy, capped at 10 train / 10 validation batches.
# NOTE(review): unlike the fit above, no max_epochs is passed, so the estimator's default applies.
out = active_estimator.fit(
    train_loader=datamodule.train_loader(),
    validation_loader=datamodule.validation_loader(),
    limit_train_batches=10,
    limit_validation_batches=10,
)

In [None]:
# Active-learning datamodule over the same prepared splits and tokenizer.
active_datamodule = ActiveClassificationDataModule.from_dataset_dict(dataset_dict, tokenizer=tokenizer)

In [None]:
# Short active-learning run: max_rounds=3, query_size=50, validation_perc=0.3.
# Every stage (fit, test, pool scoring) is capped at 3 batches to keep it quick.
round_fit_kwargs = dict(max_epochs=3, limit_train_batches=3, limit_validation_batches=3)
active_out = active_estimator.active_fit(
    active_datamodule=active_datamodule,
    max_rounds=3,
    query_size=50,
    validation_perc=0.3,
    fit_kwargs=round_fit_kwargs,
    test_kwargs=dict(limit_batches=3),
    pool_kwargs=dict(limit_batches=3),
)

In [None]:
# Write the labelled dataset under ./results (read back from results/labelled_dataset.parquet below).
active_datamodule.save_labelled_dataset("results")

In [None]:
# Reload the labelled dataset written by `save_labelled_dataset("results")`.
df = pd.read_parquet(Path("results") / "labelled_dataset.parquet")

In [None]:
# Inspect the labelled dataset saved by the active-learning run.
df