In [1]:
# Enable the IPython autoreload extension so edits to imported modules are picked up live.
%load_ext autoreload
# Mode 2: applies autoreload to all imported modules.
%autoreload 2

In [2]:
from datasets import load_dataset, load_from_disk
from energizer.datastores.pandas import PandasDataStoreForSequenceClassification

from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer

In [3]:
# Checkpoint used for tokenization.
model_name = "google/bert_uncased_L-2_H-128_A-2"

# Sentence encoder used to compute the text embeddings, and the tokenizer for `model_name`.
embedder = SentenceTransformer(model_name_or_path="all-MiniLM-L6-v2")
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
# Load AG News and rename the target column from `label` to `labels`.
dataset_dict = load_dataset("ag_news").rename_columns({"label": "labels"})

# Embed the training split only (the test split gets no `embedding` column).
# No `device` is forced on `encode`: it then runs on the device the embedder was
# loaded on (a GPU when one is available, otherwise the CPU), so this cell no
# longer fails on CPU-only machines while behaving the same on GPU hosts.
dataset_dict["train"] = dataset_dict["train"].map(
    lambda batch: {"embedding": embedder.encode(batch["text"], batch_size=512)},
    batched=True,
)

# Persist both splits so later cells can reload them without re-embedding.
dataset_dict.save_to_disk("agnews")

