In [2]:
# Notebook setup: imports, a blank English pipeline and the pickled dataset.
import warnings
# Suppresses every warning category for the rest of the session.
warnings.filterwarnings('ignore')

import spacy
from spacy.tokens import DocBin
import pandas as pd

# Blank English pipeline; create_blank_doc() calls it on the marked-up sentence text.
nlp = spacy.blank("en")

from utilities import read_csv
from tqdm import tqdm
import json
import re
import pickle
from itertools import product
import random

    
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load data/dataset_doc_list.pkl when it comes from a trusted source.
# The last cell iterates doc_list and indexes it with each iterated key; the
# items it reads expose "doc" and "facts" entries.
with open("data/dataset_doc_list.pkl", "rb") as f:
    doc_list = pickle.load(f)

In [3]:
def num_doc(doc):
    """Label every noun chunk and every token outside a noun chunk by position.

    The units of ``doc`` are its noun chunks (``doc.noun_chunks``) plus each
    token whose index ``token.i`` is not covered by any chunk. Units are
    ordered by their starting token index and numbered from 0.

    Args:
        doc: An iterable of tokens (each with ``.i`` and ``.text``) that also
            exposes ``noun_chunks``, an iterable of spans with ``.start``,
            ``.end`` and ``.text``.

    Returns:
        dict: Maps ``"<unit text> (<rank>)"`` to the unit object (chunk or
        token), in increasing order of starting token index.
    """
    units = [(chunk.start, chunk) for chunk in doc.noun_chunks]

    # A set keeps the per-token membership test O(1); the previous list made
    # this step quadratic in the document length.
    covered = {i for _, chunk in units for i in range(chunk.start, chunk.end)}
    units.extend((token.i, token) for token in doc if token.i not in covered)

    # Sort once, and only on the index, so the chunk/token objects themselves
    # are never compared when two units share a start position (the stable
    # sort then keeps chunks ahead of tokens, in their original order).
    units.sort(key=lambda unit: unit[0])

    return {f"{unit.text} ({rank})": unit for rank, (_, unit) in enumerate(units)}

def process_facts(doc, facts):
    """Resolve textual facts against ``doc`` and add sampled unrelated pairs.

    Each fact is a string of the form ``"<head> :: <relation> :: <tail>"``
    whose head and tail are labels produced by ``num_doc(doc)``. Facts that do
    not split into exactly three parts are skipped, as are facts whose head or
    tail label resolves to a single ``Token`` instead of a multi-token unit.
    An unknown head/tail label raises ``KeyError``.

    Args:
        doc: The parsed document the facts refer to.
        facts: Iterable of fact strings.

    Returns:
        tuple: ``(processed_facts, ents, rels)`` where

        * ``processed_facts`` is a list of ``(head, relation, tail)`` triples;
          ``head`` and ``tail`` are lists of tokens of ``doc``, ``relation`` is
          the list of labelled units whose label occurs in the relation part
          of the fact. It ends with randomly sampled head/tail pairs that no
          fact connects, each carrying an empty relation list (at most as many
          as there are real facts).
        * ``ents`` is the de-duplicated list of all entity tokens.
        * ``rels`` is the de-duplicated list of all relation units.

        ``ents`` and ``rels`` come from sets, so their order is arbitrary.
    """
    token_type = spacy.tokens.token.Token
    labelled = num_doc(doc)

    processed_facts = []
    for fact in facts:
        parts = fact.split(" :: ")
        if len(parts) != 3:
            continue
        h, r, t = parts

        # Head is checked first; the tail label is only looked up when the
        # head is not a plain token.
        head_unit = labelled[h]
        if type(head_unit) is token_type:
            continue
        tail_unit = labelled[t]
        if type(tail_unit) is token_type:
            continue

        head = [doc[i] for i in range(head_unit.start, head_unit.end)]
        tail = [doc[i] for i in range(tail_unit.start, tail_unit.end)]
        # Every labelled unit whose label text appears inside the relation part.
        relation = [unit for label, unit in labelled.items() if label in r]
        processed_facts.append((head, relation, tail))

    # Distinct entities in order of first appearance (head before tail).
    entities = []
    for head, _, tail in processed_facts:
        for candidate in (head, tail):
            if candidate not in entities:
                entities.append(candidate)

    entity_tokens = [token for entity in entities for token in entity]
    relation_units = [unit for _, relation, _ in processed_facts for unit in relation]
    ents = list(set(entity_tokens))
    rels = list(set(relation_units))

    # Candidate negatives: ordered pairs of different entities not linked by a fact.
    pairs = [(head, tail) for head, _, tail in processed_facts]
    all_pairs = [
        (h, t)
        for h, t in product(entities, entities)
        if ((h, t) not in pairs) and (h != t)
    ]
    sample_size = min(len(all_pairs), len(processed_facts))
    null_edges = random.sample(all_pairs, sample_size)

    # Each negative pair gets its own fresh empty relation list.
    processed_facts.extend((head, [], tail) for head, tail in null_edges)

    return processed_facts, ents, rels

