In [49]:
import os, spacy, gzip, rltk, gzip, random, os
import torch
from tqdm import tqdm
from collections import defaultdict
import networkx as nx
from numpy import dot
from numpy.linalg import norm
import numpy as np
from sentence_transformers import SentenceTransformer
from textblob import TextBlob
from nltk.sentiment import SentimentIntensityAnalyzer
import matplotlib.pyplot as plt
from nltk.corpus import wordnet as wn
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import json

In [23]:
input_dir='input'
wordnet_file="%s/kgtk_wordnet.tsv" % input_dir
cskg_embeddings_file="%s/bert_embeddings.txt" % input_dir
cskg_tranE_file="%s/trans_log_dot_0.1.tsv.gz" % input_dir
cskg_complex_file='%s/comp_log_dot_0.1.tsv.gz' % input_dir
cskg_file="%s/cskg_renamed.tsv.gz" % input_dir

## Data Preparation
load data model, embedding file, and black list (takes 10 mins, grab a coffee)

In [4]:
import nltk
nltk.download('vader_lexicon')

[nltk_data] Downloading package vader_lexicon to
[nltk_data]     /Users/filipilievski/nltk_data...


True

In [5]:
# sentence transformer model
model = SentenceTransformer('nli-bert-large')
# nlp model
nlp = spacy.load("en_core_web_sm")
# sentiment model
sia = SentimentIntensityAnalyzer()

Some weights of the model checkpoint at /Users/filipilievski/.cache/torch/sentence_transformers/sbert.net_models_nli-bert-large/0_BERT were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
def text2token(text):
    doc=nlp(text)
    for sent in doc.sents:
        break
        
    return list(sent)

def load_source(filename):
    # load data
    with open(filename,"r",encoding="utf-8") as f:
        head=f.readline().strip().split("\t")
        data=[]
        for line in f:
            temp=line.strip().split("\t")
            data.append([item.split("|")[0] for item in temp])
            
    return head, data

def cosine_similar(embed1, embed2):
    similar=dot(embed1, embed2)/(norm(embed1)*norm(embed2))
    return similar

# load cskg embedding file
def load_embedding():
    with open(cskg_embeddings_file,"r") as f:
        head= f.readline().strip().split("\t")

        # obtain embedding_sentence in file
        cskg_word_embeddings=dict()

        for item in tqdm(f):
            # obtain list of line
            line=item.strip().split("\t")

            # only property is text embedding can obtain embeddings
            word=line[0]
            prop=line[1]
            embedding=line[2]

            if prop=="text_embedding":
                cskg_word_embeddings[word]=embedding
    return cskg_word_embeddings

# load complex and tranE file
def load_embedding_gz(filename):
    cskg_word_embeddings=dict()
    f=gzip.open(filename,'rb')
    
    for item in tqdm(f):
        # obtain list of line
        line=item.strip().decode("utf-8").split("\t")
        word=line[0]
        embed=np.array([float(_) for _ in line[2:]])
        cskg_word_embeddings[word]=embed
        
    return cskg_word_embeddings

def synsets_cosine_sim(label1, label2):
    # find the highest cosine similarity between two label by synsets definition
    syns1=wn.synsets(label1.replace(" ","_"))
    syns2=wn.synsets(label2.replace(" ","_"))
    
    sents1=[]
    sents2=[]
    
    for syn in syns1:
        sent=syn.definition()
        sents1.append(sent)
        
    for syn in syns2:
        sent=syn.definition()
        sents2.append(sent)
        
    sents1_embed=model.encode(sents1)
    sents2_embed=model.encode(sents2)
    max_sim=0
    
    for embed1 in sents1_embed:
        for embed2 in sents2_embed:
            sim=cosine_similar(embed1,embed2)
            max_sim=max(max_sim,sim)
            
    return max_sim

def find_sent(sent_target, items):
    for item in items:
        sent_candit=item[0]
        
        if sent_candit==sent_target:
            return item
        
