In [4]:
import requests,bs4,lxml,spacy,os,pickle,math,faiss
from tqdm import tqdm
import numpy as np
from sentence_transformers import SentenceTransformer
from numpy import dot
from numpy.linalg import norm

In [5]:
_DOWNLOAD_URL = "https://cs.nyu.edu/faculty/davise/papers/WinogradSchemas/WSCollection.xml"
cskg_embeddings_file="./cskg_embedding/cskg_embeddings.txt"
cskg_connected_file="../kg-bert/data/cskg/cskg_connected.tsv"

In [6]:
# prepare model
model = SentenceTransformer('nli-bert-large')
nlp = spacy.load("en_core_web_sm")

## Download WSC Dataset

The code is based on the tensorflow. 
Link: https://www.tensorflow.org/datasets/catalog/wsc273  
Source code: https://github.com/tensorflow/datasets/blob/45ba108ef87162b17335245fc97cabc75efcfa54/tensorflow_datasets/text/wsc273/wsc273.py#L115

In [7]:
def normalize_text(text):
    text = text.strip()
    # Correct a misspell.
    text = text.replace("recieved", "received")
    text = text.replace("\n", " ")
    text = text.replace("  ", " ")
    return text

def normalize_cap(option, pron):
    """Normalize the capitalization of the option according to the pronoun."""
    cap_tuples = [
      ("The", "the"), ("His", "his"), ("My", "my"),
      ("Her", "her"), ("Their", "their"), ("An", "an"), ("A", "a")]
    uncap_dict = dict(cap_tuples)
    cap_dict = dict([(t[1], t[0]) for t in cap_tuples])
    words = option.split(" ")
    first_word = words[0]
    if pron[0].islower():
        first_word = uncap_dict.get(first_word, first_word)
    else:
        first_word = cap_dict.get(first_word, first_word)
    words[0] = first_word
    option = " ".join(words)
    return option

def parse_wsc273_xml(xml_data):
    """Parse the XML file containing WSC273 examples."""
    soup = bs4.BeautifulSoup(xml_data, "lxml")
    schemas = soup.find_all("schema")
    # Only the first 273 examples are included in WSC273.
    for i, schema in enumerate(schemas[:273]):
        txt1 = schema.find_all("txt1")[0].get_text()
        txt1 = normalize_text(txt1)
        txt2 = schema.find_all("txt2")[0].get_text()
        txt2 = normalize_text(txt2)
        pron = schema.find_all("pron")[0].get_text()
        pron = normalize_text(pron)
        answers = [ans.get_text().strip() for ans in schema.find_all("answer")]
        normalized_answers = [normalize_cap(ans, pron) for ans in answers]
        assert len(answers) == 2
        choice = schema.find_all("correctanswer")[0].get_text().strip()
        label = {"A": 0, "B": 1}[choice[0]]
        if len(txt2) == 1:
            # If there is only one punctuation left after the pronoun,
            # then no space should be inserted.
            text = f"{txt1} {pron}{txt2}"
        else:
            text = f"{txt1} {pron} {txt2}"
        pronoun_text = pron
        pronoun_start = len(txt1) + 1
        pronoun_end = len(txt1) + len(pron) + 1
        example = dict(
            text=text,
            pronoun_text=pronoun_text,
            pronoun_start=pronoun_start,
            pronoun_end=pronoun_end,
            option1=answers[0],
            option2=answers[1],
            option1_normalized=normalized_answers[0],
            option2_normalized=normalized_answers[1],
            label=label,
            idx=i)
        assert text[pronoun_start:pronoun_end] == pronoun_text
        yield example

In [8]:
r = requests.get(_DOWNLOAD_URL)

In [9]:
examples=parse_wsc273_xml(r.text)

In [10]:
examples=list(parse_wsc273_xml(r.text))

## Analysis WSC

Extract property  
Example:  
'The city councilmen refused the demonstrators a permit because they feared violence.'  
Extracted property: refused a permit, feared violence  
Hyp sentence: feared violence casuses refused a permit

In [11]:
class token_text():
    # build a class to generate empty token
    def __init__(self, text):
        self.text=text

def find_end_tree(root,res=[], left_=False, right_=False):
    # find the end token of dependency Tree.
    # left_ means find left hand side children
    # right_ means find right hand side children
    if root.n_lefts*left_ + root.n_rights*right_ > 0:
        if left_ and right_:
            res.append(root)
            for child in root.children:
                find_end_tree(child, res=res, left_=left_,right_=right_)
        elif left_:
            res.append(root)
            for child in root.lefts:
                find_end_tree(child, res=res, left_=left_,right_=right_)
        elif right_:
            res.append(root)
            for child in root.rights:
                find_end_tree(child, res=res, left_=left_,right_=right_)
    else:
        res.append(root)
    
    return res        
        
