In [1]:
# Install the notebook's dependencies with the %pip magic so they land in the
# environment of the running kernel rather than whichever `pip` is on PATH.
%pip install torch transformers numpy pandas seaborn matplotlib tqdm wikipedia

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import math
import torch
import wikipedia
import pandas as pd
import torch
from transformers import AutoTokenizer, GPTNeoForCausalLM, AutoModelForSeq2SeqLM

In [3]:
# The tokenizer and the seq2seq model must come from the same checkpoint, so
# the identifier is written once and shared by both loads.
REBEL_CHECKPOINT = "Babelscape/rebel-large"

tokenizer = AutoTokenizer.from_pretrained(REBEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(REBEL_CHECKPOINT)

In [4]:
def extract_relations_from_model_output(text):
    """Parse one decoded model output string into a list of relation dicts.

    The text is scanned token by token (whitespace split) after removing the
    ``<s>``, ``<pad>`` and ``</s>`` markers. Three marker tokens switch which
    field the following words are appended to:

    * ``<triplet>`` -> words go to the head (subject) of a new triplet
    * ``<subj>``    -> words go to the tail (object)
    * ``<obj>``     -> words go to the relation type

    Words seen before the first marker are ignored.

    Args:
        text: decoded model output, special tokens included.

    Returns:
        A list of ``{'head': str, 'type': str, 'tail': str}`` dicts, in the
        order the triplets are completed. Only triplets whose three parts are
        all non-empty are returned.
    """
    relations = []
    subject, relation, object_ = '', '', ''

    def emit():
        # Keep a triplet only when head, type and tail are all present, so a
        # malformed or truncated output never yields a half-empty relation.
        head, type_, tail = subject.strip(), relation.strip(), object_.strip()
        if head and type_ and tail:
            relations.append({'head': head, 'type': type_, 'tail': tail})

    # 'x' = no marker seen yet, 't' = reading head, 's' = reading tail,
    # 'o' = reading relation type.
    current = 'x'
    text_replaced = (text.strip()
                     .replace("<s>", "")
                     .replace("<pad>", "")
                     .replace("</s>", ""))
    for token in text_replaced.split():
        if token == "<triplet>":
            # A new head starts: flush the pending triplet and clear all state
            # so nothing stale can leak into the next triplet.
            emit()
            current = 't'
            subject, relation, object_ = '', '', ''
        elif token == "<subj>":
            # Another tail for the same head: flush the pending triplet and
            # forget its relation, otherwise an output cut off before the next
            # "<obj>" would pair the old relation with the new tail.
            emit()
            current = 's'
            relation, object_ = '', ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token

    # Flush whatever triplet was still being built when the text ended.
    emit()
    return relations

In [5]:
class KB():
    """In-memory knowledge base of entities and relations between them.

    Relations are only stored when both their head and tail can be resolved
    through ``get_wikipedia_data``; the resolved page titles become the
    entity keys and replace the head/tail text of the relation.
    """

    def __init__(self):
        # Page title -> remaining entity fields (everything except "title").
        self.entities = {}
        # Relation dicts with "head", "type", "tail" and "meta" keys.
        self.relations = []
        # Successful lookups keyed by the queried name, so the same candidate
        # is not fetched again for every relation that mentions it.
        self._wikipedia_cache = {}

    def are_relations_equal(self, r1, r2):
        """Return True when both relations share head, type and tail."""
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        """Return True when a relation equal to ``r1`` is already stored."""
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def merge_relations(self, r1):
        """Add the spans of ``r1`` that are missing from the stored equal relation.

        Raises:
            IndexError: if no stored relation equals ``r1``.
        """
        r2 = [r for r in self.relations
              if self.are_relations_equal(r1, r)][0]
        spans_to_add = [span for span in r1["meta"]["spans"]
                        if span not in r2["meta"]["spans"]]
        r2["meta"]["spans"] += spans_to_add

    def get_wikipedia_data(self, candidate_entity):
        """Look up ``candidate_entity`` with ``wikipedia.page``.

        Returns:
            A new dict with the page's "title", "url" and "summary", or None
            when the lookup raises any ``Exception``.
        """
        try:
            entity_data = self._wikipedia_cache.get(candidate_entity)
            if entity_data is None:
                page = wikipedia.page(candidate_entity, auto_suggest=False)
                entity_data = {
                    "title": page.title,
                    "url": page.url,
                    "summary": page.summary
                }
                # Only successes are cached; a failed lookup is retried next time.
                self._wikipedia_cache[candidate_entity] = entity_data
        except Exception:
            # Best-effort lookup: any failure means "no usable entity".
            # Catching Exception instead of a bare except lets
            # KeyboardInterrupt / SystemExit stop a long run.
            return None
        # Hand out a copy so callers cannot alter the cached entry.
        return dict(entity_data)

    def add_entity(self, e):
        """Store entity ``e`` under its "title", keeping its other fields as the value."""
        self.entities[e["title"]] = {k:v for k,v in e.items() if k != "title"}

    def add_relation(self, r):
        """Add relation ``r`` if both of its ends resolve to an entity.

        ``r`` is modified in place: its "head" and "tail" are replaced by the
        resolved page titles. If an equal relation is already stored, only the
        new spans from ``r["meta"]["spans"]`` are merged into it.
        """
        # check on wikipedia
        candidate_entities = [r["head"], r["tail"]]
        entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]

        # if one entity does not exist, stop
        if any(ent is None for ent in entities):
            return

        # manage new entities
        for e in entities:
            self.add_entity(e)

        # rename relation entities with their wikipedia titles
        r["head"] = entities[0]["title"]
        r["tail"] = entities[1]["title"]

        # manage new relation
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def print(self):
        """Print every (title, data) entity pair and every relation dict."""
        print("Entities:")
        for e in self.entities.items():
            print(f"  {e}")
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")

