In [1]:
import json
from collections import Counter
import pandas as pd

In [2]:
from py2neo import Node, Relationship, Graph
from tqdm.auto import tqdm

## Preconfiguring db access and clearing existing database

In [3]:
credentials = {
    'host': '####',
    'username': '###',
    'password': '###'
}

In [4]:
graph: Graph = Graph(credentials['host'], auth=(credentials['username'], credentials['password']))

In [5]:
def nuke_database():
    graph.run("MATCH (n) DETACH DELETE n")

In [6]:
nuke_database()

## Reading clusters and getting rid of unannotated clusters(optional)

In [7]:
with open("./data/clusters_examples.json", 'r', encoding='utf-8') as f:
    cluster_triplets = json.load(f)
    
with open("./data/clusters_annotation.json", 'r', encoding='utf-8') as f:
    cluster_annot = json.load(f)

In [8]:
len(cluster_annot.keys()), len(cluster_triplets.keys())

(55, 99)

In [9]:
common_keys = set(cluster_triplets.keys()).intersection(set(cluster_annot.keys()))

In [10]:
#mapping keys from str -> int, then sorting them, then mapping back int -> str
common_keys = list(
    map(str, sorted(
        list(map(int, common_keys))
    ))
)

## Separating mapping into triplets and their predicates

In [11]:
cluster2predicates = {}
triplets = []
for key in tqdm(common_keys):
    
    cluster = cluster_triplets[key]
    predicates = list(cluster['predicates'].keys())
    rel_type = cluster_annot[key]['rel_type']
    cluster2predicates[rel_type] = predicates
    
    for triplet in cluster['data']:
        reformatted_triplet = {
            'subject': triplet[0], # who/what is acting
            'object': triplet[1], # who/what is a target/part of action
            'rel_type': rel_type
        }
        triplets.append(reformatted_triplet)

HBox(children=(IntProgress(value=0, max=55), HTML(value='')))




In [34]:
def triplet_representation(triplet):
    return f"({triplet['subject']})-[{triplet['rel_type']}]->({triplet['object']})" 

## Some stats

In [13]:
print(f"Total num of triplets: {len(triplets)}")

Total num of triplets: 10899


In [14]:
Counter(map(triplet_representation, triplets)).most_common(10)

[('(Investigation)-[founder]->(FBI)', 24),
 ('(Portugal)-[released]->(Russia)', 7),
 ('(Nadhani)-[developed]->(three years)', 5),
 ('(Nadhani)-[developed]->(first three years)', 5),
 ('(SAP)-[released]->(SAP)', 5),
 ('(WikiLeaks)-[asking]->(Trump Jr.)', 4),
 ('(Oracle Corporation)-[acquire]->(Sun Microsystems)', 4),
 ('(Oracle)-[acquire]->(Sun Microsystems)', 4),
 ('(Sega)-[who_is]->(Japan)', 4),
 ('(Sundar Pichai)-[ceo_at]->(Google)', 4)]

In [15]:
print(f"Num of uniq triplets: {len(set(map(triplet_representation, triplets)))}")

Num of uniq triplets: 9943


## Removing duplicates

In [16]:
triplets = list(
    map(json.loads, set( # removing duplicates and mapping back to python dicts
        map(json.dumps, triplets)) # mapping into hashable type: str
    )
)

## Inserting into database (in object-relation-object structure)

In [17]:
object_mapping = {}

for subject in set(obj['subject'] for obj in triplets):
    object_mapping[subject] = Node('object', name=subject)
    #graph.create(subject_mapping[subject])
    
for obj in set(obj['object'] for obj in triplets):
    if obj not in object_mapping.keys():
        object_mapping[obj] = Node('object', name=obj)
    #graph.create(object_mapping[obj])
    
for obj in tqdm(triplets):
    
    object_node = object_mapping[obj['object']]
    subject_node = object_mapping[obj['subject']]
    #print(object_node)
    
    relation_type = obj['rel_type']
    relationship = Relationship.type("predicate")
    rel = relationship(subject_node, object_node, type=relation_type)
    
    graph.create(rel)

