In [138]:
import string
import csv
import numpy as np
from sklearn.metrics import pairwise_distances
import spacy
from spacy.lang.en import English
import fasttext
import rdflib
import re
import editdistance
import pickle

# import en_core_web_trf
# import fasttext.util

In [172]:
class KG_handler:
    def __init__(self):
        '''
        ### Notices 
            - all2lbl has more "P" item than ent2id dose (ent2lbl:298, ent2id:47)
            - all2lbl contain all entities and relations, while ent2id lacks some relations
            - labels of entity are not unique, but entity URIs are unique
            - labels of relations are unique (at least unique till now)     
            - to obtain query, we need movie name (entity label), relation label and relation URI
        '''
        
        # language model
        print('step 3')
        self.LM = fasttext.load_model('utils/cc.en.300.bin')
        try:
            # self.spacy_model = en_core_web_trf.load()
            self.spacy_model = spacy.load('en_core_web_trf')
        except OSError:
            print('Can\'t find model, please run command "python -m spacy download en_core_web_trf" to download it and restart the program')

        
        
        
        # RDF config
        print('step 1')
        self.graph = rdflib.Graph().parse('utils/14_graph.nt', format='turtle')
        self.WD = rdflib.Namespace('http://www.wikidata.org/entity/')
        self.WDT = rdflib.Namespace('http://www.wikidata.org/prop/direct/')
        self.DDIS = rdflib.Namespace('http://ddis.ch/atai/')
        self.RDFS = rdflib.namespace.RDFS
        self.SCHEMA = rdflib.Namespace('http://schema.org/')      
        
          
        # label relevant
        print('step 2')
        self.all2lbl = {ent: str(lbl) for ent, lbl in self.graph.subject_objects(self.RDFS.label)}
        self.rel2lbl = {k:v for k, v in self.all2lbl.items() if self._is_rel(k)}
        self.lbl2rel = {lbl: ent for ent, lbl in self.rel2lbl.items()}
        self.rel_lbl_set = set([v for v in self.rel2lbl.values()])
        self.ent2lbl = {k:v for k, v in self.all2lbl.items() if self._is_ent(k)}
        self.lbl2ent = {lbl: ent for ent, lbl in self.ent2lbl.items()}  # not unique
        self.ent_lbl_set = set([v for v in self.ent2lbl.values()])
        # with open('utils/ent_lbl2vec.pkl', 'rb') as f:
        #     self.ent_lbl2vec = pickle.load(f)
        # with open('utils/rel_lbl2vec.pkl', 'rb') as f:
        #     self.rel_lbl2vec = pickle.load(f)
        self.ent_lbl2vec = {k:self.LM.get_word_vector(k) for k in self.ent_lbl_set}
        self.rel_lbl2vec = {k:self.LM.get_word_vector(k) for k in self.rel_lbl_set}
        # self.all_lbl_set
        
        # embedding relevant
        # with open('utils/relation_ids.del', 'r') as ifile:
        #     self.rel2id = {rdflib.term.URIRef(rel): int(idx) for idx, rel in csv.reader(ifile, delimiter='\t')}
        # self.id2rel = {v: k for k, v in self.rel2id.items()}
        # with open('utils/entity_ids.del', 'r') as ifile:
        #     self.ent2id = {rdflib.term.URIRef(ent): int(idx) for idx, ent in csv.reader(ifile, delimiter='\t')}
        # self.id2ent = {v: k for k, v in self.ent2id.items()}
        # self.ent_emb = np.load('utils/entity_embeds.npy')
        # self.rel_emb = np.load('utils/relation_embeds.npy')
        
        self.synonyms_dict = {
            'cast member' :['actor', 'actress', 'cast'],
            'genre': ['type', 'kind'],
            'publication date': ['release', 'date', 'airdate', 'publication', 'launch', 'broadcast','released','launched'],
            'executive producer': ['showrunner'],
            'screenwriter': ['scriptwriter', 'screenplay', 'teleplay', 'writer', 'script', 'scenarist', 'story'],
            'director of photography': ['cinematographer', 'DOP', 'dop'],
            'film editor': ['editor'],
            'production designer': ['designer'],
            'box office': ['box', 'office', 'funding'],
            'cost': ['budget', 'cost'],
            'nominated for': ['nomination', 'award', 'finalist', 'shortlist', 'selection'],
            'costume designer': ['costume'],
            'official website': ['website', 'site'],
            'filming location': ['flocation'],
            'narrative website': ['nlocation'],
            'production company': ['company'],
            'country of origin': ['origin', 'country'],
            '–' : ['-']
        }
        self.replacement_dict = {k:v for k,  v_list in self.synonyms_dict.items() for v in v_list }
        
    def _ruler_based(self, query:str):
        '''
        Param:
            query: natrual language query from the user
        Return:
            {'ent_lbl': None | str, 'rel_lbl': None | str, 'rel_postfix': None | str}:
            diction of extraction results. if not found than set as None
        '''
        
        res = {'ent_lbl': None,
               'rel': None,
               'rel_lbl': None,
               'rel_postfix': None}
        
        # pre-proecssing
        tokens = self._replace(query)
        
        # get all word combination that match existed entities
        word_seq = [' '.join(tokens[i:j+1]) for i in range(len(tokens)) for j in range(i, len(tokens))]
        matched_seq = [seq for seq in word_seq if seq in (self.ent_lbl_set | self.rel_lbl_set)]      # all exited entities that appear in the sentence
        
        # extraction 
        ent_candidates = []
        for seq in matched_seq:
            if seq in self.ent_lbl_set:
                ent_candidates.append(seq)
            # detected relation
            if seq in self.rel_lbl_set:
                # which means there are two possible word for relations
                if res['rel_lbl'] != None:
                    print("WARNIND: multiple possible relations detected...")
                res['rel'] = seq
                res['rel_lbl'] = self.lbl2rel[seq]
                res['rel_postfix'] = self._get_rel_label(res['rel_lbl'])
        # process possible entities
        ent_candidates = sorted(ent_candidates, key = lambda x:len(x[0]), reverse=True)
        res['ent_lbl'] = ent_candidates[0]
    
        return res
    
    def _similarity_based(self, query, top_k_rel=10, top_k_ent=1):
        '''
        similar to _rule_based(), but return top k similar results in list form, which is:
        Return:
            {'ent_lbl': list(str), 'rel_lbl': list(str), 'rel_postfix': list(str)}:
        '''
        
        res = {'ent_lbl': None,
               'rel': None,
               'rel_lbl': None,
               'rel_postfix': None}
        
        # pre-processing
        tokens = self._replace(query).join()
        
        # token process
        doc = self.spacy_model(tokens)
        
        
        ent = []    # entity
        rel = []    # relation
        for token in doc:
            # print(token)
            if (token.ent_iob_ != "O"):
                ent.append(token.lemma_)
            elif (token.pos_=='NOUN') | (token.pos_=="VERB"):
                rel.append(token.lemma_)
                
                
        # find the closest relation information
        rel_wv = np.array([i for i in self.rel_lbl2vec.values()])
        rel_lbl = [i for i in self.rel_lbl2vec.keys()]
        rel = " ".join(rel)
        wv = self.LM.get_word_vector(rel).reshape((1,-1))
        dist = pairwise_distances(wv, rel_wv).flatten()
        closest_rel_idx = dist.argsort()[:top_k_rel]
        closest_rel_lbl = [rel_lbl[i] for i in closest_rel_idx]
        closest_rel_uri = [self.lbl2rel[i] for i in closest_rel_lbl]
        
        
        # find the closest ent information
        extracted_ent = " ".join(ent)
        ent_wv = np.array([i for i in self.ent_lbl2vec.values()])
        ent_lbl = [i for i in self.ent_lbl2vec.keys()]
        wv = self.LM.get_word_vector(extracted_ent).reshape((1,-1))
        dist = pairwise_distances(wv, ent_wv).flatten()
        closest_ent_idx = dist.argsort()[:top_k_ent]
        closest_ent_lbl = [ent_lbl[i] for i in closest_ent_idx]

        res['rel'] = rel
        res['ent_lbl'] = closest_ent_lbl[0]
        res['rel_lbl'] = closest_rel_lbl
        res['rel_postfix'] = [self._get_rel_label(uri) for uri in closest_rel_uri]

        return res
    
    
    def _replace(self, sent:str) -> list :
        '''
        replace words in sentence if they have relevant replacement in the dictionary
        '''
        
        # remove all possible punctuation in the end of sentence
        cleaned_sent = sent.rstrip(string.punctuation + ' ')
        tokens = cleaned_sent.split()
        tokens = [ self.replacement_dict[token] if token in self.replacement_dict.keys() else token for token in tokens ]
        
        return tokens
    
    def get_query_res(self, user_input):
        
        '''
        Return:
            2D-list (n_queries, n_cols_per_query)
        '''
        
        ruler_attemp = self._ruler_based(user_input)
        # print(ruler_attemp)

        # calling backup extraction strategy
        if None in ruler_attemp.values():
            similarity_attemp = self._similarity_based(user_input)
            for k,v in ruler_attemp.items():
                if v == None:
                    ruler_attemp[k] = similarity_attemp[k]    
        similarity_attemp = self._similarity_based(user_input)


        # print(similarity_attemp)

        # grab result from graph
        for i in range(len(ruler_attemp['rel_postfix'])):
            movie_name = ruler_attemp['ent_lbl']
            target_label = ruler_attemp['rel_postfix'][i]
            target_name = ruler_attemp['rel_lbl'][i]
            if "date" in target_name:
                query = f'''PREFIX ddis: <http://ddis.ch/atai/>

                PREFIX wd: <http://www.wikidata.org/entity/>

                PREFIX wdt: <http://www.wikidata.org/prop/direct/>

                PREFIX schema: <http://schema.org/>

                SELECT ?date WHERE {{
                    ?movie rdfs:label "{movie_name}"@en.

                    ?movie wdt:{target_label} ?date

                }} LIMIT 1'''
            else:
                query = f'''PREFIX ddis: <http://ddis.ch/atai/>

                PREFIX wd: <http://www.wikidata.org/entity/>

                PREFIX wdt: <http://www.wikidata.org/prop/direct/>

                PREFIX schema: <http://schema.org/>

                SELECT ?lbl WHERE {{
                    ?sub rdfs:label "{movie_name}"@en.

                    ?sub wdt:{target_label} ?obj.

                    ?obj rdfs:label ?lbl.

                }} LIMIT 1'''
            query  = query.strip()
            # print(query)
            
            res = []
            for row in KG_handler.graph.query(query):
                res.append([str(i) for i in row]) 
            # print(res)
            if len(res) != 0:
                break
            
        return res    
    
    
    def _get_rel_label(self, URI):
        return str(URI).split('/')[-1]
    
    def _is_rel(self, URI):
        label = self._get_rel_label(URI)
        return label[0] == 'P'
    
    def _is_ent(self, URI):
        label = self._get_rel_label(URI)
        return label[0] == 'Q'   
    
    


