In [1]:
import json

def get_transcript(file_path):
    """Return the "video_transcript" field of a JSON transcript file.

    Args:
        file_path: Path to a JSON file whose top-level object has a
            "video_transcript" key.

    Returns:
        The value stored under "video_transcript".

    Raises:
        KeyError: If the parsed object has no "video_transcript" key.
    """
    # Read as UTF-8 explicitly: open()'s default encoding is platform
    # dependent, and transcripts may contain non-ASCII characters.
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data["video_transcript"]

In [2]:
import re

def clean_transcript(transcript):
    """Normalise a raw transcript to single-spaced, lower-case text.

    Every run of whitespace (spaces, tabs, newlines, carriage returns and
    other Unicode whitespace) collapses to a single space. The result is then
    stripped and lower-cased.

    Args:
        transcript: Raw transcript string.

    Returns:
        The cleaned transcript ("" for an empty or all-whitespace input).
    """
    # \s also matches \r and \t. A newline-only substitution leaves those
    # behind (e.g. from Windows "\r\n" line endings).
    transcript = re.sub(r'\s+', ' ', transcript)
    return transcript.strip().lower()

In [3]:
import os

data_dir = "../data/"
# Preview each cleaned transcript. Files are processed in sorted order
# because os.listdir returns entries in arbitrary order.
for file_name in sorted(os.listdir(data_dir)):
    if file_name.endswith(".json"):
        file_path = os.path.join(data_dir, file_name)
        transcript = clean_transcript(get_transcript(file_path))
        print(transcript)

does find the president libel. we'll see what that means. >> martha: thanks, eric. we do have this decision. we're waiting for the content of it. it's 92 pages long. it's no price that judge engoron has found the former president, donald trump, libel in this case. this is a civil case. there is about money, not criminality. just to rewind a little bit here. there's a lot of different cases to keep track on. in this case, letitia james said when she was running for attorney general that she was going to get trump. we'll play what she said. the decision is in. it's 92 pages long. james asked for $370 million in damages. she wants him to be prevented from doing business in new york state. so as we wait to download all of this and get a look at the number and as soon as we get it, we'll tell you what it is. this is something that the former president is watching closely. he has layers of judicial dates, appointments, decisions coming down in the middle of the south carolina primary, which 

In [4]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# REBEL relation-extraction checkpoint. The tokenizer and the seq2seq model
# are loaded from the same checkpoint so their vocabularies match.
MODEL_NAME = "Babelscape/rebel-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

In [5]:
def extract_relations_from_model_output(text):
    """Parse a decoded REBEL generation into a list of relation dicts.

    The decoded text is read as a marker-delimited token stream:
    tokens after ``<triplet>`` form the head, tokens after ``<subj>`` form
    the tail, and tokens after ``<obj>`` form the relation type. A relation is
    emitted when a new ``<triplet>`` or ``<subj>`` begins after a relation
    type has been read, and once more at the end if head, type and tail are
    all non-empty. ``<s>``, ``<pad>`` and ``</s>`` are removed first.

    Args:
        text: One decoded sequence (special tokens kept).

    Returns:
        A list of ``{'head': ..., 'type': ..., 'tail': ...}`` dicts with
        whitespace-stripped values, in the order they were found.
    """
    found = []
    head = rel = tail = ''
    mode = 'x'  # 'x' = before any marker; 't'/'s'/'o' = after that marker

    def emit():
        # Reads the current head/rel/tail from the enclosing scope.
        found.append({
            'head': head.strip(),
            'type': rel.strip(),
            'tail': tail.strip()
        })

    cleaned = text.strip()
    for marker in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(marker, "")

    for tok in cleaned.split():
        if tok == "<triplet>":
            mode = 't'
            if rel:
                emit()
                rel = ''
            head = ''
        elif tok == "<subj>":
            mode = 's'
            if rel:
                emit()
            tail = ''
        elif tok == "<obj>":
            mode = 'o'
            rel = ''
        elif mode == 't':
            head += ' ' + tok
        elif mode == 's':
            tail += ' ' + tok
        elif mode == 'o':
            rel += ' ' + tok

    if head and rel and tail:
        emit()
    return found

In [6]:
class KB():
    """In-memory collection of de-duplicated relations.

    Each relation is a dict with "head", "type" and "tail" keys. Relations
    passed to ``merge_relations`` / ``add_relation`` for an already-stored
    triple must also carry ``["meta"]["spans"]`` (a list). Adding an equal
    relation merges its unseen spans into the stored one instead of appending
    a duplicate.
    """

    def __init__(self):
        self.relations = []

    def are_relations_equal(self, r1, r2):
        """Return True when r1 and r2 have the same head, type and tail."""
        for field in ("head", "type", "tail"):
            if r1[field] != r2[field]:
                return False
        return True

    def exists_relation(self, r1):
        """Return True if a relation equal to r1 is already stored."""
        for stored in self.relations:
            if self.are_relations_equal(r1, stored):
                return True
        return False

    def merge_relations(self, r1):
        """Append r1's spans not yet recorded on the stored equal relation.

        Raises IndexError if no equal relation is stored (``add_relation``
        only calls this after ``exists_relation`` succeeds).
        """
        matches = [stored for stored in self.relations
                   if self.are_relations_equal(r1, stored)]
        target_spans = matches[0]["meta"]["spans"]
        new_spans = [span for span in r1["meta"]["spans"]
                     if span not in target_spans]
        target_spans.extend(new_spans)

    def add_relation(self, r):
        """Store r, or merge its spans into an equal stored relation."""
        if self.exists_relation(r):
            self.merge_relations(r)
        else:
            self.relations.append(r)

    def print(self):
        """Print every stored relation, one per indented line."""
        print("Relations:")
        for stored in self.relations:
            print(f"  {stored}")