HBox(children=(IntProgress(value=0, max=9943), HTML(value='')))




## That is how to perform queries

In [38]:
def query(subject_name = None, relation_type = None, object_name = None, verbose=False):
    """
    At least two arguments should not be None for adequate amount of results
    """
    query = "MATCH (s:object)-[r:predicate]->(o:object) [WHERE_CLAUSE]RETURN s.name as subject, r.type as rel_type, o.name as object"
    
    clause = "WHERE "
    used = False
    if subject_name is not None:
        clause += f"s.name = '{subject_name}' "
        used = True
    if relation_type is not None:
        clause_item = f"r.type = '{relation_type}' "
        if used:
            clause_item = "AND " + clause_item
            
        clause += clause_item
        used = True
    if object_name is not None:
        clause_item = f"o.name = '{object_name}' "
        if used:
            clause_item = "AND " + clause_item
            
        clause += clause_item
        used = True
        
    if used:
        query = query.replace("[WHERE_CLAUSE]", clause)
    else:
        query = query.replace("[WHERE_CLAUSE]", '')
        
    if verbose:
        print(query)
        
    return list(map(triplet_representation, graph.run(query).to_data_frame().to_dict(orient='records')))

In [72]:
cluster2predicates['ceo_at']

['CEO of',
 'is CEO of',
 "'s CEO is",
 'was CEO of',
 'was',
 'CEO at',
 'then-CEO of',
 'new CEO of',
 'CEO in',
 'inventor of',
 'is CEO at',
 'sent back',
 'CEOs of',
 'was CEO until',
 'be CEO at',
 'is',
 'was CEO From',
 'CEO until',
 'was announced On']

In [74]:
cluster2predicates['heads']

