In [1]:
#!python3 -m spacy download en_core_web_trf

In [2]:
from collections import defaultdict
import pandas as pd
import spacy
spacy.require_gpu()
from tqdm import tqdm
from paraphrase_metrics import metrics as pm
import language_tool_python
tool = language_tool_python.LanguageTool('en-US')

In [3]:
nlp = spacy.load("en_core_web_trf")

In [4]:
split = "train"
df = pd.read_csv("./mrpc/mrpc_"+split+".csv")
df["corrected_label"] = df.shape[0]*[None]
df.head()

Unnamed: 0.1,Unnamed: 0,s1,s2,label,corrected_label
0,0,"Amrozi accused his brother, whom he called ""th...","Referring to him as only ""the witness"", Amrozi...",1,
1,1,Yucaipa owned Dominick's before selling the ch...,Yucaipa bought Dominick's in 1995 for $693 mil...,0,
2,2,They had published an advertisement on the Int...,"On June 10, the ship's owners had published an...",1,
3,3,"Around 0335 GMT, Tab shares were up 19 cents, ...","Tab shares jumped 20 cents, or 4.6%, to set a ...",0,
4,4,"The stock rose $2.11, or about 11 percent, to ...",PG&E Corp. shares jumped $1.63 or 8 percent to...,1,


In [5]:
def rchop(s, suffix):
    if s.endswith(suffix):
        return s[:-len(suffix)]
    return s

def lchop(s, prefix):
    if s.startswith(prefix):
        return s[len(prefix):]
    return s

def collate_entities(doc):
    output = defaultdict(list)
    for ent in doc.ents:
        ent_text = ent.text
        if ent_text[0]=='"' or ent_text[0]=="'":
            ent_text = ent_text[1:]
        if ent_text[-1]=='"' or ent_text[-1]=="'":
            ent_text = ent_text[:-1]
        rstrip_list = ["percent", "%"]
        for s in rstrip_list:
            ent_text = rchop(ent_text, s)
        lstrip_list = ["the ", "$US", "$"]
        for s in lstrip_list:
            ent_text = lchop(ent_text, s)
        ent_text = ent_text.strip()
        if ent_text not in output[ent.label_]:
            output[ent.label_].append(ent_text)
    return dict(output)

replacements = []
black_list = ['US->U.S.', 'U.S.->US',
              'U.S.->United States', 'United States->U.S.',
              'U.N.->United Nations', 'United Nations->U.N.',
              'each year->annually', 'annually->each year',
              '10 years ago->a decade ago','a decade ago->10 years ago',
              'a year earlier->A year ago', 'A year ago->a year earlier',
              'a day->one day', 'one day->a day',
              'November 11->Nov. 11', 'Nov. 11->November 11',
              'more than 50 million->over 50 million', 'over 50 million->more than 50 million']

def preprocess(intext):
    outtext = intext.replace("’", "'")
    return outtext

def postprocess(intext):
    matches = tool.check(intext)
    categories = []
    mistakes = []
    corrections = []
    for rules in matches:
        if len(rules.replacements)>0:
            categories.append(rules.category)
            mistakes.append(intext[rules.offset:rules.errorLength+rules.offset])
            corrections.append(rules.replacements[0])
    corrections = list(zip(categories, mistakes, corrections))
    #print(intext)
    if len(corrections) > 0:
        for c in corrections:
            #print(c[0], ":", c[1], "->", c[2])
            if c[0] == "CASING":
                intext = list(intext)
                intext[rules.offset:rules.errorLength+rules.offset] = list(rules.replacements[0])
                intext = "".join(intext)
        #print(intext)
    return intext
    

