In [1]:
import argparse
import os.path

import coreferee
import networkx as nx
from nltk import sent_tokenize
import pandas as pd
from pyvis.network import Network
import spacy
from spacy.matcher import DependencyMatcher
from spacy.pipeline import merge_entities


def coref_resolution(text):
    """
    Replace each coreferring token in ``text`` with the mention(s) it resolves to.

    The text is parsed with the ``en_core_web_trf`` spaCy pipeline, extended with
    the ``merge_entities`` and ``coreferee`` components. Every token is passed to
    ``doc._.coref_chains.resolve``; when that yields a list of tokens, their texts
    are joined with ' and ' and used in place of the token, otherwise the token's
    own text is kept.

    Arguments:
    text (str): input text

    Returns:
    coref_doc (str): the (possibly replaced) token texts joined by single spaces.
    """
    # The pipeline is built on every call, exactly as before.
    nlp = spacy.load('en_core_web_trf')
    nlp.add_pipe('merge_entities')
    nlp.add_pipe('coreferee')

    doc = nlp(text)

    resolved_texts = []
    for token in doc:
        # A falsy result from resolve() means there is nothing to substitute,
        # so fall back to the token itself.
        resolved = doc._.coref_chains.resolve(token) or token
        if isinstance(resolved, list):
            resolved_texts.append(' and '.join(t.text for t in resolved))
        else:
            resolved_texts.append(resolved.text)

    # Join once at the end instead of re-joining the growing list on every
    # iteration; the resulting string is the same.
    return " ".join(resolved_texts)


In [2]:
def get_triples(sent):
    """
    Extract subject-verb-object (SVO) triples from the input text.

    Coreferences in ``sent`` are resolved first with ``coref_resolution``. The
    resolved text is then parsed with ``en_core_web_trf`` (plus the ``coreferee``
    and ``merge_noun_chunks`` components) and searched with a dependency pattern:
    a VERB token with a direct ``nsubj`` child and a direct ``dobj`` child.
    Each match contributes one (source, relation, target) row.

    Arguments:
    sent (str): input text

    Returns:
    kg_df (dataframe): one row per distinct triple, with the columns
        "source" (subject text), "relation" (verb text) and "target" (object text).
    """
    # Resolve coreferences in the argument that was passed in, not in a
    # notebook-level global, so the function works on a fresh kernel.
    sent = coref_resolution(sent)

    nlp = spacy.load("en_core_web_trf")
    nlp.add_pipe('coreferee')
    nlp.add_pipe("merge_noun_chunks")

    doc = nlp(sent)

    matcher = DependencyMatcher(nlp.vocab)
    pattern = [
        {"RIGHT_ID": "predicate", "RIGHT_ATTRS": {"POS": "VERB"}},
        {"LEFT_ID": "predicate", "REL_OP": ">", "RIGHT_ID": "subject",
         "RIGHT_ATTRS": {"DEP": "nsubj"}},
        {"LEFT_ID": "predicate", "REL_OP": ">", "RIGHT_ID": "object",
         "RIGHT_ATTRS": {"DEP": "dobj"}},
    ]
    matcher.add("SVO", [pattern])

    triple_dict = {"source": [], "relation": [], "target": []}

    # Each match is (match_id, token_ids); token_ids follows the order of the
    # pattern entries above: predicate, subject, object.
    for _match_id, token_ids in matcher(doc):
        triple_dict["source"].append(doc[token_ids[1]].text)
        triple_dict["relation"].append(doc[token_ids[0]].text)
        triple_dict["target"].append(doc[token_ids[2]].text)

    kg_df = pd.DataFrame(triple_dict).drop_duplicates()
    return kg_df


In [3]:
# Read the whole text file. Line breaks are turned into spaces (not removed) so
# the last word of one line is not fused with the first word of the next.
with open('A Charlie Brown Christmas.txt', 'r') as f:
    text = f.read().replace('\n', ' ')

# Extract the SVO triples; the bare name below displays the dataframe.
kg_df = get_triples(text)
kg_df



Unnamed: 0,source,relation,target
0,the program,made,program debut
1,Charlie Brown,direct,a neighborhood Christmas play
2,Linus,tells,Charlie Brown
3,the producers,went,an unconventional route
4,soundtrack,features,a jazz score
5,soundtrack absence,led,both the producers
6,A Charlie Brown Christmas,received,high ratings
7,A Charlie Brown Christmas success,paved,the way
8,A Charlie Brown Christmas jazz soundtrack,achieved,commercial success
9,ABC,holds,the rights


In [11]:
def dashboard(kg_df):
    """
    Visualise a dataframe of SVO triples as a knowledge graph saved to 'nx.html'.

    A networkx ``MultiDiGraph`` is built from the "source" and "target" columns,
    with every remaining column stored as an edge attribute. Each graph node is
    added to a pyvis ``Network`` and each edge is added with its own "relation"
    value as the edge label.

    Arguments:
    kg_df (dataframe): dataframe of triples with the columns "source",
        "relation" and "target"

    Returns:
    The value returned by ``net.show('nx.html')``.
    """
    net = Network()
    G = nx.from_pandas_edgelist(kg_df, 'source', 'target', edge_attr=True, create_using=nx.MultiDiGraph)
    net.show_buttons(filter_=['nodes', 'edges'])

    for node in G.nodes():
        net.add_node(node)

    # Take the label from the attributes stored on the edge itself. Looking the
    # relation up by target text would give every edge that shares a target
    # the relation of the first matching row.
    for source, target, attrs in G.edges(data=True):
        net.add_edge(source, target, label=attrs['relation'])

    net.save_graph('nx.html')

    return net.show('nx.html')
    



In [12]:
# Build the graph from the extracted triples and keep the result; the bare
# name on the last line makes it the cell's displayed output.
net = dashboard(kg_df)
net