In [1]:
# Automatically reload edited project modules (e.g. `src.strategies`) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

import numpy as np
import pandas as pd
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from energizer.datastores import PandasDataStoreForSequenceClassification
from src.strategies import FullSubset, RandomStrategy

In [3]:
# Tiny BERT checkpoint (per its name: 2 layers, hidden size 128, 2 heads),
# presumably chosen so debugging runs are fast.
model_name = "google/bert_uncased_L-2_H-128_A-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenise the comment text of every split (batched, 4 worker processes).
raw_ds_dict = load_dataset("pietrolesci/wiki_toxic_indexed")
ds_dict = raw_ds_dict.map(
    lambda ex: tokenizer(ex["comment_text"]),
    batched=True,
    num_proc=4,
)

Found cached dataset parquet (/home/pl487/.cache/huggingface/datasets/pietrolesci___parquet/pietrolesci--wiki_toxic_indexed-bbeb1b8d65bf4665/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/pl487/.cache/huggingface/datasets/pietrolesci___parquet/pietrolesci--wiki_toxic_indexed-bbeb1b8d65bf4665/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-617ce17a7b47db48_*_of_00004.arrow
Loading cached processed dataset at /home/pl487/.cache/huggingface/datasets/pietrolesci___parquet/pietrolesci--wiki_toxic_indexed-bbeb1b8d65bf4665/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-ef736a700f5331c7_*_of_00004.arrow
Loading cached processed dataset at /home/pl487/.cache/huggingface/datasets/pietrolesci___parquet/pietrolesci--wiki_toxic_indexed-bbeb1b8d65bf4665/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-302176ead9687723_*_of_00004.arrow


In [4]:
# Wrap the tokenised splits in an energizer datastore whose row key is the "uid" column.
ds = PandasDataStoreForSequenceClassification()
datastore_kwargs = dict(
    input_names=["input_ids", "attention_mask"],
    target_name="labels",
    tokenizer=tokenizer,
    uid_name="uid",
)
ds.from_dataset_dict(ds_dict, **datastore_kwargs)

In [5]:
# Attach the index for this pre-computed embedding column to the datastore
# (presumably used by FullSubset for neighbour lookups — confirm in src.strategies).
emb_name = "embedding_all-mpnet-base-v2"
ds.add_index(emb_name)

In [6]:
# Classification head on top of the tiny BERT; label mappings come from the datastore.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(ds.labels),
    id2label=ds.id2label,
    label2id=ds.label2id,
)

# Active-learning strategy under test. For a random-sampling baseline, use
# `RandomStrategy(model=model, accelerator="gpu")` (imported above) instead.
estimator_kwargs = dict(
    accelerator="gpu",
    num_neighbours=100,
    subset_size=10_000,
    seed=42,
    score_fn="least_confidence",
)
estimator = FullSubset(model=model, **estimator_kwargs)

Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification w

In [7]:
# Seed the labelled set with ids 0..99, tagged as labelling round -1
# (assumed to be uids — confirm ds.label's contract).
initial_ids = [*range(100)]
ds.label(initial_ids, -1)
# Build the dataloaders: batch size 32 for training, 512 for evaluation.
ds.prepare_for_loading(batch_size=32, eval_batch_size=512)

In [8]:
# Short smoke test of the active-learning loop: small queries, one epoch,
# and only two test/pool batches per round.
active_fit_kwargs = {
    "query_size": 50,
    "max_rounds": 2,
    "max_epochs": 1,
    "limit_test_batches": 2,
    "limit_pool_batches": 2,
}
r = estimator.active_fit(ds, **active_fit_kwargs)

Completed rounds:   0%|          | 0/3 [00:00<?, ?it/s]

Completed epochs: 0it [00:00, ?it/s]

Epoch 0: 0it [00:00, ?it/s]

Test: 0it [00:00, ?it/s]

Pool: 0it [00:00, ?it/s]

In [17]:
# Sanity check: the datastore should hold exactly the train + validation rows.
n_train = len(ds_dict["train"])
n_validation = len(ds_dict["validation"])
n_train, n_validation, len(ds.data), n_train + n_validation

(127656, 31915, 159571, 159571)

In [63]:
# NOTE(review): absolute, machine-specific path to an IMDB run's labelled dataset,
# unrelated to the Wiki-Toxic datastore above — make it configurable before sharing.
labelled_parquet_path = Path(
    "/home/pl487/allset/outputs/debug/imdb/randomguide_2023-05-19T16-28-44/logs/labelled_dataset.parquet"
)
pd.read_parquet(labelled_parquet_path)

