In [1]:
import torch

from tqdm import tqdm
from datasets import load_dataset
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score, balanced_accuracy_score
from transformers import RobertaTokenizer, RobertaForMaskedLM

In [2]:
# Single checkpoint id shared by tokenizer and model, so the two cannot drift apart.
MODEL_NAME = 'roberta-base'
tokenizer = RobertaTokenizer.from_pretrained(MODEL_NAME)
# The model is only used for inference; .eval() returns the model itself, so the
# assignment keeps the cell from echoing the module repr.
model = RobertaForMaskedLM.from_pretrained(MODEL_NAME).eval()

In [3]:
# Evaluation data. `sick`, the two MultiNLI sets and `esnli` are validation splits
# (consumed by analyze_all); `sick_test` is the SICK test split (consumed by analyze_final).
sick, sick_test = (
    load_dataset("sick", split=split_name)
    for split_name in ("validation", "test")
)
multinli_matched, multinli_mismatched = (
    load_dataset("multi_nli", split=split_name)
    for split_name in ("validation_matched", "validation_mismatched")
)
# e-SNLI is loaded through a dataset script at a relative path instead of a dataset name.
esnli = load_dataset("../datasets/esnli.py", split="validation")

Found cached dataset sick (/home/imger/.cache/huggingface/datasets/sick/default/0.0.0/c6b3b0b44eb84b134851396d6d464e5cb8f026960519d640e087fe33472626db)
Found cached dataset sick (/home/imger/.cache/huggingface/datasets/sick/default/0.0.0/c6b3b0b44eb84b134851396d6d464e5cb8f026960519d640e087fe33472626db)
Found cached dataset multi_nli (/home/imger/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)
Found cached dataset multi_nli (/home/imger/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)


Downloading and preparing dataset esnli/plain_text to /home/imger/.cache/huggingface/datasets/esnli/plain_text/0.0.2/64fd5bee4cf6dcae59e2b804162412bbe9646aab00da31dac1cecc0ad4f798fd...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Dataset esnli downloaded and prepared to /home/imger/.cache/huggingface/datasets/esnli/plain_text/0.0.2/64fd5bee4cf6dcae59e2b804162412bbe9646aab00da31dac1cecc0ad4f798fd. Subsequent calls will reuse this data.


In [4]:
# Vocabulary id of the "<mask>" placeholder; used to find the masked position in an encoded prompt.
MASK = tokenizer("<mask>", add_special_tokens=False)["input_ids"][0]

# Word lists chosen up front, as described in the outline.
A_PRIORI_ENTAILMENT_WORDS = [
    "as",
    "because",
    "so",
]
A_PRIORI_NEUTRAL_WORDS = [
    "and",
    "also",
    "or",
]
A_PRIORI_CONTRADICTION_WORDS = [
    "but",
    "although",
    "still",
]

# Improved word choice, obtained on a subset of the MultiNLI training set:
#   1. for every sample, take the most probable mask fillers and count them per label;
#   2. from each label's count of a word, subtract that word's count under all other labels;
#   3. keep the words with the largest remaining count for each label.
# The result is words that are common for one label but uncommon for the others.
# NOTE(review): the original note speaks of the 5 most probable words and of a random
# subset, while predict_mask in this notebook takes the top 3 and the collection loop
# reads the first 50_000 rows -- confirm which settings produced these lists.
TUNED_ENTAILMENT_WORDS = [
    "Also", "Yes", "More", "Certainly", "yes",
    "by", "Specifically", "Indeed", "Yeah",
]
TUNED_NEUTRAL_WORDS = [
    "Apparently", "Perhaps", "Clearly", "Obviously", "Presumably",
]
TUNED_CONTRADICTION_WORDS = [
    "Yet", "However", "but", "Unfortunately", "Otherwise",
    "Except", "no", "Nearly", "Currently", "Sadly",
    "Instead", "Not", "Previously", "Until",
]

# One plain word per label (surprisingly good).
SIMPLE_ENTAILMENT_WORDS = ["yes"]
SIMPLE_NEUTRAL_WORDS = ["maybe"]
SIMPLE_CONTRADICTION_WORDS = ["no"]

