In [36]:
import torch

from tqdm import tqdm
from datasets import load_dataset
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score, balanced_accuracy_score
from transformers import RobertaTokenizer, RobertaForMaskedLM

In [2]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMaskedLM.from_pretrained('roberta-base')

In [26]:
sick = load_dataset("sick", split="validation")
multinli_matched = load_dataset("multi_nli", split="validation_matched")
multinli_mismatched = load_dataset("multi_nli", split="validation_mismatched")
esnli = load_dataset("../datasets/esnli.py", split="validation")

Found cached dataset sick (/home/imger/.cache/huggingface/datasets/sick/default/0.0.0/c6b3b0b44eb84b134851396d6d464e5cb8f026960519d640e087fe33472626db)
Found cached dataset multi_nli (/home/imger/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)
Found cached dataset multi_nli (/home/imger/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)
Found cached dataset esnli (/home/imger/.cache/huggingface/datasets/esnli/plain_text/0.0.2/262495ebbd9e71ec9b0c37a93e378f1b353dc28bb904305e011506792a02996b)


In [163]:
MASK = tokenizer("<mask>", add_special_tokens=False)["input_ids"][0]

# as described in outline
# ENTAILMENT_WORDS = ["as", "because", "so"]
# NEUTRAL_WORDS = ["and", "also", "or"]
# CONTRADICTION_WORDS = ["but", "although", "still"]

# improved word choice:
# count most probable 5 words for sample and aggregate per label
# subtract amount of all other labels
# take words with biggest count for each label
# result => Words that are common for that label but uncommon for others
# done for a random subset of the training set of multinli
ENTAILMENT_WORDS = ["Also", "Yes", "More", "Certainly", "yes", "by", "Specifically", "Indeed", "Yeah"]
NEUTRAL_WORDS = ["Apparently", "Perhaps", "Clearly", "Obviously", "Presumably"]
CONTRADICTION_WORDS = ["Yet", "However", "but", "Unfortunately", "Otherwise", "Except", "no", "Nearly", "Currently", "Sadly", "Instead", "Not", "Previously", "Until"]

# simple words (surprisingly good)
# ENTAILMENT_WORDS = ["yes"]
# NEUTRAL_WORDS = ["maybe"]
# CONTRADICTION_WORDS = ["no"]

ENTAILMENT_IDS = tokenizer.convert_tokens_to_ids(["Ġ"+word for word in ENTAILMENT_WORDS])
NEUTRAL_IDS = tokenizer.convert_tokens_to_ids(["Ġ"+word for word in NEUTRAL_WORDS])
CONTRADICTION_IDS = tokenizer.convert_tokens_to_ids(["Ġ"+word for word in NEUTRAL_WORDS])

def zero_shot_classify(premise, hypothesis):
    hypothesis = hypothesis.split(" ")
    hypothesis[0] = hypothesis[0].lower()
    hypothesis = " ".join(hypothesis)
    combined = f"{premise} <mask> {hypothesis}"
    tokenized, attention_mask = tokenizer(combined, return_tensors="pt", add_special_tokens=True).values()

    logits = model(tokenized, attention_mask=attention_mask)["logits"]
    mask_idx = tokenized[0].tolist().index(MASK)
    entailment_score = max(logits[0, mask_idx, ENTAILMENT_IDS])
    neutral_score = max(logits[0, mask_idx, NEUTRAL_IDS])
    contradiction_score = max(logits[0, mask_idx, CONTRADICTION_IDS])

    max_score = max(entailment_score, neutral_score, contradiction_score)
    if entailment_score == max_score:
        return 0
    if neutral_score == max_score:
        return 1
    if contradiction_score == max_score:
        return 2
    assert False, "This should never be reached"

In [29]:
def analyze(dataset, premise_key, hypothesis_key, label_key="label"):
    true, predicted = [], []
    for d in tqdm(dataset):
        true.append(d[label_key])
        predicted.append(zero_shot_classify(d[premise_key], d[hypothesis_key]))
    print(f"MCC: {matthews_corrcoef(true, predicted)}")
    print(f"F1: {f1_score(true, predicted, average='macro')}")
    print(f"Acc: {accuracy_score(true, predicted)}")
    print(f"BAcc: {balanced_accuracy_score(true, predicted)}")

In [164]:
analyze(sick, "sentence_A", "sentence_B")

