In [1]:
from bert import QA
import os

import json
import numpy as np
import tensorflow as tf
import matplotlib
import subprocess

matplotlib.use('Agg')

import networkx as nx
import json
import argparse
import matplotlib.pyplot as plt
import random
import string

import os
from nltk import tokenize
from nltk.corpus import stopwords

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [53]:
# Folder that holds the story text files, relative to this notebook.
data_dir = '../../../data'

# Story to process (file name without the ".txt" extension).
# Other stories that have been used with this notebook:
#   zelda-botw, resolved_zelda_botw, rapunzel, resolved_rapunzel,
#   02_the_two_towers, 03_the_return_of_the_king, little-red-riding-hood
filename = "01_the_fellowship_of_the_ring"

In [3]:
def readGraph(file):
    '''
    Load a world description from a JSON file.
    The file must contain a JSON object whose values each have an 'objects' list.
    params:
        `file`: path to the JSON file
    returns:
        a 3-tuple of (keys of the JSON object as a dict keys view,
        flat list of every entry of every 'objects' list, empty list)
    '''
    with open(file, 'r') as handle:
        world = json.load(handle)

    all_objects = [obj for place in world.values() for obj in place['objects']]
    return world.keys(), all_objects, []

# Primer questions used to pull entities of each kind out of the story context.
location_primer = "Where is the location in the story?"
character_primer = "Who is somebody in the story?"
object_primer = "What is an object in the story?"

# Relation question templates; `{}` is filled with the entity being asked about.
# Naming is <type of the entity asked about>2<type of the expected answer>.
loc2loc_templates = ["What location is next to {} in the story?"]

loc2obj_templates = ["What is in {} in the story?"]
obj2loc_templates = ["What location is {} in the story?"]

loc2char_templates = ["Who is in {} in the story?"]
char2loc_templates = ["What location is {} in the story?"]

obj2char_templates = ["Who has {} in the story?"]

# NOTE: despite the name, this is the full NLTK English stop-word list, not only articles.
articles = set(stopwords.words('english'))

