In [2]:
import pandas as pd
import json, os
from transformers import pipeline
import transformers

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
claim_connective_premise = ["conversely", "whereas", "but", "however", "while", "although", "though", "meanwhile", 
                            "previously", "instead", "or", "and", "then", "when", "as", "for", "accordingly", 
                            "because", "with", "since", "as", "and", "when", "if", "but"]
premise_connective_claim = ["yet", "still", "nevertheless", "rather", "nonetheless", "therefore", "thereby", "so",
                            "hence", "consequently", "thus"]
unclear_order = ['additionally', 'after', 'afterward', 'also', 'alternatively', 'anyway', 'before', 'despite', 'else', 
                 'essentially', 'eventually', 'except', 'finally', 'further', 'furthermore', 'indeed', 'later', 'lest', 
                 'likewise', 'meantime', 'moreover', 'next', 'once', 'otherwise', 'overall', 'particularly', 'plus',
                 'regardless', 'separately', 'similarly', 'simultaneously', 'specifically', 'thereafter', 'till', 
                 'ultimately', 'unless', 'until', 'upon', 'whatever', 'whenever', 'without', 'earlier', 'nor'] # claim premise as default

In [3]:
with open("connectives_dict.json") as f:
    connective_dict = json.load(f)
connective_dict

{'all': {'attack': ['although',
   'though',
   'but',
   'nevertheless',
   'if',
   'previously',
   'when',
   'and',
   'then',
   'while',
   'however',
   'still',
   'or',
   'instead',
   'as',
   'nonetheless',
   'meanwhile',
   'yet',
   'rather',
   'besides',
   'nor',
   'whereas',
   'earlier',
   'conversely'],
  'support': ['but',
   'because',
   'if',
   'when',
   'and',
   'so',
   'then',
   'as',
   'since',
   'therefore',
   'thus',
   'thereby',
   'consequently',
   'hence',
   'accordingly',
   'for',
   'with',
   'given'],
  'both': ['but', 'if', 'when', 'and', 'then', 'as'],
  'all': ['once',
   'although',
   'though',
   'but',
   'because',
   'nevertheless',
   'before',
   'until',
   'if',
   'previously',
   'when',
   'and',
   'so',
   'then',
   'while',
   'however',
   'also',
   'after',
   'separately',
   'still',
   'or',
   'moreover',
   'instead',
   'as',
   'nonetheless',
   'unless',
   'meanwhile',
   'yet',
   'since',
   'rather',

In [3]:
with open("connectives_dict_reduced_BERT.json") as f:
    connective_dict_bert = json.load(f)

In [4]:
with open("connectives_dict_reduced-RoBERTa.json") as f:
    connective_dict_roberta = json.load(f)

In [5]:
with open("connectives_dict_reduced_XLM-RoBERTa.json") as f:
    connective_dict_clmroberta = json.load(f)

In [7]:
set(connective_dict_bert["all"]["support"]).symmetric_difference(set(connective_dict_roberta["all"]["support"]))

{'accordingly', 'consequently', 'hence', 'thereby', 'therefore', 'thus'}

In [8]:
def get_new_sent(text1, text2, pattern):
    if text1.strip()[-1] in [".", ",", "!", "?", ":", ";", "-"]:
        text1 = text1[:-1]
    # print(text1)
    if text2.strip().capitalize() == text2.strip():
        text2 = text2[0].strip().lower()+text2[1:].strip()  
        # todo remove strip here otherwise "a" at sentence beginning will be concatenated with next word
    # print(text2)
    # print(" ".join([text1.strip(), pattern, text2.strip().lower()]).strip())
    return " ".join([text1.strip(), pattern, text2.strip().lower()]).strip()

- removed in bert:
    ["moreover", "furthermore", "lest", "conversely", "when/then", "everytime"]
    

In [None]:
# connective_dict = connective_dict_bert
# connective_name = "all_conns_bert"
# model_list = ["distilbert-base-uncased", "bert-base-uncased", "bert-large-uncased"]

for connective_dict, connective_name, model_list in [(connective_dict_roberta, "all_conns_roberta", ["distilbert-base-uncased", "bert-base-uncased", "bert-large-uncased", "roberta-base", "xlm-roberta-base", "roberta-large", "xlm-roberta-large"])]:
    # (connective_dict_bert, "all_conns_bert", ["distilbert-base-uncased", "bert-base-uncased", "bert-large-uncased"]),
   #  (connective_dict_xlm-roberta, "all_conns_xlm-roberta", ["distilbert-base-uncased", "bert-base-uncased", "bert-large-uncased", "xlm-roberta-base", "xlm-roberta-large"])
    name = "all_conns"
    split_perc = "all"
    split_conn = "all"
    # for name, split_perc, split_conn in [("all_conns", "all", "all")]: # , ("all_attack_support", "all", "attack+support"),  ("all_attack", "all", "attack"),  ("all_support", "all", "support"),  ("all_both", "all", "both")]:
    for data_path in ["argmin_all.csv", "ibmcs_all.csv", "perspectrum_all.csv"]:
        data_name = data_path[:-4]
        df = pd.read_csv("data/"+data_path, index_col=0)
        output_data = []
        if split_conn == "attack+support":
            connectives = connective_dict[split_perc]["attack"] + connective_dict[split_perc]["support"]
        else:
            connectives = connective_dict[split_perc][split_conn]
        for model_name in model_list:  # , "roberta-base", "xlm-roberta-base"]:  # ["microsoft/deberta-base"]:  # ["distilbert-base-cased", "bert-base-cased", "roberta-base", "xlm-roberta-base", "bert-large-cased", "roberta-large", "xlm-roberta-large"]:
            classifier = pipeline("fill-mask", model=model_name, device=0)
            masked_token = classifier.tokenizer.mask_token
            print(data_name, connective_name, name, model_name)
            for i, row in df.iterrows():
                output_dict = row.to_dict()
                new_sent_clwp = get_new_sent(row["claim"], row["premise"], masked_token)
                new_sent_plwc = get_new_sent(row["premise"], row["claim"], masked_token)
                output_dict["input_clwp"] = new_sent_clwp
                output_dict["input_plwc"] = new_sent_plwc
                result_clwp = classifier(new_sent_clwp, targets=connectives, top_k=classifier.model.config.vocab_size)
                result_plwc = classifier(new_sent_plwc, targets=connectives, top_k=classifier.model.config.vocab_size)
                # connectives = set([item["token_str"] for item in result])
                # print(model_name, sorted(list(set(connective_dict[split_conn][split_perc]).difference(connectives))))
                for item in result_clwp:
                    if item["token_str"] in connectives and (item["token_str"] in claim_connective_premise or item["token_str"] in unclear_order):
                        # df.loc[i, model_name+"_"+item["token_str"]] =  item["score"]
                        output_dict[model_name+"_"+item["token_str"]] =  item["score"]
                for item in result_plwc:
                    if item["token_str"] in connectives and item["token_str"] in premise_connective_claim:
                        # df.loc[i, model_name+"_"+item["token_str"]] =  item["score"]
                        output_dict[model_name+"_"+item["token_str"]] =  item["score"]
                if i%1000 == 0:
                    print(i, len(df))
                output_data.append(output_dict)
            df = pd.DataFrame(output_data)
            output_data = []
        if not os.path.exists("features/"+connective_name+"/"+name):
            os.makedirs("features/"+connective_name+"/"+name)
        df.to_csv("features/"+connective_name+"/"+name+"/"+data_name+".csv", index=False)

argmin_all all_conns_bert all_conns distilbert-base-uncased
1000 11139
8000 11139
10000 11139
12000 11139
14000 11139
21000 11139
24000 11139


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


argmin_all all_conns_bert all_conns bert-base-uncased
0 11139
1000 11139
2000 11139
