In [67]:
from pathlib import Path
from tqdm import tqdm
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss, math, random, re,pickle, os
from numpy import dot
from numpy.linalg import norm
import spacy
from collections import defaultdict
from nltk import pos_tag

In [48]:
# input
cskg_embeddings_file="./cskg_embedding/cskg_embeddings.txt"
cskg_connected_file="../kg-bert/data/cskg/cskg_connected.tsv"
RICA_file="./RICA/RICA_material_KnowledgeTable.csv"
sample_1k_lines="./1k_lines/test_sentences.txt"
ground_truth_file="./1k_lines/test_sentences_m.txt"

# output
temp_result = "./cskg_embedding/result.txt"

## Function

In [4]:
# design dependency rule
class token_text():
    # build a class to generate empty token
    def __init__(self, text):
        self.text=text

def AddPeriod(sent_text):
    # add period to sentences
    if sent_text[-1]!=".":
        sent_text+="."
    return sent_text

def subitem_depCheck(subs, require={}):
    # find the items having the required relation
    output_token=token_text("")
    for sub in subs:
        rel=sub.dep_
        
        if rel in require:
            output_token=sub
            break
            
    return output_token

def walk_tree(node, depth, depths={}):
    # walk through the sentence and find the depth
    depths[node] = depth
    if node.n_lefts + node.n_rights > 0:
        return [walk_tree(child, depth + 1,depths=depths) for child in node.children]

def find_max_idx(token, res=[]):
    # find the most right subtoken idx 
    if token.n_rights> 0:
        for child in token.rights:
            res.append(child.i)
            find_max_idx(child, res=res)
    else:
        res.append(token.i)
    return max(res)

def find_min_idx(token, res=[]):
    # find the most right subtoken idx 
    if token.n_lefts> 0:
        for child in token.lefts:
            res.append(child.i)
            find_min_idx(child, res=res)
    else:
        res.append(token.i)
    return min(res)
    
def find_end_tree(root,res=[], left_=False, right_=False):
    # find the end token of dependency Tree.
    # left_ means find left hand side children
    # right_ means find right hand side children
    if root.n_lefts*left_ + root.n_rights*right_ > 0:
        if left_ and right_:
            for child in root.children:
                find_end_tree(child, res=res, left_=left_,right_=right_)
        elif left_:
            for child in root.lefts:
                res.append(child)
                find_end_tree(child, res=res, left_=left_,right_=right_)
        elif right_:
            for child in root.rights:
                res.append(child)
                find_end_tree(child, res=res, left_=left_,right_=right_)
    else:
        res.append(root)
    
    return res

def SentencePartDetection(sent):
    # split sentence into different part
    # output the root token idx of each part
        # the span for different token
        
    doc=sent.doc
    spanLength_list=[]
    for token in sent:
        token_index=token.i
        for child in token.children:
            child_index = child.i
            
            if child.dep_=="punct":
                continue
            spanLength_list.append((abs(token_index-child_index),(token_index, child_index)))
            
    spanLength_list.sort(reverse=True)
    #print(spanLength_list)
    part_pos=set()
    
    depths={}
    #print(sent.root)
    [walk_tree(sent.root, 0, depths=depths) for sent in doc.sents]
    part_pos.add(sent.root.i)
    #print(depths)
    for span_length, tokens in spanLength_list:
        token1_index, token2_index=tokens
        token1=doc[token1_index]
        token2=doc[token2_index]
        #print(part_pos, token1,token2,depths[token2])
        if len(list(token2.children))<1 or depths[token2]>2:
            continue
        
        if token2.dep_ in {"conj","ccomp","advcl","advmod"} and token2.pos_ not in {"ADJ", "PUNCT"}:
            
            if token1_index not in part_pos and token2_index not in part_pos and len(part_pos)>=2:
                continue
            
            part_pos.add(token1_index)
            part_pos.add(token2_index)
        
        if len(part_pos)>=3:
            break
            
    part_pos=sorted(list(part_pos))
    return part_pos