def correct_row(row, verbose=False, replacement="longer"):
    global replacements, black_list
    s1_text, s2_text = preprocess(row["s1"]), preprocess(row["s2"])
    s1_text_og, s2_text_og = str(row["s1"]), str(row["s2"])
    s1, s2 = nlp(s1_text), nlp(s2_text)
    s1_ents = collate_entities(s1)
    s2_ents = collate_entities(s2)
    ent_types = list(set(list(s1_ents.keys())+list(s2_ents.keys())))
    correctable = True
    need_correction = False
    sub_source = []
    sub_target = []
    sub_type = []
    for et in ent_types:
        try:
            s1_count = len(s1_ents[et])
        except KeyError:
            s1_count = 0
        try:
            s2_count = len(s2_ents[et])
        except KeyError:
            s2_count = 0
        if s1_count != s2_count:
            correctable = False
            need_correction = True
        elif s1_count==s2_count==1:
            e1 = s1_ents[et][0]
            e2 = s2_ents[et][0]
            if e1.lower()!=e2.lower():
                need_correction = True
                if len(e1) < len(e2):
                    sub_source.append(e1)
                    sub_target.append(e2)
                else:
                    sub_source.append(e2)
                    sub_target.append(e1)
                sub_type.append(et)
        elif s1_count>1 or s2_count>1:
            ignore_list = set([])
            sub_list = set([])
            for ent in set(s1_ents[et]+s2_ents[et]):
                if ent in s1_ents[et] and ent in s2_ents[et]:
                    ignore_list.add(ent)
                else:
                    need_correction = True
                    sub_list.add(ent)
            if len(sub_list)==2:
                #print(sub_list)
                e1, e2 = sub_list
                need_correction = True
                if e1.lower()!=e2.lower():
                    if len(e1) < len(e2):
                        sub_source.append(e1)
                        sub_target.append(e2)
                    else:
                        sub_source.append(e2)
                        sub_target.append(e1)
                    sub_type.append(et)
            elif len(sub_list)==0:
                pass
            else:
                correctable = False
    #if len(sub_source)>2:
    #    verbose = True
    if correctable and len(sub_source)>0:
        if verbose:
            print(s1_text)
            print(s2_text)
        for s, t, et in zip(sub_source, sub_target, sub_type):
            if verbose:
                print(et, ":", s, "->", t)
            if replacement=="longer":
                replacement = s+"->"+t
                if replacement not in black_list:
                    if t not in s1_text:
                        s1_text = s1_text.replace(s, t)
                    if t not in s2_text:
                        s2_text = s2_text.replace(s, t)
                    replacements.append(replacement)
            else:
                replacement = t+"->"+s
                if replacement not in black_list:
                    s1_text = s1_text.replace(t, s)
                    s2_text = s2_text.replace(t, s)
                    replacements.append(t+"->"+s)
        if verbose:
            print("")
        if s1_text != s2_text:
            corrected = True
        else:
            corrected = False
            need_correction = False
    else:
        corrected = False
    if corrected:
        assert need_correction
        if s1_text.lower() == s1_text_og.lower() and s2_text.lower() == s2_text_og.lower():
            corrected = False
            need_correction = False
        s1_text = postprocess(s1_text)
        s2_text = postprocess(s2_text)
    return {"need_correction": need_correction,
            "corrected": corrected,
            "s1": s1_text,
            "s2": s2_text,
            "s1_og": s1_text_og,
            "s2_og": s2_text_og}

In [6]:
total = 0
cant_correct = 0
corrected = 0
no_need_correct = 0

og_s1 = []
og_s2 = []
new_s1 = []
new_s2 = []
og_label = []
new_label = []
remarks = []

for index, row in tqdm(df.iterrows(), total=df.shape[0]):
    if row["label"] == 1:
        total += 1
        out = correct_row(row, verbose=False, replacement="longer")
        if out["need_correction"] and out["corrected"]:
            new_s1.append(out["s1"])
            new_s2.append(out["s2"])
            og_s1.append(out["s1_og"])
            og_s2.append(out["s2_og"])
            og_label.append(1)
            new_label.append(1)
            remarks.append("corrected")
            corrected += 1
            out = correct_row(row, verbose=False, replacement="shorter")
            new_s1.append(out["s1"])
            new_s2.append(out["s2"])
            og_s1.append(out["s1_og"])
            og_s2.append(out["s2_og"])
            og_label.append(1)
            new_label.append(1)
            remarks.append("corrected")
        elif out["need_correction"] and not out["corrected"]:
            new_s1.append(row["s1"])
            new_s2.append(row["s2"])
            og_s1.append(row["s1"])
            og_s2.append(row["s2"])
            s1, s2 = nlp(row["s1"]), nlp(row["s2"])
            ld = pm.ld(s1, s2)
            if ld > 0.1:
                og_label.append(1)
                new_label.append(0)
                cant_correct += 1
                remarks.append("can't correct")
            else:
                og_label.append(1)
                new_label.append(1)
                no_need_correct += 1
                remarks.append("no need to correct")
        else:
            new_s1.append(row["s1"])
            new_s2.append(row["s2"])
            og_s1.append(row["s1"])
            og_s2.append(row["s2"])
            og_label.append(1)
            new_label.append(1)
            no_need_correct += 1
            remarks.append("no need to correct")
    else:
        new_s1.append(row["s1"])
        new_s2.append(row["s2"])
        og_s1.append(row["s1"])
        og_s2.append(row["s2"])
        og_label.append(0)
        new_label.append(0)
        no_need_correct += 1
        remarks.append("no need to correct")