In [35]:
class World:
    def __init__(self, args):
        '''
        Create an empty world KG, read the run settings and load the first story context.
        params:
            `args`: settings object exposing `seed`, `nsamples`, `write_sfdp`,
            `context_lines`, `cutoffs` and `input_text`
        '''
        # Seed the global RNG up front so later random template choices are reproducible.
        random.seed(args.seed)

        # Run settings copied from `args`.
        self.context_lines = args.context_lines  # sentences read per context window, or 'all'
        self.write_sfdp = args.write_sfdp
        self.nsamples = args.nsamples

        # Graph for the whole story, plus a scratch graph for the context being processed.
        self.graph = nx.Graph()
        self.current_graph = nx.Graph()

        # Accumulated results.
        self.entities = dict(locations=set(), objects=set(), characters=set())
        self.relations = []
        self.plots = []

        self.load_cutoffs(args.cutoffs)
        self.load_story(args.input_text)
        self.next_context()

        self.model = QA('model/albert-large-squad')


    def load_cutoffs(self, cutoffs):
        '''
        Load cutoffs from parameter.
        params:
            `cutoffs`: either a string ['`fairy`', '`mystery`'] or 
            a list of 3 cutoff decimals for [character, location, object] respectively.
        '''
        if cutoffs == 'fairy':
            self.cutoffs = [6.5, -7, -5]  # fairy
        elif cutoffs == 'mystery':
            self.cutoffs = [3.5, -7.5, -6]  # mystery
        else:
            self.cutoffs = [float(i) for i in cutoffs.split()]
            assert len(self.cutoffs) == 3


    def load_story(self, input_text):
        '''
        Read the story file, flatten it to one line and split it into sentences.
        params:
            `input_text`: path to story text file
        '''
        with open(input_text, 'r') as story_file:
            stripped_lines = [line.strip() for line in story_file]
        flat_text = ' '.join(stripped_lines)

        # `remaining_story` starts out as the full sentence list.
        self.story = tokenize.sent_tokenize(flat_text)
        self.remaining_story = self.story
    
    def next_context(self):
        '''
        Load the next `context_lines` lines of the story.
        '''
        if self.context_lines == 'all':
            self.context = ' '.join(self.remaining_story)
            self.remaining_story = []
        else:
            self.context = ' '.join(self.remaining_story[:self.context_lines])
            self.remaining_story = self.remaining_story[self.context_lines:]
        self.unmasked_context = self.context


    def is_connected(self):
        '''
        Tell whether the world graph (`self.graph`) has exactly one connected component.
        '''
        components = nx.connected_components(self.graph)
        component_count = sum(1 for _ in components)
        return component_count == 1


    def query(self, context, query, nsamples=10, cutoff=8):
        '''
        Query the model for the top `nsamples` candidates for the given `query`.
        params:
            `query`: query string
            `nsamples`: maximum number of candidates to return
            `cutoff`: cutoff for the model
        '''
        return self.model.predictTopK(context, query, nsamples, cutoff)


    def generateNeighbors(self, nsamples=100):
        '''
        Query the model for relation candidates of every entity in the current context graph.
        Results are stored in `self.candidates[entity][answer_type]`.
        params:
            `nsamples`: maximum number of candidates to return
        '''
        # Per node type: (answer type, question templates, index into self.cutoffs),
        # listed in the order the questions are asked.
        plan = {
            "location": (
                ('location', loc2loc_templates, 1),
                ('object', loc2obj_templates, 2),
                ('character', loc2char_templates, 0),
            ),
            "object": (
                ('location', obj2loc_templates, 1),
                ('character', obj2char_templates, 0),
            ),
            "character": (
                ('location', char2loc_templates, 1),
            ),
        }

        self.candidates = {}
        for node in self.current_graph.nodes:
            found = self.candidates[node] = {}
            node_type = self.current_graph.nodes[node]['type']
            for answer_type, templates, cutoff_idx in plan.get(node_type, ()):
                question = random.choice(templates).format(node)
                found[answer_type] = self.query(self.context, question, nsamples, self.cutoffs[cutoff_idx])


    def relatedness(self, u, v):
        '''
        Compute the relatedness between two entities. Calculated as :math:`P(x,u) = [p(x,u)+o(u,x)]/2`
        params:
            `u`: entity 1
            `v`: entity 2
        '''
        s = self.get_rel_prob(u, v) 
        s += self.get_rel_prob(v, u)
        return s


    def get_rel_prob(self, u, v):
        '''
        Score how strongly the candidate answers stored for `u` point at `v`.
        Each answer is compared to `v` as a bag of words with stop words removed. An
        answer counts only when no node of the current graph shares more words with it
        than `v` does; it then adds `shared_word_count * probability` to the score.
        params:
            `u`: entity whose candidates (from `generateNeighbors`) are inspected
            `v`: entity being matched; must be a node of `self.current_graph`
        returns:
            the accumulated score, 0 when `u` has no candidates for `v`'s type
        '''
        s = 0
        v_type = self.current_graph.nodes[v]['type']
        if u not in self.candidates or v_type not in self.candidates[u]:
            return s
        u2v, probs = self.candidates[u][v_type]
        if u2v is None:
            return s

        # Loop-invariant work: tokenise the target and every graph node once,
        # instead of once per candidate answer.
        target_words = set(v.split()).difference(articles)
        node_words = [set(x.split()).difference(articles) for x in self.current_graph.nodes]

        for c, p in zip(u2v, probs):
            answer_words = set(c.text.split()).difference(articles)
            overlap = len(answer_words.intersection(target_words))

            # largest overlap this answer has with any node of the graph
            best_intersect = max(
                (len(answer_words.intersection(words)) for words in node_words),
                default=0,
            )

            # increment if answer is best match BoW
            if overlap == best_intersect:
                s += overlap * p
        return s


    def extractEntity(self, query, threshold=0.05, cutoff=0):
        '''
        Extract all entities from a query result.
        params:
            `query`: query string
            `threshold`: minimum probability for a candidate to be considered
            `cutoff`: cutoff for the model
        '''
        preds, probs = self.query(self.context, query, self.nsamples, cutoff)

        if preds is None:
            print("NO ANSWER FOUND")
            return None, 0

        for pred, prob in zip(preds, probs):
            t = pred.text
            p = prob
            print('> ', t, p)
            if len(t) < 1:
                continue
            if p > threshold and "MASK" not in t:
                # find a more minimal candidate if possible
                for pred, prob in zip(preds, probs):
                    if t != pred.text and pred.text in t and prob > threshold and len(pred.text) > 2:
                        t = pred.text
                        p = prob
                        break

                t = t.strip(string.punctuation)
                remove = t

                # take out leading articles for cleaning
                words = t.split()
                if words[0].lower() in articles:
                    remove = " ".join(words[1:])
                    words[0] = words[0].lower()
                    t = " ".join(words[1:])
                if remove.strip() == '':
                    print("NO ANSWER FOUND")
                    return None, 0
                print(remove)

                self.context = self.context.replace(remove, '[MASK]').replace('  ', ' ').replace(' .', '.')
                return t, p

        return None, 0


    def generate(self, filename=None):
        '''
        Generate the world graph from the entire story by parsing it # `context_lines` at a time.
        Extracts entities, relations between entities and potential plots.
        params:
            `filename`: path of file to save extracted entities as (JSON)
        '''
        # set initial threshold
        threshold = 0.05

        while len(self.context) != 0:
            print(f'Current context: {self.context}')

            # add locations
            locs = self.extract('locations', location_primer, threshold, self.cutoffs[1])
            # add objects
            objs = self.extract('objects', object_primer, threshold, self.cutoffs[2])
            # add characters
            chars = self.extract('characters', character_primer, threshold, self.cutoffs[0])
            
            self.current_graph.clear()
            self.current_graph.add_nodes_from(locs, type='location', fillcolor="yellow", style="filled")
            self.current_graph.add_nodes_from(objs, type='object', fillcolor="white", style="filled")
            self.current_graph.add_nodes_from(chars, type='character', fillcolor="orange", style="filled")
            
            self.context = self.unmasked_context
            self.autocomplete()
            
            self.entities['locations'].update(locs)
            self.entities['objects'].update(objs)
            self.entities['characters'].update(chars)
            self.next_context()
            print('-' * 20 + '\n')

        self.entities['locations'] = list(self.entities['locations'])
        self.entities['objects'] = list(self.entities['objects'])
        self.entities['characters'] = list(self.entities['characters'])
        self.entities['relations'] = self.relations
        self.entities['plots'] = self.plots
        if filename is not None:
            with open(filename, 'w') as f:
                json.dump(self.entities, f, indent=4, sort_keys=False)


    def extract(self, entity_type, primer, threshold, cutoff):
        entity_list = []
        print('-' * 5 + f'\t{entity_type}\t' + '-' * 5)
        t, _ = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
        while t is not None and len(t) > 1:
            entity_list.append(t)
            t, _ = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
        print('-' * 20 + '\n')
        self.context = self.unmasked_context
        return entity_list


    def autocomplete(self):
        '''
        Autocomplete the current contextual graph by adding relations between entities.
        '''
        self.generateNeighbors(self.nsamples)
        print('-' * 5 + '\trelations\t' + '-' * 5)
        entities = list(self.current_graph.nodes)
        best = (-1, '', '')

        for u in entities:
            u_type = self.current_graph.nodes[u]['type']
            # print(f'Searching relations for {u}')
            for v in list(self.current_graph.nodes):
                v_type = self.current_graph.nodes[v]['type']
                if u == v or (u_type != 'location' and u_type == v_type):
                    continue
                uvrel = self.relatedness(u, v)
                best = max(best, (uvrel, v, v_type))
            
            s, v, v_type = best

            if s <= 0:
                # print(f'No relations found for {u}')
                continue
                    
            if u_type == 'location':
                if v_type == 'location':
                    rel_type = "connected to"
                else:
                    rel_type = "located in"
            else:
                rel_type = "has"
            
            self.check_and_add_world_relation(v, u, rel_type)            


    def check_and_add_world_relation(self, v, u, rel_type):
        '''
        Check if the relation is already present in the graph. 
        If present, add to plot graph, otherwise normally add to world graph
        '''
        u_type = self.current_graph.nodes[u]['type']
        v_type = self.current_graph.nodes[v]['type']

        if u not in self.graph:
            self.graph.add_node(u, type=u_type)
        if v not in self.graph:
            self.graph.add_node(v, type=v_type)

        if u_type == 'location':
            if v_type == 'location':
                # if both are location, just add it the relation
                rel_type = 'connected to'
                rel_triplet = (v, rel_type, u)
                print(f'Adding relation: {rel_triplet}')
                self.relations.append(rel_triplet)
                self.graph.add_edge(v, u, label=rel_type)
            else:
                # u is location, v is char/obj
                rel_type = 'located in'
                rel_triplet = (v, rel_type, u)
                print(f'Adding relation: {rel_triplet}')
                for c_node in self.graph[v]:
                    existing_label = self.graph[v][c_node]['label']
                    if existing_label == rel_type:
                        print(f'Found existing relation: {v}, {rel_type}, {c_node}')
                        if c_node != u:
                            # if v is 'located in' some other location (c_node) before this point,
                            # v's location is changed, put it as an event
                            print(f'Updating new relation: {rel_triplet}')
                            self.plots.append({'before': (v, existing_label, c_node), 'after': rel_triplet})
                self.graph.add_edge(v, u, label=rel_type)
                self.relations.append(rel_triplet)
        
        elif u_type == 'charater' and v_type == 'object':
            rel_type = 'has'
            rel_triplet = (u, rel_type, v)
            print(f'Adding relation: {rel_triplet}')
            for c_node in self.graph[v]:
                existing_label = self.graph[v][c_node]['label']
                if existing_label == rel_type:
                    print(f'Found existing relation: {c_node}, {rel_type}, {v}')
                    if c_node != u:
                        # if some other char (c_node) has v before this point, 
                        # v's ownership has changed, put it as event
                        print(f'Updating new relation: {rel_triplet}')
                        self.plots.append({'before': (c_node, existing_label, v), 'after': rel_triplet})
            self.graph.add_edge(v, u, label=rel_type)
            self.relations.append(rel_triplet)
        
        elif u_type == 'object' and v_type == 'character':
            rel_type = 'has'
            rel_triplet = (v, rel_type, u)
            print(f'Adding relation: {rel_triplet}')
            for c_node in self.graph[u]:
                existing_label = self.graph[u][c_node]['label']
                if existing_label == rel_type:
                    print(f'Found existing relation: {v}, {rel_type}, {c_node}')
                    if c_node != u:
                        # if some other char (c_node) has u before this point, 
                        # u's ownership has changed, put it as event
                        print(f'Updating new relation: {rel_triplet}')
                        self.plots.append({'before': (v, existing_label, c_node), 'after': rel_triplet})
            self.graph.add_edge(v, u, label=rel_type)
            self.relations.append(rel_triplet)
            
        elif v_type == 'location':
            # u is char/obj, v is location
            rel_type = 'located in'
            rel_triplet = (u, rel_type, v)
            print(f'Adding relation: {rel_triplet}')
            for c_node in self.graph[u]:
                existing_label = self.graph[u][c_node]['label']
                if existing_label == rel_type:
                    print(f'Found existing relation: {u}, {rel_type}, {c_node}')
                    if c_node != v:
                        # if u is 'located in' some other location (c_node) before this point,
                        # u's location is changed, put it as an event
                        print(f'Updating new relation: {rel_triplet}')
                        self.plots.append({'before': (u, existing_label, c_node), 'after': rel_triplet})
            self.graph.add_edge(v, u, label=rel_type)
            self.relations.append(rel_triplet)


    def export(self, filename="graph.dot"):
        '''
        Write the world graph to disk in two formats.
        params:
            `filename`: path of the Graphviz DOT file; the GML copy always goes to "graph.gml"
        '''
        world_graph = self.graph
        nx.nx_pydot.write_dot(world_graph, filename)
        nx.write_gml(world_graph, "graph.gml", stringizer=None)


    def draw(self, filename="./graph.svg"):
        '''
        Export the world graph and render it as an image.
        With `write_sfdp` set, Graphviz `sfdp` lays out "graph.dot" into the SVG file
        `filename`, and `inkscape` converts that SVG to a PNG next to it. Otherwise
        networkx/matplotlib draw the graph straight into the PNG.
        params:
            `filename`: path of the SVG file; the PNG uses the same path with a ".png" extension
        raises:
            `subprocess.CalledProcessError` if an external tool exits with an error
        '''
        self.export()

        # splitext instead of slicing off 4 characters, so any extension length works
        png_path = os.path.splitext(filename)[0] + '.png'

        if self.write_sfdp:
            # Argument lists (no shell) so paths with spaces or shell metacharacters are safe.
            svg_bytes = subprocess.check_output(
                ["sfdp", "-x", "-Goverlap=False", "-Tsvg", "graph.dot"])
            with open(filename, 'wb') as f:
                f.write(svg_bytes)
            subprocess.check_output(["inkscape", "-z", "-e", png_path, filename])
        else:
            nx.draw(self.graph, with_labels=True)
            plt.savefig(png_path)

