In [1]:
from transformers import AutoTokenizer, pipeline
from datasets import load_dataset
import pandas as pd
import itertools
from difflib import SequenceMatcher
import json
import numpy as np

import importlib
import _RE
importlib.reload(_RE)
from _RE import join_text, merge_result, combine_entities

In [2]:
def get_docred_data(get_distant=False):
    """
    Load the 'docred' dataset and return its splits as pandas DataFrames.

    Parameters
    ----------
    get_distant : bool, default False
        When truthy, the 'train_distant' split is converted as well;
        otherwise that slot of the result is None.

    Returns
    -------
    tuple
        (train_annotated, train_distant, test, validation), where every
        element is a DataFrame except train_distant, which is None unless
        `get_distant` is truthy.
    """
    # trust_remote_code=True is forwarded to `load_dataset` unchanged; per the
    # `datasets` documentation this permits the dataset's loading script to run.
    splits = load_dataset('docred', trust_remote_code=True)

    def as_frame(split_name):
        # Materialise one named split as a DataFrame.
        return pd.DataFrame(splits[split_name])

    annotated_df = as_frame('train_annotated')
    distant_df = as_frame('train_distant') if get_distant else None
    test_df = as_frame('test')
    validation_df = as_frame('validation')

    return annotated_df, distant_df, test_df, validation_df

In [6]:
# Load the DocRED splits as DataFrames. get_docred_data() is called with its
# default get_distant=False, so the second slot is None and is discarded here.
train, _, test, validation = get_docred_data()

In [3]:
def get_info(instance):
    """
    Pull the sentences, entity vertex set and relation labels out of one record.

    Parameters
    ----------
    instance : mapping
        Must provide the keys 'sents', 'vertexSet' and 'labels'. 'sents' is
        iterated as a sequence of token sequences (each joinable with
        `str.join`).

    Returns
    -------
    tuple
        (sents, vertexSet, labels): `sents` is a new list with one
        space-joined string per entry of instance['sents']; `vertexSet` and
        `labels` are the objects stored under those keys, returned as-is.
    """
    joined_sentences = []
    for tokens in instance['sents']:
        # Tokens are glued back together with single spaces.
        joined_sentences.append(' '.join(tokens))

    return joined_sentences, instance['vertexSet'], instance['labels']

In [7]:
# Unpack the first row of the annotated training split into its space-joined
# sentences, its vertexSet (bound to `entities`) and its relation labels.
sents, entities, labels = get_info(train.iloc[0])

In [14]:
# Print each vertexSet entry next to its position in the list; as the output
# shows, an entry is a list of mention dicts with name, sent_id, pos and type.
for i, e in enumerate(entities):
    print(i, e)

0 [{'name': 'Zest Airways, Inc.', 'sent_id': 0, 'pos': [0, 4], 'type': 'ORG'}, {'name': 'Asian Spirit and Zest Air', 'sent_id': 0, 'pos': [10, 15], 'type': 'ORG'}, {'name': 'AirAsia Zest', 'sent_id': 0, 'pos': [6, 8], 'type': 'ORG'}, {'name': 'AirAsia Zest', 'sent_id': 6, 'pos': [19, 21], 'type': 'ORG'}]
1 [{'name': 'Ninoy Aquino International Airport', 'sent_id': 3, 'pos': [4, 8], 'type': 'LOC'}, {'name': 'Ninoy Aquino International Airport', 'sent_id': 0, 'pos': [26, 30], 'type': 'LOC'}]
2 [{'name': 'Pasay City', 'sent_id': 0, 'pos': [31, 33], 'type': 'LOC'}]
3 [{'name': 'Metro Manila', 'sent_id': 0, 'pos': [34, 36], 'type': 'LOC'}]
4 [{'name': 'Philippines', 'sent_id': 0, 'pos': [38, 39], 'type': 'LOC'}, {'name': 'Philippines', 'sent_id': 4, 'pos': [13, 14], 'type': 'LOC'}, {'name': 'Republic of the Philippines', 'sent_id': 5, 'pos': [25, 29], 'type': 'LOC'}]
5 [{'name': 'Manila', 'sent_id': 1, 'pos': [13, 14], 'type': 'LOC'}, {'name': 'Manila', 'sent_id': 3, 'pos': [9, 10], 'type':

In [18]:
# Print one line per annotated relation. `labels` holds parallel lists, so the
# i-th head, tail and relation_text values together describe one relation.
for i in range(len(labels['head'])):
    print(f"head: {labels['head'][i]}; tail: {labels['tail'][i]}; relation: {labels['relation_text'][i]}")

head: 0; tail: 2; relation: headquarters location
head: 0; tail: 4; relation: country
head: 12; tail: 4; relation: country
head: 2; tail: 4; relation: country
head: 2; tail: 3; relation: located in the administrative territorial entity
head: 4; tail: 3; relation: contains administrative territorial entity
head: 5; tail: 4; relation: country
head: 3; tail: 2; relation: contains administrative territorial entity
head: 3; tail: 4; relation: located in the administrative territorial entity
head: 3; tail: 4; relation: country
head: 1; tail: 2; relation: located in the administrative territorial entity
head: 1; tail: 4; relation: country
head: 10; tail: 4; relation: country


In [None]:
def make_triplets(vertexSet, labels):
    '''
    Returns a list of triplets of the format [head, relation, tail].

    `head` and `tail` each contain the list of "synonym" entity names
    (e.g. Swedish and Sweden) of the vertexSet entry that the label points to.
    `relation` is [relation_id, relation_text], where relation_text explains
    the relation (e.g. "country").

    Parameters
    ----------
    vertexSet : sequence
        One entry per entity; each entry is a sequence of mention dicts that
        have a 'name' key.
    labels : mapping
        Parallel sequences under the keys 'head', 'tail', 'relation_id' and
        'relation_text'. 'head' and 'tail' hold indices into `vertexSet`.

    Returns
    -------
    list
        One [head_names, [relation_id, relation_text], tail_names] entry per
        label, in label order. The name lists are shared between triplets that
        reference the same entity (they are not copied).

    Raises
    ------
    ValueError
        If the four label sequences do not all have the same length.
    '''
    head = labels['head']
    tail = labels['tail']
    relation_ids = labels['relation_id']
    relation_texts = labels['relation_text']

    # Every label column must line up, including relation_id.
    if not len(head) == len(tail) == len(relation_ids) == len(relation_texts):
        raise ValueError("Labels are not unform length")

    # Collect the mention names of each entity, indexed like vertexSet.
    names = [[mention['name'] for mention in mentions] for mentions in vertexSet]

    # Construct triplets of the format [[head(s)], [relation_id, relation], [tail(s)]].
    # Per-iteration values use their own names so the label lists are never
    # overwritten while they are still being iterated.
    triplets = []
    for head_index, tail_index, rel_id, rel_text in zip(
        head, tail, relation_ids, relation_texts
    ):
        triplets.append([names[head_index], [rel_id, rel_text], names[tail_index]])

    return triplets