In [None]:
import pandas as pd
from tqdm import tqdm
from transformers import pipeline
from transformers import AutoModel, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification

In [None]:
def find_insert_text(str1, str2):
    """Return the words of ``str2`` that were inserted relative to ``str1``.

    Both strings are split on single spaces. The words of ``str2`` are walked
    in order while a cursor advances through the words of ``str1``: a word of
    ``str2`` equal to the word under the cursor is treated as original text
    (the cursor advances); any other word is treated as inserted text.

    Words of ``str2`` that follow the last matched word of ``str1`` are also
    reported as inserted. (Previously the scan stopped as soon as ``str1`` was
    exhausted, so text inserted at the end of the sentence was lost.)

    Args:
        str1: The original sentence.
        str2: The sentence derived from ``str1`` with extra words.

    Returns:
        The inserted words in order, each followed by a single space (so a
        non-empty result ends with a trailing space); ``''`` if nothing was
        inserted.
    """
    base_words = str1.split(' ')
    inserted = []
    cursor = 0
    for word in str2.split(' '):
        # The bounds check replaces the old early `break`: once every word of
        # str1 has been matched, all remaining words of str2 are insertions.
        if cursor < len(base_words) and word == base_words[cursor]:
            cursor += 1
        else:
            inserted.append(word)
    # Keep the historical "word + space" format so downstream inputs are unchanged.
    return ''.join(word + ' ' for word in inserted)

In [None]:
# Which classifier to evaluate: 't5' or 'deberta'. `get_label` reads this
# global at call time, and it also prefixes the result file name.
model = 'deberta'
if model == 't5':
    # NOTE(review): model_name is not read anywhere in the visible code;
    # presumably it records where the local "./t5" checkpoint came from - confirm.
    model_name = "PavanNeerudu/t5-base-finetuned-qqp"
    MODEL = AutoModelForSeq2SeqLM.from_pretrained("./t5")
    tokenizer = AutoTokenizer.from_pretrained("./t5")
    # device=0 is forwarded to the pipeline (a GPU index per the transformers docs).
    nlp = pipeline('text2text-generation', model=MODEL, tokenizer=tokenizer, device=0)
elif model == 'deberta':
    # NOTE(review): as above, model_name is unused; weights load from "./deberta".
    model_name = "Tomor0720/deberta-large-finetuned-qqp"
    MODEL = AutoModelForSequenceClassification.from_pretrained("./deberta")
    tokenizer = AutoTokenizer.from_pretrained("./deberta")
    nlp = pipeline('text-classification', model=MODEL, tokenizer=tokenizer, device=0)
else:
    # ValueError is still an Exception, but now says what went wrong.
    raise ValueError(f"Unsupported model {model!r}; expected 't5' or 'deberta'")

In [None]:
# Load the tab-separated input file; the comparison loop further down reads
# its 'text_a' and 'text_b' columns.
df = pd.read_csv('./qqp_lego.tsv', delimiter='\t')

In [None]:
def t5_get_label(question1, question2, nlp):
    """Ask a text2text pipeline whether two questions are duplicates.

    Args:
        question1: First question text.
        question2: Second question text.
        nlp: Callable that takes the prompt string and returns a sequence
            whose first element may contain a 'generated_text' entry.

    Returns:
        The 'generated_text' value of the first result, or '' when that key
        is absent.
    """
    # NOTE(review): there is no space between question1 and "question2:" -
    # confirm this matches the prompt format the checkpoint was fine-tuned on.
    prompt = "".join(["qqp question1: ", question1, "question2: ", question2])
    first = nlp(prompt)[0]
    return first['generated_text'] if 'generated_text' in first else ''