In [173]:
KG_handler = KG_handler()

step 3
step 1
step 2


In [231]:
# for i in KG_handler.ent_lbl_set:
#     if 'Star Wars Episode' in i:
#         print(i)
# for i in KG_handler.ent_lbl_set:
#     if 'Star Wars: Episode' in i:
#         print(i)

Star Wars Episode IX: The Rise of Skywalker
Star Wars Episode III: Revenge of the Sith
Star Wars Episode I: The Phantom Menace
Star Wars Episode VI: Return of the Jedi
Star Wars Episode VII: The Force Awakens
Star Wars Episode I: Battle for Naboo
Star Wars Episode II: Attack of the Clones
Star Wars Episode V: The Empire Strikes Back
Star Wars: Episode II – Attack of the Clones
Star Wars: Episode III – Revenge of the Sith (soundtrack)
Star Wars: Episode I – The Phantom Menace
Star Wars: Episode I – The Phantom Menace (soundtrack)
Star Wars: Episode VI – Return of the Jedi
Star Wars: Episode III – Revenge of the Sith
Star Wars: Episode VIII – The Last Jedi
Star Wars: Episode IV – A New Hope
Star Wars: Episode V – The Empire Strikes Back


In [240]:
def _similarity_based(model, query, top_k_rel=10, top_k_ent=10):
    '''
    similar to _rule_based(), but return top k similar results in list form, which is:
    Return:
        {'ent_lbl': list(str), 'rel_lbl': list(str), 'rel_postfix': list(str)}:
    '''


    res = {'ent_lbl': None,
            'rel': None,
            'rel_lbl': None,
            'rel_postfix': None}
    replacement_dict = {v:k for k,  v_list in model.synonyms_dict.items() for v in v_list }
    replacement_dict['-'] = '–'

    
    # remove all possible punctuation in the end of sentence
    cleaned_sentence = query.rstrip(string.punctuation+' ')
    # print(cleaned_sentence)
    tokens = cleaned_sentence.split()
    tokens = [ replacement_dict[token] if token in replacement_dict.keys() else token for token in tokens ]
    doc = model.spacy_model(' '.join(tokens))
    
    ent = []    # entity
    rel = []    # relation


    for token in doc:
        # print(token, token.ent_iob_)
        if (token.ent_iob_ != "O"):
            ent.append(token.lemma_)
        elif (token.pos_=='NOUN') | (token.pos_=="VERB"):
            rel.append(token.lemma_)
            
    # find the closest relation information
    rel_wv = np.array([i for i in model.rel_lbl2vec.values()])
    rel_lbl = [i for i in model.rel_lbl2vec.keys()]
    rel = " ".join(rel)
    # rel = 'release'
    wv = model.LM.get_word_vector(rel).reshape((1,-1))
    dist = pairwise_distances(wv, rel_wv).flatten()
    closest_rel_idx = dist.argsort()[:top_k_rel]
    closest_rel_lbl = [rel_lbl[i] for i in closest_rel_idx]
    closest_rel_uri = [model.lbl2rel[i] for i in closest_rel_lbl]
    
    # find the closest ent information
    extracted_ent = " ".join(ent)
    # print(extracted_ent)
    ent_wv = np.array([i for i in model.ent_lbl2vec.values()])
    ent_lbl = [i for i in model.ent_lbl2vec.keys()]
    wv = model.LM.get_word_vector(extracted_ent).reshape((1,-1))
    dist = pairwise_distances(wv, ent_wv).flatten()
    closest_ent_idx = dist.argsort()[:top_k_ent]
    # print(dist[dist.argsort()[:top_k_ent]])
    closest_ent_lbl = [ent_lbl[i] for i in closest_ent_idx]
    print(closest_ent_lbl)
    # print(closest_ent_lbl)

    res['rel'] = rel
    res['ent_lbl'] = closest_ent_lbl[0]
    res['rel_lbl'] = closest_rel_lbl
    res['rel_postfix'] = [model._get_rel_label(uri) for uri in closest_rel_uri]

    return res