Unnamed: 0,text,labels,uid,embedding_all-mpnet-base-v2,embedding_multi-qa-mpnet-base-dot-v1,embedding_all-MiniLM-L12-v2,input_ids,token_type_ids,attention_mask,is_labelled,is_validation,labelling_round,train_uid,dists
0,I very much looked forward to this movie. Its ...,0,39,"[-0.025794797, 0.050868154, -0.01413649, 0.004...","[-0.24517001, 0.15869692, -0.20793273, 0.20060...","[-0.08997142, -0.037334215, 0.012179131, -0.02...","[101, 1045, 2200, 2172, 2246, 2830, 2000, 2023...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",True,False,8,5066.0,0.351936
1,There are only two movies I would give a 1/10 ...,0,72,"[0.010214066, 0.035168514, 0.021841576, -0.034...","[-0.18534045, -0.15982038, -0.06648923, 0.0308...","[-0.042834364, -0.0033234707, -0.008627697, -0...","[101, 2045, 2024, 2069, 2048, 5691, 1045, 2052...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",True,False,6,1207.0,0.299353
2,Not only is it a disgustingly made low-budget ...,0,93,"[0.029638596, 0.0747271, -0.020699004, 0.01253...","[-0.008488935, 0.034309033, -0.23201317, 0.074...","[-0.06461532, -0.049454708, -0.054552767, -0.0...","[101, 2025, 2069, 2003, 2009, 1037, 19424, 213...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",True,False,3,2655.0,0.308445
3,"In director Sooraj Barjatya's Vivah,20-somethi...",0,142,"[0.011529899, 0.07033135, -0.016396757, -0.016...","[-0.12023467, 0.27032727, -0.16782328, -0.0713...","[-0.0029073274, 0.037866402, 0.026605502, -0.0...","[101, 1999, 2472, 17111, 14220, 3347, 3900, 21...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",True,False,8,2194.0,0.349822
4,"Maybe you shouldn't compare, but Wild Style an...",0,173,"[-0.03354113, 0.034324, 0.0148819145, -0.00328...","[-0.3127217, 0.102947414, -0.047185007, 0.0484...","[-0.049723655, -0.0010385206, -0.0062440704, -...","[101, 2672, 2017, 5807, 1005, 1056, 12826, 101...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",True,False,10,8724.0,0.339425
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
595,"While the story is sweet, and the dancing and ...",1,24017,"[-0.03852864, 0.032753985, -0.032079127, 0.013...","[-0.16448122, -0.07015948, -0.25898105, 0.0735...","[-0.0526425, 0.012932369, 0.002774512, 0.00415...","[101, 2096, 1996, 2466, 2003, 4086, 1010, 1998...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",True,False,9,4635.0,0.565121
596,It was very heart-warming. As an expatriated f...,1,24291,"[-0.027180221, 0.004869805, -0.0050957943, 0.0...","[-0.08668917, -0.28511098, -0.273917, 0.139006...","[-0.025526993, 0.040672258, -0.11157379, -0.16...","[101, 2009, 2001, 2200, 2540, 1011, 12959, 101...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",True,False,10,1270.0,0.322476
597,Directed by Brian De Palma and written by Oliv...,1,24302,"[-0.019601606, 0.04224149, -0.021460632, 0.044...","[-0.20593545, 0.1765558, -0.12829691, 0.399145...","[-0.05924157, -0.0105029335, -0.08769067, -0.0...","[101, 2856, 2011, 4422, 2139, 23985, 1998, 251...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",True,False,6,2116.0,0.495365
598,this was a very good movie i wished i could fi...,1,24538,"[0.009553274, 0.031344377, 0.0038972297, -0.02...","[-0.12809247, -0.12758648, -0.33170307, -0.071...","[-0.08255486, -0.04266798, -0.049890228, -0.04...","[101, 2023, 2001, 1037, 2200, 2204, 3185, 1045...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",True,False,1,993.0,0.433209


In [62]:
# Pair every row that has a `train_uid` with the row whose `uid` it references,
# so the two texts and labels can be compared side by side.
# Bind `df` here instead of relying on a later cell (In [19]) having been run,
# so this cell also works on a fresh kernel (Restart & Run All).
df = ds.data
pd.merge(
    df.loc[df["train_uid"].notna(), ["uid", "train_uid", "comment_text", "labels"]],
    df[["labels", "uid", "comment_text"]],
    left_on="train_uid",
    right_on="uid",
    how="inner",
    suffixes=("", "_train"),  # overlapping right-hand columns get the "_train" suffix
)

