In [1]:
import numpy as np
import pandas as pd

In [2]:
from allennlp.predictors.predictor import Predictor
predictor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/bert-base-srl-2020.03.24.tar.gz")

In [3]:
import spacy
from spacy.tokens import Token
Token.set_extension('srl_arg0', default=None)
Token.set_extension('srl_arg1', default=None)

In [4]:
nlp = spacy.load("en")

In [5]:
def srl(sent):
    doc = nlp(sent)
    words = [token.text for token in doc]
    for i, word in enumerate(doc):
        if word.pos_ == "VERB":
            verb = word.text
            verb_labels = [0 for _ in words]
            verb_labels[i] = 1
            instance = predictor._dataset_reader.text_to_instance(doc, verb_labels)
            output = predictor._model.forward_on_instance(instance)
            tags = output['tags']
    
            if "B-ARG0" in tags:
                start = tags.index("B-ARG0")
                end = max([i for i, x in enumerate(tags) if x == "I-ARG0"] + [start]) + 1
                word._.set("srl_arg0", doc[start:end])
    
            if "B-ARG1" in tags:
                start = tags.index("B-ARG1")
                end = max([i for i, x in enumerate(tags) if x == "I-ARG1"] + [start]) + 1
                word._.set("srl_arg1", doc[start:end])
    res = {}
    for w in doc:
        if w.pos_ == "VERB":
            # print("ARG0:", w._.srl_arg0)
            # print("VERB:", w)
            # print("ARG1:", w._.srl_arg1)
            # print("-----------------")
            res[w] = {'ARG0': w._.srl_arg0, 'ARG1': w._.srl_arg1}
    return res

In [6]:
sent1 = "Trump thinks the violence in media will harm chidren"
sent2 = "Uriah honestly thinks the movie will harm kids"

In [7]:
srl(sent1)

{thinks: {'ARG0': Trump, 'ARG1': the violence in media will harm chidren},
 will: {'ARG0': the violence in media, 'ARG1': harm chidren},
 harm: {'ARG0': the violence in media, 'ARG1': chidren}}

In [8]:
srl(sent2)

{thinks: {'ARG0': Uriah, 'ARG1': the movie will harm kids},
 will: {'ARG0': the movie, 'ARG1': harm kids},
 harm: {'ARG0': the movie, 'ARG1': kids}}

In [102]:
def get_w2v_features(sent):
    MAX_LEN = 100
    sent = [w for w in sent.split()]
    w2v_embedding = np.zeros((MAX_LEN, 96))
    length = len(sent)
    w2v_embedding[:length] = np.array([nlp(w).vector for w in sent])
    return w2v_embedding

In [10]:
emb1 = get_w2v_features(sent1)
emb2 = get_w2v_features(sent2)

In [103]:
def cosine_similarity(vec1, vec2):
    dot = np.dot(vec1, vec2)
    norm1 = np.linalg.norm(vec1)
    norm2 = np.linalg.norm(vec2)
    return dot/(norm1*norm2)

In [104]:
# no stop words
stop_words = set()

def get_wms(doc1, doc2, emb1, emb2):
    # we treat doc1 as the query document
    result1 = [(word, emb) for word, emb in zip(doc1,emb1) if (word not in stop_words) and (not np.all(emb==0))]
    document1 = [x[0] for x in result1]
    embedding1 = [x[1] for x in result1]
    
    result2 = [(word, emb) for word, emb in zip(doc2,emb2) if (word not in stop_words) and (not np.all(emb==0))]
    document2 = [x[0] for x in result2]
    embedding2 = [x[1] for x in result2]
    
    s1 = set(document1)
    s2 = set(document2)
    
    d2 = {}
    for s in s2:
        d2[s] = document2.count(s)/len(document2) # words frequency
    
    sent_sim = []
    for i2, e2 in enumerate(embedding2):
        word_sim = []
        for e1 in embedding1:
            word_sim.append(cosine_similarity(e1, e2))

        sent_sim.append(max(word_sim)*d2[document2[i2]])
    return sum(sent_sim)

In [13]:
get_wms(sent1, sent2, emb1, emb2)

# similarity using srl should be the same

0.9852413502973426

In [14]:
verb_threshold = 0.5