def elements_extraction_isa(sent_text):
    # extract elemenmt from isa format: "X is A, Y is B, X is more/less than B"
    
    sent_text=AddPeriod(sent_text)
    
    # obtain parsed sentences
    doc=nlp(sent_text)
    sent=list(doc.sents)[0]
    
    # find root token idx for each sentence part
    part_pos=SentencePartDetection(sent)
    # final compare part
    compare_root=sent[part_pos[2]]
    
    # find reasoning part
    reasoning_first_root=sent[part_pos[0]]
    reasoning_second_root=sent[part_pos[1]]
    
    # generate object1 and object2 token candidates
    leftEnd_compare_tokens=find_end_tree(compare_root,left_=True,right_=False, res=[])
    leftEnd_first_tokens=find_end_tree(reasoning_first_root,left_=True,right_=False, res=[])
    leftEnd_second_tokens=find_end_tree(reasoning_second_root,left_=True,right_=False, res=[])
    
    # find object1 and object2
    object1, object2=token_text(""),token_text("")
    
    if len(leftEnd_first_tokens)==1:
        object2=leftEnd_first_tokens[0]
        
    else:
        for token in leftEnd_first_tokens:
            if token.pos_ == "NOUN":
                object2 = token
                
            if token.dep_ in {"nsubj","acomp"}:
                object2=token
                
            if list(token.ancestors)[0].text == "is":
                object2 = token
            
    if len(leftEnd_second_tokens)==1:
        object1=leftEnd_second_tokens[0]
        
    else:
        for token in leftEnd_second_tokens:
            if token.i < part_pos[0]:
                continue
                
            if token.pos_ == "NOUN":
                object1 = token
            
            if token.dep_ in {"nsubj","acomp"}:
                object1=token
                
            if list(token.ancestors)[0].text == "is":
                object1 = token
            
    # use than find one object candidates
    for token in sent:
        if token.text == "than":
            than_token=token
            
            than_child = list(than_token.children)
            than_anc= list(than_token.ancestors)
            
            if than_child:
                object_temp =than_child[0]
                
            else:
                object_temp=than_anc[0]
                
    objects_text=[object1.text, object2.text]
    
    # replace empty object result with object candidates found by than
    if object_temp.text in objects_text:
        pass
    else:
        if objects_text[0]=="":
            object1=object_temp
        else:
            object2=object_temp
            
    # find the property of each objects
    rightEnd_first_tokens=find_end_tree(reasoning_first_root,left_=False,right_=True, res=[])
    rightEnd_second_tokens=find_end_tree(reasoning_second_root,left_=False,right_=True, res=[])
    
    max_dep_obj2=0
    
    for token in rightEnd_first_tokens:
        idx = token.i
        if token.i > part_pos[1]:
            continue
            
        else:
            max_dep_obj2=max(max_dep_obj2, idx)
            
    max_dep_obj1=0
    for token in rightEnd_second_tokens:
        idx = token.i
        if token.i > part_pos[2]:
            continue
            
        else:
            max_dep_obj1=max(max_dep_obj1, idx)
            
    obj2_property=sent[part_pos[0]+1:max_dep_obj2+1].text
    obj1_property=sent[part_pos[1]+1:max_dep_obj1+1].text
    
    # we should knwo obj2_property is the property of obj1 or obj2
    # already know obj2_property apears ealier than obj1_property in the sentence.
    object1_loc=sent.text.find(object1.text)
    object2_loc=sent.text.find(object2.text)
    
    if object1_loc <= object2_loc:
        temp=obj2_property
        obj2_property=obj1_property
        obj1_property=temp
    else:
        pass
    
    # find compare aspect
    aspect_token=subitem_depCheck(compare_root.rights, require={"acomp","attr"})
    
    if aspect_token.text:
        aspect_index=aspect_token.i
        for left in aspect_token.lefts:
            rel=left.dep_
            if rel not in ["advmod"] and left.i < aspect_index:
                aspect_index= left.i

        aspect_span= sent[aspect_index:]
        aspect_text=aspect_span.text.split(" than ")[0]
    else:
        aspect_span = sent[part_pos[2]+1:]
        aspect_text=aspect_span.text.split(" than ")[0]
        
    # find ground truth is more or less:
    more=0
    less=0
    reverse= 1
    MoreOrLess=0
    for token in sent[part_pos[2]:]:
        if token.text == "more":
            more +=1
            
        elif token.text == "less":
            less += 1
            
        elif token.text == "not" or token.text == "no":
            reverse=-1
        
    if "er " in sent[part_pos[2]:].text:
        more += 1
    
    MoreOrLess = reverse*(more-less)>0
    return object1,object2,obj1_property.split(",")[0],obj2_property.strip(",").split(",")[0],aspect_text, MoreOrLess*"more"+(1-MoreOrLess)*"less"