In [54]:

class Args:
    '''
    Attribute-style view of the settings dictionary that `World` reads.
    Every key listed in `_FIELDS` must be present in the dictionary; other keys are ignored.
    '''

    _FIELDS = (
        'input_text',
        'seed',
        'nsamples',
        'cutoffs',
        'write_sfdp',
        'random',
        'context_lines',
    )

    def __init__(self, args):
        for field in self._FIELDS:
            setattr(self, field, args[field])
                
# Settings for this run; 'cutoffs' lists three numbers in the order character, location, object.
args = Args(dict(
    input_text=os.path.join(data_dir, f'{filename}.txt'),
    seed=0,
    nsamples=14,
    cutoffs='12 15 13',
    write_sfdp=False,
    random=False,
    context_lines=10,
))

In [55]:
# Build the world from the story and save the extracted entities as JSON.
world = World(args)
output_path = f"./outputs/{filename}.json"
world.generate(output_path)

Current context: Long ago, twenty rings existed: three for elves, seven for dwarves, nine for men, and one made by the Dark Lord Sauron, in Mordor, which would rule all the others. Sauron poured all his evil and his will to dominate into this ring. An alliance of elves and humans resisted Sauron’s ring and fought against Mordor. They won the battle and the ring fell to Isildur, the son of the king of Gondor, but just as he was about to destroy the ring in Mount Doom, he changed his mind and held on to it for himself. Later he was killed, and the ring fell to the bottom of the sea. The creature Gollum discovered it and brought it to his cave. Then he lost it to the hobbit Bilbo Baggins. The movie cuts to an image of the hobbits’ peaceful Shire years later, where the wizard Gandalf has come to celebrate Bilbo’s 111th birthday. The party is an extravagant occasion with fireworks and revelry, and Bilbo entertains children with tales of his adventures. In the middle of a rambling speech, ho

