# AskBERT

This notebook uses [AskBERT](https://github.com/rajammanabrolu/WorldGeneration/tree/master/neural-based/KG-extraction) for knowledge-graph extraction, together with the [BERT-SQuAD](https://github.com/kamalkraj/BERT-SQuAD) question-answering wrapper.

Paper: Ammanabrolu, Prithviraj, Wesley Cheung, Dan Tu, William Broniec, and Mark Riedl. 2020. “Bringing Stories Alive: Generating Interactive Fiction Worlds.” Proceedings of the AAAI Conference on Artificial Intelligence and Interactive Digital Entertainment 16 (1): 3–9. https://doi.org/10.1609/aiide.v16i1.7400.


In [None]:
#%pip install nltk

import nltk
from bert import QA
from nltk import tokenize

# `from nltk import tokenize` alone does not bind the name `nltk`, so the
# package is imported explicitly for the download call. 'punkt' is the
# sentence-tokenizer model that tokenize.sent_tokenize relies on.
nltk.download('punkt')

## Inference on AskBERT

In [1]:
# Directory holding the story text files (relative to the working directory).
data_dir = "../../../data"

# Base name (without the .txt extension) of the story read in the next cell;
# uncomment one of the alternatives to switch stories.
filename = 'zelda-botw'
# filename = 'resolved_zelda_botw'
# filename = 'rapunzel'
# filename = 'resolved_rapunzel'

In [None]:
# Requires NLTK's 'punkt' sentence-tokenizer data to be installed.
import os
from nltk import tokenize

# Read the story with an explicit encoding so the result does not depend on
# the platform's locale default.
with open(os.path.join(data_dir, f"{filename}.txt"), encoding="utf-8") as f:
    lines = f.readlines()

# Collapse hard line breaks into single spaces, then split into sentences.
doc = tokenize.sent_tokenize(' '.join(line.strip() for line in lines))

for s in doc:
    print(s)

In [None]:
from bert import QA
import os

# Load the QA model through the BERT-SQuAD `QA` wrapper. Judging by the
# directory name this is ALBERT-large fine-tuned on SQuAD (confirm). The path
# is relative to the working directory.
model = QA('model/albert-large-squad')

In [None]:
# Load the "resolved" story as one string. The name suggests a
# coreference-resolved version (confirm). This replaces the sentence list
# `doc` built earlier. An explicit encoding keeps the read locale-independent.
with open(os.path.join(data_dir, "resolved_zelda_botw.txt"), encoding="utf-8") as f:
    doc = f.read()

In [None]:
def print_query_results(doc, query):
    """Ask the global QA ``model`` a question about ``doc`` and print the answers.

    Every entry of the first element of ``model.predict``'s result is printed
    on its own line. The complete, unmodified result is returned.
    """
    result = model.predict(doc, query)
    top_answers = result[0]
    for answer_entry in top_answers:
        print(answer_entry)
    return result

In [None]:
# Hard-coded synopsis excerpt used for the ad-hoc queries below. It overrides
# the file contents previously loaded into `doc`. The text appears to be
# coreference-resolved (pronouns replaced by full noun phrases), which explains
# its repetitive wording.
doc = "After escaping the confines of the Great Plateau, Link is directed to meet the wise Sheikah elder Impa, and learn about the Guardians and Divine Beasts: 10,000 years prior the Guardians and Divine Beasts were created and successfully used by another Hero and another Princess to defeat a great evil known as the Calamity Ganon. But throughout the ages, knowledge about the Guardians and Divine Beasts was lost until excavations in the ruined country of Hyrule Kingdom brought the Guardians and Divine Beasts to light once more, coinciding with the expected return of a great evil known as the Calamity Ganon a hundred years ago. Guardians were reactivated and four Champions were chosen to control Divine Beasts: the Zora princess Mipha, the Goron warrior Daruk, the Gerudo chief Urbosa, and the Rito archer Revali. All the while, Zelda, who is the daughter of King Rhoam was unsuccessfully trying to gain access to Zelda, who is the daughter of King Rhoam's own prophesied powers, accompanied on Zelda, who is the daughter of King Rhoam's quests by Zelda, who is the daughter of King Rhoam's knight, the Hylian Champion Link. When a great evil known as the Calamity Ganon ultimately attacked, a great evil known as the Calamity Ganon devastated the ruined country of Hyrule Kingdom by taking control of the ancient machines and turning the ancient machines against the Hyruleans. As a last resort, Zelda, who is the daughter of King Rhoam was able to place Link in the Shrine of Resurrection and use Zelda, who is the daughter of King Rhoam's awoken sealing powers to trap Zelda, who is the daughter of King Rhoam with a great evil known as the Calamity Ganon in Hyrule Castle."

In [None]:
import random

In [None]:
# Relation question templates for the ad-hoc queries in this notebook; `{}`
# is filled with an entity name via str.format. (The World class further
# down redefines these names as lists of templates.)
loc2loc_templates = "What location is next to {} in the story?"

loc2obj_templates = "What object is in {} in the story?" 
obj2loc_templates = "What location is the object {} in the story?"

loc2char_templates = "Who is in {} in the story?"
char2loc_templates = "What location is {} in the story?"

In [None]:
# Open-ended questions for discovering characters, locations and objects.
char_template = "Who is somebody in the story?"
loc_template = "Where is the location in the story?"
obj_template = "What is an object in the story?"

In [None]:
# Hand-picked entity names for the Zelda story. `locations` is sampled in the
# next cell; `characters` and `objects` are not used in the code shown here.
characters = ['Zelda', 'Link', 'daughter', 'King of Hyrule, King Rhoam', 'Rhoam', 'spirit of the deceased King', 'voice', 'Old Man', 'Sheikah', 'Impa', 'King', 'Revali', 'Daruk', 'Urbosa', 'Mipha']
locations = ['Hyrule Kingdom', 'Great Plateau', 'Hyrule Castle', 'ruined country', 'Temple of Time', 'Shrine of Resurrection', 'campfire']
objects = ['Spirit Orbs', "Old Man's Paraglider"]

In [None]:
# Pick a random known location and print it.
# NOTE(review): selected_loc is not used by the query below. The query asks
# about the hard-coded object 'shrines' instead; confirm which was intended.
selected_loc = random.choice(locations)
print(selected_loc)
loc_query = print_query_results(doc, obj2loc_templates.format('shrines'))
# Element [1] of the predict() result; presumably the answer scores (TODO
# confirm against bert.QA.predict).
print(loc_query[1])

In [None]:
# Ask a direct location question about the 'ruined country'.
# NOTE(review): despite its name, char_query holds a location query.
char_query = print_query_results(doc, 'Where is ruined country?')
print(char_query[1])

In [None]:
# Ask where the paraglider is, using a free-form phrasing of the obj2loc question.
a_query = print_query_results(doc, 'What location is paraglider in the story?')
print(a_query[1])

In [None]:
from bert import QA
import os

# !/usr/bin/env python3

import json
import numpy as np
import tensorflow as tf
import matplotlib
import subprocess

matplotlib.use('Agg')

import networkx as nx
import json
import argparse
import matplotlib.pyplot as plt
import random
import string


def readGraph(file):
    """Load a location -> objects JSON map.

    The file is expected to map each location name to a dict containing an
    ``'objects'`` list.

    Returns:
        A 3-tuple: the locations (the JSON object's key view), a flat list of
        every object in file order, and an empty character list (this format
        stores no characters).
    """
    with open(file, 'r') as handle:
        graph_data = json.load(handle)
    object_names = [obj for entry in graph_data.values() for obj in entry['objects']]
    return graph_data.keys(), object_names, []


# Question templates used by World.generateNeighbors. One template per list
# is picked with random.choice for each node, and `{}` is filled with the
# node name.
loc2loc_templates = ["What location is next to {} in the story?"]
loc2obj_templates = ["What is in {} in the story?", ]
loc2char_templates = ["Who is in {} in the story?", ]

obj2loc_templates = ["What location is {} in the story?", ]
obj2char_templates = ["Who has {} in the story?", ]

char2loc_templates = ["What location is {} in the story?", ]

# NOTE(review): `conjunctions` and `pronouns` are not referenced in the code shown here.
conjunctions = ['and', 'or', 'nor']
# Function words ignored when comparing QA answers with node names, and
# stripped from the start of extracted entities. ('the' is listed twice; this
# is harmless because the list is only used for membership tests and set
# difference.)
articles = ["the", 'a', 'an', 'his', 'her', 'their', 'my', 'its', 'those', 'these', 'that', 'this', 'the']
pronouns = ["He", "She", "he", "she"]


class World:
    """Build a story knowledge graph from the answers of a SQuAD-style QA model.

    Nodes are locations, characters and objects. Each node carries a ``type``
    attribute plus Graphviz styling. ``generate`` discovers entities by
    repeatedly asking the model an open question and masking each accepted
    answer out of the text. ``autocomplete`` then adds the best-scoring edges
    until the graph is connected.
    """

    def __init__(self, locs, chars, objs, relations, args):
        """Create the graph and load the story text and the QA model.

        Args:
            locs: initial location node names.
            chars: initial character node names.
            objs: initial object node names.
            relations: initial edges, in any form ``add_edges_from`` accepts.
            args: settings object. This class reads ``cutoffs``,
                ``input_text`` (path of the story file), ``nsamples``,
                ``random`` and ``write_sfdp``.

        Raises:
            ValueError: if ``args.cutoffs`` is neither a known preset
                ('fairy', 'mystery') nor three whitespace-separated numbers.
        """
        self.graph = nx.Graph()
        self.graph.add_nodes_from(locs, type='location', fillcolor="yellow", style="filled")
        self.graph.add_nodes_from(chars, type='character', fillcolor="orange", style="filled")
        self.graph.add_nodes_from(objs, type='object', fillcolor="white", style="filled")
        self.graph.add_edges_from(relations)

        self.locations = set(locs)
        self.objects = set(objs)
        self.characters = set(chars)

        self.relations = relations
        # Relation label of every edge added by autocomplete(), keyed (v, u).
        # Previously never initialised, so autocomplete() raised AttributeError.
        self.edge_labels = {}

        self.context_lines = 8

        self.args = args
        # [character, location, object] cutoffs passed to predictTopK.
        self.cutoffs = self._parse_cutoffs(args.cutoffs)

        with open(args.input_text, encoding="utf-8") as f:
            self.input_text = f.read()

        self.story = self.input_text

        print(self.story)

        self.model = QA('model/albert-large-squad')

    @staticmethod
    def _parse_cutoffs(spec):
        """Return the three cutoffs for a preset name or a 'c l o' number string."""
        presets = {
            'fairy': [6.5, -7, -5],
            'mystery': [3.5, -7.5, -6],
        }
        if spec in presets:
            return list(presets[spec])
        cutoffs = [float(i) for i in spec.split()]
        # Raise rather than assert: assertions disappear under `python -O`.
        if len(cutoffs) != 3:
            raise ValueError(
                "cutoffs must be 'fairy', 'mystery' or three numbers, got {!r}".format(spec))
        return cutoffs

    def is_connected(self):
        """Return True when the graph has at most one connected component.

        An empty graph counts as connected. Otherwise ``autocomplete`` would
        index into an empty component list.
        """
        return len(list(nx.connected_components(self.graph))) <= 1

    def query(self, query, nsamples=10, cutoff=8):
        """Ask ``query`` about the current (possibly masked) story text.

        Returns whatever ``self.model.predictTopK`` returns, which callers
        unpack as an ``(answers, scores)`` pair.
        """
        return self.model.predictTopK(self.input_text, query, nsamples, cutoff)

    def generateNeighbors(self, nsamples=100):
        """Pre-compute QA answers that link every node to candidate neighbours.

        Fills ``self.candidates[node][neighbour_type]`` with the result of
        ``query``:
        - Locations are asked about locations, objects and characters.
        - Objects are asked about locations and characters.
        - Characters are asked only about locations.
        """
        self.candidates = {}
        for u in self.graph.nodes:
            self.candidates[u] = {}
            node_type = self.graph.nodes[u]['type']
            if node_type == "location":
                self.candidates[u]['location'] = self.query(random.choice(loc2loc_templates).format(u), nsamples, self.cutoffs[1])
                self.candidates[u]['object'] = self.query(random.choice(loc2obj_templates).format(u), nsamples, self.cutoffs[2])
                self.candidates[u]['character'] = self.query(random.choice(loc2char_templates).format(u), nsamples, self.cutoffs[0])
            elif node_type == "object":
                self.candidates[u]['location'] = self.query(random.choice(obj2loc_templates).format(u), nsamples, self.cutoffs[1])
                self.candidates[u]['character'] = self.query(random.choice(obj2char_templates).format(u), nsamples, self.cutoffs[0])
            elif node_type == "character":
                self.candidates[u]['location'] = self.query(random.choice(char2loc_templates).format(u), nsamples, self.cutoffs[1])

    def relatedness(self, u, v, u_type='location', v_type='location'):
        """Score how strongly the QA answers link nodes ``u`` and ``v``.

        The score is the sum of two directions:
        - u -> v: answers to u's ``v_type`` question matched against ``v``.
        - v -> u: answers to v's ``u_type`` question matched against ``u``.

        A direction for which no question was generated contributes 0. For
        example, characters have no 'object' question, and looking it up
        previously raised KeyError.
        """
        return (self._directional_score(self.candidates[u].get(v_type), v)
                + self._directional_score(self.candidates[v].get(u_type), u))

    def _directional_score(self, answers, target):
        """Sum score-weighted word overlap between QA ``answers`` and ``target``.

        An answer counts only when no graph node shares more words with it
        than ``target`` does. Answers that better describe another node are
        therefore ignored (a best-match bag-of-words test).
        """
        if answers is None:
            return 0
        preds, probs = answers
        if preds is None:
            return 0

        target_words = set(target.split()).difference(articles)
        # The graph does not change while scoring, so tokenise its nodes once.
        node_words = [set(x.split()).difference(articles) for x in self.graph.nodes]

        score = 0
        for c, p in zip(preds, probs):
            a = set(c.text.split()).difference(articles)
            overlap = len(a.intersection(target_words))
            best_intersect = max((len(a.intersection(xx)) for xx in node_words), default=0)
            if overlap == best_intersect:
                score += overlap * p
        return score

    def extractEntity(self, query, threshold=0.05, cutoff=0):
        """Extract one entity answering ``query`` and mask it out of the text.

        Steps:
        1. Take the first prediction that scores above ``threshold`` and does
           not contain an already-masked span.
        2. Prefer a shorter confident prediction contained in it.
        3. Strip punctuation and a leading article.
        4. Replace every occurrence in ``self.input_text`` with ``[MASK]``,
           so the next call yields a different entity.

        Returns:
            ``(entity, score)``, or ``(None, 0)`` when no usable answer is left.
        """
        preds, probs = self.query(query, self.args.nsamples, cutoff)

        if preds is None:
            print("NO ANSWER FOUND")
            return None, 0

        for pred, prob in zip(preds, probs):
            t = pred.text
            p = prob
            print('> ', t, p)
            if len(t) < 1:
                continue
            if not (p > threshold) or "MASK" in t:
                continue

            # Prefer a more minimal confident candidate contained in this one.
            for alt, alt_prob in zip(preds, probs):
                if t != alt.text and alt.text in t and alt_prob > threshold and len(alt.text) > 2:
                    t = alt.text
                    p = alt_prob
                    break

            t = t.strip(string.punctuation)
            words = t.split()
            # Take out a leading article for cleaning ("the castle" -> "castle").
            if words and words[0].lower() in articles:
                t = " ".join(words[1:])
            print(t)

            # Skip answers that cannot be masked:
            # - Empty after cleaning (e.g. "." or "the"): replace('', ...) would
            #   insert [MASK] between every character of the story.
            # - Not present in the text: masking would be a no-op, so the same
            #   answer would be returned forever.
            if not t or t not in self.input_text:
                continue

            self.input_text = self.input_text.replace(t, '[MASK]').replace('  ', ' ').replace(' .', '.')
            return t, p

        return None, 0

    def _extract_all(self, primer, cutoff, threshold):
        """Call extractEntity with ``primer`` until it yields no entity longer than 1 char."""
        found = []
        t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
        while t is not None and len(t) > 1:
            found.append(t)
            t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
        return found

    def generate(self, filename="entities.json"):
        """Discover characters, locations and objects, save them, then link them.

        Each entity kind is extracted from a fresh copy of the story, because
        extraction masks answers out of ``self.input_text``. The entity lists
        are written to ``filename`` as JSON, added to the graph, and then
        ``autocomplete`` connects the graph.
        """
        threshold = 0.05

        # Keep the unmasked text so every pass starts from the full story.
        original_text = self.input_text

        print("=" * 20 + "\tcharacters\t" + "=" * 20)
        self.input_text = original_text
        chars = self._extract_all("Who is somebody in the story?", self.cutoffs[0], threshold)

        print("=" * 20 + "\tlocations\t" + "=" * 20)
        self.input_text = original_text
        locs = self._extract_all("Locations?", self.cutoffs[1], threshold)

        print("=" * 20 + "\tobjects\t\t" + "=" * 20)
        self.input_text = original_text
        objs = self._extract_all("What is an object in the story?", self.cutoffs[2], threshold)
        self.input_text = original_text

        self.graph.add_nodes_from(locs, type='location', fillcolor="yellow", style="filled")
        self.graph.add_nodes_from(chars, type='character', fillcolor="orange", style="filled")
        self.graph.add_nodes_from(objs, type='object', fillcolor="white", style="filled")

        with open(filename, 'w') as f:
            json.dump({'characters': chars, 'locations': locs, 'objects': objs}, f, indent=4, sort_keys=False)
        self.autocomplete()

    def autocomplete(self):
        """Add edges until the graph is connected.

        Each round joins the first connected component to another component
        through the best-scoring node pair (see ``relatedness``). The edge
        attaches to a random location in the first component when the best
        score is 0 or ``args.random`` is set. Edges are labelled
        "connected to", "present in" or "has".
        """
        self.generateNeighbors(self.args.nsamples)

        print("=" * 20 + "\trelations\t" + "=" * 20)
        while not self.is_connected():
            components = list(nx.connected_components(self.graph))
            main = components[0]
            best = (-1, next(iter(main)), next(iter(components[1])))

            for u in main:
                u_type = self.graph.nodes[u]['type']
                for c in components[1:]:
                    for v in c:
                        v_type = self.graph.nodes[v]['type']
                        # Only locations may be linked to nodes of their own type.
                        if u_type != 'location' and u_type == v_type:
                            continue
                        uvrel = self.relatedness(u, v, u_type, v_type)
                        best = max(best, (uvrel, u, v))

            score, u, v = best

            # Attach randomly if there is no evidence or randomness is requested;
            # keep the best pair when the main component has no location.
            if score == 0 or self.args.random:
                candidates = [c for c in main if self.graph.nodes[c]['type'] == 'location']
                if candidates:
                    u = random.choice(candidates)
            u_type = self.graph.nodes[u]['type']
            v_type = self.graph.nodes[v]['type']

            if u_type == 'location':
                rel_type = "connected to" if v_type == 'location' else "present in"
            else:
                rel_type = "has"

            print("{} {} {}".format(v, rel_type, u))
            # Previously label=type stored the builtin `type` class, which
            # write_gml cannot serialise.
            self.graph.add_edge(v, u, label=rel_type)
            self.edge_labels[(v, u)] = rel_type

    def export(self, filename="graph.dot", gml_filename="graph.gml"):
        """Write the graph as Graphviz DOT to ``filename`` and as GML to ``gml_filename``."""
        nx.nx_pydot.write_dot(self.graph, filename)
        nx.write_gml(self.graph, gml_filename, stringizer=None)

    def draw(self, filename="./graph.svg"):
        """Render the graph after exporting ``graph.dot``/``graph.gml``.

        With ``args.write_sfdp`` set, ``sfdp`` writes the SVG to ``filename``
        and ``inkscape`` converts it to a PNG next to it. Both binaries must be
        on PATH. Otherwise matplotlib saves a PNG named after ``filename``.
        """
        self.export()
        stem = os.path.splitext(filename)[0]

        # Use the instance's settings, not a notebook-level global `args`.
        if self.args.write_sfdp:
            # Argument lists avoid shell parsing of the file names.
            svg = subprocess.check_output(["sfdp", "-x", "-Goverlap=False", "-Tsvg", "graph.dot"])
            with open(filename, 'wb') as f:
                f.write(svg)
            subprocess.check_output(["inkscape", "-z", "-e", stem + ".png", filename])
        else:
            nx.draw(self.graph, with_labels=True)
            plt.savefig(stem + '.png')

In [None]:
class Args:
    """Plain attribute container built from a settings dict.

    Each expected key becomes an attribute of the same name. A missing key
    raises KeyError, and extra keys are ignored.
    """

    def __init__(self, args):
        for field in ('input_text', 'length', 'batch_size', 'temperature',
                      'model_name', 'seed', 'nsamples', 'cutoffs',
                      'write_sfdp', 'random'):
            setattr(self, field, args[field])
                
# Experiment settings. World reads input_text, nsamples, cutoffs, write_sfdp
# and random; the other fields are not read by the code shown here.
args = Args(dict(
    # input_text=os.path.join(data_dir, 'resolved_rapunzel.txt'),
    input_text=os.path.join(data_dir, 'resolved_zelda_botw.txt'),
    # input_text=os.path.join(data_dir, 'zelda-botw.txt'),
    # input_text=os.path.join(data_dir, 'rapunzel.txt'),
    # input_text=input_text,
    length=10,
    batch_size=1,
    temperature=0.5,
    model_name='117M',
    seed=0,
    nsamples=50,
    cutoffs='11.5 15 12',  # character / location / object cutoffs
    # cutoffs='fairy',
    write_sfdp=False,
    random=False,
))

In [None]:
import random
# Draw a random candidate seed and print it.
rseed = random.randint(0, 1000000)
print(rseed)
# NOTE(review): the generator is seeded with the fixed value 0, not rseed, so
# the number printed above does not describe this run. Use random.seed(rseed)
# to explore new seeds. 182664 below looks like a previously recorded seed;
# confirm.
random.seed(0)
# 182664

In [None]:
# Start from an empty graph over the story in args.input_text, extract
# entities (written to entities.json by default), then connect them.
world = World([], [], [], [], args)
world.generate()

In [None]:
# Render the graph; draw() itself also writes graph.dot and graph.gml. Then
# save a DOT copy as test.dot. export() rewrites graph.gml as well.
world.draw('test.svg')
world.export('test.dot')