def prediction_check(predict_score, threshold, ground):
    if predict_score>threshold:
        predict="1"
    else:
        predict="0"
        
    return predict==ground

In [7]:
# embedding file
cskg_tranE_embedding=load_embedding_gz(cskg_tranE_file)
cskg_complex_embedding=load_embedding_gz(cskg_complex_file)

2160968it [01:15, 28493.80it/s]
2160968it [01:21, 26573.84it/s]


In [24]:
cskg_word_embeddings=load_embedding()

4322096it [00:42, 101251.85it/s]


In [10]:
# build wordnet ontology
# load wordnet
# use wordnet build ontology to find blacklist related to person and animal
wordnet_head,wordnet_lines=load_source(wordnet_file)
wordnet_g=nx.DiGraph()

id2label=defaultdict(set)
#black_list={"room.n.02",'room.n.01',"food.n.01",'animal.n.01','person.n.01',"people.n.01","peoples.n.01"}
black_list={'animal.n.01','person.n.01',"people.n.01","peoples.n.01"}
for line in wordnet_lines:
    if line[1]== "/r/IsA":
        node1=line[0].split(":")[1]
        node2=line[2].split(":")[1]
        node1_labels=line[3].replace('"',"").split("|")
        node2_labels=line[4].replace('"',"").split("|")
        id2label[node1]=id2label[node1].union(node1_labels)
        id2label[node2]=id2label[node2].union(node2_labels)
        
        if ".v." in node1 or ".v." in node1:
            continue
        wordnet_g.add_edge(node1,node2)
    else:
        continue
        
new_black_list=set()
# check all sub class of blacklist

for item in black_list:
    temp=[item]
    new_black_list.add(item)
    while temp:
        new_temp=[]
        for issue in temp:
            new_temp+=[edge[0] for edge in wordnet_g.in_edges(issue)]
            
        new_temp_set=set(new_temp)
        new_black_list=new_black_list.union(new_temp_set)
        temp=new_temp_set
        
# check the blacklist label
blacklist_label=set()

for id_ in new_black_list:
    blacklist_label=blacklist_label.union(id2label[id_])
    
blacklist_label.add("human")

In [12]:
len(blacklist_label)

10219

In [15]:
len(new_black_list)

11117

In [21]:
# example of blacklist_label
# check whether girl in blacklist

"girl" in blacklist_label

True

In [18]:
# load cskg_renamed file
with gzip.open(cskg_file) as f:

    head=f.readline()
    cskg_renamed_dist=dict()
    for line in f:
        line=line.strip().decode("utf-8").split("\t")
        node1=line[1]
        node2=line[3]

        node1_label=line[4].split("|")[0].replace('"','')
        node2_label=line[5].split("|")[0].replace('"','')
        if node1_label != "":
            cskg_renamed_dist[node1]=node1_label

        if node2_label != "":
            cskg_renamed_dist[node2]=node2_label

## Case 2: Alternatives

I went to the (Y) and I wanted to (X). There was no (Z), can I use (C) instead?

In [88]:
# data file location

data_part1="output/case2_1.gz"
data_part2="output/case2_2.gz"

case2_output='output/case2.jsonl'

In [29]:
%%bash
kgtk query --debug -i input/cskg_renamed.tsv.gz --match '(x)<-[:`/r/CapableOf`]-(z)-[:`/r/AtLocation`]->(y)<-[:`/r/AtLocation`]-(c)-[:`/r/CapableOf`]->(x)' --where 'c!=z' --return 'z as plan, c as alt, x as goal, y as loc' -o output/case2_1.gz
kgtk query --debug -i input/cskg_renamed.tsv.gz --match '(x)<-[:`/r/UsedFor`]-(z)-[:`/r/AtLocation`]->(y)<-[:`/r/AtLocation`]-(c)-[:`/r/UsedFor`]->(x)' --where 'c!=z' --return 'z as plan, c as alt, x as goal, y as loc' -o output/case2_2.gz

