In [37]:
import json
import random
import re
from statistics import mean, median

import numpy as np
import torch
from datasets import load_dataset
from ferret import Benchmark
from ferret import SHAPExplainer, LIMEExplainer, IntegratedGradientExplainer
from transformers import RobertaForSequenceClassification, RobertaModel, RobertaConfig, RobertaTokenizer

In [38]:
# Seed of the fine-tuned checkpoint to analyse; it also picks the model directory
# and the suffix of the output files written at the end of the notebook.
seed = 42

# Seed the global RNGs too, so stochastic explainers (LIME samples random
# perturbations) give reproducible scores on Restart & Run All.
# NOTE(review): assumes the explainers draw from these global generators — confirm.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [39]:
# Validation split of e-SNLI, loaded through the local dataset script.
esnli = load_dataset(path="../datasets/esnli.py", split="validation")

Found cached dataset esnli (/home/students/loeser/.cache/huggingface/datasets/esnli/plain_text/0.0.2/262495ebbd9e71ec9b0c37a93e378f1b353dc28bb904305e011506792a02996b)


In [40]:
# The fine-tuned hypothesis-only checkpoint lives in a per-seed directory;
# the tokenizer is the stock `roberta-base` one.
model_path = "../models/roberta-base-mnli-hypothesis-only/{}/".format(seed)

model = RobertaForSequenceClassification.from_pretrained(model_path)
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

