In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
from datasets import load_dataset, load_from_disk
from energizer.active_learning.datastores.classification import ActivePandasDataStoreForSequenceClassification

from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer

In [8]:
# Tokenizer of the BERT checkpoint named below; it is handed to the datastore
# and used to tokenize the raw text later in the notebook.
model_name = "google/bert_uncased_L-2_H-128_A-2"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

# Sentence encoder used further down to compute the "embedding" column.
embedder = SentenceTransformer(model_name_or_path="all-MiniLM-L6-v2")

In [9]:
# Read the prepared AG News dataset back from local disk.
dataset_dict = load_from_disk(
    dataset_path="../../data/prepared/agnews_bert-tiny/",
)

In [11]:
# Wrap the prepared dataset in an active-learning datastore.
# `from_dataset_dict` is called on the class and its return value is the
# datastore used by the following cells.
datastore = ActivePandasDataStoreForSequenceClassification.from_dataset_dict(
    tokenizer=tokenizer,
    uid_name="uid",
    target_name="labels",
    input_names=["input_ids", "attention_mask"],
    dataset_dict=dataset_dict,  # type:ignore
)

In [13]:
# Sanity check: display one batch for the "test" split. The recorded output
# holds the input_ids / attention_mask tensors, the labels, and a uid stored
# under the on_cpu key.
datastore.show_batch("test")

