In [1]:
# Automatically reload imported modules before executing each cell, so edits to
# library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
from datasets import load_dataset, load_from_disk, Features, Value, ClassLabel, Sequence, Dataset
from energizer.datastores.pandas import PandasDataStoreForSequenceClassification

from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import srsly
from pathlib import Path

from energizer.datastores.pandas import sample
from sklearn.utils import check_random_state
from sklearn.manifold import TSNE, MDS
import numpy as np
import seaborn as sns

# Root directory for the processed/ and prepared/ artefacts written below.
data_dir = Path("..") / "data"

---
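## AG News

Embed and tokenise AG News, binarise it into *World* vs. *Others* (with *World* made rare in the training split), then build and sanity-check a datastore.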

In [None]:
# Embed every AG News example with a sentence-transformer and cache the result
# under data/processed/agnews, together with the encoder name.
# NOTE(review): the original cell used `embedder` and `meta` without defining
# them anywhere in this notebook, so it failed on a fresh kernel. The model name
# below is a placeholder — TODO confirm which encoder produced the cached embeddings.
embedder_name = "all-mpnet-base-v2"
embedder = SentenceTransformer(embedder_name, device="cuda")
# Same key as the tokenizer metadata.json written in the next cell;
# TODO(review): confirm the schema consumers of index_metadata.json expect.
meta = {"name_or_path": embedder_name}

dataset_dict = (
    load_dataset("ag_news")
    .rename_columns({"label": "labels"})
    .map(
        lambda ex: {"embedding": embedder.encode(ex["text"], device="cuda", batch_size=512)},
        batched=True,
        batch_size=1024,
    )
)

data_path = data_dir / "processed" / "agnews"
dataset_dict.save_to_disk(data_path)
srsly.write_json(data_path / "index_metadata.json", meta)

In [None]:
# Tokenise the embedded dataset for tiny BERT and save it (plus the model name)
# under data/prepared/agnews_bert_tiny.
# The source path is spelled out instead of reusing `data_path` from the previous
# cell: this cell reassigns `data_path` to its output dir, so a second run used to
# load (and then try to overwrite) its own output.
model_name = "google/bert_uncased_L-2_H-128_A-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

processed_path = data_dir / "processed" / "agnews"
dataset_dict = load_from_disk(processed_path).map(
    lambda ex: tokenizer(ex["text"], return_token_type_ids=False), batched=True
)

data_path = data_dir / "prepared" / "agnews_bert_tiny"
dataset_dict.save_to_disk(data_path)
srsly.write_json(data_path / "metadata.json", {"name_or_path": model_name})

In [5]:
# Binarise AG News into "World" (label 1) vs. "Others" (label 0).
# In the training split, keep every "Others" example plus a random sample of
# "World" examples whose size is WORLD_TRAIN_FRACTION of the original training
# set. The test split is binarised but not resampled.
WORLD_TRAIN_FRACTION = 0.05  # sampled "World" count as a fraction of len(train)
SEED = 42

features = Features(
    {
        'text': Value(dtype='string'),
        'original_labels': ClassLabel(names=['World', 'Sports', 'Business', 'Sci/Tech']),
        'embedding': Sequence(feature=Value(dtype='float32'), length=-1),
        'labels': ClassLabel(names=['Others', 'World']),
    }
)

data_path = data_dir / "processed" / "agnews"

ds_dict = (
    load_from_disk(data_path)
    .rename_columns({"labels": "original_labels"})
    # original label 0 is "World" -> new label 1; everything else -> 0 ("Others")
    .map(lambda ex: {"labels": [int(i == 0) for i in ex["original_labels"]]}, batched=True, features=features)
)

# downsample the minority ("World") class
df = ds_dict["train"].to_pandas()
rng = check_random_state(SEED)

n_world = int(WORLD_TRAIN_FRACTION * len(df))
world_ids = df.loc[df["original_labels"] == 0].sample(n_world, random_state=rng).index
new_df = df.loc[(df["labels"] == 0) | df.index.isin(world_ids)]
assert (new_df["labels"] == 1).sum() == n_world, "unexpected number of 'World' examples"

ds_dict["train"] = Dataset.from_pandas(new_df, features=features, preserve_index=False)
ds_dict.save_to_disk(data_dir / "processed" / "agnews_binarised")

Loading cached processed dataset at /home/pl487/allset/data/processed/agnews/train/cache-9144a554b5bf98f0.arrow
Loading cached processed dataset at /home/pl487/allset/data/processed/agnews/test/cache-ea336ed6e6612afe.arrow


Saving the dataset (0/1 shards):   0%|          | 0/96000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7600 [00:00<?, ? examples/s]

In [11]:
# Tokenise the binarised splits for tiny BERT, wrap them in a pandas-backed
# datastore with an L2 index over the "embedding" column, and save it.
model_name = "google/bert_uncased_L-2_H-128_A-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
ds_dict = ds_dict.map(lambda ex: tokenizer(ex["text"], return_token_type_ids=False), batched=True)

datastore = PandasDataStoreForSequenceClassification()
datastore.from_dataset_dict(
    ds_dict,
    input_names=["input_ids", "attention_mask"],
    target_name="labels",
    tokenizer=tokenizer,
)
datastore.add_index("embedding", metric="l2")

datastore_dir = data_dir / "prepared" / "agnews_binarised_bert-tiny"
datastore.save(datastore_dir)

Map:   0%|          | 0/96000 [00:00<?, ? examples/s]

Loading cached processed dataset at /home/pl487/allset/data/processed/agnews/test/cache-33ea29df3c931399.arrow


Saving the dataset (0/1 shards):   0%|          | 0/96000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7600 [00:00<?, ? examples/s]

In [None]:
# Reload the datastore saved by the build cell (prepared/agnews_binarised_bert-tiny).
# The previous path, prepared/agnews_bert-tiny, is never written by this notebook.
ds = PandasDataStoreForSequenceClassification.load(data_dir / "prepared" / "agnews_binarised_bert-tiny")

In [None]:
# Take row 200 of the datastore as the query and fetch 100 search results for its embedding.
query_row, n_results = 200, 100
query = ds.data.iloc[query_row]
ids, _ = ds.search(query["embedding"], n_results)

In [None]:
# Print the query (with its label name), then one "[label] text" line per result.
print(f"query:\n   [{ds.id2label[query.labels]}] {query.text}\nresults:")
retrieved = ds.get_by_ids(ids[0])[["labels", "text"]]
result_lines = retrieved.apply(lambda row: f"[{ds.id2label[row['labels']]}] {row['text']}", axis=1)
print("   " + "\n   ".join(result_lines))

In [None]:
# Reload the datastore saved by the build cell so the cells below compare the
# in-memory `datastore` with its on-disk round-trip. The previous path,
# ./agnews_datastore, is never written by this notebook.
ds = PandasDataStoreForSequenceClassification.load(data_dir / "prepared" / "agnews_binarised_bert-tiny")

In [None]:
# Inputs, target and device flag of the reloaded datastore.
(
    ds.input_names,
    ds.target_name,
    ds.on_cpu,
)

In [None]:
# Inputs, target, device flag and features of the in-memory datastore.
(
    datastore.input_names,
    datastore.target_name,
    datastore.on_cpu,
    datastore._features,
)

In [None]:
# Same attributes for the reloaded datastore, for side-by-side comparison.
(
    ds.input_names,
    ds.target_name,
    ds.on_cpu,
    ds._features,
)

In [None]:
# Label distributions: in-memory vs. reloaded.
(
    datastore.label_distribution(),
    ds.label_distribution(),
)

In [None]:
# Query the in-memory datastore with its own first row (query_in_set=True)
# and list the texts of the 10 returned examples.
query = datastore.data.iloc[0]
ids, dists = datastore.search(query["embedding"], 10, query_in_set=True)

print(f"query: {query.text}\nresults:")
neighbour_texts = datastore.get_by_ids(ids[0])["text"]
print("  - " + "\n  - ".join(neighbour_texts))

In [None]:
# Same search against the reloaded datastore, for comparison with the previous cell.
query = ds.data.iloc[0]
ids, dists = ds.search(query["embedding"], 10, query_in_set=True)

print(f"query: {query.text}\nresults:")
neighbour_texts = ds.get_by_ids(ids[0])["text"]
print("  - " + "\n  - ".join(neighbour_texts))

In [None]:
# Label names: in-memory vs. reloaded.
(
    datastore.labels,
    ds.labels,
)

In [None]:
# Label-to-id mappings: in-memory vs. reloaded.
(
    datastore.label2id,
    ds.label2id,
)

In [None]:
# Train datasets before any labelling: in-memory vs. reloaded.
(
    datastore.train_dataset(),
    ds.train_dataset(),
)

In [None]:
# Label rows 0 and 1 in round 1 with validation_perc=0.5 on both datastores;
# the second call's return value is the cell output.
datastore.label(
    indices=[0, 1],
    round=1,
    validation_perc=0.5,
)
ds.label(
    indices=[0, 1],
    round=1,
    validation_perc=0.5,
)

In [None]:
# Train datasets with no argument and with 0 (presumably a round index — confirm).
(
    datastore.train_dataset(),
    datastore.train_dataset(0),
    ds.train_dataset(),
    ds.train_dataset(0),
)

In [None]:
# Pool datasets with no argument and with 0, for both datastores.
(
    datastore.pool_dataset(),
    datastore.pool_dataset(0),
    ds.pool_dataset(),
    ds.pool_dataset(0),
)

In [None]:
# Validation datasets with no argument and with 0, for both datastores.
(
    datastore.validation_dataset(),
    datastore.validation_dataset(0),
    ds.validation_dataset(),
    ds.validation_dataset(0),
)

In [None]:
# Test datasets: in-memory vs. reloaded.
(
    datastore.test_dataset(),
    ds.test_dataset(),
)

In [None]:
# Loading preparation output: in-memory vs. reloaded.
(
    datastore.prepare_for_loading(),
    ds.prepare_for_loading(),
)

In [None]:
# Example batches: in-memory vs. reloaded.
(
    datastore.show_batch(),
    ds.show_batch(),
)