[2021-05-26 09:49:07 sqlstore]: IMPORT graph directly into table graph_1 from /Users/filipilievski/mcs/kg-bert/GameSentence/input/cskg_renamed.tsv.gz ...
[2021-05-26 09:50:16 query]: SQL Translation:
---------------------------------------------
  SELECT graph_1_c1."node1" "_aLias.plan", graph_1_c4."node1" "_aLias.alt", graph_1_c1."node2" "_aLias.goal", graph_1_c3."node2" "_aLias.loc"
     FROM graph_1 AS graph_1_c1, graph_1 AS graph_1_c2, graph_1 AS graph_1_c3, graph_1 AS graph_1_c4
     WHERE graph_1_c1."label"=?
     AND graph_1_c2."label"=?
     AND graph_1_c3."label"=?
     AND graph_1_c4."label"=?
     AND graph_1_c1."node1"=graph_1_c2."node1"
     AND graph_1_c1."node2"=graph_1_c4."node2"
     AND graph_1_c2."node2"=graph_1_c3."node2"
     AND graph_1_c3."node1"=graph_1_c4."node1"
     AND (graph_1_c4."node1" != graph_1_c1."node1")
  PARAS: ['/r/CapableOf', '/r/AtLocation', '/r/AtLocation', '/r/CapableOf']