{'input_ids': tensor([[  101,  2470,  2003,  5791,  1999,  4367,  1996, 25935,  9949,  5080,
           9338,  2003, 21366,  2000, 13467, 10908,  1012,   102]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
 <InputKeys.LABELS: 'labels'>: tensor([1]),
 <InputKeys.ON_CPU: 'on_cpu'>: {<SpecialKeys.ID: 'uid'>: [122939]}}

In [14]:
# Locations of a previously built index file and its JSON metadata.
# NOTE(review): the file names refer to all-mpnet-base-v2, whereas `embedder`
# is all-MiniLM-L6-v2 — confirm that queries against this index are encoded
# with the matching model before searching it.
index_path = "../../data/processed/agnews/all-mpnet-base-v2_cosine.bin"
meta_path = "../../data/processed/agnews/all-mpnet-base-v2_cosine.json"

# Attach the stored index (and its metadata) to the datastore.
datastore.load_index(index_path, meta_path)

In [15]:
# Inspect the index object now attached to the datastore; the recorded output
# is an hnswlib index with space='cosine' and dim=768.
datastore.index

<hnswlib.Index(space='cosine', dim=768)>

In [None]:
# Load the raw AG News splits and rename the target column to "labels", the
# name the datastore cells pass as `target_name`.
dataset_dict = load_dataset("ag_news").rename_columns({"label": "labels"})

# Embed the training texts in batches. No `device` is forced: `encode` then
# runs on the device `embedder` is already placed on, so the cell also works
# on machines without a GPU instead of failing on a hard-coded "cuda".
dataset_dict["train"] = dataset_dict["train"].map(
    lambda ex: {"embedding": embedder.encode(ex["text"], batch_size=512)},
    batched=True,
)

# Persist all splits (train now carries the "embedding" column) under ./agnews.
dataset_dict.save_to_disk("agnews")

In [None]:
# Reload the dataset saved under "agnews" and tokenize its "text" column in
# batches; token type ids are not requested from the tokenizer.
dataset_dict = load_from_disk("agnews").map(
    lambda ex: tokenizer(ex["text"], return_token_type_ids=False),
    batched=True,
)

In [None]:
# Show the feature schema of the train split, then of the test split.
tuple(dataset_dict[split].features for split in ("train", "test"))

In [None]:
# Build the datastore with the class this notebook actually imports.
# `PandasDataStoreForSequenceClassification` is never imported, so the previous
# form raised NameError on a fresh kernel. `from_dataset_dict` is called on the
# class and its return value is kept (the same pattern as the first datastore
# cell) instead of instantiating an empty store and discarding the result.
# NOTE(review): `uid_name` is not passed here, as in the original cell —
# confirm the argument is optional for data without a uid column.
datastore = ActivePandasDataStoreForSequenceClassification.from_dataset_dict(
    dataset_dict=dataset_dict,
    tokenizer=tokenizer,
    # on_cpu=["embedding", "text"],
    input_names=["input_ids", "attention_mask"],
    target_name="labels",
)

In [None]:
# Add an index for "embedding" with the "l2" metric.
# NOTE(review): presumably a nearest-neighbour index over the embeddings
# computed earlier in the notebook — confirm against `add_index`.
datastore.add_index("embedding", metric="l2")

In [None]:
# Persist the datastore under ./agnews_datastore; the next cell loads it back
# from the same path.
datastore.save("./agnews_datastore")

In [None]:
# Reload the saved datastore through the class imported at the top of the
# notebook. `PandasDataStoreForSequenceClassification` is never imported, so the
# previous spelling raised NameError on a fresh kernel.
ds = ActivePandasDataStoreForSequenceClassification.load("./agnews_datastore")

In [None]:
# Configuration carried by the reloaded datastore.
(
    ds.input_names,
    ds.target_name,
    ds.on_cpu,
)

In [None]:
# Configuration and feature schema of the in-memory datastore.
(
    datastore.input_names,
    datastore.target_name,
    datastore.on_cpu,
    datastore._features,
)

In [None]:
# Same fields for the reloaded datastore, for comparison with the cell above.
(
    ds.input_names,
    ds.target_name,
    ds.on_cpu,
    ds._features,
)

In [None]:
# Label distribution of the original datastore, then of the reloaded one.
tuple(store.label_distribution() for store in (datastore, ds))

In [None]:
# Use the first row of the datastore as the query and search the index with its
# embedding. The positional 10 is presumably the number of results requested.
query = datastore.data.iloc[0]
ids, dists = datastore.search(
    query["embedding"],
    10,
    query_in_set=True,
)

# Print the query text, then one bullet per retrieved text.
print(f"query: {query.text}\nresults:")
print("  - ", "\n  - ".join(datastore.get_by_ids(ids[0]).text), sep="")

In [None]:
# Repeat the search against the reloaded datastore: first row as the query,
# its embedding as the search vector.
query = ds.data.iloc[0]
ids, dists = ds.search(
    query["embedding"],
    10,
    query_in_set=True,
)

# Print the query text, then one bullet per retrieved text.
print(f"query: {query.text}\nresults:")
print("  - ", "\n  - ".join(ds.get_by_ids(ids[0]).text), sep="")

In [None]:
# Labels known to the original datastore, then to the reloaded one.
tuple(store.labels for store in (datastore, ds))

In [None]:
# label2id mapping of the original datastore, then of the reloaded one.
tuple(store.label2id for store in (datastore, ds))

In [None]:
# Training dataset of each datastore before any call to `label`.
tuple(store.train_dataset() for store in (datastore, ds))

In [None]:
# Apply the same labelling call to both datastores: indices 0 and 1, round 1,
# validation_perc 0.5. The value returned by the second call is the cell output.
datastore.label(
    indices=[0, 1],
    round=1,
    validation_perc=0.5,
)
ds.label(
    indices=[0, 1],
    round=1,
    validation_perc=0.5,
)

In [None]:
# For each datastore: the training dataset with no argument, then with 0.
tuple(
    dataset
    for store in (datastore, ds)
    for dataset in (store.train_dataset(), store.train_dataset(0))
)

In [None]:
# For each datastore: the pool dataset with no argument, then with 0.
tuple(
    dataset
    for store in (datastore, ds)
    for dataset in (store.pool_dataset(), store.pool_dataset(0))
)

In [None]:
# For each datastore: the validation dataset with no argument, then with 0.
tuple(
    dataset
    for store in (datastore, ds)
    for dataset in (store.validation_dataset(), store.validation_dataset(0))
)

In [None]:
# Test dataset of the original datastore, then of the reloaded one.
tuple(store.test_dataset() for store in (datastore, ds))

In [None]:
# Call `prepare_for_loading` on the original datastore, then on the reloaded
# one, and show both return values.
tuple(store.prepare_for_loading() for store in (datastore, ds))

In [None]:
# One batch from each datastore, shown together for comparison.
tuple(store.show_batch() for store in (datastore, ds))