['replaced',
 'president of',
 'was appointed',
 'chairman of',
 'President of',
 'son of',
 'was president of',
 'was replaced by',
 'is son of',
 'was appointed as',
 'replacing',
 'was appointed in',
 'was appointed In',
 'president at',
 'was replaced In',
 'was appointed by',
 "'s president is",
 'proved In',
 'appointed',
 'was president for',
 'was appointed at_time',
 "'s President is",
 'president from',
 "'s son is",
 'President at',
 'be replaced by',
 'replaced Phillips as',
 'of chairman is',
 'being replaced by',
 'was appointed at',
 'vice-president of',
 'was',
 'been appointed',
 'unanimously appointed as',
 'was replaced on',
 'proved in',
 'replaced Irimajiri as',
 'has replaced',
 'appointed as',
 'was Chairman of',
 'Chairman of',
 'be replaced',
 'rotundity of',
 'was replaced as',
 'since President is',
 'is chairman of',
 'agreed On',
 'was appointed On',
 'replaced Yang as',
 'replaced President with',
 'were replaced by',
 'was son of',
 'be replaced in',
 'is

In [55]:
query(subject_name='Microsoft')

['(Microsoft)-[released]->(Zune)',
 '(Microsoft)-[collaborate]->(1994)',
 '(Microsoft)-[announced]->(18 January 2017)',
 '(Microsoft)-[launch]->(November 2015)',
 '(Microsoft)-[used_in]->(Messenger)',
 '(Microsoft)-[start_in]->(Windows Vista)',
 '(Microsoft)-[announced]->(19 July 2012)',
 '(Microsoft)-[announced]->(August 2015)',
 '(Microsoft)-[acquire]->(Acompli)',
 '(Microsoft)-[sued]->(1998)',
 '(Microsoft)-[who_is]->(February 2009)',
 '(Microsoft)-[announced]->(July 2009)',
 '(Microsoft)-[announced]->(May 2017)',
 '(Microsoft)-[announced]->(April 2015)',
 '(Microsoft)-[announced]->(May 12)',
 '(Microsoft)-[announced]->(July 2015)',
 '(Microsoft)-[acquire]->(May 2009)',
 '(Microsoft)-[announced]->(September 30)',
 '(Microsoft)-[released]->(year prior)',
 '(Microsoft)-[announced]->(11 November 2014)',
 '(Microsoft)-[introducing]->(late 1988)',
 '(Microsoft)-[purchased]->(Funk)',
 '(Microsoft)-[launch]->(Windows Phone Community)',
 '(Microsoft)-[developed]->(1996)',
 '(Microsoft)-[ann

In [56]:
relations_str = json.dumps(cluster2predicates)

## Writing stuff for api

In [57]:
%%writefile qa_service.py

from bentoml import BentoService, api, env, artifacts
from bentoml.artifact import TextFileArtifact
from bentoml.handlers import JsonHandler

import flair
from flair.data import Sentence
from flair.models import SequenceTagger

from py2neo import Graph

import json

import pandas as pd

@env(pip_dependencies=['flair', 'torch', 'pandas', 'py2neo', 'numpy'])
@artifacts([TextFileArtifact('relations')])
class QAService(BentoService):
    
    def get_entities(self, the_question, model):
        the_sentenced_question = Sentence(the_question)
        model.predict(the_sentenced_question)
        spans = [span for span in the_sentenced_question.get_spans('ner') if span.tag == "PER" or span.tag == "MISC" or span.tag == 'LOC']
        entities = [" ".join([tok.text for tok in span.tokens]) for span in spans]
        return entities
    
    def query(self, graph, subject_name = None, relation_type = None, object_name = None):
        """
        At least two arguments should not be None for adequate amount of results
        """
        query = "MATCH (s:object)-[r:predicate]->(o:object) [WHERE_CLAUSE]RETURN s.name as subject, r.type as relation, o.name as object"

        clause = "WHERE "
        used = False
        if subject_name is not None:
            clause += f"s.name = '{subject_name}' "
            used = True
        if relation_type is not None:
            clause_item = f"r.type = '{relation_type}' "
            if used:
                clause_item = "AND " + clause_item

            clause += clause_item
            used = True
        if object_name is not None:
            clause_item = f"o.name = '{object_name}' "
            if used:
                clause_item = "AND " + clause_item

            clause += clause_item
            used = True

        if used:
            query = query.replace("[WHERE_CLAUSE]", clause)
        else:
            query = query.replace("[WHERE_CLAUSE]", '')

        return graph.run(query).to_data_frame()
    
    def execute_question(self, question, ner_model, relations, graph):
        ##--stage 1: detect entities--##
        entities = self.get_entities(question, model=ner_model)

        true_relation_types = []
        ##--stage 2: detect relations--##
        for candidate_relation_type, candidate_relations_set in relations.items():
            if any(relation in question for relation in candidate_relations_set):
                true_relation_types.append(candidate_relation_type)

        ##--stage 3: executing queries--##
        dfs = []
        for entity in entities:
            for rel_type in true_relation_types:
                dfs.append(self.query(graph, subject_name=entity, relation_type=rel_type))
                dfs.append(self.query(graph, object_name=entity, relation_type=true_relation_type))

        return pd.concat(dfs, axis=0) if len(dfs) > 0 else []

    @api(JsonHandler)
    def predict(self, input_dict):
        
        credentials = {
            'host': '####',
            'username': '####',
            'password': '####'
        }
        
        ner = SequenceTagger.load('ner')
        relations = json.loads(self.artifacts.relations)
        graph = Graph(credentials['host'], auth=(credentials['username'], credentials['password']))
        question = input_dict['question']
        
        dataframe = self.execute_question(question, ner, relations, graph)
        return dataframe.to_json(orient='records') if len(dataframe) != 0 else []
        
        

Overwriting qa_service.py


In [58]:
from qa_service import QAService

In [59]:
qa = QAService.pack(relations=relations_str)

In [73]:
qa.predict({'question': "Where Sundar Pichai is CEO at ?"})

2019-10-13 15:50:13,264 loading file /root/.flair/models/en-ner-conll03-v0.4.pt


'[{"object":"Larry Page","relation":"heads","subject":"Sundar Pichai"},{"object":"CEO","relation":"heads","subject":"Sundar Pichai"}]'