def deberta_get_label(question1, question2, nlp):
    """Classify a question pair with a text-classification pipeline.

    Args:
        question1: First question text.
        question2: Second question text.
        nlp: Callable that takes one string and returns a sequence whose
            first element has a 'label' entry.

    Returns:
        'not_duplicate' for 'LABEL_0', 'duplicate' for 'LABEL_1', and
        "ERROR" for any other label.
    """
    # NOTE(review): the two questions are joined into ONE string, so the
    # tokenizer sees a single segment. If the checkpoint was fine-tuned on
    # sentence pairs, passing {"text": ..., "text_pair": ...} may be the
    # intended input - confirm before changing, since results would shift.
    joined = " ".join((question1, question2))
    raw_label = nlp(joined)[0]['label']
    if raw_label == 'LABEL_1':
        return 'duplicate'
    if raw_label == 'LABEL_0':
        return 'not_duplicate'
    return "ERROR"

def get_label(question1, question2, nlp):
    """Route a question pair to the labeller for the globally selected model.

    Reads the module-level ``model`` string at call time.

    Args:
        question1: First question text.
        question2: Second question text.
        nlp: Pipeline callable forwarded unchanged to the chosen labeller.

    Returns:
        The result of ``t5_get_label`` when ``model == 't5'``, of
        ``deberta_get_label`` when ``model == 'deberta'``, and '' otherwise.
    """
    labellers = (('t5', t5_get_label), ('deberta', deberta_get_label))
    for model_key, labeller in labellers:
        if model == model_key:
            return labeller(question1, question2, nlp)
    return ''

In [None]:
def _collect_group_mismatches(rows, group_no):
    """Compare every pair of rows in one group and report label disagreements.

    For each pair (i, j) with i < j, two labels are requested from
    ``get_label``: one for the pair of inserted texts (from
    ``find_insert_text``) and one for the pair of full 'text_b' sentences.
    A record is produced whenever the two labels differ.

    Args:
        rows: Rows (with 'text_a' and 'text_b') that share one 'text_a'.
        group_no: Group number written into each produced record.

    Returns:
        A list of dicts, one per disagreeing pair; empty if there are none.
    """
    if len(rows) < 2:
        # No pairs to compare; also avoids touching the rows at all.
        return []
    # The inserted text depends only on the row, so compute it once per row
    # instead of once per pair.
    inserts = [find_insert_text(r['text_a'], r['text_b']) for r in rows]
    mismatches = []
    for i in range(len(rows) - 1):
        for j in range(i + 1, len(rows)):
            insert_same = get_label(inserts[i], inserts[j], nlp)
            context_same = get_label(rows[i]['text_b'], rows[j]['text_b'], nlp)
            if insert_same == context_same:
                continue
            if insert_same == 'duplicate':
                reason = 'same meaning insert, different meaning context'
            else:
                reason = 'same meaning context, different meaning insert'
            mismatches.append({
                'group_no': group_no,
                'original_text': rows[i]['text_a'],
                'text_a': rows[i]['text_b'],
                'insert_text_a': inserts[i],
                'text_b': rows[j]['text_b'],
                'insert_text_b': inserts[j],
                'wrong_reason': reason,
            })
    return mismatches


# Walk the frame in order, treating each run of consecutive rows with the same
# 'text_a' as one group. group_no only advances for groups that produced at
# least one mismatch record.
former_text = ''
row_list = []
records = []
group_no = 1
for index, row in tqdm(df.iterrows(), total=len(df)):
    current_text = str(row['text_a'])
    if row_list and current_text != former_text:
        # 'text_a' changed: the previous group is complete, so compare it.
        group_records = _collect_group_mismatches(row_list, group_no)
        if group_records:
            records.extend(group_records)
            group_no = group_no + 1
        row_list = []
    row_list.append(row)
    former_text = current_text

# Flush the final group: nothing follows it to trigger the comparison above,
# so without this step the last run of rows would never be compared.
group_records = _collect_group_mismatches(row_list, group_no)
if group_records:
    records.extend(group_records)
    group_no = group_no + 1

# Build the result frame once from the collected records rather than growing
# it row by row inside the loop.
res_df = pd.DataFrame(records, columns=['group_no', 'original_text' ,'text_a', 'insert_text_a','text_b', 'insert_text_b', 'wrong_reason'])

In [None]:
# Write the mismatch table to a CSV whose name is prefixed with the model key.
result_path = model + '_SSM_词义理解_result.csv'
res_df.to_csv(result_path)