### Importing libraries

In [22]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import math
import IPython
from pyvis.network import Network
import wikipedia

### Loading the relation extraction model

In [14]:
# Laoding model and Tokenizers
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")

### Offering 2 types of text to Knowledge Base
1. Short text to Knowledge Base (Feeding summarised text to build KB)
2. Long Text to Knowledge Base (Feeding non-summarised text to build KB)

In [15]:
def extract_relations_from_model_output(text):
    relations = []
    relation, subject, relation, object_ = '', '', '', ''
    text = text.strip()
    current = 'x'
    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                relations.append({
                    "head": subject.strip(),
                    "type": relation.strip(),
                    "tail": object_.strip()
                })
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                relations.append({
                    "head": subject.strip(),
                    "type": relation.strip(),
                    "tail": object_.strip()
                })
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    
    if subject != '' and relation != '' and object_ != '':
        relations.append({
            "head": subject.strip(),
            "type": relation.strip(),
            "tail": object_.strip()
        })
    
    return relations

### Implementing a KB class to deal with adding new relations to the Knowledge base

In [20]:
class KB():
    def __init__(self):
        self.relations = []

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])
    
    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)
    
    def merge_relations(self, r1):
        r2 = [r for r in self.relations if self.are_relations_equal(r1, r)[0]]
        spans_to_add = [span for span in r1["meta"]["spans"] if span not in r2["meta"]["spans"]]
        r2["meta"]["spans"] += spans_to_add
    
    def add_relations(self, r):
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def print(self):
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")

### Defining a function that returns KB object with relations extracted from a short text

In [17]:
def from_small_text_to_kb(text, verbose=False):
    kb = KB()

    # Tokenizer text
    model_inputs = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
    if verbose:
        print(f"Num tokens: {len(model_inputs)}")

    # Generate
    gen_kwargs = {
        "max_length": 216,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": 3
    }
    generated_tokens = model.generate(
        **model_inputs,
        **gen_kwargs,
    )
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

    # Creating the KB
    for sentence_pred in decoded_preds:
        relations = extract_relations_from_model_output(sentence_pred)
        for r in relations:
            kb.add_relations(r)
    
    return kb


### Loading text from the CMU intro lecture transcript and build a knowledge graph on it

In [18]:
cmu_lecture_transcript_path = "cmu_computer_graphics_intro_voice_transcribed.txt"
with open(cmu_lecture_transcript_path) as src:
    cmu_lecture_text = src.read()

print(cmu_lecture_text)

 Welcome to Computer Graphics 15462-662 at Carnegie Mellon University. I'm Kenan Crane. I'm a professor of computer science and robotics. And I also do research in computer graphics, so specifically in the area of geometric algorithms. The purpose of this video is to give you all the information that you'll need to succeed this semester. So periodically we'll upload little videos to cover administrative things, to talk about what's been going on this week, and to answer any significant questions that have come up. I should also say that all the information today is available on the course webpage at 15462.courses.cs.cmu.edu. So please go ahead, check out that link, read through especially the course info page in detail because there's a lot of things that I won't say here in this video but that are important for you to know as you go through the course. we have a great set of TAs this semester. So if you have any questions, please at any time, feel free to email them, email me, post a 

The KB build is below

In [19]:
kb = from_small_text_to_kb(cmu_lecture_text, verbose=True)
kb.print()

Num tokens: 2
Relations:
  {'head': 'Kenan Crane', 'type': 'field of work', 'tail': 'computer graphics'}
  {'head': 'computer graphics', 'type': 'part of', 'tail': 'computer science'}
  {'head': 'Computer Graphics 15462-662', 'type': 'main subject', 'tail': 'computer science'}


### Creating a text_to_kb function that manages the texts with spanning logic

