In [4]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

In [6]:
# NOTE(review): `device` is not referenced by any other cell visible here, and
# the model loaded below is never moved onto it — confirm whether
# `model.to(device)` was intended.
device = torch.device("cuda:1")
# The tokenizer comes from the base checkpoint while the classifier below comes
# from the SNLI fine-tuned checkpoint; presumably both share one vocabulary —
# TODO confirm.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# Class index order of the classifier output (mirrored by `class_map` in
# `analyze_result`):
# 0: contradiction 1: entailment 2: neutral
model = AutoModelForSequenceClassification.from_pretrained('textattack/bert-base-uncased-snli')

In [100]:
from nltk.corpus import stopwords

def perturb_text(text, stop_words=None):
    """Build one masked variant of ``text`` per distinct non-stop-word token.

    ``text`` is split on single spaces. For every token that is not a stop
    word, a copy of the text is produced in which *all* occurrences of that
    exact token are replaced by ``'[MASK]'``.

    Args:
        text: The sentence to perturb.
        stop_words: Optional iterable of lower-case words that must never be
            masked. Defaults to NLTK's English stop-word list.

    Returns:
        A list of ``(masked_text, masked_token)`` tuples, ordered by the first
        appearance of ``masked_token`` in ``text``.
    """
    if stop_words is None:
        stop_words = stopwords.words('english')
    # Build the lookup set once instead of once per token.
    stop_words = set(stop_words)

    tokens = text.split(' ')
    already_masked = set()
    text_perturbed = []
    for token in tokens:
        # Empty tokens come from consecutive spaces (or empty input); masking
        # them would insert a token rather than hide one.
        if not token or token.lower() in stop_words:
            continue
        # Every occurrence of a token is masked at once, so a repeated token
        # would only yield an identical duplicate perturbation.
        if token in already_masked:
            continue
        already_masked.add(token)
        masked_text = ' '.join('[MASK]' if w == token else w for w in tokens)
        text_perturbed.append((masked_text, token))

    return text_perturbed



In [101]:
import itertools


def _class_probabilities(model, premise, hypothesis):
    """Return the softmax class distribution for one premise/hypothesis pair.

    The pair is encoded with the notebook-level ``tokenizer`` and run through
    ``model`` without gradient tracking.
    """
    encoded = tokenizer(premise, text_pair=hypothesis, return_tensors='pt')
    # Follow the model onto whatever device it lives on; a model that exposes
    # no ``device`` attribute receives the encoded inputs unchanged.
    model_device = getattr(model, 'device', None)
    if model_device is not None:
        encoded = encoded.to(model_device)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(**encoded).logits[0]
    return torch.softmax(logits, dim=-1)


def exp(model, premise, hypothesis, k):
    """Rank joint premise/hypothesis word maskings by confidence drop.

    Steps:
      1. Score the unmodified pair; the arg-max class becomes the target class.
      2. Mask each candidate premise word (hypothesis untouched) and, separately,
         each candidate hypothesis word (premise untouched), recording how much
         the target-class probability drops.
      3. Keep the ``k`` largest drops on each side and score every combination
         of one kept premise masking with one kept hypothesis masking.

    Args:
        model: Sequence classifier whose call result exposes ``.logits``.
        premise: Premise sentence.
        hypothesis: Hypothesis sentence.
        k: Maximum number of single-word maskings kept per side.

    Returns:
        ``(target_class, orig_confidence, final_confidences, topk_pairs)`` where
        ``target_class`` is the predicted class index for the original pair,
        ``orig_confidence`` its probability, ``topk_pairs`` the list of
        ``((masked_premise, word), (masked_hypothesis, word))`` combinations and
        ``final_confidences`` the target-class probability drop for each pair,
        in the same order.
    """
    probabilities = _class_probabilities(model, premise, hypothesis)
    orig_confidence, target_class = probabilities.max(-1)
    target_class = target_class.item()
    orig_confidence = orig_confidence.item()

    def confidence_drop(pre, hyp):
        # Positive values mean the perturbation lowered the target-class probability.
        return orig_confidence - _class_probabilities(model, pre, hyp)[target_class].item()

    # premise first
    perturbed_premise = perturb_text(premise)
    pre_confidences = [confidence_drop(sent, hypothesis) for sent, _ in perturbed_premise]

    # perturb hypothesis
    perturbed_hyp = perturb_text(hypothesis)
    hyp_confidences = [confidence_drop(premise, sent) for sent, _ in perturbed_hyp]

    # Indices of the k largest drops per side (k is capped at the number of candidates).
    pre_topk = torch.tensor(pre_confidences).topk(k=min(k, len(pre_confidences)))[1].tolist()
    hyp_topk = torch.tensor(hyp_confidences).topk(k=min(k, len(hyp_confidences)))[1].tolist()
    topk_premises = [perturbed_premise[i] for i in pre_topk]
    topk_hyp = [perturbed_hyp[i] for i in hyp_topk]

    # Score every (premise masking, hypothesis masking) combination.
    topk_pairs = list(itertools.product(topk_premises, topk_hyp))
    final_confidences = [
        confidence_drop(pre_sent, hyp_sent)
        for (pre_sent, _), (hyp_sent, _) in topk_pairs
    ]

    return target_class, orig_confidence, final_confidences, topk_pairs