In [50]:
# List every node of the world graph with its type and its labelled neighbours.
for node in world.graph.nodes:
    node_type = world.graph.nodes[node]["type"]
    print(f'"{node}" is {node_type}')
    for neighbor, edge_attrs in world.graph.adj[node].items():
        print(f'\t {edge_attrs["label"]} "{neighbor}"')

"top of the Great Plateau" is location
	 connected to "Shrine of Resurrection"
	 connected to "campfire"
	 connected to "slumber"
	 connected to "ruined country"
	 connected to "Shrines"
	 connected to "Link"
	 connected to "particular the Oman Au Shrine"
	 connected to "Shrine"
	 connected to "leaves"
	 connected to "Owa"
	 connected to "Keh Namut"
	 connected to "Old Man's Paraglider"
	 connected to "Spirit Orbs"
	 connected to "Ja Baij"
	 located in "voice"
	 located in "Old Man promises"
	 located in "Old Man"
	 located in "awakens"
	 located in "way"
"Shrine of Resurrection" is location
	 connected to "top of the Great Plateau"
	 connected to "Hyrule Kingdom"
	 connected to "Great Plateau"
	 connected to "Hyrule Castle"
	 connected to "Shrine of Resurrection"
	 connected to "ruined"
	 connected to "Zelda"
	 connected to "Link"
	 connected to "Ganon"
	 connected to "Hyruleans"
	 connected to "Calamity"
	 connected to "Rhoam"
	 connected to "Hylian Champion"
	 connected to "waste"
	

In [None]:
# Scratch graph for trying out node and edge attributes.
g = nx.Graph()
g.add_nodes_from(['a', 'b'])
g.add_node('c', type='gg')
g.add_edges_from([('a', 'b'), ('a', 'c')], type='conn')

In [None]:
# Show the attributes stored on one node of the world graph.
world.graph.nodes['bottom of the sea']