def elements_extraction_other(sent_text):
    # new information extraction rule for other format
    # add punction
    sent_text=AddPeriod(sent_text)
    
    doc=nlp(sent_text)
    sent=list(doc.sents)[-1]
    
    # find root token idx for each sentence part
    part_pos=SentencePartDetection(sent)
    
    # without than pick the last token
    for token in list(sent)[::-1]:
        if token.pos_!="PUNCT":
            object_2=token
            break
    
    # find object2 by than
    for token in list(sent)[::-1]:
        #print(token)    
        if token.text == "than":
            than_token=token
            
            than_child = list(than_token.children)
            than_anc= list(than_token.ancestors)
            
            if than_child:
                for token in than_child:
                    if token.pos_!="PUNCT":
                        object_2 =token
                        break
                break
            else:
                for token in than_child:
                    if token.pos_!="PUNCT":
                        object_2 =token
                        break
                break
    
    # find compare root 
    compare_root=doc[part_pos[-1]]
    
    # find the left token of compare root
    leftEnd_compare_tokens=find_end_tree(compare_root,left_=True,right_=False, res=[])
    
    object_1=token_text(None)
    for token in leftEnd_compare_tokens:
        # limitation
        #print(token)
        if len(part_pos)>1 and token.i < part_pos[-2]:
            continue
        
        if token.pos_ == "PUNCT":
            continue
            
        # judege whether it is object1
        #print(token)
        if token.pos_ == "NOUN":
            object_1 = token
            break
            
        if token.dep_ in {"nsubj","acomp"}:
            object_1 = token
            break
            
    if object_1.text == None:
        for token in leftEnd_compare_tokens:
            # limitation
            #print(token)
            if len(part_pos)>1 and token.i < part_pos[-2]:
                continue

            if token.pos_ == "PUNCT":
                continue

            # judege whether it is object1

            if list(token.ancestors)[0].text == "is" and token.text!="so":
                object_1 = token
                break
    object_1,object_2=object_1.text, object_2.text
    
    # find the common property of object1 and object2
    # replace the item name to make sentence more correct
    if object_1:
        sent_text=sent_text.replace(object_1,"Jack")
    if object_2:
        sent_text=sent_text.replace(object_2,"Amy")
        
    doc=nlp(sent_text)
    sent=list(doc.sents)[0]
    
    # find root token idx for each sentence part
    part_pos=SentencePartDetection(sent)
    
    # two possible property
    property_root=doc[part_pos[0]]
    
    if property_root.pos_=="VERB":
        # find a cluster of token than close to root
        root_idx=property_root.i
        mindis_token=[token_text(""),float("inf")]
        for child in property_root.rights:
            idx=child.idx
            dif=idx-root_idx
            if child.text in {"Amy","Jack"}:
                continue
            if dif < mindis_token[1]:
                mindis_token=[child,dif]
        if mindis_token[0].text=="":
            most_right=root_idx
        else:
            most_right=find_max_idx(mindis_token[0],res=[])
        property_tokens=doc[root_idx:most_right+1].text.replace(" Amy "," ").replace(" Jack "," ")
    else:
        root_idx=property_root.i
        mindis_token=[token_text(""),float("inf")]
        for child in property_root.rights:
            idx=child.idx
            dif=idx-root_idx
            if child.text in {"Amy","Jack"}:
                continue
            if dif < mindis_token[1]:
                mindis_token=[child,dif]
        if mindis_token[0].text=="":
            most_right=root_idx
            most_left=root_idx
        else:
            most_right=find_max_idx(mindis_token[0],res=[])
            most_left=find_min_idx(mindis_token[0],res=[])
        property_tokens=doc[most_left:most_right+1].text.replace(" Amy "," ").replace(" Jack "," ")
    property_tokens=property_tokens.replace(" Amy's "," ").replace(" Jack's "," ")
    property_tokens=property_tokens.replace(" Amy","").replace(" Jack","")
    property_tokens=property_tokens.split("than")[0]

    sent=list(doc.sents)[-1]
    # find root token idx for each sentence part
    part_pos=SentencePartDetection(sent)
    
    # find compare aspect
    compare_root=doc[part_pos[-1]]
    aspect_token=subitem_depCheck(compare_root.rights, require={"acomp","attr"})
    
    if aspect_token.text:
        aspect_index=aspect_token.i
        for left in aspect_token.lefts:
            rel=left.dep_
            if rel not in ["advmod"] and left.i < aspect_index:
                aspect_index= left.i

        aspect_span= sent[aspect_index:]
        aspect_text=aspect_span.text.split(" than ")[0]
    else:
        aspect_span = sent[part_pos[-1]+1:]
        aspect_text=aspect_span.text.split(" than ")[0]
        
    # find ground truth is more or less:
    more=0
    less=0
    reverse= 1
    MoreOrLess=0
    for token in sent[part_pos[-1]:]:
        if token.text == "more":
            more +=1
            
        elif token.text == "less":
            less += 1
            
        elif token.text == "not" or token.text == "no":
            reverse=-1
        
    if "er " in sent[part_pos[-1]:].text:
        more += 1
    
    MoreOrLess = reverse*(more-less)>0
    Amy_pos=sent_text.find("Amy")
    Jack_pos=sent_text.find("Jack")
    if Amy_pos<Jack_pos:
        hyp_sent=f"Amy is {property_tokens}"
    else:
        hyp_sent=f"Jack is {property_tokens}"
    return str(object_1),str(object_2),aspect_text,property_tokens,MoreOrLess*"more"+(1-MoreOrLess)*"less",hyp_sent

