In [1]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import math
import torch
import wikipedia

import IPython
from pyvis.network import Network

# REBEL checkpoint shared by the tokenizer and the seq2seq model.
REBEL_CHECKPOINT = "Babelscape/rebel-large"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(REBEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(REBEL_CHECKPOINT)
model.to(device)  # place the model weights on the selected device
print(device)

cuda


In [2]:
def extract_relations_from_model_output(text):
    """Parse one decoded model output string into a list of triplets.

    The string is expected to use the markers ``<triplet>``, ``<subj>`` and
    ``<obj>``: words after ``<triplet>`` form the head, words after ``<subj>``
    form the tail and words after ``<obj>`` form the relation type. The
    markers ``<s>``, ``<pad>`` and ``</s>`` are removed before parsing and
    words seen before the first marker are ignored.

    Args:
        text: Decoded output string.

    Returns:
        A list of ``{'head': str, 'type': str, 'tail': str}`` dicts in the
        order the triplets were completed.
    """
    triplets = []
    head_buf = rel_buf = tail_buf = ''
    mode = None  # which buffer plain words go to; None until a marker is seen

    def snapshot():
        # Freeze the current buffers into one triplet dict.
        return {
            'head': head_buf.strip(),
            'type': rel_buf.strip(),
            'tail': tail_buf.strip()
        }

    cleaned = text.strip()
    for special in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(special, "")

    for piece in cleaned.split():
        if piece == "<triplet>":
            # A new head starts: flush the pending triplet, if any.
            mode = 'head'
            if rel_buf != '':
                triplets.append(snapshot())
                rel_buf = ''
            head_buf = ''
        elif piece == "<subj>":
            # A new tail for the same head: flush the pending triplet but keep
            # the relation buffer as it is (only <obj>/<triplet> clear it).
            mode = 'tail'
            if rel_buf != '':
                triplets.append(snapshot())
            tail_buf = ''
        elif piece == "<obj>":
            mode = 'type'
            rel_buf = ''
        elif mode == 'head':
            head_buf += ' ' + piece
        elif mode == 'tail':
            tail_buf += ' ' + piece
        elif mode == 'type':
            rel_buf += ' ' + piece

    # Emit the trailing triplet only when all three parts were collected.
    if head_buf != '' and rel_buf != '' and tail_buf != '':
        triplets.append(snapshot())
    return triplets

class KB():
    """In-memory knowledge base of entities, relations and their sources.

    Attributes:
        entities: ``{entity_title: {"url": ..., "summary": ...}}``.
        relations: list of ``{"head", "type", "tail", "meta"}`` dicts, where
            ``meta`` maps an ``article_url`` to ``{"spans": [...]}``.
        sources: ``{article_url: {"article_title": ...,
            "article_publish_date": ...}}``.
    """

    def __init__(self):
        self.entities = {}   # { entity_title: {...} }
        self.relations = []  # [ { head: entity_title, type: ..., tail: entity_title,
                             #     meta: { article_url: { spans: [...] } } } ]
        self.sources = {}    # { article_url: {...} }

    def merge_with_kb(self, kb2):
        """Fold every relation, entity and source of ``kb2`` into this KB.

        A relation may cite several articles (one ``meta`` key per article).
        Each citation is added separately so that every article's source
        record and spans are carried over, not only the first one.
        ``kb2`` itself is left unmodified: relation dicts and span lists are
        copied before they are stored here.

        Args:
            kb2: The KB to merge from.

        Raises:
            KeyError: if a relation cites an article missing from
                ``kb2.sources``.
        """
        for r in kb2.relations:
            for article_url, article_meta in r["meta"].items():
                source_data = kb2.sources[article_url]
                # Copy the per-article meta (and its span list) so later
                # merges into this KB cannot mutate kb2's data.
                meta_copy = dict(article_meta)
                if "spans" in meta_copy:
                    meta_copy["spans"] = list(meta_copy["spans"])
                relation = {**r, "meta": {article_url: meta_copy}}
                self.add_relation(relation, source_data["article_title"],
                                  source_data["article_publish_date"])

    def are_relations_equal(self, r1, r2):
        """Return True when both relations share head, type and tail."""
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def _find_relation(self, r):
        """Return the stored relation equal to ``r``, or None if there is none."""
        return next((known for known in self.relations
                     if self.are_relations_equal(r, known)), None)

    def exists_relation(self, r1):
        """Return True when a relation equal to ``r1`` is already stored."""
        return self._find_relation(r1) is not None

    def _merge_meta(self, r1, r2):
        """Merge the per-article ``meta`` of ``r2`` into stored relation ``r1``."""
        for article_url, incoming in r2["meta"].items():
            if article_url not in r1["meta"]:
                # different article: adopt its meta as-is
                r1["meta"][article_url] = incoming
            else:
                # existing article: append only spans not recorded yet
                known_spans = r1["meta"][article_url]["spans"]
                spans_to_add = [span for span in incoming["spans"]
                                if span not in known_spans]
                r1["meta"][article_url]["spans"] += spans_to_add

    def merge_relations(self, r2):
        """Merge ``r2``'s meta into the already stored equal relation.

        Args:
            r2: Relation whose ``meta`` should be merged; all of its article
                URLs are processed.

        Raises:
            IndexError: if no equal relation is stored.
        """
        r1 = self._find_relation(r2)
        if r1 is None:
            raise IndexError("no stored relation matches the relation to merge")
        self._merge_meta(r1, r2)

    def get_wikipedia_data(self, candidate_entity):
        """Return the entity record for ``candidate_entity``.

        The live Wikipedia lookup is currently disabled (kept below as a
        comment), so this always returns a placeholder record whose title is
        the candidate string itself and never returns None.
        """
        # try:
        #     page = wikipedia.page(candidate_entity, auto_suggest=False)
        #     entity_data = {
        #         "title": page.title,
        #         "url": page.url,
        #         "summary": page.summary
        #     }
        #     return entity_data
        # except:
        entity_data = {
            "title": candidate_entity,
            "url": "NA",
            "summary": "NA"
        }
        return entity_data

    def add_entity(self, e):
        """Store entity ``e`` under its title, keeping every other field."""
        self.entities[e["title"]] = {k: v for k, v in e.items() if k != "title"}

    def add_relation(self, r, article_title, article_publish_date):
        """Add relation ``r`` together with its entities and source.

        ``r`` is modified in place (head/tail are replaced by the resolved
        entity titles) and, when new, stored by reference.

        Args:
            r: ``{"head", "type", "tail", "meta"}`` dict; ``meta`` must hold
                at least one article URL.
            article_title: Title recorded for the first article URL in
                ``r["meta"]`` if that source is not known yet.
            article_publish_date: Publish date recorded alongside the title.
        """
        # resolve both ends of the relation
        candidate_entities = [r["head"], r["tail"]]
        entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]

        # if one entity does not exist, stop
        if any(ent is None for ent in entities):
            return

        # manage new entities
        for e in entities:
            self.add_entity(e)

        # rename relation entities with their resolved titles
        r["head"] = entities[0]["title"]
        r["tail"] = entities[1]["title"]

        # add source if not in kb
        article_url = list(r["meta"].keys())[0]
        if article_url not in self.sources:
            self.sources[article_url] = {
                "article_title": article_title,
                "article_publish_date": article_publish_date
            }

        # manage new relation (single scan of the stored relations)
        existing = self._find_relation(r)
        if existing is None:
            self.relations.append(r)
        else:
            self._merge_meta(existing, r)

    def print(self):
        """Print all entities, relations and sources to stdout."""
        print("Entities:")
        for e in self.entities.items():
            print(f"  {e}")
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")
        print("Sources:")
        for s in self.sources.items():
            print(f"  {s}")