def AddPeriod(sent_text):
    # add period to sentences
    if sent_text[-1]!=".":
        sent_text+="."
    return sent_text
def TokensByText(sent, label,doc):
    # find tokens by label name
    blacklist = {
      "the","his","my","her", "their", "an", "a"}
    
    stack=[]
    res_stacks=[]
    for token in sent:
        token_label=token.text
        token_lemma=token.lemma_
        if token_label in blacklist or token_lemma in blacklist:
            continue
            
        if token_label in label or token_lemma in label:
            if stack:
                # check whether idx is continuous
                last_token_idx=stack[-1].i
                if token.i-last_token_idx==1:
                    # continous
                    stack.append(token)
                else:
                    stack=[]
                    stack.append(token)
            else:
                stack.append(token)
            
            # check whether tokens is found.
            tokens_label=doc[stack[0].i:stack[-1].i+1]
            if tokens_label.text==label:
                res_stacks.append(stack)
    if not res_stacks:    
        return stack
    else:
        for stack_ in res_stacks:
            if SubjectIdentifier(stack_,sent):
                return stack_
            
        return res_stacks[-1]

def ObjModify(obj):
    blacklist = {
      "the","his","my","her", "their", "an", "a"}
    obj=obj.lower()
    obj_list=obj.split(" ")
    tmp=[]
    for token in obj_list:
        if token in blacklist:
            continue
        else:
            tmp.append(token)
            
    return " ".join(tmp)

def SubjectIdentifier(tokens,sent):
    # identify tokens are subject or object
    # use the last token of tokens
    
    SUBJECTS = ["nsubj", "nsubjpass", "csubj", "csubjpass", "agent", "expl"]
    token = tokens[-1]
    
    # token is root:
    if token==sent.root:
        return True
    
    
    if token.dep_ in SUBJECTS:
        return True
    
    return False

def SubjectFindProp(obj_tokens,other_tokens,doc):
    # find property of subject
    obj_last_token=obj_tokens[-1]
    obj_text=" ".join([_.text for _ in obj_tokens])
    other_text=" ".join([_.text for _ in other_tokens])
    anc=list(obj_last_token.ancestors)
    if len(anc)>0:
        root_token=list(obj_last_token.ancestors)[0]
    else:
        root_token=obj_last_token
    root_idx=root_token.i
    mindis_token=[token_text(""),float("inf")]
    for child in root_token.rights:
        idx=child.idx
        dif=idx-root_idx
        if child.text in other_text:
            continue
        if dif < mindis_token[1]:
            mindis_token=[child,dif]
    if mindis_token[0].text=="":
        most_right=root_idx
    else:
        most_right=find_max_idx(mindis_token[0],res=[])
    property_tokens=doc[root_idx:most_right+1]
    skip_blacklist = {
      "the","his","my","her", "their", "an", "a"}
    tmp=[]
    for token_ in property_tokens:
        if token_.text in skip_blacklist:
            continue
        else:
            tmp.append(token_)
    property_tokens=tmp
    break_blacklist = {
      "that","for","so","because", "and", "until","but","when"}
    
    tmp=[]
    for token_ in property_tokens:
        if token_.text in break_blacklist:
            break
        else:
            tmp.append(token_)
    property_tokens=tmp
    property_text=" ".join([_.text for _ in property_tokens]).replace(obj_text,"").replace(other_text,"").strip()
    property_text=property_text.replace("  "," ")
    return property_text

def find_max_idx(token, res=[]):
    # find the most right subtoken idx 
    if token.n_rights> 0:
        for child in token.rights:
            res.append(child.i)
            find_max_idx(child, res=res)
    else:
        res.append(token.i)
    return max(res)

def find_min_idx(token, res=[]):
    # find the most right subtoken idx 
    if token.n_lefts> 0:
        for child in token.lefts:
            res.append(child.i)
            find_min_idx(child, res=res)
    else:
        res.append(token.i)
    return min(res)

def LastRootDetection(sent):
    root_token=sent[-2]
    for token in sent:
        anc=list(token.ancestors)
        token_idx=token.i
        if token.pos_=="PUNCT" or token.pos_=="ADV":
            continue
            
        if anc:
            anc_i=anc[0].i
            diff=token_idx-anc_i
            
            if diff >= 0.3*len(sent):
                root_token=token
        else:
            root_token=token
                
    return root_token

