# AskBERT

This notebook uses [AskBERT](https://github.com/rajammanabrolu/WorldGeneration/tree/master/neural-based/KG-extraction) and [BERT-SQuAD](https://github.com/kamalkraj/BERT-SQuAD).

Paper: Ammanabrolu, Prithviraj, Wesley Cheung, Dan Tu, William Broniec, and Mark Riedl. 2020. “Bringing Stories Alive: Generating Interactive Fiction Worlds.” Proceedings of the AAAI Conference on Artificial Intelligence and Interactive Digital Entertainment 16 (1): 3–9. https://doi.org/10.1609/aiide.v16i1.7400.


In [None]:
#%pip install nltk

import nltk
from bert import QA
from nltk import tokenize
nltk.download('punkt')

## Inference on AskBERT

In [None]:
import nltk
nltk.download('punkt')
from nltk import tokenize


In [42]:
from bert import QA

In [1]:
input_text = "../../../data/zelda-botw.txt"

In [43]:
model = QA('model/albert-large-squad')

In [39]:
with open("resolved_zelda_botw.txt") as f:
    doc = f.read()
    # doc = tokenize.sent_tokenize(doc)    

In [40]:
def print_query_results(doc, query):
    answer = model.predict(doc, query)
    for preds in answer[0]:
        print(preds)
    return answer

In [None]:
# doc is a single string, so split it into sentences before printing
for s in tokenize.sent_tokenize(doc):
    print(s)

In [3]:
doc = "After escaping the confines of the Great Plateau, Link is directed to meet the wise Sheikah elder Impa, and learn about the Guardians and Divine Beasts: 10,000 years prior the Guardians and Divine Beasts were created and successfully used by another Hero and another Princess to defeat a great evil known as the Calamity Ganon. But throughout the ages, knowledge about the Guardians and Divine Beasts was lost until excavations in the ruined country of Hyrule Kingdom brought the Guardians and Divine Beasts to light once more, coinciding with the expected return of a great evil known as the Calamity Ganon a hundred years ago. Guardians were reactivated and four Champions were chosen to control Divine Beasts: the Zora princess Mipha, the Goron warrior Daruk, the Gerudo chief Urbosa, and the Rito archer Revali. All the while, Zelda, who is the daughter of King Rhoam was unsuccessfully trying to gain access to Zelda, who is the daughter of King Rhoam's own prophesied powers, accompanied on Zelda, who is the daughter of King Rhoam's quests by Zelda, who is the daughter of King Rhoam's knight, the Hylian Champion Link. When a great evil known as the Calamity Ganon ultimately attacked, a great evil known as the Calamity Ganon devastated the ruined country of Hyrule Kingdom by taking control of the ancient machines and turning the ancient machines against the Hyruleans. As a last resort, Zelda, who is the daughter of King Rhoam was able to place Link in the Shrine of Resurrection and use Zelda, who is the daughter of King Rhoam's awoken sealing powers to trap Zelda, who is the daughter of King Rhoam with a great evil known as the Calamity Ganon in Hyrule Castle."

In [44]:
loc_query = print_query_results(doc, 'What is a location in the story?')
print(loc_query[1])

NbestPrediction(text='Hyrule Castle.', start_logit=2.1551175117492676, end_logit=2.5169548988342285)
NbestPrediction(text='Hyrule Kingdom', start_logit=0.20551830530166626, end_logit=0.6382617354393005)
NbestPrediction(text='Castle.', start_logit=-2.103416681289673, end_logit=2.5169548988342285)
NbestPrediction(text='in Hyrule Castle.', start_logit=-2.175450086593628, end_logit=2.5169548988342285)
NbestPrediction(text='Hyrule', start_logit=2.1551175117492676, end_logit=-2.271923542022705)
NbestPrediction(text='the Great Plateau,', start_logit=-0.5107864737510681, end_logit=-0.5040255188941956)
NbestPrediction(text='ruined country of Hyrule Kingdom', start_logit=-2.4173471927642822, end_logit=0.6382617354393005)
NbestPrediction(text='Great Plateau,', start_logit=-1.3218371868133545, end_logit=-0.5040255188941956)
NbestPrediction(text='Zelda,', start_logit=-0.898211658000946, end_logit=-1.5138918161392212)
NbestPrediction(text='the Shrine of Resurrection', start_logit=-1.1728227138519287
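
Each `NbestPrediction` above carries a start and an end logit. A minimal sketch of the usual SQuAD-style ranking, assuming a candidate's score is `start_logit + end_logit` and that probabilities come from a softmax over the candidates. Whether `model.predict` computes exactly this is an assumption, and `nbest_probabilities` is a hypothetical helper, not part of `bert.QA`. Only the first three predictions are used, so the probabilities differ from those over the full list.

```python
import math

# Logit pairs copied from the first three predictions printed above.
nbest = [
    ("Hyrule Castle.", 2.1551175117492676, 2.5169548988342285),
    ("Hyrule Kingdom", 0.20551830530166626, 0.6382617354393005),
    ("Castle.", -2.103416681289673, 2.5169548988342285),
]


def nbest_probabilities(entries):
    """Softmax over start_logit + end_logit for each candidate span."""
    scores = [start + end for _, start, end in entries]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [(text, e / total) for (text, _, _), e in zip(entries, exps)]


probs = nbest_probabilities(nbest)
for text, p in probs:
    print(f"{text!r}: {p:.3f}")
```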

In [None]:
char_query = print_query_results(doc, 'Who is somebody in the story?')
print(char_query[1])

In [None]:
a_query = print_query_results(doc, 'Who are in Hyrule Castle in the story?')
print(a_query[1])

In [4]:
with open(input_text) as f:
    doc = f.read()

In [None]:
from bert import QA

# !/usr/bin/env python3

import json
import os
import numpy as np
import tensorflow as tf
import matplotlib
import subprocess

matplotlib.use('Agg')

import networkx as nx
import argparse
import matplotlib.pyplot as plt
import random
import string


def readGraph(file):
    with open(file, 'r') as fp:
        # G = nx.parse_edgelist(, nodetype = int)

        dump = fp.read()
        j = json.loads(dump)
        locs = j.keys()
        objs = []
        for loc in j.values():
            objs += loc['objects']
    return locs, objs, []


loc2loc_templates = ["What location is next to {} in the story?"]

loc2obj_templates = ["What is in {} in the story?", ]
obj2loc_templates = ["What location is {} in the story?", ]

loc2char_templates = ["Who is in {} in the story?", ]
char2loc_templates = ["What location is {} in the story?", ]

conjunctions = ['and', 'or', 'nor']
articles = ["the", 'a', 'an', 'his', 'her', 'their', 'my', 'its', 'those', 'these', 'that', 'this', 'the']
pronouns = [" He ", " She ", " he ", " she "]


class World:
    def __init__(self, locs, objs, relations, args):
        self.graph = nx.Graph()
        self.graph.add_nodes_from(locs, type='location', fillcolor="yellow", style="filled")
        self.graph.add_nodes_from(objs, type='object')
        self.graph.add_edges_from(relations)

        self.locations = {v for v in locs}
        self.objects = {v for v in objs}
        self.edge_labels = {}

        self.args = args

        # load the story text and the QA model
        with open(args.input_text) as f:
            self.input_text = f.read()
        
        print(self.input_text)

        self.model = QA('model/albert-large-squad')

    def is_connected(self):
        return len(list(nx.connected_components(self.graph))) == 1

    def query(self, query, nsamples=10, cutoff=8):
        return self.model.predictTopK(self.input_text, query, nsamples, cutoff)

    def generateNeighbors(self, nsamples=100):
        self.candidates = {}
        for u in self.graph.nodes:
            self.candidates[u] = {}
            if self.graph.nodes[u]['type'] == "location":
                self.candidates[u]['location'] = self.query(random.choice(loc2loc_templates).format(u), nsamples)
                self.candidates[u]['object'] = self.query(random.choice(loc2obj_templates).format(u), nsamples)
                self.candidates[u]['character'] = self.query(random.choice(loc2char_templates).format(u), nsamples)
            if self.graph.nodes[u]['type'] == "object":
                self.candidates[u]['location'] = self.query(random.choice(obj2loc_templates).format(u), nsamples)
            if self.graph.nodes[u]['type'] == "character":
                self.candidates[u]['location'] = self.query(random.choice(char2loc_templates).format(u), nsamples)

    def relatedness(self, u, v, type='location'):

        s = 0
        u2v, probs = self.candidates[u][type]

        if u2v is not None:
            for c, p in zip(u2v, probs):
                a = set(c.text.split()).difference(articles)
                b = set(v.split()).difference(articles)

                # find best intersect
                best_intersect = 0
                for x in self.graph.nodes:
                    xx = set(x.split()).difference(articles)
                    best_intersect = max(best_intersect, len(a.intersection(xx)))

                # increment if answer is best match BoW
                if len(a.intersection(b)) == best_intersect:
                    s += len(a.intersection(b)) * p

                # naive method
                # s += len(a.intersection(b)) * p

        v2u, probs = self.candidates[v]['location']

        if v2u is not None:
            for c, p in zip(v2u, probs):
                a = set(c.text.split()).difference(articles)
                b = set(u.split()).difference(articles)

                # find best intersect
                best_intersect = 0
                for x in self.graph.nodes:
                    xx = set(x.split()).difference(articles)
                    best_intersect = max(best_intersect, len(a.intersection(xx)))

                # increment if answer is best match BoW
                if len(a.intersection(b)) == best_intersect:
                    s += len(a.intersection(b)) * p

                # naive method
                # s += len(a.intersection(b)) * p

        return s

    def extractEntity(self, query, threshold=0.05, cutoff=0):
        preds, probs = self.query(query, self.args.nsamples, cutoff)

        if preds is None:
            print("NO ANSWER FOUND")
            return None, 0

        for pred, prob in zip(preds, probs):
            t = pred.text
            p = prob
            print('> ', t, p)
            if len(t) < 1:
                continue
            if p > threshold and "MASK" not in t:

                # find a more minimal candidate if possible
                for pred, prob in zip(preds, probs):
                    if t != pred.text and pred.text in t and prob > threshold and len(pred.text) > 2:
                        t = pred.text
                        p = prob
                        break

                t = t.strip(string.punctuation)
                remove = t

                # take out leading articles for cleaning
                words = t.split()
                if words[0].lower() in articles:
                    remove = " ".join(words[1:])
                    words[0] = words[0].lower()
                    t = " ".join(words[1:])
                print(remove)

                self.input_text = self.input_text.replace(remove, '[MASK]').replace('  ', ' ').replace(' .', '.')
                return t, p
            # else:
            # find a more minimal candidate if possible
            # for pred, prob in zip(preds, probs):
            #     if prob > threshold and "MASK" not in pred.text and len(pred.text) > 2 and pred.text in t:
            #         t = pred.text.strip(string.punctuation)
            #         p = prob
            #         self.input_text = self.input_text.replace(t, '[MASK]').replace('  ', ' ').replace(' .', '.')
            #         print(t, p)
            #         return t, p

        return None, 0

    def generate(self):

        locs = []
        objs = []
        chars = []

        # set thresholds/cutoffs
        threshold = 0.05

        if self.args.cutoffs == 'fairy':
            cutoffs = [6.5, -7, -5]  # fairy
        elif self.args.cutoffs == 'mystery':
            cutoffs = [3.5, -7.5, -6]  # mystery
        else:
            cutoffs = [float(i) for i in self.args.cutoffs.split()]
            assert len(cutoffs) == 3

        # save input text
        tmp = self.input_text[:]

        # add chars
        print("=" * 20 + "\tcharacters\t" + "=" * 20)
        self.input_text = tmp
        primer = "Who is somebody in the story?"
        cutoff = cutoffs[0]
        t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
        while t is not None and len(t) > 1:
            if len(chars) > 1:
                cutoff = cutoffs[0]
            chars.append(t)
            t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)

        print("=" * 20 + "\tlocations\t" + "=" * 20)

        # add locations
        self.input_text = tmp
        primer = "Where is the location in the story?"
        cutoff = cutoffs[1]
        t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
        while t is not None and len(t) > 1:
            locs.append(t)

            if len(locs) > 1:
                cutoff = cutoffs[1]

            t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)

        print("=" * 20 + "\tobjects\t\t" + "=" * 20)

        # add objects
        self.input_text = tmp
        primer = "What is an object in the story?"
        cutoff = cutoffs[2]
        t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
        while t is not None and len(t) > 1:
            if len(objs) > 1:
                cutoff = cutoffs[2]
            objs.append(t)
            t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
        self.input_text = tmp

        self.graph.add_nodes_from(locs, type='location', fillcolor="yellow", style="filled")
        self.graph.add_nodes_from(chars, type='character', fillcolor="orange", style="filled")
        self.graph.add_nodes_from(objs, type='object', fillcolor="white", style="filled")

        # with open('stats.txt', 'a') as f:
        # f.write(args.input_text + "\n")
        # f.write(str(len(locs)) + "\n")
        # f.write(str(len(chars)) + "\n")
        # f.write(str(len(objs)) + "\n")
        self.autocomplete()

    def autocomplete(self):

        self.generateNeighbors(self.args.nsamples)

        print("=" * 20 + "\trelations\t" + "=" * 20)
        while not self.is_connected():
            components = list(nx.connected_components(self.graph))
            best = (-1, next(iter(components[0])), next(iter(components[1])))

            main = components[0]

            loc_done = True
            for c in components[1:]:
                for v in c:
                    if self.graph.nodes[v]['type'] == 'location':
                        loc_done = False

            for u in main:
                if self.graph.nodes[u]['type'] != 'location':
                    continue

                for c in components[1:]:
                    for v in c:
                        if not loc_done and self.graph.nodes[v]['type'] != 'location':
                            continue
                        best = max(best, (self.relatedness(u, v, self.graph.nodes[v]['type']), u, v))

            _, u, v = best

            # attach randomly if empty or specified
            if _ == 0 or self.args.random:
                candidates = []
                for c in components[0]:
                    if self.graph.nodes[c]['type'] == 'location':
                        candidates.append(c)
                u = random.choice(candidates)
            if self.graph.nodes[u]['type'] == 'location' and self.graph.nodes[v]['type'] == 'location':
                type = "connected to"
            else:
                type = "located in"
            print("{} {} {}".format(v, type, u))
            self.graph.add_edge(v, u, label=type)
            self.edge_labels[(v, u)] = type

    def export(self, filename="graph.dot"):
        nx.nx_pydot.write_dot(self.graph, filename)
        nx.write_gml(self.graph, "graph.gml", stringizer=None)

    def draw(self, filename="./graph.svg"):
        self.export()

        if self.args.write_sfdp:
            cmd = "sfdp -x -Goverlap=False -Tsvg graph.dot"
            returned_value = subprocess.check_output(cmd, shell=True)
            with open(filename, 'wb') as f:
                f.write(returned_value)
            cmd = "inkscape -z -e {}.png {}.svg".format(filename[:-4], filename[:-4])
            returned_value = subprocess.check_output(cmd, shell=True)
        else:
            nx.draw(self.graph, with_labels=True)
            plt.savefig(filename[:-4] + '.png')

In [3]:
import random
random.seed(0)

In [20]:
class Args:
	def __init__(self, args):
		self.input_text = args['input_text']
		self.length = args['length']
		self.batch_size = args['batch_size']
		self.temperature = args['temperature']
		self.model_name = args['model_name']
		self.seed = args['seed']
		self.nsamples = args['nsamples']
		self.cutoffs = args['cutoffs']
		self.write_sfdp = args['write_sfdp']
		self.random = args['random']
                
args = Args({
	# 'input_text': 'resolved_zelda_botw.txt',
	# 'input_text': 'resolved_rapunzel.txt',
	'input_text': 'rapunzel.txt',
	# 'input_text': input_text,
	'length': 10,
	'batch_size': 1,
	'temperature': 1,
	'model_name' : '117M',
	'seed' : 0,
	'nsamples' : 15,
	'cutoffs' : '11.5 15 12',
	# 'cutoffs' : 'fairy',
	'write_sfdp': False,
	'random': False
})

In [21]:
world = World([], [], [], args)
world.generate()

A lonely couple, who want a child, live next to a walled garden belonging to a sorceress. The wife, experiencing the cravings associated with the arrival of her long-awaited pregnancy, notices some rapunzel, growing in the garden and longs for it. She refuses to eat anything else and gets sick, and the husband begins to fear for her life. One night, her husband breaks into the garden to get some for her. She makes a salad out of it and greedily eats it. It tastes so good that she longs for more. So her husband goes to get some more for her. As he scales the wall to return home, the sorceress catches him and accuses him of theft. He begs for mercy, and she agrees to be lenient, and allows him to take all the rapunzel he wants, on condition that the baby be given to her when it's born. Desperate, he agrees. When his wife has a baby girl, the sorceress takes her to raise as her own and names her "Rapunzel" after the plant her mother craved. She grows up to be the most beautiful child in t
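
`generate()` asks the same question repeatedly and, through `extractEntity`, replaces each accepted answer with `[MASK]` so that the next query has to return a different entity. A minimal sketch of that loop, with the QA model swapped for a toy lookup: `toy_qa`, `KNOWN_NAMES`, `extract_entities`, and the example sentence are hypothetical stand-ins, and the thresholds and cutoffs used in the class are omitted.

```python
KNOWN_NAMES = ["Rapunzel", "sorceress", "prince"]  # hypothetical toy vocabulary


def toy_qa(text):
    """Stand-in for the QA model: return the first known name still in text."""
    for name in KNOWN_NAMES:
        if name in text:
            return name
    return None


def extract_entities(text, answer_fn, max_entities=10):
    """Query, record the answer, mask it in the text, and query again."""
    entities = []
    for _ in range(max_entities):
        answer = answer_fn(text)
        if answer is None:
            break
        entities.append(answer)
        # Masking every occurrence forces the next query to a new entity.
        text = text.replace(answer, "[MASK]")
    return entities


story = "The sorceress locks Rapunzel in a tower until a prince finds Rapunzel."
found = extract_entities(story, toy_qa)
print(found)
```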

In [5]:
world.draw('test.svg')
world.export('test.dot')

The iterable function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use np.iterable instead.
  if not cb.iterable(width):


# QA with RoBERTa-base trained on SQuAD 2.0

In [None]:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline

model_name = "deepset/roberta-base-squad2"
# a) Get predictions
nlp = pipeline('question-answering', model=model_name, tokenizer=model_name)

In [None]:
QA_input = {
    'question': 'Who is somebody in the story?',
    'context': doc
}
res = nlp(QA_input)
print(res)

In [None]:
res

In [None]:
# b) Load model & tokenizer
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
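# Encode the question together with the context; only the context is truncated
pt_batch = tokenizer(
    QA_input['question'],
    doc,
    padding=True,
    truncation="only_second",
    max_length=512,
    return_tensors="pt"
)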

In [None]:
pt_outputs = model(**pt_batch)

In [None]:
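from torch import nn

# The model returns start/end logits per token, not a single tensor
start_probs = nn.functional.softmax(pt_outputs.start_logits, dim=-1)
end_probs = nn.functional.softmax(pt_outputs.end_logits, dim=-1)
print(start_probs)
print(end_probs)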
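
The cells above stop at per-token model outputs. To get an answer string, one still has to pick a start and end position. A minimal pure-Python sketch of one common choice: maximise `start + end` over spans where the end is not before the start and the span is at most `max_answer_len` tokens long. The token list and logit values are made-up toy data, and `best_span` is a hypothetical helper rather than part of `transformers`.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (score, start_idx, end_idx) of the highest-scoring valid span."""
    best = (float("-inf"), 0, 0)
    for i, s in enumerate(start_logits):
        # Only consider ends at or after the start, within the length limit.
        for j in range(i, min(i + max_answer_len, len(end_logits))):
            score = s + end_logits[j]
            if score > best[0]:
                best = (score, i, j)
    return best


# Toy data, not real model output.
tokens = ["Who", "?", "Link", "meets", "Impa", "."]
start = [0.1, -1.0, 1.0, -0.5, 3.0, -2.0]
end = [-0.3, -1.0, 0.5, -0.2, 2.5, -1.0]

score, i, j = best_span(start, end)
answer = " ".join(tokens[i:j + 1])
print(answer, score)
```

With real model output, `start` and `end` would be the logits for one example, and the token indices would be mapped back to text through the tokenizer.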