In [7]:
import math
import torch


def _compute_span_boundaries(num_tokens, span_length):
    """Return [start, end) windows of width span_length covering all num_tokens tokens."""
    num_spans = max(1, math.ceil(num_tokens / span_length))
    if num_spans == 1:
        # One window; slicing past the end of the sequence is harmless.
        return [[0, span_length]]
    # Spread the starts evenly between 0 and num_tokens - span_length, so the
    # last window ends exactly at num_tokens. Integer arithmetic keeps that
    # final start exact. Since num_spans == ceil(num_tokens / span_length),
    # consecutive starts are at most span_length apart. So no token is
    # skipped, and every window holds exactly span_length tokens, as
    # torch.stack requires.
    last_start = num_tokens - span_length
    boundaries = []
    for i in range(num_spans):
        start = i * last_start // (num_spans - 1)
        boundaries.append([start, start + span_length])
    return boundaries


def from_text_to_kb(text, span_length=128, verbose=False):
    """Extract relations from ``text`` with the global REBEL model into a KB.

    The text is tokenized once with the global ``tokenizer`` and split into
    windows of ``span_length`` tokens. Together the windows cover every token,
    overlapping when the length does not divide evenly. All windows are sent
    as one batch to the global ``model.generate`` with beam search. Each
    decoded hypothesis is parsed with ``extract_relations_from_model_output``.
    Every relation records the window it came from in
    ``relation["meta"]["spans"]`` before being added to a ``KB``, which merges
    duplicates.

    Args:
        text: Input text.
        span_length: Number of tokens per window.
        verbose: If True, print the token count, window count and boundaries.

    Returns:
        A ``KB`` holding the extracted, de-duplicated relations.
    """
    # tokenize whole text
    inputs = tokenizer([text], return_tensors="pt")
    input_ids = inputs["input_ids"][0]
    attention_mask = inputs["attention_mask"][0]

    # compute span boundaries
    num_tokens = len(input_ids)
    spans_boundaries = _compute_span_boundaries(num_tokens, span_length)
    if verbose:
        print(f"Input has {num_tokens} tokens")
        print(f"Input has {len(spans_boundaries)} spans")
        print(f"Span boundaries are {spans_boundaries}")

    # transform input into one batch row per window
    batch = {
        "input_ids": torch.stack(
            [input_ids[start:end] for start, end in spans_boundaries]),
        "attention_mask": torch.stack(
            [attention_mask[start:end] for start, end in spans_boundaries]),
    }

    # generate relations
    num_return_sequences = 3
    gen_kwargs = {
        "max_length": 256,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": num_return_sequences
    }
    generated_tokens = model.generate(**batch, **gen_kwargs)

    # Keep special tokens: the parser relies on the <triplet>/<subj>/<obj> markers.
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    # create kb
    kb = KB()
    for i, sentence_pred in enumerate(decoded_preds):
        # Rows come in groups of num_return_sequences per input window.
        span = spans_boundaries[i // num_return_sequences]
        for relation in extract_relations_from_model_output(sentence_pred):
            relation["meta"] = {"spans": [span]}
            kb.add_relation(relation)

    return kb
# Quick sanity check on a toy sentence before processing the real transcripts.
# The recorded output (cat/dog "opposite of") shows the model can emit
# relations the text never states.
kb = from_text_to_kb("The cat is on the mat. The dog is on the mat.")
kb.print()

Relations:
  {'head': 'cat', 'type': 'opposite of', 'tail': 'dog', 'meta': {'spans': [[0, 128]]}}
  {'head': 'dog', 'type': 'opposite of', 'tail': 'cat', 'meta': {'spans': [[0, 128]]}}


In [8]:
def kb_to_json(kb):
    """Wrap a knowledge base's relations in a JSON-serialisable dict.

    The result has a single "relations" key that references (does not copy)
    ``kb.relations``.
    """
    return dict(relations=kb.relations)

In [10]:
import json

# Directory of transcript JSON files. Point this at "../test_data/" for a
# quick run on the test set.
data_dir = "../data/"
output_dir = "../knowledge_bases/"
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(output_dir, exist_ok=True)

# Sorted so runs process files in a reproducible order (os.listdir order is
# arbitrary).
for file_name in sorted(os.listdir(data_dir)):
    if not file_name.endswith(".json"):
        continue
    file_path = os.path.join(data_dir, file_name)
    transcript = clean_transcript(get_transcript(file_path))
    print(transcript)
    kb = from_text_to_kb(transcript, verbose=True)
    kb_json = kb_to_json(kb)
    # Each knowledge base is saved under the same file name as its transcript.
    output_file_path = os.path.join(output_dir, file_name)
    # Distinct handle name: don't rebind the loop variable.
    with open(output_file_path, 'w', encoding='utf-8') as out_file:
        json.dump(kb_json, out_file)
    print(f"Saved knowledge base to {output_file_path}")

judge orders trump to pay $355 million in civil fraud trial
Input has 15 tokens
Input has 1 spans
Span boundaries are [[0, 128]]
Saved knowledge base to ../knowledge_bases/CNN-Judge orders Trump to pay $355 million in civil fraud trial.json