def elements_extraction_other(example):
    # new information extraction rule for other format
    # add punction
    
    sent_text=example["text"]
    # get the text of two object
    obj1=example["option1"]
    obj2=example["option2"]
    
    sent_text=AddPeriod(sent_text)
    
    doc=nlp(sent_text)
    # find entity label.
    ents_=list(doc.ents)
    
    person_num=0
        
    for ent_ in ents_:
        if ent_.text in {obj1,obj2} or obj1 in ent_.text or obj2 in ent_.text:
            if ent_.label_=="PERSON":
                person_num+=1
    
    # find property 2
    last_sent=list(doc.sents)[-1]
    last_root=LastRootDetection(last_sent)
    start_=last_root.i
    for token_ in doc[last_root.i:]:
        if token_.pos_=="AUX" or token_.lemma_ in {"is","are","be"}:
            start_=token_.i
            continue
        
        if token_.text in {"he","she","they"} and SubjectIdentifier([token_],last_sent):
            start_=token_.i+1
            continue
    prop_last=(" ").join([_.text for _ in doc[start_:] if _.pos_ not in {"PUNCT"} and _.text not in {"so","too","very"}])

    if person_num>0:
        reverse={"that","because","for"}
        # example: Joan made sure to thank Susan for all the help she had recieved.
        # recieved causes thank all the help
        doc=nlp(sent_text)
        obj1_tokens=[]
        obj2_tokens=[]
        for _sent in doc.sents:
            if not obj1_tokens:
                obj1_tokens=TokensByText(_sent,obj1,doc)
                if obj1_tokens:
                    sent1=_sent
            if not obj2_tokens:
                obj2_tokens=TokensByText(_sent,obj2,doc)
                if obj2_tokens:
                    sent2=_sent
            if not obj1_tokens or not obj2_tokens:
                return None
            
        # use obj tokens to decide it is subject or object
        obj1_pos=SubjectIdentifier(obj1_tokens,sent1)
        obj2_pos=SubjectIdentifier(obj2_tokens,sent2)
        if obj1_pos and obj2_pos:
            # both are subject
            # A is property, B is property
            obj1_prop=SubjectFindProp(obj1_tokens,obj2_tokens,doc)
            obj2_prop=SubjectFindProp(obj2_tokens,obj1_tokens,doc)
            #sent1= obj1_prop+" "+prop_last
            #sent2= obj2_prop+" "+prop_last
            for token in doc:
                for item in reverse:
                    if item in token.text:
                        sent1= prop_last+" causes "+obj1_prop
                        sent2= prop_last+" causes "+obj2_prop
                        return 0, sent1, sent2
                    else:
                        sent1= obj1_prop+" causes "+prop_last
                        sent2= obj2_prop+" causes "+prop_last
            return 0, sent1, sent2
        elif obj1_pos or obj2_pos:
            # A property B
            if obj1_pos:
                obj_prop=SubjectFindProp(obj1_tokens,obj2_tokens,doc)
                status=["obj1","obj2"]
            else:
                obj_prop=SubjectFindProp(obj2_tokens,obj1_tokens,doc)
                status=["obj2","obj1"]
            
            sent=""
            for token in doc:
                for item in reverse:
                    if item in token.text:
                        sent= prop_last+" causes "+obj_prop
                        return 1, sent
                    else:
                        sent= obj_prop+" causes "+prop_last
            return 1, sent
        else:
            obj1_prop=" ".join([_.text for _ in obj1_tokens])
            obj2_prop=" ".join([_.text for _ in obj2_tokens])
            return
    else:
        # example: The city councilmen refused the demonstrators a permit because they feared violence.
        # sentences: The city councilmen feared violence; the demonstrators feared violence
        sent1= obj1+" "+prop_last
        sent2= obj2+" "+prop_last
        
        return 2, sent1, sent2

In [12]:
embed1=model.encode("was upset causes comforted")
embed2=model.encode("relieve antonym upset")
sim=dot(embed1, embed2)/(norm(embed1)*norm(embed2))
sim

0.9122619

In [13]:
example_information =[]
with open("./cskg_embedding/wsc.txt","w") as f:
    for example in examples:
        sent=example["text"]
        tmp=elements_extraction_other(example)
        if tmp:
            if tmp[0]==0:
                hypo_sent=[tmp[1],tmp[2]]
            elif tmp[0]==1:
                hypo_sent=[tmp[1]]
            elif tmp[0]==2:
                hypo_sent=[tmp[1],tmp[2]]
        else:
            hypo_sent=[""]
        
        f.write("Sentence: ")
        f.write(sent+"\n")
        f.write("Hypo Sentence: ")
        f.write("\t".join(hypo_sent)+"\n")
        f.write("\n")
        
        example_information.append({"Sentence":sent,
                                "hypo sentences":hypo_sent})

## Load cskg embedding

In [14]:
# load cskg file

with open(cskg_connected_file, "r") as f:
    head = f.readline().strip().split("\t")
    
    # load lines only contain relation==HasProperty
    cskg_lines=[]
    
    for item in f:
        line = item.strip().split("\t")
        relation_id=line[2]
        cskg_lines.append(line)
        