def get_class_token_map(entailment_words, neutral_words, contradiction_words):
    """Map each NLI class name to the vocabulary ids of its candidate words.

    Every word is looked up as the single token "Ġ" + word (the form in which the
    notebook's own top-k output prints mid-sentence fillers, e.g. 'ĠAlso').

    Args:
        entailment_words: words whose score votes for "entailment".
        neutral_words: words whose score votes for "neutral".
        contradiction_words: words whose score votes for "contradiction".

    Returns:
        dict with the keys "entailment", "neutral" and "contradiction"; each value
        is the list of token ids for the given words, in the same order.

    Raises:
        ValueError: if a word has no single-token vocabulary entry. Such a word
            would otherwise resolve to the unknown-token id and the class would be
            scored with the logit of the unknown token.
    """
    unknown_id = tokenizer.unk_token_id

    def to_tokens(words):
        words = list(words)  # allow any iterable, and iterate it twice below
        token_ids = tokenizer.convert_tokens_to_ids(["Ġ" + word for word in words])
        missing = [
            word
            for word, token_id in zip(words, token_ids)
            if token_id is None or token_id == unknown_id
        ]
        if missing:
            raise ValueError(
                f"Words without a single-token vocabulary entry: {missing}"
            )
        return token_ids

    return {
        "entailment": to_tokens(entailment_words),
        "neutral": to_tokens(neutral_words),
        "contradiction": to_tokens(contradiction_words),
    }

# Token-id lookup tables, one per word-list family defined above.
A_PRIORI_MAPPING = get_class_token_map(
    entailment_words=A_PRIORI_ENTAILMENT_WORDS,
    neutral_words=A_PRIORI_NEUTRAL_WORDS,
    contradiction_words=A_PRIORI_CONTRADICTION_WORDS,
)
TUNED_MAPPING = get_class_token_map(
    entailment_words=TUNED_ENTAILMENT_WORDS,
    neutral_words=TUNED_NEUTRAL_WORDS,
    contradiction_words=TUNED_CONTRADICTION_WORDS,
)
SIMPLE_MAPPING = get_class_token_map(
    entailment_words=SIMPLE_ENTAILMENT_WORDS,
    neutral_words=SIMPLE_NEUTRAL_WORDS,
    contradiction_words=SIMPLE_CONTRADICTION_WORDS,
)

In [5]:
def zero_shot_classify(premise, hypothesis, class_token_map):
    """Classify a premise/hypothesis pair by scoring a masked connective between them.

    The prompt is "<premise> <mask> <hypothesis>", with the first word of the
    hypothesis lower-cased. Each class is scored by the highest logit, at the mask
    position, among that class's candidate token ids.

    Args:
        premise: premise sentence.
        hypothesis: hypothesis sentence.
        class_token_map: dict with the keys "entailment", "neutral" and
            "contradiction", each holding a non-empty sequence of vocabulary ids
            (as built by get_class_token_map).

    Returns:
        0 for entailment, 1 for neutral, 2 for contradiction. Ties go to the
        earliest class in that order.

    Raises:
        ValueError: if a class has no candidate ids, or the encoded prompt
            contains no mask token.
    """
    # Lower-case only the first space-separated word; the rest is kept verbatim.
    first_word, separator, remainder = hypothesis.partition(" ")
    combined = f"{premise} <mask> {first_word.lower()}{separator}{remainder}"

    # Read the encoding by key instead of relying on the order of its values.
    encoded = tokenizer(combined, return_tensors="pt", add_special_tokens=True)
    input_ids = encoded["input_ids"]

    # Pure inference: skip building the autograd graph for every sample.
    with torch.no_grad():
        logits = model(input_ids, attention_mask=encoded["attention_mask"])["logits"]

    mask_idx = input_ids[0].tolist().index(MASK)
    mask_logits = logits[0, mask_idx]

    class_scores = []
    for class_name in ("entailment", "neutral", "contradiction"):
        token_ids = class_token_map[class_name]
        if len(token_ids) == 0:
            raise ValueError(f"No candidate tokens given for class '{class_name}'")
        class_scores.append(mask_logits[token_ids].max().item())

    # list.index returns the first maximum, which gives the tie order documented above.
    return class_scores.index(max(class_scores))

In [6]:
def analyze(dataset, class_token_map, premise_key, hypothesis_key, label_key="label"):
    """Run zero_shot_classify over a dataset and print four agreement metrics.

    Args:
        dataset: iterable of records indexable by the three key arguments.
        class_token_map: candidate-token mapping forwarded to zero_shot_classify.
        premise_key: record key holding the premise text.
        hypothesis_key: record key holding the hypothesis text.
        label_key: record key holding the gold label.

    Prints MCC, macro-averaged F1, accuracy and balanced accuracy, one per line.
    Nothing is returned.
    """
    gold_labels = []
    guesses = []
    for record in tqdm(dataset):
        gold_labels.append(record[label_key])
        guess = zero_shot_classify(record[premise_key], record[hypothesis_key], class_token_map)
        guesses.append(guess)

    # (printed name, scorer) pairs; each score is computed right before it is printed.
    scorers = (
        ("MCC", matthews_corrcoef),
        ("F1", lambda gold, pred: f1_score(gold, pred, average='macro')),
        ("Acc", accuracy_score),
        ("BAcc", balanced_accuracy_score),
    )
    for metric_name, scorer in scorers:
        print(f"{metric_name}: {scorer(gold_labels, guesses)}")