In [108]:
# Example inputs, one premise/hypothesis pair per NLI class. Only the
# uncommented pair below is active; swap the comments to try another one.
# NOTE(review): the bracketed line under each pair presumably lists the word
# alignment expected to matter most for the prediction — confirm.

# contradiction example
# premise = "I didn't think that the movie was that great."
# hypothesis = "The movie was excellent."
# [(didn't + great <-> excellent)]

# entailment example
# premise = 'At the other end of Pennsylvania Avenue, people began to line up for a White House tour.'
# hypothesis = 'People formed a line at the end of Pennsylvania Avenue.'
# [(began to line up <-> formed a line)]

# neutral example
premise = "Your gift is appreciated by each and every student who will benefit from your generosity."	
hypothesis = "Hundreds of students will benefit from your generosity."
# [each and every <-> Hundreds of]


In [109]:
import numpy as np
def analyze_result(prediction, confidence, conf_drops, perturbations,
                   premise=None, hypothesis=None):
    """Print a ranked report of joint premise/hypothesis maskings.

    Args:
        prediction: Predicted class index (0 contradiction, 1 entailment,
            2 neutral).
        confidence: Probability of ``prediction`` on the unperturbed pair.
        conf_drops: One confidence drop per entry of ``perturbations``.
        perturbations: Sequence of
            ``((masked_premise, word), (masked_hypothesis, word))`` pairs.
        premise: Original premise to echo in the report. When omitted, the
            notebook-level ``premise`` variable is used.
        hypothesis: Original hypothesis to echo in the report. When omitted,
            the notebook-level ``hypothesis`` variable is used.

    Returns:
        None. The report is written to standard output: the original pair, the
        original prediction, every perturbation ordered by decreasing
        confidence drop, and the sets of masked premise / hypothesis words.
    """
    # Prefer the texts handed in explicitly so the report cannot drift from the
    # inputs that produced the results; fall back to the notebook globals so
    # existing four-argument calls keep working.
    if premise is None:
        premise = globals()['premise']
    if hypothesis is None:
        hypothesis = globals()['hypothesis']

    print('premise:', premise)
    print('hypothesis:', hypothesis)
    print()
    class_map = ['contradiction', 'entailment', 'neutral']
    print(f'original prediction was {class_map[prediction]} / with confidence: {confidence}\n')

    # Order everything by decreasing confidence drop.
    conf_drops = np.array(conf_drops)
    idx = conf_drops.argsort(axis=0)[::-1]
    conf_drops = conf_drops[idx]
    perturbations = [perturbations[i] for i in idx]

    pert_sents = [(s[0][0], s[1][0]) for s in perturbations]
    pert_pre_words = {s[0][1] for s in perturbations}
    pert_hyp_words = {s[1][1] for s in perturbations}

    # Each drop is shown with a leading '-' to read as a decrease.
    for i, (pert, conf) in enumerate(zip(pert_sents, conf_drops), 1):
        print(f'{i}. {pert} | -{conf}')

    print()
    print('premise:', pert_pre_words, '\nhypothesis:', pert_hyp_words)


In [110]:
# Run the masking experiment on the active example (top 4 single-word maskings
# per side) and print the ranked report.
analyze_result(*exp(model, premise, hypothesis, k=4))

premise: Your gift is appreciated by each and every student who will benefit from your generosity.
hypothesis: Hundreds of students will benefit from your generosity.

original prediction was neutral / with confidence: 0.9418114423751831

1. ('Your gift is appreciated by each and every student who will benefit from your [MASK]', '[MASK] of students will benefit from your generosity.') | -0.924758305773139
2. ('Your gift is [MASK] by each and every student who will benefit from your generosity.', '[MASK] of students will benefit from your generosity.') | -0.9190915487706661
3. ('Your gift is appreciated by each and every student who will [MASK] from your generosity.', '[MASK] of students will benefit from your generosity.') | -0.9189088735729456
4. ('Your [MASK] is appreciated by each and every student who will benefit from your generosity.', '[MASK] of students will benefit from your generosity.') | -0.918523658066988
5. ('Your gift is appreciated by each and every student who will [MA