In [3]:
def computeSpans(inputs, verbose=False, span_length=128):
    """Cut the first tokenized sequence into overlapping fixed-size spans.

    The number of spans is ``ceil(num_tokens / span_length)``; consecutive
    spans overlap by just enough tokens that the spans do not run past the
    end of the sequence.

    Args:
        inputs: Mapping with ``"input_ids"`` and ``"attention_mask"``; only
            the first sequence (index 0) of each is used.
        verbose: Print token count, span count and span boundaries.
        span_length: Number of tokens per span.

    Returns:
        ``(tensor_ids, tensor_masks, spans_boundaries)``: two lists with the
        per-span slices moved to the module-level ``device`` and the list of
        ``[start, end]`` boundaries.
    """
    token_count = len(inputs["input_ids"][0])
    if verbose:
        print(f"Input has {token_count} tokens")

    span_count = math.ceil(token_count / span_length)
    if verbose:
        print(f"Input has {span_count} spans")

    # Tokens the spans cover in excess of the input, spread over the gaps
    # between consecutive spans (at least one gap to avoid dividing by zero).
    surplus = span_count * span_length - token_count
    overlap = math.ceil(surplus / max(span_count - 1, 1))

    # Each span starts `stride` tokens after the previous one.
    stride = span_length - overlap
    spans_boundaries = [[k * stride, k * stride + span_length]
                        for k in range(span_count)]
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")

    # slice ids and masks span by span and move each slice to the device
    tensor_ids = [inputs["input_ids"][0][lo:hi].to(device)
                  for lo, hi in spans_boundaries]
    tensor_masks = [inputs["attention_mask"][0][lo:hi].to(device)
                    for lo, hi in spans_boundaries]

    return tensor_ids, tensor_masks, spans_boundaries