Unnamed: 0,uid,train_uid,comment_text,labels,labels_train,uid_train,comment_text_train
0,1778,17430,"FUCK YOU, GO SUCK SOME DOG COCK! YOU FUCKING D...",1,1,17430,faggot faggot faggot faggot faggot faggot fagg...
1,153620,17430,your retarded your retardedyour retardedyour r...,1,1,17430,faggot faggot faggot faggot faggot faggot fagg...
2,3473,12,mother fucker \n\nfuck off dick face,1,1,12,FUCK FUCKITY FUCK FUCK FUCK
3,18971,12,FUCK YOU FUCK YOU FUCK YOUa,1,1,12,FUCK FUCKITY FUCK FUCK FUCK
4,22022,12,fuck you bitch nigger slut ass fuck whore,1,1,12,FUCK FUCKITY FUCK FUCK FUCK
...,...,...,...,...,...,...,...
95,133062,3473,In Response - FUCK YOU,1,1,3473,mother fucker \n\nfuck off dick face
96,137282,98,FUCK OF NAZI PROPAGANDA\nodpierdalcie sie jeba...,1,0,98,Pilecki's organisation1940-43 \n\nYears before...
97,140537,68120,what's up asshole bitch,1,1,68120,Fuck Off \n\nFuck Off i do what i want bitch
98,146906,47126,Stop talking to me \n\nShut the fuck up you st...,1,1,47126,hay fuck youhay fuck you \n\nhay fuck you


In [14]:
# Rows per uid, sorted ascending: any duplicated uid would appear at the tail.
uid_counts = ds.data.groupby("uid").size()
uid_counts.sort_values()

uid
0         1
106377    1
106378    1
106379    1
106380    1
         ..
53191     1
53192     1
53193     1
53186     1
159570    1
Length: 159571, dtype: int64

In [19]:
df = ds.data

# Number of labelled rows added in each labelling round.
df.loc[df["is_labelled"].eq(True)].groupby("labelling_round").size()

labelling_round
-1    100
 1     50
 2     50
dtype: int64

In [46]:
assert df["uid"].nunique() == len(df)

In [50]:
i = df["train_uid"][df["train_uid"].notna()].unique()

In [35]:
df["uid"].isin(df["train_uid"]).sum() / df["train_uid"].notna().sum()

0.42

In [39]:
# nunique() already ignores NaN, so no explicit notna() filter is needed.
df["train_uid"].nunique()

42

In [40]:
# Label distribution among rows that were matched to a train_uid.
df["labels"][df["train_uid"].notna()].value_counts()

labels
1    93
0     7
Name: count, dtype: int64

In [11]:
# Inspect the rows that carry a train_uid.
with_train_uid = ds.data["train_uid"].notna()
ds.data.loc[with_train_uid]