100%|██████████| 495/495 [00:24<00:00, 20.55it/s]


MCC: 0.04600834360286191
F1: 0.21167878243349944
Acc: 0.33535353535353535
BAcc: 0.3571825564708127


In [165]:
analyze(multinli_matched, "premise", "hypothesis")

100%|██████████| 9815/9815 [11:26<00:00, 14.30it/s]

MCC: 0.11392030968470503
F1: 0.31279251578936584
Acc: 0.4044829342842588
BAcc: 0.3946197962978763





In [166]:
analyze(multinli_mismatched, "premise", "hypothesis")

100%|██████████| 9832/9832 [11:01<00:00, 14.87it/s]

MCC: 0.1283559309201467
F1: 0.3214575127463627
Acc: 0.41120829943043125
BAcc: 0.4040290359228627





In [167]:
analyze(esnli, "premise", "hypothesis")

100%|██████████| 9842/9842 [09:17<00:00, 17.66it/s]


MCC: 0.13421134002556806
F1: 0.327133422273503
Acc: 0.4106888843730949
BAcc: 0.4097624903100461


### Analysis on which tokens are most prevalent for specific labels:

In [None]:
import copy

from itertools import islice
from collections import Counter

In [111]:
def predict_mask(premise, hypothesis):
    hypothesis = hypothesis.split(" ")
    hypothesis[0] = hypothesis[0].lower()
    hypothesis = " ".join(hypothesis)
    combined = f"{premise} <mask> {hypothesis}"
    tokenized, attention_mask = tokenizer(combined, return_tensors="pt", add_special_tokens=True).values()

    logits = model(tokenized, attention_mask=attention_mask)["logits"]
    mask_idx = tokenized[0].tolist().index(MASK)
    return tokenizer.convert_ids_to_tokens(torch.topk(logits[0, mask_idx], k=3).indices.tolist())

In [109]:
multinli_train = load_dataset("multi_nli", split="train")

Using the latest cached version of the module from /home/imger/.cache/huggingface/modules/datasets_modules/datasets/multi_nli/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39 (last modified on Sat Feb 25 17:09:18 2023) since it couldn't be found locally at multi_nli., or remotely on the Hugging Face Hub.
Found cached dataset multi_nli (/home/imger/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)


In [116]:
tokens_by_label = {0: [], 1: [], 2: []}
size = 50_000
for d in tqdm(islice(multinli_train, size), total=size):
    tokens_by_label[d["label"]].extend(predict_mask(d["premise"], d["hypothesis"]))

100%|██████████| 50000/50000 [53:20<00:00, 15.62it/s]  


In [117]:
counters = {label : Counter(words) for label, words in tokens_by_label.items()}
counters_ = {label : copy.deepcopy(c) for label, c in counters.items()}

for label in counters_.keys():
    for word in counters_[label].keys():
        for other in counters.keys():
            if word not in counters[other] or other==label:
                continue
            counters_[label][word] -= counters[other][word]

for label, c in counters_.items():
    print(label, c.most_common()[:15])

0 [('Ġ,', 114), ('ĠAlso', 74), ('ĠYes', 71), ('ĠMore', 31), ('ĠCertainly', 29), ('Ġyes', 26), ('Ġby', 23), ('The', 21), ('ĠBelow', 19), ('ĠSpecifically', 18), ('ĠIndeed', 18), ('ĠHow', 17), ('Ġ:', 17), ('Ġ.', 15), ('ĠYeah', 15)]
1 [('ĠApparently', 77), ('ĠPerhaps', 33), ('ĠClearly', 30), ('ĠObviously', 25), ('ĠPresumably', 21), ('ĠPlus', 13), ('ĠAlready', 12), ('ĠAmong', 9), ('ĠTogether', 8), ('ĠMajor', 6), ('Ġdollars', 6), ('ĠContinuous', 5), ('ĠSure', 5), ('ĠDespite', 5), ('Ġensure', 5)]
2 [('ĠYet', 1040), ('ĠHowever', 668), ('Ġbut', 542), ('ĠUnfortunately', 373), ('ĠOtherwise', 206), ('ĠExcept', 151), ('ĠAlmost', 143), ('Ġno', 64), ('ĠNearly', 46), ('ĠCurrently', 43), ('ĠSadly', 43), ('ĠInstead', 30), ('ĠNot', 27), ('ĠPreviously', 27), ('ĠUntil', 27)]