100%|██████████| 4076/4076 [01:50<00:00, 36.76it/s]


In [7]:
print(total)
print(no_need_correct)
print(corrected+cant_correct)
print(corrected/(corrected+cant_correct))
print(corrected)

2753
1978
2098
0.15872259294566254
333


In [8]:
#from collections import Counter
#replacements_counter = Counter(replacements)
#replacements_counter.most_common(100)

In [9]:
df_out = pd.DataFrame(list(zip(og_s1, og_s2, new_s1, new_s2, og_label, new_label, remarks)),
                    columns =["og_s1", "og_s2", "new_s1", "new_s2", "og_label", "new_label", "remarks"])

In [10]:
df_out.head(5)

Unnamed: 0,og_s1,og_s2,new_s1,new_s2,og_label,new_label,remarks
0,"Amrozi accused his brother, whom he called ""th...","Referring to him as only ""the witness"", Amrozi...","Amrozi accused his brother, whom he called ""th...","Referring to him as only ""the witness"", Amrozi...",1,1,no need to correct
1,Yucaipa owned Dominick's before selling the ch...,Yucaipa bought Dominick's in 1995 for $693 mil...,Yucaipa owned Dominick's before selling the ch...,Yucaipa bought Dominick's in 1995 for $693 mil...,0,0,no need to correct
2,They had published an advertisement on the Int...,"On June 10, the ship's owners had published an...",They had published an advertisement on the Int...,"On June 10, the ship's owners had published an...",1,1,no need to correct
3,"Around 0335 GMT, Tab shares were up 19 cents, ...","Tab shares jumped 20 cents, or 4.6%, to set a ...","Around 0335 GMT, Tab shares were up 19 cents, ...","Tab shares jumped 20 cents, or 4.6%, to set a ...",0,0,no need to correct
4,"The stock rose $2.11, or about 11 percent, to ...",PG&E Corp. shares jumped $1.63 or 8 percent to...,"The stock rose $2.11, or about 11 percent, to ...",PG&E Corp. shares jumped $1.63 or 8 percent to...,1,0,can't correct


In [11]:
df_out.tail(5)

Unnamed: 0,og_s1,og_s2,new_s1,new_s2,og_label,new_label,remarks
4404,"""At this point, Mr. Brando announced: 'Somebod...","Brando said that ""somebody ought to put a bull...","""At this point, Mr. Brando announced: 'Somebod...","Brando said that ""somebody ought to put a bull...",1,1,no need to correct
4405,"Martin, 58, will be freed today after serving ...",Martin served two thirds of a five-year senten...,"Martin, 58, will be freed today after serving ...",Martin served two thirds of a five-year senten...,0,0,no need to correct
4406,"""We have concluded that the outlook for price ...","In a statement, the ECB said the outlook for p...","""We have concluded that the outlook for price ...","In a statement, the ECB said the outlook for p...",1,0,can't correct
4407,The notification was first reported Friday by ...,MSNBC.com first reported the CIA request on Fr...,The notification was first reported Friday by ...,MSNBC.com first reported the CIA request on Fr...,1,0,can't correct
4408,The 30-year bond US30YT=RR rose 22/32 for a yi...,The 30-year bond US30YT=RR grew 1-3/32 for a y...,The 30-year bond US30YT=RR rose 22/32 for a yi...,The 30-year bond US30YT=RR grew 1-3/32 for a y...,0,0,no need to correct


In [12]:
df_out.to_csv("./mrpc_"+split+"_corrected.csv", index=False)

In [13]:
df_out.sum()

og_s1        Amrozi accused his brother, whom he called "th...
og_s2        Referring to him as only "the witness", Amrozi...
new_s1       Amrozi accused his brother, whom he called "th...
new_s2       Referring to him as only "the witness", Amrozi...
og_label                                                  3086
new_label                                                 1321
remarks      no need to correctno need to correctno need to...
dtype: object