In [21]:
def from_text_to_kb(text, span_length=128, verbose=False):
    # tokenize whole text
    inputs = tokenizer([text], return_tensors="pt")

    # compute span boundaries
    num_tokens = len(inputs["input_ids"][0])
    if verbose:
        print(f"Input has {num_tokens} tokens")
    num_spans = math.ceil(num_tokens / span_length)
    if verbose:
        print(f"Input has {num_spans} spans")
    overlap = math.ceil((num_spans * span_length - num_tokens) / 
                        max(num_spans - 1, 1))
    spans_boundaries = []
    start = 0
    for i in range(num_spans):
        spans_boundaries.append([start + span_length * i,
                                 start + span_length * (i + 1)])
        start -= overlap
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")

    # transform input with spans
    tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
                  for boundary in spans_boundaries]
    tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
                    for boundary in spans_boundaries]
    inputs = {
        "input_ids": torch.stack(tensor_ids),
        "attention_mask": torch.stack(tensor_masks)
    }

    # generate relations
    num_return_sequences = 3
    gen_kwargs = {
        "max_length": 256,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": num_return_sequences
    }
    generated_tokens = model.generate(
        **inputs,
        **gen_kwargs,
    )

    # decode relations
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    # create kb
    kb = KB()
    i = 0
    for sentence_pred in decoded_preds:
        current_span_index = i // num_return_sequences
        relations = extract_relations_from_model_output(sentence_pred)
        for relation in relations:
            relation["meta"] = {
                "spans": [spans_boundaries[current_span_index]]
            }
            kb.add_relation(relation)
        i += 1

    return kb

### Performing entity linking
This is to make sure that entities that are similar to each other are merged into a single entity and prevents overly large clusters (We are using Wikipedia for this purpose)

In [23]:
# This is the modified version of the KB class
class KB():
    def __init__(self):
        self.entities = {}
        self.relations = []

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])
    
    def get_wikipedia_data(self, candidate_entity):
        try:
            page = wikipedia.page(candidate_entity, auto_suggest=False)
            entity_data = {"title": page.title,
                           "url": page.url,
                           "summary": page.summary}
            return entity_data
        except:
            return None
        
    def add_entity(self, e):
        self.entities[e["title"]] = {k:v for k, v in e.items() if k != "title"}

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)
    
    def merge_relations(self, r1):
        r2 = [r for r in self.relations if self.are_relations_equal(r1, r)[0]]
        spans_to_add = [span for span in r1["meta"]["spans"] if span not in r2["meta"]["spans"]]
        r2["meta"]["spans"] += spans_to_add
    
    def add_relations(self, r):
        # We now first check on wikipedia
        candidate_entities = [r["head"], r["tail"]]
        entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]

        # If one of the entities does not exist, stop
        if any(ent is None for ent in entities):
            return
        
        # Managing new entities
        for e in entities:
            self.add_entity(e)

        # Renaming relation entities with their wikipedia titles
        r["head"] = entities[0]["title"]
        r["tail"] = entities[1]["title"]

        # Managing the new relation
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def print(self):
        print("Entities:")
        for e in self.entities.items():
            print(f"  {e}")
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")

### Using Pyvis to visualise the knowledge base that is built
Defining a save_network_html function that:
1. Initializes an empty directed pyvis network
2. Add the knowledge base entities as nodes
3. Adds the knowledge base relations as edges
4. Save the network in HTML file

In [24]:
def save_network_html(kb, filename="network.html"):
    # Creating the empty network
    net = Network(directed=True, width="700px", height="700px", bgcolor="#eeeeee")

    # nodes
    color_entity = "#00FF00"
    for e in kb.entities:
        net.add_node(e, shape="circle", color=color_entity)

    # Edges
    for r in kb.relations:
        net.add_edge(r["head"], r["tail"], title=r["type"], label=r["type"])

    # Saving the network
    net.repulsion(node_distance=200,
                  central_gravity=0.2,
                  spring_length=200,
                  spring_strength=0.05,
                  damping=0.09)
    net.set_edge_smooth("dynamic")
    net.show(filename)