In [10]:
from tqdm import tqdm
from datasets import load_dataset
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score, balanced_accuracy_score
from transformers import RobertaTokenizer, RobertaForMaskedLM

In [2]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMaskedLM.from_pretrained('roberta-base')

In [26]:
sick = load_dataset("sick", split="validation")
multinli_matched = load_dataset("multi_nli", split="validation_matched")
multinli_mismatched = load_dataset("multi_nli", split="validation_mismatched")
esnli = load_dataset("../datasets/esnli.py", split="validation")

Found cached dataset sick (/home/imger/.cache/huggingface/datasets/sick/default/0.0.0/c6b3b0b44eb84b134851396d6d464e5cb8f026960519d640e087fe33472626db)
Found cached dataset multi_nli (/home/imger/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)
Found cached dataset multi_nli (/home/imger/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)
Found cached dataset esnli (/home/imger/.cache/huggingface/datasets/esnli/plain_text/0.0.2/262495ebbd9e71ec9b0c37a93e378f1b353dc28bb904305e011506792a02996b)


In [4]:
MASK = tokenizer("<mask>", add_special_tokens=False)["input_ids"][0]
ENTAILMENT_IDS = tokenizer.convert_tokens_to_ids(["as", "because", "so"])
NEUTRAL_IDS = tokenizer.convert_tokens_to_ids(["and", "also", "or"])
CONTRADICTION_IDS = tokenizer.convert_tokens_to_ids(["but", "although", "still"])

def zero_shot_classify(premise, hypothesis):
    hypothesis = hypothesis.split(" ")
    hypothesis[0] = hypothesis[0].lower()
    hypothesis = " ".join(hypothesis)
    combined = f"{premise} <mask> {hypothesis}"
    tokenized, attention_mask = tokenizer(combined, return_tensors="pt", add_special_tokens=True).values()

    logits = model(tokenized, attention_mask=attention_mask)["logits"]
    mask_idx = tokenized[0].tolist().index(MASK)
    entailment_score = sum(logits[0, mask_idx, ENTAILMENT_IDS])
    neutral_score = sum(logits[0, mask_idx, NEUTRAL_IDS])
    contradiction_score = sum(logits[0, mask_idx, CONTRADICTION_IDS])

    max_score = max(entailment_score, neutral_score, contradiction_score)
    if entailment_score == max_score:
        return 0
    if neutral_score == max_score:
        return 1
    if contradiction_score == max_score:
        return 2
    assert False, "This should never be reached"

In [29]:
def analyze(dataset, premise_key, hypothesis_key, label_key="label"):
    true, predicted = [], []
    for d in tqdm(dataset):
        true.append(d[label_key])
        predicted.append(zero_shot_classify(d[premise_key], d[hypothesis_key]))
    print(f"MCC: {matthews_corrcoef(true, predicted)}")
    print(f"F1: {f1_score(true, predicted, average='macro')}")
    print(f"Acc: {accuracy_score(true, predicted)}")
    print(f"BAcc: {balanced_accuracy_score(true, predicted)}")

In [30]:
analyze(sick, "sentence_A", "sentence_B")

100%|██████████| 495/495 [00:29<00:00, 17.05it/s]

MCC: 0.11693422964062707
F1: 0.40600730362958853
Acc: 0.4686868686868687
BAcc: 0.4088917615056656





In [31]:
analyze(multinli_matched, "premise", "hypothesis")

100%|██████████| 9815/9815 [11:11<00:00, 14.61it/s]


MCC: 0.07961488162656762
F1: 0.3687037677176715
Acc: 0.375649516046867
BAcc: 0.3824356409470984


In [32]:
analyze(multinli_mismatched, "premise", "hypothesis")

100%|██████████| 9832/9832 [11:54<00:00, 13.76it/s]

MCC: 0.07933894087538897
F1: 0.357109767038657
Acc: 0.3721521562245728
BAcc: 0.38034714567859346





In [33]:
analyze(esnli, "premise", "hypothesis")

100%|██████████| 9842/9842 [09:49<00:00, 16.68it/s]


MCC: 0.04946855672579988
F1: 0.2779206944750587
Acc: 0.35257061572851045
BAcc: 0.3552969403729247