Some weights of the model checkpoint at ../models/roberta-base-mnli-hypothesis-only/42/ were not used when initializing RobertaModel: ['classifier.dense.bias', 'classifier.out_proj.bias', 'classifier.dense.weight', 'classifier.out_proj.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at ../models/roberta-base-mnli-hypothesis-only/42/ and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it f

In [41]:
# Order matters: calculate_dataset_scores names the explainers by position
# (shap, lime, integrated_gradient, integrated_gradient_multiply_by_inputs).
explainers = [SHAPExplainer(model, tokenizer), LIMEExplainer(model, tokenizer)]
explainers.extend(
    IntegratedGradientExplainer(model, tokenizer, multiply_by_inputs=multiply)
    for multiply in (False, True)
)
bench = Benchmark(model, tokenizer, explainers=explainers)

In [42]:
# RoBERTa's byte-level BPE prefixes a token that follows a space with this marker.
BPE_DIVIDER_TOKEN = "Ġ"


def _parse_highlighted_word_indices(esnli_explanation):
    """Parse an e-SNLI highlight annotation into a set of integer word indices.

    Accepts a string such as "3,4,5" or "{}" (every run of digits is one index),
    an iterable of index-like values such as ["3", "4"] or [3, 4], or None
    (nothing highlighted).
    """
    if esnli_explanation is None:
        return set()
    if isinstance(esnli_explanation, str):
        return {int(number) for number in re.findall(r"\d+", esnli_explanation)}
    return {int(index) for index in esnli_explanation}


def create_one_hot(hypothesis, esnli_explanation, tokenize=None):
    """Turn an e-SNLI word-level highlight into a per-token 0/1 rationale.

    Args:
        hypothesis: text to tokenize.
        esnli_explanation: highlighted word indices, either as a string such as
            "3,4,5" / "{}" or as an iterable of index values; None means no
            word is highlighted.
        tokenize: optional callable mapping a string to a list of BPE tokens;
            defaults to the notebook's global ``tokenizer.tokenize``.

    Returns:
        A list with one entry per token of ``hypothesis``: 1 if the token belongs
        to a highlighted word, else 0.
    """
    if tokenize is None:
        tokenize = tokenizer.tokenize
    # Compare parsed integers: a substring test on the raw string would, e.g.,
    # wrongly mark word 1 as highlighted for the annotation "11".
    highlighted = _parse_highlighted_word_indices(esnli_explanation)

    one_hot_encoding = []
    # Word index starts at 0 and advances on each token carrying the space marker,
    # so sub-word pieces and attached punctuation share the preceding word's index.
    # NOTE(review): assumes e-SNLI indices are 0-based over the same whitespace
    # word segmentation — confirm against the annotation format.
    current_word_index = 0
    for token in tokenize(hypothesis):
        if token.startswith(BPE_DIVIDER_TOKEN):
            current_word_index += 1
        one_hot_encoding.append(1 if current_word_index in highlighted else 0)
    return one_hot_encoding

In [43]:
def calculate_evaluations(row):
    """Explain and evaluate one e-SNLI example with every configured ferret explainer.

    Reads ``hypothesis``, ``label`` and ``sentence2_highlighted_1`` from ``row``,
    builds a token-level rationale from the highlight annotation, runs
    ``bench.explain`` on the hypothesis alone for the row's label, and evaluates
    each resulting explanation against that rationale.

    Adds two entries to ``row`` and returns it, so it can be used with ``Dataset.map``:
      * ``ferret_explanations`` — the attribution scores of each explanation;
      * ``evaluations`` — for each explanation, the list of its evaluation scores.
    """
    text = row["hypothesis"]
    target = row["label"]
    human_rationale = create_one_hot(text, row["sentence2_highlighted_1"])

    explanations = bench.explain(text, target=target)
    row["ferret_explanations"] = [explanation.scores for explanation in explanations]

    per_explanation_scores = []
    for explanation in explanations:
        evaluation = bench.evaluate_explanation(explanation, target, human_rationale)
        per_explanation_scores.append(
            [evaluation_score.score for evaluation_score in evaluation.evaluation_scores]
        )
    row["evaluations"] = per_explanation_scores
    return row

# The `ferret_explanations` column holds one list of token attribution scores per explainer, in the order SHAP, LIME, Integrated Gradients, and Integrated Gradients multiplied by inputs.
# The `evaluations` column holds, for each of those explanations, its scores for AOPC comprehensiveness, AOPC sufficiency, Tau LOO, AUPRC, token F1 and token IOU (in that order).
# https://github.com/g8a9/ferret/blob/b1343501db6367ca9048283862f8e0763c72e4ba/ferret/benchmark.py#L88

In [44]:
def calculate_score(dataset, row_value_selector, score):
    """Apply ``score`` to the values selected from every row of ``dataset``.

    ``row_value_selector`` maps one row to a value. ``score`` is any callable that
    consumes an iterable of values (e.g. ``sum``, ``min``, ``statistics.mean``) and
    returns a single result. Values are passed lazily as a plain Python iterator,
    not as dataset table entries.
    """
    selected_values = (row_value_selector(row) for row in dataset)
    return score(selected_values)

def calculate_dataset_scores(dataset):
    """Aggregate per-example ferret evaluation scores into summary statistics.

    ``dataset`` is any iterable of rows (e.g. a Hugging Face ``Dataset`` or a list
    of dicts) where ``row["evaluations"][i][j]`` is the score of evaluator ``j``
    for explainer ``i``, as produced by ``calculate_evaluations``.

    Returns:
        ``{explainer_name: {evaluator_name: {statistic: value}}}`` with the
        statistics sum, min, max, median and mean.

    Raises:
        ValueError: if ``dataset`` has no rows.
    """
    scores = [sum, min, max, median, mean]
    score_names = ["sum", "min", "max", "median", "mean"]
    explainer_names = ["shap", "lime", "integrated_gradient", "integrated_gradient_multiply_by_inputs"]
    evaluator_names = ["comprehensiveness", "sufficiency", "loo", "auprc_plausibility", "tokenf1_plausibility", "tokeniou_plausibility"]

    # Materialise the evaluations once: the dataset is scanned a single time instead
    # of once per (explainer, evaluator, statistic), and one-shot iterables work.
    all_evaluations = [row["evaluations"] for row in dataset]
    if not all_evaluations:
        raise ValueError("calculate_dataset_scores() needs at least one row")

    dataset_scores = {}
    for i, explainer_name in enumerate(explainer_names):
        dataset_scores[explainer_name] = {}
        for j, evaluator_name in enumerate(evaluator_names):
            values = [evaluations[i][j] for evaluations in all_evaluations]
            dataset_scores[explainer_name][evaluator_name] = {
                name: score(values) for name, score in zip(score_names, scores)
            }

    return dataset_scores


In [45]:
# Explain and evaluate every example; each row gains `ferret_explanations` and `evaluations`.
esnli_with_evaluations = esnli.map(function=calculate_evaluations)

Map:   0%|          | 0/30 [00:00<?, ? examples/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Explainer:   0%|          | 0/4 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

Evaluator:   0%|          | 0/6 [00:00<?, ?it/s]

In [46]:
# Summary statistics (sum/min/max/median/mean) per explainer and evaluator.
dataset_scores = calculate_dataset_scores(dataset=esnli_with_evaluations)

In [48]:
# Persist the per-example results and the aggregated scores, both keyed by seed.
evaluations_path = f"../datasets/esnli_evaluations_hypothesis_only_{seed}.hf"
scores_path = f"../datasets/esnli_evaluation_scores_hypothesis_only_{seed}.json"

esnli_with_evaluations.save_to_disk(evaluations_path)
with open(scores_path, "w") as scores_file:
    json.dump(dataset_scores, scores_file)

Saving the dataset (0/1 shards):   0%|          | 0/30 [00:00<?, ? examples/s]