In [1]:
from bert import QA
import os

# Root folder holding the story text files (<name>.txt).
data_dir = "../../../data"

# NOTE(review): `model` is not referenced elsewhere in this notebook; World
# constructs its own QA model by default.
model = QA('model/albert-large-squad')

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [11]:
import json
import numpy as np
import tensorflow as tf
import matplotlib
import subprocess

matplotlib.use('Agg')

import networkx as nx
import json
import argparse
import matplotlib.pyplot as plt
import random
import string


def readGraph(file):
	"""Load a world description from a JSON file.

	The file maps each location name to a record holding an ``objects`` list.

	Returns:
		(locations, objects, relations): the location names (the dict's keys
		view), every object listed under any location in file order, and an
		empty list of relations.
	"""
	with open(file, 'r') as fp:
		world = json.load(fp)
	objects = [obj for record in world.values() for obj in record['objects']]
	return world.keys(), objects, []


# Question templates for the QA model; "{}" is filled with an entity name.
# <from>2<to>_templates asks, about an entity of type <from>, which entity of
# type <to> it relates to (used by World.generateNeighbors).
loc2loc_templates = ["What location is next to {} in the story?"]

loc2obj_templates = ["What is in {} in the story?", ]
obj2loc_templates = ["What location is {} in the story?", ]
obj2char_templates = ["Who has {} in the story?", ]

loc2char_templates = ["Who is in {} in the story?", ]
char2loc_templates = ["What location is {} in the story?", ]
char2obj_templates = ["What does {} have in the story?", ]

# NOTE(review): `conjunctions` and `pronouns` are not referenced in this notebook.
conjunctions = ['and', 'or', 'nor']
# Words ignored when comparing QA answers with entity names, and stripped from
# the front of extracted entities.
articles = ["the", 'a', 'an', 'his', 'her', 'their', 'my', 'its', 'those', 'these', 'that', 'this']
pronouns = [" He ", " She ", " he ", " she "]