def get_wms_srl_old(sent1, sent2):
    # we treat doc1 as the query document
    # for same verb, calculate the word mover's distance between corresponding args in two sentences
    # can be used as a supplementary to our dependency parsing based approach
    # currently using sentence inputs, for word list inputs, join the words together to perform srl
    parsing_result1 = srl(sent1)
    parsing_result2 = srl(sent2)
    
    doc1 = [w for w in sent1.split()]
    doc2 = [w for w in sent2.split()]
    
    emb1 = get_w2v_features(sent1)
    emb2 = get_w2v_features(sent2)
    
    # similarity is 0 if no verb is same
    sent_sim = [0]
    
    for v1 in parsing_result1:
        for v2 in parsing_result2:
            # check verb similarity first, if greater than threshold, calculate corresponding arg0 and arg1 similarity
            # using cosine similarity between word2vec features
            verb_sim = cosine_similarity(nlp(v1.text).vector, nlp(v2.text).vector)
            if verb_sim < verb_threshold:
                continue
            
            # check if verbs are the same
            # if v1.text != v2.text:
            #     continue
            
            # calculate similarity in arg0
            if parsing_result1[v1]['ARG0'] is None or parsing_result2[v2]['ARG0'] is None:
                sim_arg0 = 0
            else:
                arg0_1 = parsing_result1[v1]['ARG0'].text
                arg0_2 = parsing_result2[v2]['ARG0'].text
                arg0_1 = [w for w in arg0_1.split()]
                arg0_2 = [w for w in arg0_2.split()]
                arg0_1_emb = [emb1[doc1.index(w)] for w in arg0_1]
                arg0_2_emb = [emb2[doc2.index(w)] for w in arg0_2]
                s1 = set(arg0_1)
                s2 = set(arg0_2)
                d2 = {}
                for s in s2:
                    # d2[s] = arg0_2.count(s)/len(arg0_2)
                    d2[s] = arg0_2.count(s)/len(doc2)
                arg0_sim_list = []
                for i2, e2 in enumerate(arg0_2_emb):
                    word_sim = []
                    for e1 in arg0_1_emb:
                        word_sim.append(cosine_similarity(e1, e2))
                    arg0_sim_list.append(max(word_sim)*d2[arg0_2[i2]])
                sim_arg0 = sum(arg0_sim_list)
                print('ARG0 similarity for verb {}: {}'.format(v1.text, sim_arg0))
            
            # calculate similarity in arg1
            if parsing_result1[v1]['ARG1'] is None or parsing_result2[v2]['ARG1'] is None:
                sim_arg1 = 0
            else:
                arg1_1 = parsing_result1[v1]['ARG1'].text
                arg1_2 = parsing_result2[v2]['ARG1'].text
                arg1_1 = [w for w in arg1_1.split()]
                arg1_2 = [w for w in arg1_2.split()]
                arg1_1_emb = [emb1[doc1.index(w)] for w in arg1_1]
                arg1_2_emb = [emb2[doc2.index(w)] for w in arg1_2]
                s1 = set(arg1_1)
                s2 = set(arg1_2)
                d2 = {}
                for s in s2:
                    # d2[s] = arg1_2.count(s)/len(arg1_2)
                    d2[s] = arg1_2.count(s)/len(doc2)
                arg1_sim_list = []
                for i2, e2 in enumerate(arg1_2_emb):
                    word_sim = []
                    for e1 in arg1_1_emb:
                        word_sim.append(cosine_similarity(e1, e2))
                    arg1_sim_list.append(max(word_sim)*d2[arg1_2[i2]])
                sim_arg1 = sum(arg1_sim_list)
                print('ARG1 similarity for verb {}: {}'.format(v1.text, sim_arg1))
            sent_sim.append(sim_arg0 + sim_arg1 + verb_sim * doc2.count(v2.text) / len(doc2))
    return sum(sent_sim)

In [15]:
get_wms_srl_old(sent1, sent2)

ARG0 similarity for verb thinks: 0.05743413319981495
ARG1 similarity for verb thinks: 0.5346373936310943
ARG0 similarity for verb will: 0.20483168299083568
ARG1 similarity for verb will: 0.20480571064025865
ARG0 similarity for verb harm: 0.20483168299083568
ARG1 similarity for verb harm: 0.07980571064025864


1.6613463215436783

In [123]:
verb_threshold = 0.5

def get_all_triples(sent):
    parsing_result = srl(sent)
    res = []
    for verb, args in parsing_result.items():
        res.append((args['ARG0'], verb, args['ARG1']))
    return res