Unnamed: 0,id,comment_text,labels,uid,embedding_all-mpnet-base-v2,embedding_multi-qa-mpnet-base-dot-v1,embedding_all-MiniLM-L12-v2,input_ids,token_type_ids,attention_mask,is_labelled,is_validation,labelling_round,train_uid,dists
1778,1644bccf6f37f19e,"FUCK YOU, GO SUCK SOME DOG COCK! YOU FUCKING D...",1,1778,"[0.01764762, 0.05288508, -0.010527764, -0.0131...","[-0.045378618, -0.25800213, -0.44898406, -0.27...","[-0.041358825, 0.015455857, 0.0055711074, -0.0...","[101, 6616, 2017, 1010, 2175, 11891, 2070, 389...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",True,False,2,17430,0.538952
3473,6aa03f857ba747da,mother fucker \n\nfuck off dick face,1,3473,"[-0.021624373, 0.03843027, -0.007751778, 0.059...","[-0.16709238, -0.5578038, -0.39088944, -0.1443...","[-0.060571082, 0.044763703, 0.0028691227, 0.00...","[101, 2388, 6616, 2121, 6616, 2125, 5980, 2227...","[0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 1, 1]",True,False,1,12,0.573351
7928,d6ab1b62b0d08a2e,why don't you suck my ass until your lips blee...,1,7928,"[-0.009330071, 0.098941974, 0.013997296, 0.010...","[-0.023266658, -0.18005346, -0.28239602, -0.11...","[-0.055138674, -0.016808288, 0.017413128, -0.0...","[101, 2339, 2123, 1005, 1056, 2017, 11891, 202...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",True,False,1,24,0.500528
10187,18f8afbb8f66304c,YOU ARE A F**KING LIAR!,1,10187,"[0.014268233, 0.050565224, -0.020841537, 0.091...","[-0.095857754, -0.22293137, -0.4484912, 0.0633...","[-0.06453303, 0.025427138, 0.008142441, -0.010...","[101, 2017, 2024, 1037, 1042, 1008, 1008, 2332...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",True,False,2,59,0.600376
10648,b9875864a7d2a30c,Alrighty\n\nYou can go fuck yourself.,1,10648,"[-0.026609253, 0.09604505, -0.003319245, -0.02...","[-0.16098726, -0.103996314, -0.40565914, -0.17...","[0.035441663, -0.083025225, 0.012442009, 0.002...","[101, 10303, 2100, 2017, 2064, 2175, 6616, 442...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",True,False,2,18971,0.302180
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
153581,0468dd91f5506285,I was speaking the truth u son of a bitch go f...,1,153581,"[0.0038699205, 0.03174462, 0.0025089267, 0.002...","[0.08494834, -0.21561271, -0.34298617, -0.2999...","[-0.031451266, -0.020049753, 0.017962962, -0.0...","[101, 1045, 2001, 4092, 1996, 3606, 1057, 2365...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",True,False,2,111030,0.473465
153620,e1469c196b2c8724,your retarded your retardedyour retardedyour r...,1,153620,"[0.050225183, 0.022684412, 0.019161776, -0.020...","[-0.08857655, -0.41304967, -0.32461682, -0.113...","[0.015227961, -0.05367273, 0.08996908, 0.03628...","[101, 2115, 2128, 7559, 5732, 2115, 2128, 7559...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",True,False,2,17430,0.501090
155097,7e40d2ad4ef45b62,FUCK YOU! FUCK YOU! FUCK YOU! FUCK YOU! \nFUCK...,1,155097,"[-0.033801775, 0.037717223, -0.010647806, -0.0...","[-0.19134074, -0.17650607, -0.44849518, -0.236...","[-0.0728808, 0.111763485, -0.03898648, -0.0915...","[101, 6616, 2017, 999, 6616, 2017, 999, 6616, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",True,False,1,12,0.490013
159013,365a53aa15419ac2,Fuck you Motha Fucka \n\nPenis,1,159013,"[-0.0014293266, -0.03462546, -0.014728903, 0.0...","[-0.14177944, -0.6943717, -0.44628122, -0.1260...","[-0.09731447, 0.11882899, -0.05196268, -0.0270...","[101, 6616, 2017, 5820, 2050, 6616, 2050, 1908...","[0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 1, 1]",True,False,1,12,0.480344


In [None]:
# Quick smoke test of plain (non-active) fitting on two batches. Train on the
# labelled training split, not the test split, so the test set stays held out.
r = estimator.fit(train_loader=ds.train_loader(), max_epochs=1, limit_train_batches=2)

In [None]:
estimator.current_pool["train_ids"]

In [None]:
a, b = ds.get_embeddings([96, 156443]).tolist()

In [None]:
# FIXME(review): legacy cells from an older API; on a fresh kernel they raise NameError.
# `ClassificationActiveDataModule` is never imported (the import cell below provides
# `ActiveClassificationDataModule`), and `dataset_dict` is only defined further down.
datamodule = ClassificationActiveDataModule.from_dataset_dict(
    dataset_dict, tokenizer=tokenizer
)

In [None]:
# FIXME(review): `meta` is never defined in this notebook (perhaps `metadata`, read from
# metadata.yaml further down — confirm it contains these keys).
datamodule.load_index(meta["hnsw_index_path"], embedding_dim=meta["embedding_dim"])

In [None]:
# FIXME(review): `meta` is undefined and `SimilaritySearchStrategyForSequenceClassification`
# is never imported in this notebook. Note this also re-binds `model`.
model = AutoModelForSequenceClassification.from_pretrained(
    meta["name_or_path"],
    id2label=datamodule.id2label,
    label2id=datamodule.label2id,
    num_labels=len(datamodule.labels),
)
active_estimator = SimilaritySearchStrategyForSequenceClassification(
    model=model, seed=42
)

In [None]:
# Two active-learning rounds with query_size=100, evaluating on at most 10 test batches.
active_estimator.active_fit(
    max_rounds=2,
    query_size=100,
    active_datamodule=datamodule,
    limit_test_batches=10,
)