class World:
	"""Graph of a story's characters, locations and objects.

	Entities are extracted from the text at ``args.input_text`` by asking a QA
	model questions; ``autocomplete`` then adds relations until the graph is a
	single connected component.
	"""

	def __init__(self, locs, objs, relations, args, model=None):
		"""Build the initial graph and load the story text.

		Args:
			locs: location names, added as ``type='location'`` nodes.
			objs: object names, added as ``type='object'`` nodes.
			relations: edges passed to ``Graph.add_edges_from``.
			args: settings object; reads ``cutoffs``, ``input_text`` (a path to
				the story), ``nsamples``, ``random`` and ``write_sfdp``.
			model: optional QA model exposing ``predictTopK``. When omitted, a
				new ``QA('model/albert-large-squad')`` is constructed; pass an
				already-loaded model to avoid loading it again.
		"""
		self.graph = nx.Graph()
		self.graph.add_nodes_from(locs, type='location', fillcolor="yellow", style="filled")
		self.graph.add_nodes_from(objs, type='object')
		self.graph.add_edges_from(relations)

		self.locations = set(locs)
		self.objects = set(objs)
		self.edge_labels = {}

		self.args = args
		# Score cutoffs indexed as [character, location, object] (see generate).
		if args.cutoffs == 'fairy':
			self.cutoffs = [6.5, -7, -5]
		elif args.cutoffs == 'mystery':
			self.cutoffs = [3.5, -7.5, -6]
		else:
			self.cutoffs = [float(i) for i in args.cutoffs.split()]
			assert len(self.cutoffs) == 3, "cutoffs needs three numbers: character location object"

		with open(args.input_text) as f:
			self.input_text = f.read()

		print(self.input_text)

		self.model = model if model is not None else QA('model/albert-large-squad')

	def is_connected(self):
		"""Return True when the graph forms exactly one connected component."""
		return nx.number_connected_components(self.graph) == 1

	def query(self, query, nsamples=10, cutoff=8):
		"""Ask the QA model ``query`` about the current (possibly masked) text.

		Callers unpack the result as ``(predictions, probabilities)``;
		``predictions`` may be None and each prediction exposes ``.text``.
		"""
		return self.model.predictTopK(self.input_text, query, nsamples, cutoff)

	def generateNeighbors(self, nsamples=100):
		"""Cache QA answers linking every node to each related entity type.

		Fills ``self.candidates[node][other_type]`` with the result of
		``self.query`` for a templated question about ``node``; cutoffs come
		from ``self.cutoffs`` ([character, location, object]).
		"""
		# node type -> ordered (other type, templates, cutoff index). The order
		# fixes the sequence of random.choice calls; keep it stable.
		plans = {
			"location": (
				("location", loc2loc_templates, 1),
				("object", loc2obj_templates, 2),
				("character", loc2char_templates, 0),
			),
			"object": (
				("location", obj2loc_templates, 1),
				("character", obj2char_templates, 0),
			),
			"character": (
				("location", char2loc_templates, 1),
				("object", char2obj_templates, 2),
			),
		}
		self.candidates = {}
		for u in self.graph.nodes:
			self.candidates[u] = {}
			for other_type, templates, idx in plans.get(self.graph.nodes[u]['type'], ()):
				question = random.choice(templates).format(u)
				self.candidates[u][other_type] = self.query(question, nsamples, self.cutoffs[idx])

	def _directional_score(self, answers, probs, target):
		"""Sum ``overlap * prob`` over answers that match ``target`` best.

		Words in ``articles`` are ignored. An answer counts only when no graph
		node shares more words with it than ``target`` does (bag-of-words best
		match). Returns 0 when ``answers`` is None.
		"""
		if answers is None:
			return 0
		target_words = set(target.split()).difference(articles)
		node_words = [set(x.split()).difference(articles) for x in self.graph.nodes]
		score = 0
		for answer, prob in zip(answers, probs):
			answer_words = set(answer.text.split()).difference(articles)
			overlap = len(answer_words & target_words)
			best = max((len(answer_words & words) for words in node_words), default=0)
			if overlap == best:
				score += overlap * prob
		return score

	def relatedness(self, u, v, u_type='location', v_type='location'):
		"""Score how strongly ``u`` and ``v`` are linked by the cached QA answers.

		Adds the score of answers about ``u`` that name ``v`` to the score of
		answers about ``v`` that name ``u``; requires ``generateNeighbors``.
		"""
		u2v, u2v_probs = self.candidates[u][v_type]
		v2u, v2u_probs = self.candidates[v][u_type]
		return self._directional_score(u2v, u2v_probs, v) + self._directional_score(v2u, v2u_probs, u)

	def extractEntity(self, query, threshold=0.05, cutoff=0):
		"""Ask ``query`` and return the first usable answer, masking it in the text.

		An answer is usable when its probability exceeds ``threshold`` and it
		does not contain "MASK". A shorter confident answer contained in it is
		preferred; surrounding punctuation and a leading article are stripped.
		Every occurrence of the result in ``self.input_text`` is replaced with
		"[MASK]" (generate relies on this to extract entities one by one).

		Returns:
			(entity, probability), or (None, 0) when nothing usable was found.
		"""
		preds, probs = self.query(query, self.args.nsamples, cutoff)

		if preds is None:
			print("NO ANSWER FOUND")
			return None, 0

		for pred, prob in zip(preds, probs):
			t = pred.text
			p = prob
			print('> ', t, p)
			if len(t) < 1:
				continue
			if not (p > threshold and "MASK" not in t):
				continue

			# Prefer a more minimal confident candidate contained in this answer.
			# TODO: IS IT REALLY NEEDED?
			for alt, alt_prob in zip(preds, probs):
				if t != alt.text and alt.text in t and alt_prob > threshold and len(alt.text) > 2:
					t = alt.text
					p = alt_prob
					break

			t = t.strip(string.punctuation)

			# Take out a leading article for cleaning.
			words = t.split()
			if words and words[0].lower() in articles:
				words = words[1:]
				t = " ".join(words)
			if not words:
				# Only punctuation/whitespace or a lone article remained. Skipping
				# avoids words[0] on an empty list and replace('', '[MASK]'), which
				# would insert "[MASK]" between every character of the story.
				continue
			remove = t
			print(remove)

			self.input_text = self.input_text.replace(remove, '[MASK]').replace('  ', ' ').replace(' .', '.')
			return t, p

		return None, 0

	def _collect_entities(self, primer, threshold, cutoff):
		"""Ask ``primer`` repeatedly, collecting answers until none is usable."""
		found = []
		t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
		while t is not None and len(t) > 1:
			if t in found:
				# Masking did not remove this answer from the text, so the same
				# question would most likely keep returning it; stop instead.
				break
			found.append(t)
			t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
		return found

	def generate(self, filename="entities.json"):
		"""Extract characters, locations and objects, save them, then link them.

		Each entity type starts from the unmasked story. The names become graph
		nodes, are written to ``filename`` as JSON (creating its directory if
		needed), and ``autocomplete`` connects the graph. ``self.input_text``
		is restored afterwards, even if extraction raises.
		"""
		import os

		threshold = 0.05
		original_text = self.input_text

		# (banner, primer question, cutoff); cutoffs are [character, location, object].
		queries = (
			("\tcharacters\t", "Who is somebody in the story?", self.cutoffs[0]),
			("\tlocations\t", "Locations?", self.cutoffs[1]),
			("\tobjects\t\t", "What is an object in the story?", self.cutoffs[2]),
		)
		found = []
		try:
			for banner, primer, cutoff in queries:
				print("=" * 20 + banner + "=" * 20)
				self.input_text = original_text
				found.append(self._collect_entities(primer, threshold, cutoff))
		finally:
			self.input_text = original_text
		chars, locs, objs = found

		self.graph.add_nodes_from(locs, type='location', fillcolor="yellow", style="filled")
		self.graph.add_nodes_from(chars, type='character', fillcolor="orange", style="filled")
		self.graph.add_nodes_from(objs, type='object', fillcolor="white", style="filled")

		out_dir = os.path.dirname(filename)
		if out_dir:
			os.makedirs(out_dir, exist_ok=True)
		with open(filename, 'w') as f:
			json.dump({'characters': chars, 'locations': locs, 'objects': objs}, f, indent=4, sort_keys=False)
		self.autocomplete()

	@staticmethod
	def _relation_type(u_type, v_type):
		"""Return the label of an edge that is printed as ``"<v> <label> <u>"``."""
		if u_type == 'location':
			return "connected to" if v_type == 'location' else "present in"
		if v_type == 'object':
			return "held by"
		return "has"

	def autocomplete(self):
		"""Add one edge per iteration until the graph is connected.

		Each iteration links the highest-``relatedness`` pair between the
		first component and any other one. It attaches to a random location
		of the first component instead when no pair scored above zero or
		``args.random`` is set. Edges get a ``label`` attribute, which is also
		stored in ``self.edge_labels``.
		"""
		self.generateNeighbors(self.args.nsamples)

		print("=" * 20 + "\trelations\t" + "=" * 20)
		# An empty graph has no components to join (components[0] would fail).
		while self.graph.number_of_nodes() > 0 and not self.is_connected():
			components = list(nx.connected_components(self.graph))
			main = components[0]
			best = (-1, next(iter(main)), next(iter(components[1])))

			for u in main:
				u_type = self.graph.nodes[u]['type']
				for component in components[1:]:
					for v in component:
						v_type = self.graph.nodes[v]['type']
						# Only locations may link to their own type.
						if u_type != 'location' and u_type == v_type:
							continue
						best = max(best, (self.relatedness(u, v, u_type, v_type), u, v))

			score, u, v = best

			# Attach randomly if nothing scored (or no pair was comparable) or if requested.
			if score <= 0 or self.args.random:
				# Sorted so random.choice is reproducible under a fixed seed
				# regardless of set iteration order.
				candidates = sorted(c for c in main if self.graph.nodes[c]['type'] == 'location')
				if candidates:
					u = random.choice(candidates)
			u_type = self.graph.nodes[u]['type']
			v_type = self.graph.nodes[v]['type']

			# Keep the location on the u side so the edge reads "<thing> present in <location>".
			if v_type == 'location' and u_type != 'location':
				u, v = v, u
				u_type, v_type = v_type, u_type

			rel_type = self._relation_type(u_type, v_type)
			print("{} {} {}".format(v, rel_type, u))
			self.graph.add_edge(v, u, label=rel_type)
			self.edge_labels[(v, u)] = rel_type

	def export(self, filename="graph.dot"):
		"""Write the graph as Graphviz DOT to ``filename`` and as GML to ``graph.gml``."""
		nx.nx_pydot.write_dot(self.graph, filename)
		nx.write_gml(self.graph, "graph.gml", stringizer=None)

	def draw(self, filename="./graph.svg"):
		"""Export the graph and render it.

		With ``args.write_sfdp``, Graphviz ``sfdp`` lays out ``graph.dot``
		into an SVG at ``filename``, and ``inkscape`` converts it to a PNG
		beside it. Otherwise networkx/matplotlib draws a PNG beside ``filename``.
		"""
		import os

		self.export()
		stem = os.path.splitext(filename)[0]

		if self.args.write_sfdp:
			# Argument lists (no shell) so file names are never parsed by a shell.
			svg = subprocess.check_output(["sfdp", "-x", "-Goverlap=False", "-Tsvg", "graph.dot"])
			with open(filename, 'wb') as f:
				f.write(svg)
			subprocess.check_output(["inkscape", "-z", "-e", stem + ".png", filename])
		else:
			fig, ax = plt.subplots()
			nx.draw(self.graph, ax=ax, with_labels=True)
			fig.savefig(stem + '.png')
			# Close the figure so repeated draws do not accumulate open figures.
			plt.close(fig)

