In [11]:
from tqdm.notebook import tqdm
import pandas as pd
pd.set_option("display.max_colwidth", None)
import json
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForSequenceClassification

In [12]:
df = pd.read_csv("../data/valid_df.csv", keep_default_na=False)

In [13]:
df.count()

Unnamed: 0      3458
arg_id          3458
key_point_id    3458
label           3458
argument        3458
topic           3458
stance          3458
key_point       3458
dtype: int64

In [14]:
entailment_model = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
tokenizer = AutoTokenizer.from_pretrained(entailment_model)
model = AutoModelForSequenceClassification.from_pretrained(entailment_model).to("cuda:0")

Some weights of the model checkpoint at ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
MAX_LENGTH=256

def compute_entailment(premise, hypothesis):
        tokenized_input_seq_pair = tokenizer.encode_plus(premise, hypothesis, max_length=MAX_LENGTH, return_token_type_ids=True, truncation=True)
        input_ids = torch.Tensor(tokenized_input_seq_pair['input_ids']).long().unsqueeze(0).cuda()
        token_type_ids = torch.Tensor(tokenized_input_seq_pair['token_type_ids']).long().unsqueeze(0).cuda()
        attention_mask = torch.Tensor(tokenized_input_seq_pair['attention_mask']).long().unsqueeze(0).cuda()
        outputs = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=None)
        predicted_probability  =torch.softmax(outputs[0], dim=1)[0].tolist()
        entailment_prob = predicted_probability[0]
        neutral_prob = predicted_probability[1]
        contradiction_prob = predicted_probability[2]
        #result = {'entail':entailment_prob, 'neutral':neutral_prob, 'contradict':contradiction_prob}
        return entailment_prob

In [16]:
arg = """school uniforms cut down on bulling and keep everyone the same."""
kp = """School uniform reduces bullying"""

In [17]:
compute_entailment(arg, kp)

0.9748885631561279

In [18]:
valid_df = df

In [19]:
def match_argument_with_keypoints(result, kp_dict, arg_dict):
    for arg_id, arg in arg_dict.items():
        result[arg_id] = {}
        for kp_id, kp in kp_dict.items():
            result[arg_id][kp_id] = compute_entailment(arg, kp)
    return result

In [20]:
argument_keypoints = {}
for topic in tqdm(valid_df.topic.unique()):
    for stance in [-1, 1]:
        topic_keypoint_ids = valid_df[(valid_df.topic==topic) & (valid_df.stance==stance)]['key_point_id'].tolist()
        topic_keypoints = valid_df[(valid_df.topic==topic) & (valid_df.stance==stance)]['key_point'].tolist()
        topic_kp_dict = dict(zip(topic_keypoint_ids, topic_keypoints))
        
        topic_argument_ids = valid_df[(valid_df.topic==topic) & (valid_df.stance==stance)]['arg_id'].tolist()
        topic_arguments = valid_df[(valid_df.topic==topic) & (valid_df.stance==stance)]['argument'].tolist()
        topic_arg_dict= dict(zip(topic_argument_ids, topic_arguments))
        # match 
        argument_keypoints = match_argument_with_keypoints(argument_keypoints, topic_kp_dict, topic_arg_dict)

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))




In [21]:
json.dump(argument_keypoints, open("entailment_all_valid_predictions.json", "w",encoding='utf-8'))

In [22]:
! python3 ../src-py/track_1_kp_matching.py ../data entailment_all_valid_predictions.json our_valid

mAP strict= 0.6893564141187134 ; mAP relaxed = 0.8970236914678582