cskg_sents=[]
for line in cskg_lines:
    cskg_sents.append(f"{line[4]} {line[6]} {line[5]}")

In [None]:
isfile_=os.path.isfile('cskg_model_embed.pickle')
if isfile_:
    with open('cskg_model_embed.pickle', 'rb') as handle:
        cskg_embed = pickle.load(handle)
else:
    # divide task into 1000chunks and run loop
    chunk_num = 1000
    chunks_size = math.ceil(len(cskg_sents)/1000)
    cskg_embed=np.empty((0,1024), dtype=np.float32)
    count = 0

    for num in tqdm(range(1,chunk_num+1)):
        end = num*chunks_size
        start = chunks_size*(num-1)
        temp_embed=model.encode(cskg_sents[start:end])
        cskg_embed=np.append(cskg_embed, temp_embed, axis=0)
        count += 1 
        
    # store file into desktop
    with open('cskg_model_embed.pickle', 'wb') as handle:
        pickle.dump(cskg_embed, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [247]:
cskg_embed=np.array([S/(math.sqrt(sum(S**2))) for S in tqdm(cskg_embed)])

100%|██████████| 6003237/6003237 [19:42<00:00, 5078.00it/s]


In [260]:
hyp_sents=[]
for line in example_information:
    sents=line["hypo sentences"]
    for sent in sents:
        hyp_sents.append(sent)
    
hyp_sents_embed=model.encode(hyp_sents)
hyp_sents_embed=np.array([S/(math.sqrt(sum(S**2))) for S in tqdm(hyp_sents_embed)])

# use faiss to find neareast
d= cskg_embed.shape[1]
index = faiss.IndexFlatL2(d)
index.add(cskg_embed)

min_sim=1
max_sim=0
# find the closest edges
k = 1
D, I = index.search(hyp_sents_embed, k)
for idx,embed in zip(I,hyp_sents_embed):
    idx=idx[0]
    embed1=cskg_embed[idx]
    embed2=embed
    similar=dot(embed1, embed2)/(norm(embed1)*norm(embed2))
    min_sim=min(min_sim,similar)
    max_sim=max(max_sim,similar)

100%|██████████| 446/446 [00:00<00:00, 3418.65it/s]


In [261]:
min_sim,max_sim

(0.7112264, 1.0)

In [262]:
threshold=(min_sim+max_sim)/2

In [274]:
correct=0
with open("./cskg_embedding/wsc.txt", "w") as f:
    f.write(f"The threshold of cosine similar is: {threshold}\n")
    f.write("######################\n")
    count=0
    QA_num=0
    for idx in range(len(examples)):
        example=examples[idx]
        line=example_information[idx]
        sent=example["text"]
        hypo_sents=line["hypo sentences"]
        Most_Similar_CSKG=[]
        sims=[]
        for hypo_sent in hypo_sents:
            idx=I[count][0]
            Most_Similar_CSKG.append(cskg_sents[idx])
            embed1=cskg_embed[idx]
            embed2=hyp_sents_embed[count]
            similar=dot(embed1, embed2)/(norm(embed1)*norm(embed2))
            sims.append(str(round(similar,2)))
            count+=1
        QA_num+=1
        if len(sims)==1:
            ground=example["label"]
            if float(sims[0]) >= threshold:
                prediction=0
                if ground == 0:
                    correct+=1
            else:
                prediction=1
                if ground==1:
                    correct+=1
        elif len(sims)==2:
            ground=example["label"]
            if float(sims[0]) >= float(sims[1]):
                prediction=0
                if ground ==0:
                    correct+=1
            else:
                prediction=1
                if ground == 1:
                    correct+=1
                    
        f.write(str(QA_num)+")."+sent+"\n")
        f.write(f"Hypo Sentences:{'; '.join(hypo_sents)}\n")
        f.write(f"Cosine Similarity:{'; '.join(sims)}\n")
        f.write(f"Most Similar CSKG Line:{'; '.join(Most_Similar_CSKG)}\n")
        f.write(f"Ground:{ground}\n")
        f.write(f"Prediction:{prediction}\n")
        f.write("\n")
        

In [275]:
correct/len(examples)

0.4945054945054945

In [265]:
Most_Similar_CSKG

['counterfeit watch subclass of counterfeit consumer good',
 "loses wallet|personx loses persony's wallet effect on person x will be blamed"]

In [51]:
embed=model.encode("abdsa")

In [48]:
hypo_sents_embed.shape

(1024,)

In [201]:
doc=nlp("bob paid for charlie's college education. He is very generous.")
spacy.displacy.render(doc, style="dep")