In [9]:
import random
# Story to process. Other inputs used with this notebook:
# 'zelda_botw', 'resolved_zelda_botw', 'resolved_rapunzel'.
filename = 'rapunzel'

# Seed Python's `random` module, which World uses to pick question templates
# and for random attachment in autocomplete.
random.seed(0)

class Args:
	"""Attribute-style wrapper around the notebook's settings dictionary.

	World reads ``input_text``, ``cutoffs``, ``nsamples``, ``write_sfdp`` and
	``random``; the other fields are stored but not read by World here.
	Extra keys are ignored; a missing key raises KeyError.
	"""

	def __init__(self, args):
		fields = (
			'input_text', 'length', 'batch_size', 'temperature', 'model_name',
			'seed', 'nsamples', 'cutoffs', 'write_sfdp', 'random',
		)
		# Copied in this order, so KeyError names the first missing field.
		for field in fields:
			setattr(self, field, args[field])
                
# Settings for World. `cutoffs` is either three space-separated scores in the
# order [character, location, object] or a preset name ('fairy' / 'mystery').
args = Args(dict(
	input_text=os.path.join(data_dir, filename + '.txt'),
	length=10,
	batch_size=1,
	temperature=1,
	model_name='117M',
	seed=0,
	nsamples=15,
	cutoffs='11.5 15 12',
	write_sfdp=False,
	random=False,
))

In [12]:
# Start from an empty graph over the story text, then extract entities and
# relations; the entity lists are written under ./outputs/entities/.
world = World([], [], [], args)
world.generate(filename='./outputs/entities/%s.json' % filename)

A lonely couple, who want a child, live next to a walled garden belonging to a sorceress. The wife, experiencing the cravings associated with the arrival of her long-awaited pregnancy, notices some rapunzel, growing in the garden and longs for it. She refuses to eat anything else and gets sick, and the husband begins to fear for her life. One night, her husband breaks into the garden to get some for her. She makes a salad out of it and greedily eats it. It tastes so good that she longs for more. So her husband goes to get some more for her. As he scales the wall to return home, the sorceress catches him and accuses him of theft. He begs for mercy, and she agrees to be lenient, and allows him to take all the rapunzel he wants, on condition that the baby be given to her when it's born. Desperate, he agrees. When his wife has a baby girl, the sorceress takes her to raise as her own and names her "Rapunzel" after the plant her mother craved. She grows up to be the most beautiful child in t

: 