def get_first_level_triples(triples):
    if triples == []:
        return []
    
    res = []
    for i, triple in enumerate(triples):
        verb = triple[1]
        first_level_verb_flag = True
        for j, other_triple in enumerate(triples):
            if i == j:
                continue
            other_triple_arg0 = other_triple[0].text if other_triple[0] is not None else ''
            other_triple_arg1 = other_triple[2].text if other_triple[2] is not None else ''
            
            if verb.text in other_triple_arg0 or verb.text in other_triple_arg1:
                # print('{} not in subject: {} and object: {}'.format(verb.text, other_triple[0].text, other_triple[2].text))
                first_level_verb_flag = False
        if first_level_verb_flag == True:
            res.append(triple)
    return res

def get_child_triples(triple, triple_list):
    res = []
    for other_triple in triple_list:
        verb = other_triple[1]
        triple_arg0 = triple[0].text if triple[0] is not None else ''
        triple_arg1 = triple[2].text if triple[2] is not None else ''
        if verb.text in triple_arg0 or verb.text in triple_arg1:
            res.append(other_triple)
    return res

def remove_triple(triple, triple_list, matched):
    if matched == False:
        # for unmatched triple, just remove the triple
        triple_list.remove(triple)
    else:
        # for matched triple, remove the triple and all child triples
        child_triples = get_child_triples(triple, triple_list)
        for child_triple in child_triples:
            triple_list.remove(child_triple)
        triple_list.remove(triple)
        
def get_wms_srl_recursive(sent1, sent2):
    # print(sent1)
    # print(sent2)
    sent1 = sent1.replace(',', ' ')
    sent1 = sent1.replace('.', ' ')
    sent1 = sent1.replace('"', ' ')
    sent1 = sent1.replace("'", ' ')
    sent1 = sent1.replace("’", ' ')
    sent1 = sent1.lower()
    sent2 = sent2.replace(',', ' ')
    sent2 = sent2.replace('.', ' ')
    sent2 = sent1.replace('"', ' ')
    sent2 = sent1.replace("'", ' ')
    sent2 = sent2.replace("’", ' ')
    sent2 = sent2.lower()
    # we treat sent1 as the query sentence and sent2 as the candidate sentence
    all_triples1 = get_all_triples(sent1)
    all_triples2 = get_all_triples(sent2)
    
    first_level_triples1 = get_first_level_triples(all_triples1)
    first_level_triples2 = get_first_level_triples(all_triples2)
    
    doc1 = [w for w in sent1.split()]
    doc2 = [w for w in sent2.split()]
    
    emb1 = get_w2v_features(sent1)
    emb2 = get_w2v_features(sent2)
    
    # similarity is 0 if no pair of verbs is similar
    sim_sent = 0
    
    if all_triples1 == []:
        # print('SRL got no result on sentence: {}'.format(sent1))
        return 0
    elif all_triples2 == []:
        # print('SRL got no result on sentence: {}'.format(sent2))
        return 0
    
    while all_triples2:
        for triple2 in first_level_triples2:
            # find the the most similar verb in sent1
            v2 = triple2[1].text
            max_sim_verb = 0
            nearest_triple1 = None
            for triple1 in first_level_triples1:
                v1 = triple1[1].text
                sim_verb = cosine_similarity(nlp(v2).vector, nlp(v1).vector)
                if sim_verb >= max_sim_verb:
                    max_sim_verb = sim_verb
                    nearest_triple1 = triple1
            
            # print('The most similar verb for {} is {}, similarity: {}.'.format(v2, nearest_triple1[1].text, max_sim_verb))
            
            if nearest_triple1 is None:
                sim_triple = 0
                remove_triple(triple2, all_triples2, False)
                
            # compute triple similarity
            elif max_sim_verb < verb_threshold:
                # unmatched triples
                # print('Verb similarity smaller than threshold: {}, triple similarity is zero.'.format(verb_threshold))
                sim_triple = 0
                
                # remove the triples
                remove_triple(nearest_triple1, all_triples1, False)
                remove_triple(triple2, all_triples2, False)
                
            else:
                # matched triples
                sim_verb = cosine_similarity(nlp(triple2[1].text).vector, nlp(nearest_triple1[1].text).vector)
                
                # calculate the similarity in ARG0
                if triple2[0] is None or nearest_triple1[0] is None:
                    sim_arg0 = 0
                else:
                    arg0_1 = nearest_triple1[0].text
                    arg0_2 = triple2[0].text
                    arg0_1 = [w for w in arg0_1.split()]
                    arg0_2 = [w for w in arg0_2.split()]
                    arg0_1_emb = [emb1[doc1.index(w)] for w in arg0_1]
                    arg0_2_emb = [emb2[doc2.index(w)] for w in arg0_2]
                    s1 = set(arg0_1)
                    s2 = set(arg0_2)
                    d2 = {}
                    for s in s2:
                        d2[s] = arg0_2.count(s)/len(arg0_2)
                    arg0_sim_list = []
                    for i2, e2 in enumerate(arg0_2_emb):
                        word_sim = []
                        for e1 in arg0_1_emb:
                            word_sim.append(cosine_similarity(e1, e2))
                        arg0_sim_list.append(max(word_sim)*d2[arg0_2[i2]])
                    sim_arg0 = sum(arg0_sim_list)
                # print('ARG0 similarity for verb {}: {}'.format(v2, sim_arg0))
                
                # calculate similarity in ARG1
                if triple2[2] is None or nearest_triple1[2] is None:
                    sim_arg1 = 0
                else:
                    arg1_1 = nearest_triple1[2].text
                    arg1_2 = triple2[2].text
                    arg1_1 = [w for w in arg1_1.split()]
                    arg1_2 = [w for w in arg1_2.split()]
                    arg1_1_emb = [emb1[doc1.index(w)] for w in arg1_1]
                    arg1_2_emb = [emb2[doc2.index(w)] for w in arg1_2]
                    s1 = set(arg1_1)
                    s2 = set(arg1_2)
                    d2 = {}
                    for s in s2:
                        d2[s] = arg1_2.count(s)/len(arg1_2)
                    arg1_sim_list = []
                    for i2, e2 in enumerate(arg1_2_emb):
                        word_sim = []
                        for e1 in arg1_1_emb:
                            word_sim.append(cosine_similarity(e1, e2))
                        arg1_sim_list.append(max(word_sim)*d2[arg1_2[i2]])
                    sim_arg1 = sum(arg1_sim_list)
                # print('ARG1 similarity for verb {}: {}'.format(v2, sim_arg1))
                
                sim_triple = sim_verb + sim_arg0 + sim_arg1
                
                # remove the triples and child triples
                remove_triple(nearest_triple1, all_triples1, True)
                remove_triple(triple2, all_triples2, True)
            
            sim_sent += sim_triple
            first_level_triples1 = get_first_level_triples(all_triples1)
            first_level_triples2 = get_first_level_triples(all_triples2)
    
    return sim_sent