def _ruler_based(model, query:str):
    '''
    Param:
        query: natrual language query from the user
    Return:
        {'ent_lbl': None | str, 'rel_lbl': None | str, 'rel_postfix': None | str}:
        diction of extraction results. if not found than set as None
    '''
    
    res = {'ent_lbl': None,
            'rel': None,
            'rel_lbl': None,
            'rel_postfix': None}
    replacement_dict = {v:k for k,  v_list in model.synonyms_dict.items() for v in v_list }
    replacement_dict['-'] = '–'

    
    # remove all possible punctuation in the end of sentence
    cleaned_sentence = query.rstrip(string.punctuation+' ')
    # print(cleaned_sentence)
    tokens = cleaned_sentence.split()
    
    # matching possible existing entities and relations from sentence
    tokens = [ replacement_dict[token] if token in replacement_dict.keys() else token for token in tokens ]
    word_seq = [' '.join(tokens[i:j+1]) for i in range(len(tokens)) for j in range(i, len(tokens))]
    matched_seq = [seq for seq in word_seq if seq in (model.ent_lbl_set | model.rel_lbl_set)]      # all exited entities that appear in the sentence

    # print(tokens)
    word_seq = [' '.join(tokens[i:j+1]) for i in range(len(tokens)) for j in range(i, len(tokens))]
    # print(word_seq)
    matched_seq = [seq for seq in word_seq if seq in (model.ent_lbl_set | model.rel_lbl_set)]      # all exited entities that appear in the sentence
    # print(matched_seq)
    
    
    # extraction 
    ent_candidates = []
    for seq in matched_seq:
        if seq in model.ent_lbl_set:
            ent_candidates.append(seq)
        # detected relation
        if seq in model.rel_lbl_set:
            # which means there are two possible word for relations
            if res['rel_lbl'] != None:
                print("WARNIND: multiple possible relations detected...")
            res['rel'] = seq
            res['rel_lbl'] = [seq]
            res['rel_postfix'] = [model._get_rel_label(model.lbl2rel[seq])]
            
            
    # extract entity from candidates
    ent_candidates = sorted(ent_candidates, key = lambda x:len(x), reverse=True)
    print(ent_candidates)
    res['ent_lbl'] = ent_candidates[0]
    

    return res


