In [1]:
# import
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import os
cache_dir = "/shared/.cache/transformers"
gpu_device = 2
te_model = AutoModelForSequenceClassification.from_pretrained('joeddav/xlm-roberta-large-xnli', cache_dir=cache_dir).to('cuda:'+str(gpu_device))
tokenizer = AutoTokenizer.from_pretrained('joeddav/xlm-roberta-large-xnli', cache_dir=cache_dir)
os.chdir('/shared/lyuqing/probing_for_event/')

Some weights of the model checkpoint at joeddav/xlm-roberta-large-xnli were not used when initializing XLMRobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [2]:
def entailment(premise, hypothesis):

    x = tokenizer.encode(premise, hypothesis, return_tensors='pt', truncation='only_first').to('cuda:'+str(gpu_device))
    logits = te_model(x)[0]
    entail_contradiction_logits = logits
    probs = entail_contradiction_logits.softmax(1)
    prob_label_is_true = float(probs[:, 2])
    return prob_label_is_true

In [3]:
def load_trg_probe_lexicon(fr, level):
    lexicon = {}
    if level == 'fine':
        for line in fr:
            line = line.strip()
            if line:
                if line.isupper():
                    event_type = line
                else:
                    lexicon[event_type] = line
                    
    return lexicon
trg_probes_frn = 'source/lexicon/nli_topics.txt'
with open(trg_probes_frn, 'r') as fr:
    trg_probe_lexicon = load_trg_probe_lexicon(fr, 'fine')

In [4]:
def predict_event(sentence, trg_probe_lexicon):
    result_dict = {}
    for event_type in trg_probe_lexicon.keys():
        label = trg_probe_lexicon[event_type]
        hypothesis = f'This text is about {label}.'
        premise = sentence
        orig_entail_prob = entailment(premise, hypothesis)
#         if pair_premise_strategy:
#             sub_pattern = '\s?' + trigger_text + '\s?'
#             truncated_premise = re.sub(pattern=sub_pattern, string=premise, repl=' ').strip()
#             truncated_entail_prob = self.entailment(truncated_premise, hypothesis)
#             delta = orig_entail_prob - truncated_entail_prob

#         if self.pair_premise_strategy == 'max_delta':
#             result_dict[event_type] = delta
#         elif self.pair_premise_strategy == 'max_conf+delta':
#             result_dict[event_type] = orig_entail_prob + delta
#         elif self.pair_premise_strategy == None:
        result_dict[event_type] = orig_entail_prob

    sorted_res = sorted(result_dict.items(), key=lambda x: x[1], reverse=True)
    top_type, confidence = sorted_res[0][0], sorted_res[0][1]
    
    return top_type, confidence

In [7]:
# sentence = "击毙反政府武装分子" # "Killed the anti-government rebels"

# sentence = "平民暂时离开家园。" # "Civilians temporarily departed from home"

# sentence = "政府军缴获了一批武器。" # "Government forces seized a batch of weapons."

# sentence = "政府军方面有１２人受伤" # "12 people were injured in the government forces"

# sentence = "10多年来邻国内战" # The neighboring country has been in civil war for more than ten years

predict_event(sentence, trg_probe_lexicon)

('TRANSPORT', 0.9896987080574036)

In [14]:
# entailment testing
sentence = "10多年来邻国内战" # The neighboring country has been in civil war for more than ten years
hypothesis = "This text is about an attack or a war."
entailment(sentence, hypothesis)

0.8199309706687927