In [6]:
def from_text_to_kb(text, span_length=128, verbose=False):
    """Extract relations from ``text`` and collect them in a new ``KB``.

    The text is tokenized once with the module-level ``tokenizer``, cut into
    ``span_length``-token windows, and the windows are passed as one batch to
    the module-level ``model``. Every decoded sequence is parsed with
    ``extract_relations_from_model_output`` and each relation is tagged with
    the boundaries of the window it came from before being added to the KB.

    Args:
        text: input text to process.
        span_length: number of tokens per window.
        verbose: when True, print the token count, the window count and the
            window boundaries.

    Returns:
        The populated ``KB`` instance.
    """
    # Encode the full text; windows are sliced from this single encoding.
    encoded = tokenizer([text], return_tensors="pt")
    token_ids = encoded["input_ids"][0]
    token_mask = encoded["attention_mask"][0]

    # Work out how many windows are needed and how far consecutive windows
    # must overlap so that every window is exactly span_length wide.
    num_tokens = len(token_ids)
    if verbose:
        print(f"Input has {num_tokens} tokens")
    num_spans = math.ceil(num_tokens / span_length)
    if verbose:
        print(f"Input has {num_spans} spans")
    overlap = math.ceil((num_spans * span_length - num_tokens) /
                        max(num_spans - 1, 1))
    stride = span_length - overlap
    spans_boundaries = [[idx * stride, idx * stride + span_length]
                        for idx in range(num_spans)]
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")

    # One batch row per window.
    batch = {
        "input_ids": torch.stack(
            [token_ids[lo:hi] for lo, hi in spans_boundaries]),
        "attention_mask": torch.stack(
            [token_mask[lo:hi] for lo, hi in spans_boundaries]),
    }

    # Beam search, keeping several candidate sequences per window.
    num_return_sequences = 3
    generated_tokens = model.generate(
        **batch,
        max_length=256,
        length_penalty=0,
        num_beams=3,
        num_return_sequences=num_return_sequences,
    )

    # Special tokens are kept because the relation parser relies on them.
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    # Sequences come back grouped by window, num_return_sequences per window,
    # so integer division maps a sequence index back to its window.
    kb = KB()
    for pred_index, sentence_pred in enumerate(decoded_preds):
        span_index = pred_index // num_return_sequences
        for relation in extract_relations_from_model_output(sentence_pred):
            relation["meta"] = {"spans": [spans_boundaries[span_index]]}
            kb.add_relation(relation)

    return kb

In [7]:
# Run relation extraction on the 'title' of every row of the input CSV and
# save one (id, head, type, tail) row per extracted relation.
df = pd.read_csv('output_final.csv', sep=';')

# One [id, head, type, tail] record per relation, across all rows.
data = []

for index, row in df.iterrows():
    # `row_id` rather than `id`, to avoid shadowing the builtin.
    row_id = row['id']
    text = row['title']
    kb = from_text_to_kb(text, verbose=True)
    for r in kb.relations:
        data.append([row_id, r['head'], r['type'], r['tail']])

# Build the frame once, after the loop: this avoids re-creating it for every
# relation and guarantees `ndf` exists even when no relation was extracted.
ndf = pd.DataFrame(data, columns=['id', 'head', 'type', 'tail'])

ndf.to_csv('rebel_final.csv')

Input has 18 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 12 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 51 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 44 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 55 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 11 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 22 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 22 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 18 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 15 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 19 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 13 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 13 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 13 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 24 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 17 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 17 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 15 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 21 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 12 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 15 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 21 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 16 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 10 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 8 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 14 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 18 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 20 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 9 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 16 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 20 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 9 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 10 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 27 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 17 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 16 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 8 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 14 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 17 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 15 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 19 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 26 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 16 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 16 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 25 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 6 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 8 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 12 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 8 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 16 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 15 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 14 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 15 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 31 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 9 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 27 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 18 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 16 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 15 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 8 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 14 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 14 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 14 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 15 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 18 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 19 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 11 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 16 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 11 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 11 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 16 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 17 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 17 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 27 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 12 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 17 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 13 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 16 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 9 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 17 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 11 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 11 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 18 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 18 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 15 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 15 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 15 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 9 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 42 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 42 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 42 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 42 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 42 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 42 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 16 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 22 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 16 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 61 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 61 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 61 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 61 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 61 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 61 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 61 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 61 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 61 tokens
Input has 1 spans
Span boundaries are [[0, 128]]




  lis = BeautifulSoup(html).find_all('li')


Input has 20 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 20 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 20 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 10 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 51 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 51 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 21 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 21 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 57 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 57 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 57 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 57 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 57 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 57 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 57 tokens
Input has 1 sp



  lis = BeautifulSoup(html).find_all('li')


Input has 23 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 23 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 23 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 23 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 23 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 23 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 18 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 18 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 18 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 18 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 18 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 61 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 61 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 61 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Input has 61 tokens
Input has 1 sp