In [242]:
with open('utils/relation_ids.del', 'r') as ifile:
    rel2id = {rdflib.term.URIRef(rel): int(idx) for idx, rel in csv.reader(ifile, delimiter='\t')}
id2rel = {v: k for k, v in rel2id.items()}
with open('utils/entity_ids.del', 'r') as ifile:
    ent2id = {rdflib.term.URIRef(ent): int(idx) for idx, ent in csv.reader(ifile, delimiter='\t')}
id2ent = {v: k for k, v in ent2id.items()}
ent_emb = np.load('utils/entity_embeds.npy')
rel_emb = np.load('utils/relation_embeds.npy')


In [245]:
# testing cases
# user_query = "Who is the screenwriter of The Masked Gang: Cyprus? "
user_query = 'Who is the director of Star Wars: Episode VI - Return of the Jedi? '
# user_query = 'When was "The Godfather" release?' 
# user_query = 'What is the publication date of "The Godfather" ? '

ruler_attemp = _ruler_based(KG_handler, user_query)
# print(ruler_attemp)

# calling backup extraction strategy
if None in ruler_attemp.values():
    similarity_attemp = _similarity_based(KG_handler, user_query)
    for k,v in ruler_attemp.items():
        if v == None:
            ruler_attemp[k] = similarity_attemp[k]
            
