In [82]:
import multiprocessing
from datasets import load_dataset, DatasetDict
from transformers import pipeline

In [83]:
# Seeds identifying the trained checkpoints. They stay strings because they are
# spliced into checkpoint paths and into the prediction_*/score_* column names.
seeds = [str(seed_value) for seed_value in (42, 69, 1337)]

# Number of worker processes handed to Dataset.filter(num_proc=...).
cpu_count = multiprocessing.cpu_count()

# Root directories for checkpoints and saved datasets (relative paths).
model_path, dataset_path = "../models", "../datasets"

In [84]:
# One text-classification pipeline per seed, keyed by the seed string.
models = {
    seed: pipeline("text-classification", model=f"{model_path}/hypothesis-only/{seed}")
    for seed in seeds
}

In [85]:
mnli = load_dataset("multi_nli")

# Keep only the first 100 training examples; the other splits are left as loaded.
# NOTE(review): this looks like a quick-iteration subset -- confirm whether the
# saved datasets are meant to be built from the full train split.
first_hundred_rows = range(100)
mnli["train"] = mnli["train"].select(first_hundred_rows)

Found cached dataset multi_nli (/Users/I518253/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)


  0%|          | 0/3 [00:00<?, ?it/s]

In [86]:
# Maps the label strings emitted by the classification pipelines to the integer
# ids that are later compared against the dataset's `label` column.
label_mapping = {
    "ENTAILMENT": 0,
    "NEUTRAL": 1,
    "CONTRADICTION": 2
}

def add_prediction_columns(record, model, seed):
    """Run one model on one example and return its prediction as new columns.

    Args:
        record: Dataset row providing the `premise` and `hypothesis` strings.
        model: Callable (here a text-classification pipeline) that accepts a
            ``{"text": ..., "text_pair": ...}`` dict and returns either a
            ``{"label": ..., "score": ...}`` dict or a list/tuple whose first
            element is such a dict.
        seed: Identifier used as the suffix of the generated column names.

    Returns:
        ``{"prediction_<seed>": <int id from label_mapping>, "score_<seed>": <score>}``.

    Raises:
        KeyError: If the predicted label (compared case-insensitively) is not a
            key of ``label_mapping``.
    """
    # NOTE(review): the models in this notebook are loaded from a
    # "hypothesis-only/" directory, yet the premise is passed as well --
    # confirm this matches how the checkpoints were trained.
    prediction = model({"text": record["premise"], "text_pair": record["hypothesis"]})

    # Some pipeline call paths wrap a single result in a list; unwrap it so both
    # return shapes are handled. Presumably the first entry is the top-scoring
    # candidate -- verify if top_k is ever configured.
    if isinstance(prediction, (list, tuple)):
        prediction = prediction[0]

    # Normalise casing so "entailment" and "ENTAILMENT" resolve to the same id.
    label = str(prediction["label"]).upper()
    if label not in label_mapping:
        raise KeyError(
            f"Unexpected label {prediction['label']!r} from model {seed}; "
            f"expected one of {sorted(label_mapping)}"
        )

    return {f"prediction_{seed}": label_mapping[label], f"score_{seed}": prediction["score"]}

# Annotate the train split with one prediction/score column pair per seed,
# then write the whole DatasetDict to disk.
for seed in seeds:
    print(f"Adding predictions of model {seed}...")
    map_kwargs = {"model": models[seed], "seed": seed}
    mnli["train"] = mnli["train"].map(add_prediction_columns, fn_kwargs=map_kwargs)
mnli.save_to_disk(f"{dataset_path}/mnli_with_predictions")

Adding predictions of model 42...


Loading cached processed dataset at /Users/I518253/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-8db6083a8e077fe6.arrow


Adding predictions of model 69...


Loading cached processed dataset at /Users/I518253/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-85d1a21ad6eb7be7.arrow


Adding predictions of model 1337...


Loading cached processed dataset at /Users/I518253/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-5fde458676fe6892.arrow


Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/9815 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/9832 [00:00<?, ? examples/s]

In [87]:
# Display the training record at index 2; its output includes the added
# prediction_*/score_* columns next to the original fields.
sample_index = 2
mnli["train"][sample_index]

{'promptID': 134793,
 'pairID': '134793e',
 'premise': 'One of our number will carry out your instructions minutely.',
 'premise_binary_parse': '( ( One ( of ( our number ) ) ) ( ( will ( ( ( carry out ) ( your instructions ) ) minutely ) ) . ) )',
 'premise_parse': '(ROOT (S (NP (NP (CD One)) (PP (IN of) (NP (PRP$ our) (NN number)))) (VP (MD will) (VP (VB carry) (PRT (RP out)) (NP (PRP$ your) (NNS instructions)) (ADVP (RB minutely)))) (. .)))',
 'hypothesis': 'A member of my team will execute your orders with immense precision.',
 'hypothesis_binary_parse': '( ( ( A member ) ( of ( my team ) ) ) ( ( will ( ( execute ( your orders ) ) ( with ( immense precision ) ) ) ) . ) )',
 'hypothesis_parse': '(ROOT (S (NP (NP (DT A) (NN member)) (PP (IN of) (NP (PRP$ my) (NN team)))) (VP (MD will) (VP (VB execute) (NP (PRP$ your) (NNS orders)) (PP (IN with) (NP (JJ immense) (NN precision))))) (. .)))',
 'genre': 'fiction',
 'label': 0,
 'prediction_42': 1,
 'score_42': 0.8558771014213562,
 'predi

In [88]:
def correct_by_at_least(record, at_least):
    """Tell whether enough models predicted the record's gold label.

    Looks up ``prediction_<seed>`` for every entry of the module-level ``seeds``
    list and counts how many equal ``record["label"]``.

    Args:
        record: Row containing ``label`` plus one ``prediction_<seed>`` per seed.
        at_least: Minimum number of matching predictions required.

    Returns:
        True when the number of matching predictions is >= ``at_least``.
    """
    gold = record["label"]
    hits = sum(1 for seed in seeds if record[f"prediction_{seed}"] == gold)
    return hits >= at_least

In [90]:
# Keep only the training examples that at least two of the models classified
# correctly, and save them as a DatasetDict with a single "train" split.
print("Filtering split by at least two...")
mnli_at_least_two = DatasetDict({
    "train": mnli["train"].filter(
        correct_by_at_least,
        fn_kwargs={"at_least": 2},
        num_proc=cpu_count,
    )
})
print("Saving filtered dataset...")
mnli_at_least_two.save_to_disk(f"{dataset_path}/mnli_at_least_two")

Loading cached processed dataset at /Users/I518253/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-76be5b57b56b6b81_*_of_00010.arrow


Filtering split by at least two...


Flattening the indices:   0%|          | 0/50 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/50 [00:00<?, ? examples/s]

In [91]:
# Keep only the training examples that three models classified correctly
# (all of them, given the three seeds configured in this notebook), and save
# them as a DatasetDict with a single "train" split.
print("Filtering split by three...")
mnli_three = DatasetDict({
    "train": mnli["train"].filter(
        correct_by_at_least,
        fn_kwargs={"at_least": 3},
        num_proc=cpu_count,
    )
})
print("Saving filtered dataset...")
mnli_three.save_to_disk(f"{dataset_path}/mnli_three")

Filtering split by three...


Filter (num_proc=10):   0%|          | 0/100 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/37 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/37 [00:00<?, ? examples/s]