Found cached dataset ag_news (/home/pl487/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


  0%|          | 0/2 [00:00<?, ?it/s]

Map:   0%|          | 0/120000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/120000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7600 [00:00<?, ? examples/s]

In [5]:
# Reload the saved splits and tokenize every text in batches; token type ids are not requested.
dataset_dict = load_from_disk("agnews").map(
    lambda batch: tokenizer(batch["text"], return_token_type_ids=False),
    batched=True,
)

Map:   0%|          | 0/120000 [00:00<?, ? examples/s]

Map:   0%|          | 0/7600 [00:00<?, ? examples/s]

In [6]:
# Show the column schema of the train and test splits side by side.
tuple(dataset_dict[split].features for split in ("train", "test"))

({'text': Value(dtype='string', id=None),
  'labels': ClassLabel(names=['World', 'Sports', 'Business', 'Sci/Tech'], id=None),
  'embedding': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
  'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
  'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)},
 {'text': Value(dtype='string', id=None),
  'labels': ClassLabel(names=['World', 'Sports', 'Business', 'Sci/Tech'], id=None),
  'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
  'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)})

In [7]:
# Create an empty datastore and fill it from the tokenized splits.
# `on_cpu` is intentionally not passed here; an earlier variant of this call used
# on_cpu=["embedding", "text"].
datastore = PandasDataStoreForSequenceClassification()
datastore.from_dataset_dict(
    dataset_dict=dataset_dict,
    input_names=["input_ids", "attention_mask"],
    target_name="labels",
    tokenizer=tokenizer,
)

In [8]:
# Build a search index over the `embedding` column using the "l2" metric
# (presumably L2 / Euclidean distance -- confirm against the energizer API).
datastore.add_index("embedding", metric="l2")

In [9]:
# Write the datastore to the ./agnews_datastore directory.
datastore.save("./agnews_datastore")

Saving the dataset (0/1 shards):   0%|          | 0/120000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7600 [00:00<?, ? examples/s]

In [10]:
# Reload the saved datastore into a second object, `ds`, so it can be compared
# with the in-memory `datastore` in the cells that follow.
ds = PandasDataStoreForSequenceClassification.load("./agnews_datastore")

In [11]:
# Display the configuration restored on the reloaded datastore.
tuple(getattr(ds, attr) for attr in ("input_names", "target_name", "on_cpu"))

(['input_ids', 'attention_mask'], 'labels', [])

In [12]:
# Display the configuration and feature schema of the in-memory datastore.
tuple(
    getattr(datastore, attr)
    for attr in ("input_names", "target_name", "on_cpu", "_features")
)

(['input_ids', 'attention_mask'],
 'labels',
 [],
 {'text': Value(dtype='string', id=None),
  'labels': ClassLabel(names=['World', 'Sports', 'Business', 'Sci/Tech'], id=None),
  'embedding': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
  'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
  'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
  'unique_id': Value(dtype='int64', id=None),
  'is_labelled': Value(dtype='bool', id=None),
  'is_validation': Value(dtype='bool', id=None),
  'labelling_round': Value(dtype='int64', id=None)})

In [13]:
# Display the same configuration and feature schema for the reloaded datastore.
tuple(
    getattr(ds, attr)
    for attr in ("input_names", "target_name", "on_cpu", "_features")
)

(['input_ids', 'attention_mask'],
 'labels',
 [],
 {'text': Value(dtype='string', id=None),
  'labels': ClassLabel(names=['World', 'Sports', 'Business', 'Sci/Tech'], id=None),
  'embedding': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
  'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
  'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
  'unique_id': Value(dtype='int64', id=None),
  'is_labelled': Value(dtype='bool', id=None),
  'is_validation': Value(dtype='bool', id=None),
  'labelling_round': Value(dtype='int64', id=None)})

In [14]:
# Compare the label distribution of the original and the reloaded datastore.
tuple(store.label_distribution() for store in (datastore, ds))

({2: 30000, 3: 30000, 1: 30000, 0: 30000},
 {2: 30000, 3: 30000, 1: 30000, 0: 30000})

In [15]:
# Take the first row of the datastore as the query and search the index with its
# embedding, asking for 10 results. `query_in_set=True` presumably tells the search
# that the query itself is part of the indexed data -- confirm against the API.
query = datastore.data.iloc[0]
ids, dists = datastore.search(query["embedding"], 10, query_in_set=True)

# Print the query text, then one bullet per retrieved text.
print(
    f"query: {query.text}\nresults:",
    "  - " + "\n  - ".join(datastore.get_by_ids(ids[0]).text),
    sep="\n",
)

query: Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.
results:
  - Wall St. Bears Claw Back Into the Black  NEW YORK (Reuters) - Short-sellers, Wall Street's dwindling  band of ultra-cynics, are seeing green again.
  - Wall St. Seen Lower on Oil; Google Eyed (Reuters) Reuters - Sky-high oil prices are likely to\pressure Wall Street once again on Thursday, while earnings\news from tech giants Ciena (CIEN.N) and Nortel (NT.TO) and\Google's (GOOG.O) awaited Nasdaq debut will also steer\sentiment.
  - Wall St. Seen Rising as Oil Prices Slip (Reuters) Reuters - U.S. stock futures pointed toward a\higher Wall Street open on Tuesday after crude oil prices fell\for a third day, easing investors' fears that costly oil would\squeeze company profits and slow growth.
  - Wall St. Seen Sliding After Jobless Data  NEW YORK (Reuters) - U.S. stock futures pointed to a  slightly lower open on Wall Street  

In [16]:
# Repeat the same search on the reloaded datastore to check the index survived
# the save/load round trip: first row as the query, 10 results requested.
query = ds.data.iloc[0]
ids, dists = ds.search(query["embedding"], 10, query_in_set=True)

# Print the query text, then one bullet per retrieved text.
print(
    f"query: {query.text}\nresults:",
    "  - " + "\n  - ".join(ds.get_by_ids(ids[0]).text),
    sep="\n",
)

query: Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.
results:
  - Wall St. Bears Claw Back Into the Black  NEW YORK (Reuters) - Short-sellers, Wall Street's dwindling  band of ultra-cynics, are seeing green again.
  - Wall St. Seen Lower on Oil; Google Eyed (Reuters) Reuters - Sky-high oil prices are likely to\pressure Wall Street once again on Thursday, while earnings\news from tech giants Ciena (CIEN.N) and Nortel (NT.TO) and\Google's (GOOG.O) awaited Nasdaq debut will also steer\sentiment.
  - Wall St. Seen Rising as Oil Prices Slip (Reuters) Reuters - U.S. stock futures pointed toward a\higher Wall Street open on Tuesday after crude oil prices fell\for a third day, easing investors' fears that costly oil would\squeeze company profits and slow growth.
  - Wall St. Seen Sliding After Jobless Data  NEW YORK (Reuters) - U.S. stock futures pointed to a  slightly lower open on Wall Street  

In [17]:
# Compare the label names exposed by the original and the reloaded datastore.
tuple(store.labels for store in (datastore, ds))

(['World', 'Sports', 'Business', 'Sci/Tech'],
 ['World', 'Sports', 'Business', 'Sci/Tech'])

In [18]:
# Compare the label-name -> id mapping of the original and the reloaded datastore.
tuple(store.label2id for store in (datastore, ds))

({'World': 0, 'Sports': 1, 'Business': 2, 'Sci/Tech': 3},
 {'World': 0, 'Sports': 1, 'Business': 2, 'Sci/Tech': 3})

In [19]:
# Request the training set from both datastores before anything has been labelled.
tuple(store.train_dataset() for store in (datastore, ds))

(None, None)

In [20]:
# Label the same two instances in both datastores, using round 1 and a
# validation percentage of 0.5. The second call's return value is the cell output.
datastore.label(
    indices=[0, 1],
    round=1,
    validation_perc=0.5,
)
ds.label(
    indices=[0, 1],
    round=1,
    validation_perc=0.5,
)

2

In [21]:
# For each datastore, show the training set without an argument and with argument 0
# (presumably a labelling round -- confirm against the energizer API).
tuple(
    split
    for store in (datastore, ds)
    for split in (store.train_dataset(), store.train_dataset(0))
)

(Dataset({
     features: ['text', 'labels', 'embedding', 'input_ids', 'attention_mask', 'unique_id', 'is_labelled', 'is_validation', 'labelling_round'],
     num_rows: 1
 }),
 None,
 Dataset({
     features: ['text', 'labels', 'embedding', 'input_ids', 'attention_mask', 'unique_id', 'is_labelled', 'is_validation', 'labelling_round'],
     num_rows: 1
 }),
 None)

In [22]:
# For each datastore, show the pool without an argument and with argument 0
# (presumably a labelling round -- confirm against the energizer API).
tuple(
    split
    for store in (datastore, ds)
    for split in (store.pool_dataset(), store.pool_dataset(0))
)

(Dataset({
     features: ['text', 'embedding', 'input_ids', 'attention_mask', 'unique_id', 'is_labelled', 'is_validation', 'labelling_round'],
     num_rows: 119998
 }),
 Dataset({
     features: ['text', 'embedding', 'input_ids', 'attention_mask', 'unique_id', 'is_labelled', 'is_validation', 'labelling_round'],
     num_rows: 120000
 }),
 Dataset({
     features: ['text', 'embedding', 'input_ids', 'attention_mask', 'unique_id', 'is_labelled', 'is_validation', 'labelling_round'],
     num_rows: 119998
 }),
 Dataset({
     features: ['text', 'embedding', 'input_ids', 'attention_mask', 'unique_id', 'is_labelled', 'is_validation', 'labelling_round'],
     num_rows: 120000
 }))

In [23]:
# For each datastore, show the validation set without an argument and with argument 0
# (presumably a labelling round -- confirm against the energizer API).
tuple(
    split
    for store in (datastore, ds)
    for split in (store.validation_dataset(), store.validation_dataset(0))
)

(Dataset({
     features: ['text', 'labels', 'embedding', 'input_ids', 'attention_mask', 'unique_id', 'is_labelled', 'is_validation', 'labelling_round'],
     num_rows: 1
 }),
 None,
 Dataset({
     features: ['text', 'labels', 'embedding', 'input_ids', 'attention_mask', 'unique_id', 'is_labelled', 'is_validation', 'labelling_round'],
     num_rows: 1
 }),
 None)

In [24]:
# Compare the test split held by the original and the reloaded datastore.
tuple(store.test_dataset() for store in (datastore, ds))

(Dataset({
     features: ['text', 'labels', 'input_ids', 'attention_mask'],
     num_rows: 7600
 }),
 Dataset({
     features: ['text', 'labels', 'input_ids', 'attention_mask'],
     num_rows: 7600
 }))

In [25]:
# Call `prepare_for_loading` on both datastores, original first, and display what each returns.
tuple(store.prepare_for_loading() for store in (datastore, ds))

(None, None)

In [26]:
# Display one batch from each datastore to check that both produce the same tensors.
tuple(store.show_batch() for store in (datastore, ds))

({'input_ids': tensor([[  101,  2813,  2358,  1012,  6468, 15020,  2067,  2046,  1996,  2304,
            1006, 26665,  1007, 26665,  1011,  2460,  1011, 19041,  1010,  2813,
            2395,  1005,  1055,  1040, 11101,  2989,  1032,  2316,  1997, 11087,
            1011, 22330,  8713,  2015,  1010,  2024,  3773,  2665,  2153,  1012,
             102]]),
  'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
           1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
  'labels': tensor([2]),
  <InputKeys.ON_CPU: 'on_cpu'>: {<SpecialKeys.ID: 'unique_id'>: [0]}},
 {'input_ids': tensor([[  101,  2813,  2358,  1012,  6468, 15020,  2067,  2046,  1996,  2304,
            1006, 26665,  1007, 26665,  1011,  2460,  1011, 19041,  1010,  2813,
            2395,  1005,  1055,  1040, 11101,  2989,  1032,  2316,  1997, 11087,
            1011, 22330,  8713,  2015,  1010,  2024,  3773,  2665,  2153,  1012,
             102]]),
  'attention_mask': te