similarity_attemp = _similarity_based(KG_handler, user_query)
print(similarity_attemp)
# print(similarity_attemp)
# print(ruler_attemp)
ruler_attemp = similarity_attemp
# grab result from graph
for i in range(len(ruler_attemp['rel_postfix'])):
    movie_name = ruler_attemp['ent_lbl']
    target_label = ruler_attemp['rel_postfix'][i]
    target_name = ruler_attemp['rel_lbl'][i]
    if "date" in target_name:
        query = f'''PREFIX ddis: <http://ddis.ch/atai/>

        PREFIX wd: <http://www.wikidata.org/entity/>

        PREFIX wdt: <http://www.wikidata.org/prop/direct/>

        PREFIX schema: <http://schema.org/>

        SELECT ?date WHERE {{
            ?movie rdfs:label "{movie_name}"@en.

            ?movie wdt:{target_label} ?date

        }} LIMIT 1'''
    else:
        query = f'''PREFIX ddis: <http://ddis.ch/atai/>

        PREFIX wd: <http://www.wikidata.org/entity/>

        PREFIX wdt: <http://www.wikidata.org/prop/direct/>

        PREFIX schema: <http://schema.org/>

        SELECT ?lbl WHERE {{
            ?sub rdfs:label "{movie_name}"@en.

            ?sub wdt:{target_label} ?obj.

            ?obj rdfs:label ?lbl.

        }} LIMIT 1'''
    query  = query.strip()
    print(query)
    res = []
    for row in KG_handler.graph.query(query):
        res.append([str(i) for i in row]) 
    print(res)
    # stop if any result grabed 
    if False:
        break
    