In [7]:
def analyze_all(class_token_map):
    """Print the metrics of one candidate-token mapping on every validation set.

    Runs, in order, on SICK, MultiNLI matched, MultiNLI mismatched and e-SNLI,
    printing a heading line before each report.
    """
    def report(heading, dataset, premise_key, hypothesis_key):
        # Heading first, then the metric lines printed by analyze.
        print(heading)
        analyze(dataset, class_token_map, premise_key, hypothesis_key)

    report("Sick:", sick, "sentence_A", "sentence_B")
    report("MNLI matched:", multinli_matched, "premise", "hypothesis")
    report("MNLI mismatched:", multinli_mismatched, "premise", "hypothesis")
    report("e-SNLI:", esnli, "premise", "hypothesis")

In [8]:
def analyze_final(class_token_map):
    """Print the metrics of one candidate-token mapping on the SICK test split."""
    print("Sick:")
    analyze(
        dataset=sick_test,
        class_token_map=class_token_map,
        premise_key="sentence_A",
        hypothesis_key="sentence_B",
    )

In [13]:
# Validation-split metrics for the a-priori word lists.
analyze_all(A_PRIORI_MAPPING)

Sick:


100%|██████████| 495/495 [00:37<00:00, 13.36it/s]


MCC: 0.21537353051732913
F1: 0.5286962262075036
Acc: 0.5616161616161616
BAcc: 0.5393986829503846
MNLI matched:


100%|██████████| 9815/9815 [17:43<00:00,  9.23it/s]


MCC: 0.15749408293163084
F1: 0.3823974681088829
Acc: 0.41426388181355067
BAcc: 0.42598170989880807
MNLI mismatched:


100%|██████████| 9832/9832 [16:19<00:00, 10.04it/s]


MCC: 0.17839241642916967
F1: 0.38224980265188807
Acc: 0.42219283970707894
BAcc: 0.4341494540624504
e-SNLI:


100%|██████████| 9842/9842 [08:33<00:00, 19.16it/s]

MCC: 0.06294157916149738
F1: 0.30441411218464404
Acc: 0.36141028246291407
BAcc: 0.36397935474926113





In [14]:
# Validation-split metrics for the tuned word lists.
analyze_all(TUNED_MAPPING)

Sick:


100%|██████████| 495/495 [00:23<00:00, 21.23it/s]


MCC: 0.303377996600952
F1: 0.3387291685966917
Acc: 0.3515151515151515
BAcc: 0.5651903106667279
MNLI matched:


100%|██████████| 9815/9815 [11:06<00:00, 14.73it/s]


MCC: 0.22224997277557923
F1: 0.43766190190135096
Acc: 0.4664289353031075
BAcc: 0.4666184177094926
MNLI mismatched:


100%|██████████| 9832/9832 [12:20<00:00, 13.29it/s]


MCC: 0.2682915953860687
F1: 0.48990167715633315
Acc: 0.5080349877949553
BAcc: 0.5052285742705646
e-SNLI:


100%|██████████| 9842/9842 [09:48<00:00, 16.72it/s]

MCC: 0.13765738262185065
F1: 0.4153408681449812
Acc: 0.42359276569802884
BAcc: 0.4229957980693493





In [15]:
# Validation-split metrics for the one-word-per-label lists.
analyze_all(SIMPLE_MAPPING)

Sick:


100%|██████████| 495/495 [00:47<00:00, 10.36it/s]


MCC: 0.22710935885844336
F1: 0.3699180148246075
Acc: 0.3878787878787879
BAcc: 0.4773443187557768
MNLI matched:


100%|██████████| 9815/9815 [13:10<00:00, 12.42it/s]


MCC: 0.1187912755358427
F1: 0.39347136401178967
Acc: 0.40254712175241975
BAcc: 0.4079821444797877
MNLI mismatched:


100%|██████████| 9832/9832 [10:34<00:00, 15.50it/s]


MCC: 0.12730684266796982
F1: 0.39108854741812943
Acc: 0.4040886899918633
BAcc: 0.4102116482839307
e-SNLI:


100%|██████████| 9842/9842 [08:26<00:00, 19.44it/s]


MCC: 0.07630587881015334
F1: 0.32685347403136694
Acc: 0.37187563503352977
BAcc: 0.37450973866468873


In [10]:
# SICK test-split metrics for the a-priori word lists.
analyze_final(A_PRIORI_MAPPING)

Sick:




[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

MCC: 0.2550667728662275
F1: 0.5430579558590263
Acc: 0.5866286180187525
BAcc: 0.5411303316401782





In [11]:
# SICK test-split metrics for the tuned word lists.
analyze_final(TUNED_MAPPING)

Sick:




[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

MCC: 0.2910900703243763
F1: 0.34084086986356277
Acc: 0.3507949449653486
BAcc: 0.5573292723769949





In [12]:
# SICK test-split metrics for the one-word-per-label lists.
analyze_final(SIMPLE_MAPPING)

Sick:




[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

MCC: 0.20557769093489436
F1: 0.3439600572421164
Acc: 0.36220953933958416
BAcc: 0.45667709442325854





### Analysis of which tokens are most prevalent for each label

In [None]:
import copy

from itertools import islice
from collections import Counter

In [111]:
def predict_mask(premise, hypothesis, top_k=3):
    """Return the most likely fillers for a mask placed between premise and hypothesis.

    The prompt is "<premise> <mask> <hypothesis>", with the first word of the
    hypothesis lower-cased, exactly as in zero_shot_classify.

    Args:
        premise: premise sentence.
        hypothesis: hypothesis sentence.
        top_k: number of highest-scoring tokens to return (default 3, the value
            that was previously hard-coded).

    Returns:
        list of `top_k` token strings, ordered from highest to lowest logit at the
        mask position.

    Raises:
        ValueError: if the encoded prompt contains no mask token.
    """
    # Lower-case only the first space-separated word; the rest is kept verbatim.
    first_word, separator, remainder = hypothesis.partition(" ")
    combined = f"{premise} <mask> {first_word.lower()}{separator}{remainder}"

    # Read the encoding by key instead of relying on the order of its values.
    encoded = tokenizer(combined, return_tensors="pt", add_special_tokens=True)
    input_ids = encoded["input_ids"]

    # Pure inference: skip building the autograd graph for every sample.
    with torch.no_grad():
        logits = model(input_ids, attention_mask=encoded["attention_mask"])["logits"]

    mask_idx = input_ids[0].tolist().index(MASK)
    best_ids = torch.topk(logits[0, mask_idx], k=top_k).indices.tolist()
    return tokenizer.convert_ids_to_tokens(best_ids)

In [109]:
# MultiNLI training split; the following cells mine it for label-specific mask fillers.
multinli_train = load_dataset("multi_nli", split="train")

Using the latest cached version of the module from /home/imger/.cache/huggingface/modules/datasets_modules/datasets/multi_nli/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39 (last modified on Sat Feb 25 17:09:18 2023) since it couldn't be found locally at multi_nli., or remotely on the Hugging Face Hub.
Found cached dataset multi_nli (/home/imger/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)


In [116]:
# Predicted mask fillers, grouped by the gold label (0, 1, 2) of the sample they came from.
tokens_by_label = {label_id: [] for label_id in range(3)}
# Number of training rows to scan; islice takes them from the start of the split.
size = 50_000
for d in tqdm(islice(multinli_train, size), total=size):
    top_tokens = predict_mask(d["premise"], d["hypothesis"])
    tokens_by_label[d["label"]] += top_tokens

100%|██████████| 50000/50000 [53:20<00:00, 15.62it/s]  


In [117]:
# Raw frequency of every predicted filler token, per gold label.
counters = {label: Counter(words) for label, words in tokens_by_label.items()}

# Discriminative count: a token's frequency under its own label minus its frequency
# under every other label (a Counter yields 0 for tokens it has never seen).
counters_ = {
    label: Counter({
        word: count - sum(counters[other][word] for other in counters if other != label)
        for word, count in own_counts.items()
    })
    for label, own_counts in counters.items()
}

# Show the 15 highest-scoring tokens for each label.
for label, c in counters_.items():
    print(label, c.most_common()[:15])

0 [('Ġ,', 114), ('ĠAlso', 74), ('ĠYes', 71), ('ĠMore', 31), ('ĠCertainly', 29), ('Ġyes', 26), ('Ġby', 23), ('The', 21), ('ĠBelow', 19), ('ĠSpecifically', 18), ('ĠIndeed', 18), ('ĠHow', 17), ('Ġ:', 17), ('Ġ.', 15), ('ĠYeah', 15)]
1 [('ĠApparently', 77), ('ĠPerhaps', 33), ('ĠClearly', 30), ('ĠObviously', 25), ('ĠPresumably', 21), ('ĠPlus', 13), ('ĠAlready', 12), ('ĠAmong', 9), ('ĠTogether', 8), ('ĠMajor', 6), ('Ġdollars', 6), ('ĠContinuous', 5), ('ĠSure', 5), ('ĠDespite', 5), ('Ġensure', 5)]
2 [('ĠYet', 1040), ('ĠHowever', 668), ('Ġbut', 542), ('ĠUnfortunately', 373), ('ĠOtherwise', 206), ('ĠExcept', 151), ('ĠAlmost', 143), ('Ġno', 64), ('ĠNearly', 46), ('ĠCurrently', 43), ('ĠSadly', 43), ('ĠInstead', 30), ('ĠNot', 27), ('ĠPreviously', 27), ('ĠUntil', 27)]