def from_text_to_kb(text, article_url="NA", span_length=128, article_title=None, article_publish_date=None, verbose=False):
    """Build a KB from ``text``, splitting long texts into spans first.

    The text is tokenized as a whole, cut into spans by ``computeSpans``, and
    all spans are passed to the model in a single ``generate`` call. Every
    generated sequence is parsed into relations, which are stored with the
    boundaries of the span they came from.

    Args:
        text: Text to extract relations from.
        article_url: Key under which span metadata and the source are stored.
        span_length: Number of tokens per span.
        article_title: Title recorded for the source.
        article_publish_date: Publish date recorded for the source.
        verbose: Forwarded to ``computeSpans``.

    Returns:
        A new ``KB`` holding the extracted relations.
    """
    # tokenize the whole text, then cut it into spans
    encoded = tokenizer([text], return_tensors="pt")
    span_ids, span_masks, boundaries = computeSpans(encoded, verbose, span_length)

    # one batch row per span
    batch = {
        "input_ids": torch.stack(span_ids),
        "attention_mask": torch.stack(span_masks)
    }

    # generate several candidate sequences per span
    sequences_per_span = 3
    generated = model.generate(
        **batch,
        max_length=256,
        length_penalty=0,
        num_beams=3,
        num_return_sequences=sequences_per_span,
    )

    # keep the special tokens: the relation parser relies on the markers
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=False)

    kb = KB()
    for sequence_index, sequence in enumerate(decoded):
        # consecutive groups of `sequences_per_span` outputs belong to one span
        span_index = sequence_index // sequences_per_span
        for relation in extract_relations_from_model_output(sequence):
            relation["meta"] = {
                article_url: {
                    "spans": [boundaries[span_index]]
                }
            }
            kb.add_relation(relation, article_title, article_publish_date)

    return kb

In [4]:
import networkx as nx

G = nx.Graph()
def save_network(kb, filename="network2.html"):
    """Add the KB's entities and relations to the module-level graph ``G``.

    Despite its name this function writes nothing to disk; ``filename`` is
    accepted but not used.

    Args:
        kb: Object with ``entities`` (iterable of node names) and
            ``relations`` (dicts with ``head``, ``type`` and ``tail``).
        filename: Unused.
    """
    edge_color = "#00FF00"

    # nodes: one per entity
    G.add_nodes_from(kb.entities)

    # edges: one per relation, carrying the relation type as its title
    for rel in kb.relations:
        G.add_edge(rel["head"], rel["tail"], title=rel["type"], color=edge_color)

def save_network_html(kb, filename="network.html"):
    """Build a directed pyvis network from the KB and show it as ``filename``.

    Args:
        kb: Object with ``entities`` (iterable of node names) and
            ``relations`` (dicts with ``head``, ``type`` and ``tail``).
        filename: Name passed to the network's ``show`` method.
    """
    node_color = "#00FF00"
    net = Network(directed=True, width="700px", height="700px", bgcolor="#eeeeee")

    # one circular node per entity
    for entity_title in kb.entities:
        net.add_node(entity_title, shape="circle", color=node_color)

    # one edge per relation; the type is both hover title and visible label
    for rel in kb.relations:
        net.add_edge(rel["head"], rel["tail"],
                     title=rel["type"], label=rel["type"])

    # layout settings, then write/show the network
    repulsion_settings = {
        "node_distance": 200,
        "central_gravity": 0.2,
        "spring_length": 200,
        "spring_strength": 0.05,
        "damping": 0.09,
    }
    net.repulsion(**repulsion_settings)
    net.set_edge_smooth('dynamic')
    net.show(filename, notebook=False)


In [None]:

import pandas as pd
import gc
import pickle

torch.cuda.empty_cache()

def save_object(obj, filename):
    """Pickle ``obj`` to ``filename`` using the highest pickle protocol.

    An existing file at ``filename`` is overwritten.
    """
    with open(filename, 'wb') as sink:
        pickle.dump(obj, sink, protocol=pickle.HIGHEST_PROTOCOL)

# Create KB
kbGlobal = KB()