# if no any query result match, using embeddings
ent_id = ent2id[KG_handler.lbl2ent[ruler_attemp['ent_lbl']]]
rel_id = rel2id[KG_handler.lbl2rel[ruler_attemp['rel_lbl'][0]]]
head = ent_emb[ent_id]
pred = rel_emb[rel_id]

lhs = (head - pred).reshape((1,-1))

dist = pairwise_distances(lhs, ent_emb).flatten()
most_likely_idx = dist.argsort()[0]
closest_ent = KG_handler.ent2lbl[id2ent[most_likely_idx]]

print(closest_ent)

['Star Wars: Episode VI – Return of the Jedi', 'director', 'Return', 'Jedi', 'Who']
['Star Wars: Episode VI – Return of the Jedi', 'Star Wars Episode VI: Return of the Jedi', 'Star Wars: Return of the Jedi', 'Star Wars Episode II: Attack of the Clones', 'Star Wars: Episode III – Revenge of the Sith', 'Star Wars: Episode II – Attack of the Clones', 'Hidden Figures: The American Dream and the Untold Story of the Black Women Mathematicians Who Helped Win the Space Race', 'The Return of the Musketeers, or The Treasures of Cardinal Mazarin', 'Star Wars Episode IX: The Rise of Skywalker', "Molly's Game: The True Story of the 26-Year-Old Woman Behind the Most Exclusive, High-Stakes Underground Poker Game in the World"]
{'ent_lbl': 'Star Wars: Episode VI – Return of the Jedi', 'rel': 'director', 'rel_lbl': ['director', 'screenwriter', 'choreographer', 'KMRB film rating', 'described by source', 'art director', 'parent organization', 'INCAA film rating', 'headquarters location', 'operating syste

In [None]:
lm = KG_handler.LM
for i in KG_handler.rel_lbl_set:
    # print(i)
    KG_handler.rel_lbl2vec[i] = lm.get_sentence_vector(i)
# lm.get_sentence_vector('')


operator
different from
Minkultury film ID
designed by
shares border with
director of photography
narrator
publisher
replaces
director
director / manager
published in
takes place in fictional universe
cast member
list of works
time period
interested in
main subject
allegiance
lifestyle
continent
original film format
author
stepparent
capital of
performer
BBFC rating
operating area
contributor to the creative work or subject
YouTube video ID
OFLC classification
choreographer
followed by
superhuman feature or ability
voice actor
platform
occupant
mother
assessment
ethnic group
relative
medical condition
intended public
place of birth
after a work by
IMDb ID
founded by
work location
field of work
based on
inspired by
Disney A to Z ID
iTunes movie ID
participant in
religion
collection
copyright holder
sports discipline competed in
Filmiroda rating
Writers Guild of America project ID
has works in the collection
The Criterion Collection film ID
given name
has quality
country
copyright status