In [None]:
# Inspect the budget tracker after the run.
active_estimator.progress_tracker.budget_tracker

In [None]:
datamodule.train_size()

In [None]:
# Take one training batch to experiment with per-sample gradients.
# NOTE(review): depends on `datamodule` from the legacy cells above, which fail on a
# fresh kernel; `datamodule` is only defined successfully by a later cell.
loader = datamodule.train_loader()
batch = next(iter(loader))
_ = batch.pop("on_cpu")  # drop "on_cpu" so the remaining entries can go to model(**batch)
# NOTE(review): nominal loader batch size; it could exceed the real size of this batch
# if the dataset is smaller — later cells index examples 0..batch_size-1.
batch_size = loader.batch_size

In [None]:
# NOTE(review): move this import to the import cell at the top of the notebook.
import torch

In [None]:
# Loss for the whole batch.
loss = model(**batch).loss

In [None]:
# Gradient of the batch loss with respect to every model parameter tensor.
grads = torch.autograd.grad(loss, list(model.parameters()))

In [None]:
def compute_grad(model, input_ids, attn_mask, target):
    """Gradient of the loss for a single, un-batched example.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=..., labels=...)``
            whose output exposes a ``.loss`` tensor.
        input_ids, attn_mask, target: tensors for one example, without a batch axis.

    Returns:
        Tuple of gradient tensors, one per entry of ``model.parameters()``, in order.
    """
    # The model expects a leading batch axis, so add a singleton one to every input.
    batched_inputs = {
        "input_ids": input_ids.unsqueeze(0),
        "attention_mask": attn_mask.unsqueeze(0),
        "labels": target.unsqueeze(0),
    }
    example_loss = model(**batched_inputs).loss
    parameters = [param for param in model.parameters()]
    return torch.autograd.grad(example_loss, parameters)

In [None]:
# Pull out the three tensors used by the per-sample gradient experiments.
input_ids, attention_mask, target = (
    batch[key] for key in ("input_ids", "attention_mask", "labels")
)

In [None]:
def select(i):
    """Return example ``i`` of the current batch as ``(input_ids, attention_mask, target)``.

    Reads the notebook globals ``input_ids``, ``attention_mask`` and ``target``.
    """
    return tuple(tensor[i] for tensor in (input_ids, attention_mask, target))

In [None]:
norms = np.array(
    [
        [g.norm(2).item() for g in compute_grad(model, *select(i))]
        for i in range(batch_size)
    ]
)

In [None]:
norms.shape

In [None]:
# Split the model into a stateless callable plus its parameter and buffer tensors.
# NOTE(review): `functorch` is deprecated in favour of `torch.func` (PyTorch >= 2.0);
# there, `make_functional_with_buffers` is replaced by `torch.func.functional_call`.
from functorch import grad, make_functional_with_buffers, vmap

fmodel, params, buffers = make_functional_with_buffers(model)

In [None]:
def compute_loss_stateless_model(
    fmodel, params, buffers, input_ids, att_mask, label
):
    """Loss of the stateless model ``fmodel`` on one un-batched example.

    ``params`` and ``buffers`` are passed explicitly so the function can be
    differentiated with respect to ``params`` (see ``grad(..., argnums=1)``).

    Returns:
        The ``.loss`` attribute of the model output.
    """
    # Add a leading batch axis of size 1 to each per-example tensor.
    batched_ids, batched_mask, batched_label = (
        tensor.unsqueeze(0) for tensor in (input_ids, att_mask, label)
    )
    output = fmodel(
        params,
        buffers,
        input_ids=batched_ids,
        attention_mask=batched_mask,
        labels=batched_label,
    )
    return output.loss

In [None]:
compute_loss_stateless_model(fmodel, params, buffers, *select(0))

In [None]:
ft_compute_grad = grad(compute_loss_stateless_model, argnums=1)

In [None]:
ft_compute_grad(fmodel, params, buffers, *select(0))[0].requires_grad

In [None]:
%%timeit
# Benchmark: per-example gradient norms with functorch `grad`, one example at a time.
# NOTE: %%timeit runs this body inside its own generated function, so `fnorms` is
# NOT available in the notebook namespace after this cell.
fnorms = np.array(
    [
        [
            g.norm(2).item()
            for g in ft_compute_grad(fmodel, params, buffers, *select(i))
        ]
        for i in range(batch_size)
    ]
)