## Analysis

In [19]:
# model setup
# load bert model
model = SentenceTransformer('nli-bert-large')
nlp = spacy.load("en_core_web_sm")

In [20]:
# load cskg_embed file
isfile_=os.path.isfile('cskg_model_embed.pickle')
if isfile_:
    with open('cskg_model_embed.pickle', 'rb') as handle:
        cskg_embed = pickle.load(handle)
else:
    # divide task into 1000chunks and run loop
    chunk_num = 1000
    chunks_size = math.ceil(len(cskg_sents)/1000)
    cskg_embed=np.empty((0,1024), dtype=np.float32)
    count = 0

    for num in tqdm(range(1,chunk_num+1)):
        end = num*chunks_size
        start = chunks_size*(num-1)
        temp_embed=model.encode(cskg_sents[start:end])
        cskg_embed=np.append(cskg_embed, temp_embed, axis=0)
        count += 1 
        
    # store file into desktop
    with open('cskg_model_embed.pickle', 'wb') as handle:
        pickle.dump(cskg_embed, handle, protocol=pickle.HIGHEST_PROTOCOL)
        
cskg_embed=np.array([S/(math.sqrt(sum(S**2))) for S in tqdm(cskg_embed)])
cskg_embed[0]

100%|██████████| 6003237/6003237 [23:40<00:00, 4225.25it/s]


array([-0.00286278,  0.01514893,  0.02432977, ..., -0.00663066,
       -0.06165611, -0.05097185], dtype=float32)

In [28]:
with open(sample_1k_lines, "r") as f:
    sample_lines=[]
    for line in f:
        sample_lines.append(line.strip())

In [51]:
# collect a list of all possible labels
with open(ground_truth_file, "r") as f:
    ground_truth=set()
    for line in f:
        ground_truth.add(line.strip())
ground_truth

{'after', 'before', 'better', 'easier', 'harder', 'less', 'more', 'worse'}

In [59]:
ground_truth_table=[["before","after"],["easier","harder"],["more","less"],["better","worse"]]

In [107]:
# Find out special format for information extraction
# example
regex_format=".*is.*,.*is.*, so.*is.*than,*"

hyp_sents=[]
for sent in sample_lines:
    status=re.search(regex_format,sent)
    
    if not status:
        obj1,obj2,compare_aspect, property_,truth,property_descrip=elements_extraction_other(sent)
        # hyp sentences
        hyp_sents.append([f"{property_.strip()} causes {compare_aspect}"])
    else:
        obj1,obj2, obj1_property, obj2_property, aspect, truth=elements_extraction_isa(sent)
        C1= obj1_property +" is " + aspect
        C2= obj2_property +" is " + aspect
        hyp_sents.append([C1,C2])