In [86]:
get_wms_srl_recursive(sent1, sent2)

2.0509519470103115

In [87]:
df1 = pd.read_csv('./Li.csv')

In [88]:
df1

Unnamed: 0,word1,word2,sent1,sent2,human_sim
0,cord,smile,"Cord is strong, thick string.",A smile is the expression that you have on you...,0.0100
1,rooster,voyage,A rooster is an adult male chicken.,A voyage is a long journey on a ship or in a s...,0.0050
2,noon,string,Noon is 12 o’clock in the middle of the day.,"String is thin rope made of twisted threads, u...",0.0125
3,fruit,furnace,Fruit or a fruit is something which grows on a...,A furnace is a container or enclosed space in ...,0.0475
4,autograph,shore,An autograph is the signature of someone famou...,"The shores or shore of a sea, lake, or wide ri...",0.0050
...,...,...,...,...,...
59,cushion,pillow,A cushion is a fabric case filled with soft ma...,A pillow is a rectangular cushion which you re...,0.5225
60,cemetery,graveyard,A cemetery is a place where dead people’s bodi...,"A graveyard is an area of land, sometimes near...",0.7725
61,automobile,car,An automobile is a car.,A car is a motor vehicle with room for a small...,0.5575
62,midday,noon,Midday is 12 o’clock in the middle of the day.,Noon is 12 o’clock in the middle of the day.,0.9550


In [89]:
sentences1 = df1['sent1']
sentences2 = df1['sent2']

In [93]:
srl_sim = list(map(get_wms_srl_recursive, sentences1, sentences2))

In [94]:
df1['srl_sim'] = srl_sim

In [95]:
df1