In [None]:
def compute_norm(fmodel, params, buffers, input_ids, attention_mask, target):
    """Per-parameter gradient norms for a single example (suitable for ``vmap``).

    Uses the notebook-level ``ft_compute_grad``, which differentiates with respect
    to ``params``.

    Returns:
        Tuple with one scalar norm tensor per gradient returned by ``ft_compute_grad``.
    """
    per_param_grads = ft_compute_grad(
        fmodel, params, buffers, input_ids, attention_mask, target
    )
    grad_norms = [grad_tensor.norm() for grad_tensor in per_param_grads]
    return tuple(grad_norms)

In [None]:
ft_compute_sample_grad = vmap(
    compute_norm, in_dims=(None, None, None, 0, 0, 0), randomness="same"
)

In [None]:
fnorms_vmap = torch.stack(
    ft_compute_sample_grad(
        fmodel, params, buffers, input_ids, attention_mask, target
    )
).T

In [None]:
fnorms_vmap

In [None]:
# `fnorms` assigned inside the %%timeit cell does not survive it (%%timeit runs its
# body in a separate scope), so compute it here before displaying it.
fnorms = np.array(
    [
        [
            g.norm(2).item()
            for g in ft_compute_grad(fmodel, params, buffers, *select(i))
        ]
        for i in range(batch_size)
    ]
)
fnorms

In [None]:
# Double-check that the functorch grad + vmap results match hand-processing each
# example individually. The per-sample gradients themselves are not kept, so compare
# their norms: `norms` (manual loop) vs `fnorms_vmap` (vmap), both (batch, n_params).
np.testing.assert_allclose(
    norms,
    fnorms_vmap.detach().cpu().numpy(),
    rtol=1e-5,
    atol=3e-3,
)

In [None]:
# Largest absolute difference between the manual and the vmap per-sample norms.
np.abs(norms - fnorms_vmap.detach().cpu().numpy()).max()

In [None]:
import time
from pathlib import Path

import pandas as pd
import srsly
from datasets import load_from_disk
from torch.utils.data import DataLoader
from tqdm.auto import tqdm, trange
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from src.data.active_datamodule import ActiveClassificationDataModule
from src.data.datamodule import ClassificationDataModule
from src.enums import SpecialKeys
from src.estimator import Estimator
from src.huggingface import (
    EstimatorForSequenceClassification,
    UncertaintyBasedStrategyForSequenceClassification,
)

In [None]:
# Pre-processed AG News dataset and its metadata file.
data_path = Path("../data/prepared/ag_news")
metadata_path = data_path / "metadata.yaml"

dataset_dict = load_from_disk(data_path)
metadata = srsly.read_yaml(metadata_path)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(metadata["name_or_path"])

In [None]:
datamodule = ClassificationDataModule.from_dataset_dict(
    dataset_dict, tokenizer=tokenizer
)

In [None]:
model = AutoModelForSequenceClassification.from_pretrained(
    metadata["name_or_path"],
    num_labels=len(datamodule.labels),
    id2label=datamodule.id2label,
    label2id=datamodule.label2id,
)

In [None]:
estimator = EstimatorForSequenceClassification(model)

In [None]:
out = estimator.fit(
    train_loader=datamodule.train_loader(),
    validation_loader=datamodule.validation_loader(),
    limit_train_batches=10,
    limit_validation_batches=10,
    max_epochs=1,
)

In [None]:
active_estimator = UncertaintyBasedStrategyForSequenceClassification(
    model, score_fn="margin_confidence"
)

In [None]:
out = active_estimator.fit(
    train_loader=datamodule.train_loader(),
    validation_loader=datamodule.validation_loader(),
    limit_train_batches=10,
    limit_validation_batches=10,
)

In [None]:
active_datamodule = ActiveClassificationDataModule.from_dataset_dict(
    dataset_dict,
    tokenizer=tokenizer,
)

In [None]:
active_out = active_estimator.active_fit(
    active_datamodule=active_datamodule,
    max_rounds=3,
    query_size=50,
    validation_perc=0.3,
    fit_kwargs={
        "max_epochs": 3,
        "limit_train_batches": 3,
        "limit_validation_batches": 3,
    },
    test_kwargs={"limit_batches": 3},
    pool_kwargs={"limit_batches": 3},
)

In [None]:
active_datamodule.save_labelled_dataset("results")

In [None]:
df = pd.read_parquet("results/labelled_dataset.parquet")

In [None]:
df