In [108]:
map_hyp_sents=[]

for line in hyp_sents:
    for sent in line:
        map_hyp_sents.append(sent)
        
hyp_sents_embed=model.encode(map_hyp_sents)
hyp_sents_embed=np.array([S/(math.sqrt(sum(S**2))) for S in tqdm(hyp_sents_embed)])

# use faiss to find neareast
d= cskg_embed.shape[1]
index = faiss.IndexFlatL2(d)
index.add(cskg_embed)

min_sim=1
max_sim=0
# find the closest edges
k = 1
D, I = index.search(hyp_sents_embed, k)
for idx,embed in zip(I,hyp_sents_embed):
    idx=idx[0]
    embed1=cskg_embed[idx]
    embed2=embed
    similar=dot(embed1, embed2)/(norm(embed1)*norm(embed2))
    min_sim=min(min_sim,similar)
    max_sim=max(max_sim,similar)

100%|██████████| 1600/1600 [00:00<00:00, 4770.64it/s]


In [109]:
threshold=(min_sim+max_sim)/2
threshold

0.8234521150588989

In [110]:
doc=nlp('ajol is smaller than wzuexmkld, so wzuexmkld is harder to put into a box than ajol')
for token in doc:
    print(pos_tag([token.text]))

[('ajol', 'NN')]
[('is', 'VBZ')]
[('smaller', 'JJR')]
[('than', 'IN')]
[('wzuexmkld', 'NN')]
[(',', ',')]
[('so', 'RB')]
[('wzuexmkld', 'NN')]
[('is', 'VBZ')]
[('harder', 'NN')]
[('to', 'TO')]
[('put', 'NN')]
[('into', 'IN')]
[('a', 'DT')]
[('box', 'NN')]
[('than', 'IN')]
[('ajol', 'NN')]


In [117]:
prediction_truth=[]
sent_idx=0

for i in range(len(sample_lines)):
    sent=sample_lines[i]
    status=re.search(regex_format,sent)

    if not status:
        obj1,obj2,compare_aspect, property_,truth,property_descrip=elements_extraction_other(sent)
        idx=I[sent_idx][0]
        embed1=cskg_embed[idx]
        embed2=hyp_sents_embed[sent_idx]
        similar=dot(embed1, embed2)/(norm(embed1)*norm(embed2))
        property_item=property_descrip.split(" ")[0]

        if property_item== "Jack":
            for more,less in ground_truth_table:
                if more in sent or less in sent:
                    candiates=[less,more]
                    break

        else:
            for more,less in ground_truth_table:
                if more in sent or less in sent:
                    candiates=[more,less]
                    break

        predict_result=candiates[similar>threshold]
        
        prediction_truth.append(predict_result)
        sent_idx+=1
    else:
        obj1,obj2, obj1_property, obj2_property, aspect, truth=elements_extraction_isa(sent)
        idx1=I[sent_idx][0]
        idx2=I[sent_idx+1][0]
        edge1_embed=cskg_embed[idx1]
        edge2_embed=cskg_embed[idx2]
        sent_embed1=hyp_sents_embed[sent_idx]
        sent_embed2=hyp_sents_embed[sent_idx+1]

        # claculate similarity
        similar1=dot(edge1_embed, sent_embed1)/(norm(edge1_embed)*norm(sent_embed1))
        similar2=dot(edge2_embed, sent_embed2)/(norm(edge2_embed)*norm(sent_embed2))
        
        for more,less in ground_truth_table:
            if more in sent or less in sent:
                candiates=[less,more]
                break
        if similar1 > similar2:
            predict_result=more
        else:
            predict_result=less

        prediction_truth.append(predict_result)
        sent_idx+=2



In [118]:
with open(temp_result, "w") as f:
    for predict in prediction_truth:
        f.write(predict+"\n")

In [119]:
# find accuracy
with open(ground_truth_file, "r") as f:
    truth=[]
    for line in f:
        truth.append(line.strip())

accuracy=0
for ground, predict in zip(truth,prediction_truth):
    if ground ==predict:
        accuracy+=1
    
accuracy=accuracy/len(truth)
accuracy

0.49625

In [116]:
doc=nlp("hcwoctv is glass and mhzg is stone, so hcwoctv is not better at blocking light than mhzg")
spacy.displacy.render(doc, style="dep")