Unnamed: 0,word1,word2,sent1,sent2,human_sim,srl_sim
0,cord,smile,"Cord is strong, thick string.",A smile is the expression that you have on you...,0.0100,0.000000
1,rooster,voyage,A rooster is an adult male chicken.,A voyage is a long journey on a ship or in a s...,0.0050,0.000000
2,noon,string,Noon is 12 o’clock in the middle of the day.,"String is thin rope made of twisted threads, u...",0.0125,0.000000
3,fruit,furnace,Fruit or a fruit is something which grows on a...,A furnace is a container or enclosed space in ...,0.0475,2.842466
4,autograph,shore,An autograph is the signature of someone famou...,"The shores or shore of a sea, lake, or wide ri...",0.0050,0.000000
...,...,...,...,...,...,...
59,cushion,pillow,A cushion is a fabric case filled with soft ma...,A pillow is a rectangular cushion which you re...,0.5225,1.055305
60,cemetery,graveyard,A cemetery is a place where dead people’s bodi...,"A graveyard is an area of land, sometimes near...",0.7725,1.882579
61,automobile,car,An automobile is a car.,A car is a motor vehicle with room for a small...,0.5575,0.000000
62,midday,noon,Midday is 12 o’clock in the middle of the day.,Noon is 12 o’clock in the middle of the day.,0.9550,0.000000


In [96]:
df1.to_csv('./Li_srl.csv')

In [106]:
df2 = pd.read_csv('./query_1_sentences_random.csv')

In [107]:
df2

Unnamed: 0,sentence1,sentence2,whole_sim,subject_sim,action_sim,obj_sim
0,He has generally supported the existing outloo...,"""As we approach next week's FOMC day, we shoul...",0,0,0,0.0
1,He has generally supported the existing outloo...,Treasury yields shrank further on Wednesday af...,3,0,0,4.0
2,He has generally supported the existing outloo...,TD Securities: Bearish view on the US dollar 1...,0,0,0,0.0
3,He has generally supported the existing outloo...,The U.S. economy should grow at a good clip th...,0,0,0,1.0
4,"""It shows how sensitive the markets are around...","""It seems like the market is positioned for so...",3,4,5,4.0
5,"""It shows how sensitive the markets are around...","""I would expect the recent correction in equit...",2,4,5,0.0
6,"""It shows how sensitive the markets are around...","On a technical level, people have been watchin...",2,4,5,2.0
7,"""It shows how sensitive the markets are around...","The dollar dipped briefly, then rose as invest...",0,0,0,1.0
8,Though rising long-term interest rates and rec...,"""For the moment it looks like gold appears int...",2,3,5,0.0
9,Though rising long-term interest rates and rec...,"""It's not good or bad. It's a surprise and mar...",1,3,5,1.0


In [108]:
sentences1 = df2['sentence1']
sentences2 = df2['sentence2']

In [124]:
srl_sim = list(map(get_wms_srl_recursive, sentences1, sentences2))

In [125]:
df2['srl_sim'] = srl_sim

In [126]:
df2

Unnamed: 0,sentence1,sentence2,whole_sim,subject_sim,action_sim,obj_sim,srl_sim
0,He has generally supported the existing outloo...,"""As we approach next week's FOMC day, we shoul...",0,0,0,0.0,10.487179
1,He has generally supported the existing outloo...,Treasury yields shrank further on Wednesday af...,3,0,0,4.0,10.487179
2,He has generally supported the existing outloo...,TD Securities: Bearish view on the US dollar 1...,0,0,0,0.0,10.487179
3,He has generally supported the existing outloo...,The U.S. economy should grow at a good clip th...,0,0,0,1.0,10.487179
4,"""It shows how sensitive the markets are around...","""It seems like the market is positioned for so...",3,4,5,4.0,4.142857
5,"""It shows how sensitive the markets are around...","""I would expect the recent correction in equit...",2,4,5,0.0,4.142857
6,"""It shows how sensitive the markets are around...","On a technical level, people have been watchin...",2,4,5,2.0,4.142857
7,"""It shows how sensitive the markets are around...","The dollar dipped briefly, then rose as invest...",0,0,0,1.0,4.142857
8,Though rising long-term interest rates and rec...,"""For the moment it looks like gold appears int...",2,3,5,0.0,6.1
9,Though rising long-term interest rates and rec...,"""It's not good or bad. It's a surprise and mar...",1,3,5,1.0,6.1


In [127]:
df2.to_csv('./query_1_sentences_random_srl.csv')