---------------------------------------------
[2021-05-26 09:50:16 sqlsto

In [33]:
# load case2.txt file (the data file generated by kgtk)
# output head and content
f=gzip.open(data_part1)

head=f.readline()
content=[]
for line in f:
    content.append(line.strip().decode("utf-8").split("\t"))
    
f=gzip.open(data_part2)

head=f.readline()
for line in f:
    content.append(line.strip().decode("utf-8").split("\t"))

In [34]:
# remove duplicate
# content line format: [z, c, x, y]
# for case2, [z,c,x,y] is the same as [c,z,x,y]
content_set=set()

for line in content:
    z,c,x,y=line
    
    if (z,c,x,y) not in content_set and (c,z,x,y) not in content_set:
        content_set.add((z,c,x,y))

In [35]:
# before removing
print(f"Number of lines: {len(content)}")
# after removing
print(f"Number of lines after removing duplicate: {len(content_set)}")

Number of lines: 29926
Number of lines after removing duplicate: 14949


In [36]:
# from content obtain the frequency of each location
loc_distribution=dict()
for line in content_set:
    location=line[3]
    
    loc_distribution[location]=loc_distribution.get(location,0)+1
print("Number of location type:", len(loc_distribution))

# only choose top 100 locations
num_loc=100
loc_sort=sorted(loc_distribution.items(),key=lambda k:k[1], reverse=True)
loc_chosen=set([_[0] for _ in loc_sort[:num_loc]])
# manual remove 
manual_remove_list=["/c/en/symphony","/c/en/marching_band","/c/en/band","/c/en/orchestra"]
for item in manual_remove_list:
    loc_chosen.remove(item)

print("Number of chosen location:", len(loc_chosen))
# filter contenct by loc_chosen
new_content=[]

for line in content_set:
    location=line[3]
    if location in loc_chosen:
        new_content.append(line)

Number of location type: 687
Number of chosen location: 96


In [37]:
print(f"Number of lines for all locations: {len(content_set)}")
print(f"Number of lines for top 100 locations: {len(new_content)}")

Number of lines for all locations: 14949
Number of lines for top 100 locations: 9733


In [39]:
# filter content (target about 20 lines)
# I went to the (Y) and I wanted to (X). There was no (Z), can I use (C) instead?
# z as plan, c as alt, x as goal, y as loc

"""
Filter:
Find similarity between Z and C
filter sentence if leve similarity>0.6, jaccard similarity>0.6, and word synsets
"""
filter_content=[]
leve_threshold=0.6
jaccard_threshold=0.6
for i in tqdm(range(len(new_content))):
    line = new_content[i]
    z_label=cskg_renamed_dist[line[0]]
    c_label=cskg_renamed_dist[line[1]]
    x_label=cskg_renamed_dist[line[2]]
    
    leve_sim=rltk.levenshtein_similarity(z_label,c_label)
    jaccard_sim=rltk.hybrid_jaccard_similarity(set(z_label.split(" ")),set(c_label.split(" ")),
                                                   function=rltk.levenshtein_similarity)
    
    if leve_sim>=leve_threshold or jaccard_sim>=jaccard_threshold:
        continue
    
    # filter line by X, POS should be (verb)
    tokens=text2token(x_label)
    
    if len(tokens)==0:
        continue
        
    token1=tokens[0]
    z_token=text2token(line[0].split("/")[-1])[0]
    c_token=text2token(line[1].split("/")[-1])[0]
    y_label=line[-1].split("/")[-1]
    
    if z_token.lemma_==c_token.lemma_:
        continue
    
    # filter token like "cat", "dog", "woman".....
    if z_token.lemma_ in blacklist_label or c_token.lemma_ in blacklist_label:
        continue
    
    # filter "theater" and "movie theater"
    if z_token.lemma_ in c_token.lemma_ or c_token.lemma_ in z_token.lemma_:
        continue
    
#     z_embed=np.array(eval("["+cskg_word_embeddings[line[0]]+"]"))
#     c_embed=np.array(eval("["+cskg_word_embeddings[line[1]]+"]"))
#     cosine_sim=synsets_cosine_sim(z_label,c_label)
#     if cosine_sim>=0.9:
#         continue
        
    if token1.pos_=="VERB":
        filter_content.append(line)

100%|██████████| 9733/9733 [01:24<00:00, 115.57it/s]


In [40]:
print(f"Number of lines before filtering: {len(new_content)}")
print(f"Number of lines after filtering: {len(filter_content)}")

Number of lines before filtering: 9733
Number of lines after filtering: 4643


In [87]:
# generate sentence based on the generation filter line
# Sentence format: I went to the (Y) and I wanted to (X). There was no (Z), can I use (C) instead?
location_sents=defaultdict(set)

"""
output: location_sents
output format:{location_id:[sentence,
                            Node Embedding Similarity,
                            transE Embedding similarity,
                            complex similarity,
                            Z label name,
                            C label name]}
"""
possible_items=set()
with open(case2_output, 'w') as w:
    for line in tqdm(filter_content):
        z_label=cskg_renamed_dist[line[0]]
        c_label=cskg_renamed_dist[line[1]]
        x_label=cskg_renamed_dist[line[2]]
        x_label = " ".join([_.lemma_ for _ in text2token(x_label)])
        y_label=cskg_renamed_dist[line[3]]
        possible_items.add(c_label)
        sent=f"I went to the {y_label} and I wanted to {x_label}. There was no {z_label}, can I use {c_label} instead?"
        w.write(json.dumps({'story': sent, 'answer': 'Y'}) + '\n')

100%|██████████| 4643/4643 [00:16<00:00, 280.95it/s]


In [99]:
print(random.sample(possible_items, 1)[0])
print(len(possible_items))

subway platform
693


In [101]:
# Find negative cases
import random
with open(case2_output, 'a') as w:
    for line in tqdm(filter_content):
        z_label=cskg_renamed_dist[line[0]]
        x_label=cskg_renamed_dist[line[2]]
        x_label = " ".join([_.lemma_ for _ in text2token(x_label)])
        y_label=cskg_renamed_dist[line[3]]
        c_label=random.sample(possible_items, 1)[0]
        if c_label!=y_label and c_label!=cskg_renamed_dist[line[1]]:
            sent=f"I went to the {y_label} and I wanted to {x_label}. There was no {z_label}, can I use {c_label} instead?"
            w.write(json.dumps({'story': sent, 'answer': 'N'}) + '\n')

100%|██████████| 4643/4643 [00:15<00:00, 298.74it/s]


## case 1: Unmet expectations
I went to the Y. There was a X but it had no Z. Am I disappointed?

In [52]:
# data file location

data_case1="output/case1.gz"

case1_output='output/case1.jsonl'

In [54]:
%%bash
kgtk query --debug -i input/cskg_renamed.tsv.gz --match '(z)-[:`/r/AtLocation`]->(x)-[:`/r/AtLocation`]->(y)' --return 'z as needed, x as place, y as Y' -o output/case1.gz

[2021-05-26 10:16:55 query]: SQL Translation:
---------------------------------------------
  SELECT graph_1_c1."node1" "_aLias.needed", graph_1_c1."node2" "_aLias.place", graph_1_c2."node2" "_aLias.Y"
     FROM graph_1 AS graph_1_c1, graph_1 AS graph_1_c2
     WHERE graph_1_c1."label"=?
     AND graph_1_c2."label"=?
     AND graph_1_c1."node2"=graph_1_c2."node1"
  PARAS: ['/r/AtLocation', '/r/AtLocation']
---------------------------------------------


In [55]:
# load kgtk file
with gzip.open(data_case1) as f:
    head_case1=f.readline()
    content_case1=[]
    for line in f:
        content_case1.append(line.strip().decode("utf-8").split("\t"))

In [56]:
# remove duplicate
# content line format: [z, x, y]
content_case1_set=set()

for line in content_case1:
    z,x,y=line
    
    if (z,x,y) not in content_case1_set and (z,x,y) not in content_case1_set:
        content_case1_set.add((z,x,y))

In [57]:
# before removing
print(f"Number of lines: {len(content_case1)}")
# after removing
print(f"Number of lines after removing duplicate: {len(content_case1_set)}")

Number of lines: 121148
Number of lines after removing duplicate: 121148


In [58]:
# from content obtain the frequency of each location
loc_distribution=dict()
for line in content_case1_set:
    location=line[2]
    
    loc_distribution[location]=loc_distribution.get(location,0)+1

print("Number of location type:", len(loc_distribution))

# only choose top 100 locations
num_loc=100
print("Number of chosen location:", num_loc)
loc_sort=sorted(loc_distribution.items(),key=lambda k:k[1], reverse=True)
loc_chosen=set([_[0] for _ in loc_sort[:num_loc]])

# filter contenct by loc_chosen
new_content_case1=[]

for line in tqdm(content_case1):
    location=line[2]
    if location in loc_chosen:
        new_content_case1.append(line)

100%|██████████| 121148/121148 [00:00<00:00, 2280130.05it/s]

Number of location type: 3064
Number of chosen location: 100





In [59]:
print(f"Number of lines for all locations: {len(content_case1_set)}")
print(f"Number of lines for top 100 locations: {len(new_content_case1)}")

Number of lines for all locations: 121148
Number of lines for top 100 locations: 49966


In [85]:
# I went to the {y_label}. There was a {x_label} but it had no {z_label}. Am I disappointed?
# filter by sentiment analysis and cosine similarity

"""
Output:
filter_content_case1: filtered x,y,z id, [[z1,x1,y1],[z2,x2,y2],[z3,x3,y3]] 
loc_line:{location_id:[sentence,
                        node bert embedding similarity,
                        tranE Embedding similarity,
                        complex embedding similarity,
                        z label name,
                        x label name]}
"""
filter_content_case1=[]
loc_line=defaultdict(set)

with open(data_case1, 'w') as w:
    for line in tqdm(new_content_case1):
        pos_sent=""
        neg_sent=""
        z_label=cskg_renamed_dist[line[0]]
        x_label=cskg_renamed_dist[line[1]]
        y_label=cskg_renamed_dist[line[2]]
        if z_label==x_label:
            continue
        x_tokens=text2token(x_label)
        tokens=text2token(z_label)
        sent=f"I went to the {y_label}. There was a {x_label} but it had no {z_label}. Am I disappointed?"

        polar_z=sia.polarity_scores(z_label)
        polar_x=sia.polarity_scores(x_label)
        if polar_z['neg']>0 or polar_x['neg']>0:
            neg_sent=sent
        else:
            pos_sent=sent
                
        if x_label in blacklist_label:
            continue

        if [_.lemma_ for _ in x_tokens]==[_.lemma_ for _ in tokens]:
            continue

        if pos_sent:
            w.write(json.dumps({'story': pos_sent, 'answer': 'Y'}) + '\n')
        elif neg_sent:
            w.write(json.dumps({'story': pos_sent, 'answer': 'N'}) + '\n')


        filter_content_case1.append(line)

100%|██████████| 49966/49966 [04:55<00:00, 169.13it/s]


In [86]:
print(f"Number of lines before filter: {len(new_content_case1)}")
print(f"Number of lines after filter: {len(filter_content_case1)}")

Number of lines before filter: 49966
Number of lines after filter: 46817


In [80]:
for n in new_content_case1:
    if n not in filter_content_case1:
        print(n)
        input('c?')

['/c/en/00t_shirts', '/c/en/drawer', '/c/en/den']


c? 


['/c/en/00t_shirts', '/c/en/drawer', '/c/en/kitchen']


c? 


['/c/en/32_teeth', '/c/en/mouth', '/c/en/river']


c? 


['/c/en/abandoned_tractor', '/c/en/meadow', '/c/en/countryside']


c? 


['/c/en/address_label', '/c/en/drawer', '/c/en/den']


KeyboardInterrupt: Interrupted by user

## Case3: Object modifications
There was `X` in the `Y`. What can it do to the `Z1n`? `Z1v` it.

In [65]:
# data file location

data_case31="output/case3_1.gz"
data_case32="output/case3_2.gz"


case3_output='output/case3.jsonl'

In [67]:
%%bash
kgtk query --debug -i input/cskg_renamed.tsv.gz --match '(z)<-[:`/r/UsedFor`]-(x)-[:`/r/AtLocation`]->(y)' --return 'z as needed, x as place, y as Y' -o output/case3_2.gz
kgtk query --debug -i input/cskg_renamed.tsv.gz --match '(z)<-[:`/r/CapableOf`]-(x)-[:`/r/AtLocation`]->(y)' --return 'z as needed, x as place, y as Y' -o output/case3_1.gz

[2021-05-26 10:47:42 query]: SQL Translation:
---------------------------------------------
  SELECT graph_1_c1."node2" "_aLias.needed", graph_1_c1."node1" "_aLias.place", graph_1_c2."node2" "_aLias.Y"
     FROM graph_1 AS graph_1_c1, graph_1 AS graph_1_c2
     WHERE graph_1_c1."label"=?
     AND graph_1_c2."label"=?
     AND graph_1_c1."node1"=graph_1_c2."node1"
  PARAS: ['/r/UsedFor', '/r/AtLocation']
---------------------------------------------
[2021-05-26 10:47:44 query]: SQL Translation:
---------------------------------------------
  SELECT graph_1_c1."node2" "_aLias.needed", graph_1_c2."node1" "_aLias.place", graph_1_c2."node2" "_aLias.Y"
     FROM graph_1 AS graph_1_c1, graph_1 AS graph_1_c2
     WHERE graph_1_c1."label"=?
     AND graph_1_c2."label"=?
     AND graph_1_c1."node1"=graph_1_c2."node1"
  PARAS: ['/r/CapableOf', '/r/AtLocation']
---------------------------------------------


In [68]:
# load kgtk file
with gzip.open(data_case31) as f:
    head_case1=f.readline()
    content_case3=[]
    for line in f:
        content_case3.append(line.strip().decode("utf-8").split("\t"))
    
with gzip.open(data_case32) as f:

    head_case1=f.readline()
    content_case3=[]
    for line in f:
        content_case3.append(line.strip().decode("utf-8").split("\t"))

In [69]:
# remove duplicate
# content line format: [z, x, y]
content_case3_set=set()

for line in content_case3:
    z,x,y=line
    
    if (z,x,y) not in content_case3_set and (z,x,y) not in content_case3_set:
        content_case3_set.add((z,x,y))

In [70]:
# before removing
print(f"Number of lines: {len(content_case3)}")
# after removing
print(f"Number of lines after removing duplicate: {len(content_case3_set)}")

Number of lines: 214878
Number of lines after removing duplicate: 214878


In [71]:
# from content obtain the frequency of each location
loc_distribution=dict()
for line in content_case3:
    location=line[2]
    
    loc_distribution[location]=loc_distribution.get(location,0)+1

print("Number of location type:", len(loc_distribution))

# only choose top 100 locations
num_loc=100
print("Number of chosen location:", num_loc)    

loc_sort=sorted(loc_distribution.items(),key=lambda k:k[1], reverse=True)
loc_chosen=set([_[0] for _ in loc_sort[:num_loc]])

Number of location type: 5166
Number of chosen location: 100


In [72]:
# filter contenct by loc_chosen
new_content_case3=[]

for line in tqdm(content_case3_set):
    location=line[2]
    if location in loc_chosen:
        new_content_case3.append(line)

100%|██████████| 214878/214878 [00:00<00:00, 1770613.77it/s]


In [73]:
print(f"Number of lines for all locations: {len(content_case3_set)}")
print(f"Number of lines for top 100 locations: {len(new_content_case3)}")

Number of lines for all locations: 214878
Number of lines for top 100 locations: 69509


In [79]:
# filter content
# z as plan, c as alt, x as goal, y as loc
"""
Output:
loc_line:{location_id:[sentence,
                        node bert embedding similarity,
                        tranE Embedding similarity,
                        complex embedding similarity,
                        y label name,
                        z label name]}
"""
loc_line_case3=defaultdict(set)
with open(case3_output, 'w') as w:
    for i in tqdm(range(len(new_content_case3))):
        line=new_content_case3[i]

        y_id=line[2]
        x_id=line[1]
        z_id=line[0]
        x_label=cskg_renamed_dist[x_id]
        y_label=cskg_renamed_dist[y_id]
        z_label=cskg_renamed_dist[z_id]
        # filter line by X, POS should be (verb)
        z_tokens=text2token(z_label)

        if len(z_tokens)<2:
            continue
        else:
            zv=z_tokens[0]
            zn_tokens=z_tokens[1:]

            y_embed=np.array(eval("["+cskg_word_embeddings[y_id]+"]"))
            zn_text=" ".join([_.lemma_ for _ in zn_tokens])
            zn_id="/c/en/"+zn_text.replace(" ","_")

    #         # find embed.
    #         z_embed=np.array(eval("["+cskg_word_embeddings[z_id]+"]"))

    #         z_embed_transE=cskg_tranE_embedding[line[0]]
    #         y_embed_transE=cskg_tranE_embedding[line[1]]

    #         z_embed_complex=cskg_complex_embedding[line[0]]
    #         y_embed_complex=cskg_complex_embedding[line[1]]

    #         cskg_similar=cosine_similar(z_embed,y_embed)
    #         tranE_similar=cosine_similar(z_embed_transE,y_embed_transE)
    #         complex_similar=cosine_similar(z_embed_complex,y_embed_complex)

            # remove item in blacklist
            if x_label in blacklist_label or zn_text in blacklist_label:
                continue

            # remove same item
            if x_label==zn_text:
                continue

            # only verb noun structure are allowed
            if zv.pos_ == "VERB" and all([_.pos_=="NOUN" for _ in zn_tokens]):

                temp=loc_line_case3[y_id]
                sent=f"There was {x_label} in the {y_label}. What can it do to the {zn_text}?"
                w.write(json.dumps({'story': sent, 'answer': f'{zv.lemma_} it.'}) + '\n')
    #             temp.add((sent,cskg_similar,tranE_similar,complex_similar,y_label,z_label))
    #             loc_line_case3[y_id]=temp
            else:
                continue

100%|██████████| 69509/69509 [05:03<00:00, 229.22it/s]