def MergeKB(txt, span_length=128, checkpoint_every=250):
    """Extract a KB from one row and fold it into the global ``kbGlobal``.

    After every ``checkpoint_every`` rows (judged by the row label ``txt.name``
    plus one) the accumulated KB is pickled, ``kbGlobal`` is reset to None and
    Python/GPU memory is released, so each checkpoint file only holds the
    rows processed since the previous checkpoint.

    Args:
        txt: A row with a ``'content'`` entry and an integer ``name`` (e.g. a
            DataFrame row whose index label is its position).
        span_length: Token span length passed to ``from_text_to_kb``; also
            embedded in the checkpoint file name.
        checkpoint_every: Positive number of rows between two checkpoints.
    """
    global kbGlobal

    print(txt.name)
    kbLocal = from_text_to_kb(txt['content'], span_length=span_length, verbose=False)

    # Test identity rather than truthiness: kbGlobal is None right after a
    # checkpoint, and an (empty) KB object must still be merged into.
    if kbGlobal is None:
        kbGlobal = kbLocal
    else:
        kbGlobal.merge_with_kb(kbLocal)

    rows_done = int(txt.name) + 1
    if rows_done % checkpoint_every == 0:
        print('Saving KB')
        save_object(kbGlobal, f'./data/HWdb_2024_geocoded_KB_span{span_length}_{rows_done}.pickle')
        # drop the saved KB and release memory before the next batch of rows
        kbGlobal = None
        gc.collect()
        torch.cuda.empty_cache()
file = pd.read_csv('./data/HWdb_2024_geocoded.csv')

# Optional: work on a small random subset while experimenting.
# file = file.sample(n=30)

# Split the workload: MergeKB folds each row into kbGlobal and, every 250 rows,
# pickles the accumulated KB, resets it and frees GPU memory.
# A plain loop is used because MergeKB is called for its side effects only.
for _, row in file.iterrows():
    MergeKB(row)

# Rows processed after the last checkpoint are still only in memory; persist
# them too so the tail of the dataset is not lost. kbGlobal is None when the
# last row coincided with a checkpoint, so nothing is written twice.
if kbGlobal is not None and kbGlobal.relations:
    print('Saving KB')
    save_object(kbGlobal, f'./data/HWdb_2024_geocoded_KB_span128_{len(file)}.pickle')

In [None]:
# Render a KB as an interactive HTML graph via save_network_html.
# NOTE(review): `kb` is not assigned in any cell visible in this notebook, so
# on a fresh kernel the call below would raise NameError — confirm which KB is
# meant (e.g. the accumulated kbGlobal or one loaded from a checkpoint pickle)
# and bind it to `kb` before running this cell.
filename="code-network1.html"
save_network_html(kb, filename=filename)
# IPython.display.HTML(filename=filename)

# Alternative static drawing of the networkx graph G (disabled):
# import matplotlib.pyplot as plt
# nx.draw_shell(G, with_labels=True, font_weight='bold')
# plt.show

In [10]:
# Read the pickled KB checkpoints, convert their relations to RDF with rdflib,
# and export the resulting graph.
import pickle

from rdflib import Graph, URIRef, Literal, Namespace

# SECURITY: pickle.load can execute arbitrary code; only load checkpoint files
# that were produced by this notebook.
# The files are opened with `with` so the handles are closed after loading.
with open('./data/HWdb_2024_geocoded_KB500.pickle', 'rb') as kb_file:
    pickle1 = pickle.load(kb_file)
with open('./data/HWdb_2024_geocoded_KB1000.pickle', 'rb') as kb_file:
    pickle2 = pickle.load(kb_file)

# Create an RDF graph
graph = Graph()

# Define a custom namespace for your data
my_ns = Namespace("http://example.org/")

# Add one triple per relation of every loaded checkpoint. Previously only
# pickle1 was exported although pickle2 was loaded as well; the graph is a set
# of triples, so relations present in both checkpoints are stored once.
for kb_checkpoint in (pickle1, pickle2):
    for entry in kb_checkpoint.relations:
        # Head and relation type become URIs (spaces replaced with
        # underscores); the tail is kept as a plain literal.
        subject = URIRef(my_ns[entry['head'].replace(" ", "_")])
        predicate = URIRef(my_ns[entry['type'].replace(" ", "_")])
        obj = Literal(entry['tail'])

        # Add the triple to the graph
        graph.add((subject, predicate, obj))

# Serialize the RDF graph to a file in RDF/XML format
graph.serialize(destination='output.rdf', format='xml')



<Graph identifier=Naabacf0a733f491594d7aa5b881a8360 (<class 'rdflib.graph.Graph'>)>