def create_blank_doc(head, tail, doc):
    """Re-tokenize ``doc``'s text with the head and tail spans wrapped in markers.

    The characters covered by ``head`` are rewritten as ``{HEAD ~ ...}`` and
    those covered by ``tail`` as ``{TAIL ~ ...}``; the result is passed to the
    module-level ``nlp`` pipeline.

    Args:
        head: Non-empty list of tokens; character offsets are read from the
            first and last token (``.idx`` and ``.text``).
        tail: Non-empty list of tokens, read the same way.
        doc: Document whose ``.text`` is marked up.

    Returns:
        The doc produced by ``nlp`` for the marked-up text.

    NOTE(review): the fixed shift of 9 presumes the span wrapped first lies
    entirely before the other one; overlapping spans are not handled — confirm
    that callers never pass them.
    """
    chars = list(doc.text)
    head_bounds = (head[0].idx, head[-1].idx + len(head[-1].text))
    tail_bounds = (tail[0].idx, tail[-1].idx + len(tail[-1].text))

    # Wrap whichever span starts first; on equal starts the tail goes first.
    if head_bounds[0] < tail_bounds[0]:
        ordered = (("{HEAD ~ ", head_bounds), ("{TAIL ~ ", tail_bounds))
    else:
        ordered = (("{TAIL ~ ", tail_bounds), ("{HEAD ~ ", head_bounds))

    shift = 0
    for opener, (start, end) in ordered:
        lo, hi = start + shift, end + shift
        chars[lo:hi] = list(opener) + chars[lo:hi] + list("}")
        # The first wrap inserted 8 opener characters plus the closing brace,
        # so the second span's offsets move right by 9.
        shift = 9

    return nlp("".join(chars))

def create_sequence(head, relation, tail, doc):
    """Build the role-labelled token sequence for one (head, relation, tail) triple.

    Every token of ``doc`` becomes a ``(text, tag_, role)`` triple, where role
    is ``"HEAD"``, ``"TAIL"``, ``"REL"`` or ``"OTH"`` depending on whether the
    token is found in ``head``, ``tail``, ``relation`` or none of them (checked
    in that order). The first run of triples equal to the head's triples is
    then wrapped with the marker entries ``{``, ``HEAD``, ``~`` before and
    ``}`` after it, all tagged and labelled ``HEAD``; the tail is treated the
    same way with ``TAIL``.

    Args:
        head: List of tokens forming the head entity.
        relation: List of units forming the relation (may be empty).
        tail: List of tokens forming the tail entity.
        doc: Iterable of tokens exposing ``.text`` and ``.tag_``.

    Returns:
        list: The ``(text, tag, role)`` triples including the marker entries.
    """
    def role_of(token):
        # Precedence: head, then tail, then relation.
        if token in head:
            return "HEAD"
        if token in tail:
            return "TAIL"
        if token in relation:
            return "REL"
        return "OTH"

    sequence = [(token.text, token.tag_, role_of(token)) for token in doc]

    # Head is wrapped before tail; each wrap touches only the first match.
    for role, span in (("HEAD", head), ("TAIL", tail)):
        needle = [(token.text, token.tag_, role) for token in span]
        wrapped = (
            [("{", role, role), (role, role, role), ("~", role, role)]
            + needle
            + [("}", role, role)]
        )
        width = len(span)
        for start in range(len(sequence)):
            if sequence[start:start + width] == needle:
                sequence[start:start + width] = wrapped
                break

    return sequence

In [None]:
# Build the train/test DocBins: one re-tokenized, role-tagged doc per
# (head, relation, tail) triple, then write both bins to disk.
doc_bin = DocBin()
test_bin = DocBin()
# Running count of generated examples; decides which bin an example goes to.
train_count = 0
for patent in tqdm(doc_list):    
    # NOTE(review): the [20:21] slice keeps only the entry at index 20 of each
    # patent and the `break` at the end of this loop body stops after the first
    # patent — looks like a debugging configuration; confirm before a full run.
    for sentence in doc_list[patent][20:21]:
        doc = sentence["doc"]
        facts = sentence["facts"]
        
        # ents and rels are returned but not read again in this cell.
        processed_facts, ents, rels = process_facts(doc, facts)
        for (head, relation, tail) in processed_facts:
            blank_doc = create_blank_doc(head, tail, doc)
            sequence = create_sequence(head, relation, tail, doc)
            # Debug output: the marked-up doc and its (text, role) pairs.
            print(blank_doc)
            print([(item[0], item[2]) for item in sequence])
            
            # Not read again in this cell.
            text = list(doc.text)
            
            # Store each role label in the tag of the token at the same position.
            # NOTE(review): assumes blank_doc has at least len(sequence) tokens and
            # that the nlp tokenization lines up with the sequence — confirm.
            for i in range(len(sequence)):
                blank_doc[i].tag_ = sequence[i][2]
                
            train_count += 1
            # The first 336799 examples go to the training bin, all later ones to
            # the test bin.
            if train_count < 336800:
                doc_bin.add(blank_doc)
            else:
                test_bin.add(blank_doc)
    break

doc_bin.to_disk("train-2.spacy")
test_bin.to_disk("test-2.spacy")               