In [9]:
from datasets import load_dataset
import pandas as pd

# Load every DocRED split from the Hugging Face hub.
# NOTE(review): trust_remote_code=True lets `datasets` execute the dataset's own
# loading script locally — keep it only for sources you trust.
dataset = load_dataset('docred', trust_remote_code=True)

# Materialise the two training splits as pandas DataFrames (annotated first, then distant).
annotated, distant = (
    pd.DataFrame(dataset[split_name])
    for split_name in ('train_annotated', 'train_distant')
)

In [44]:
validation = pd.DataFrame(dataset['validation'])

# Peek at the tokenised sentences (a list of word lists) of the first three
# validation documents.
for row_pos in range(3):
    doc_sents = validation.iloc[row_pos]['sents']
    print(doc_sents)

[['Skai', 'TV', 'is', 'a', 'Greek', 'free', '-', 'to', '-', 'air', 'television', 'network', 'based', 'in', 'Piraeus', '.'], ['It', 'is', 'part', 'of', 'the', 'Skai', 'Group', ',', 'one', 'of', 'the', 'largest', 'media', 'groups', 'in', 'the', 'country', '.'], ['It', 'was', 'relaunched', 'in', 'its', 'present', 'form', 'on', '1st', 'of', 'April', '2006', 'in', 'the', 'Athens', 'metropolitan', 'area', ',', 'and', 'gradually', 'spread', 'its', 'coverage', 'nationwide', '.'], ['Besides', 'digital', 'terrestrial', 'transmission', ',', 'it', 'is', 'available', 'on', 'the', 'subscription', '-', 'based', 'encrypted', 'services', 'of', 'Nova', 'and', 'Cosmote', 'TV', '.'], ['Skai', 'TV', 'is', 'also', 'a', 'member', 'of', 'Digea', ',', 'a', 'consortium', 'of', 'private', 'television', 'networks', 'introducing', 'digital', 'terrestrial', 'transmission', 'in', 'Greece', '.'], ['At', 'launch', ',', 'Skai', 'TV', 'opted', 'for', 'dubbing', 'all', 'foreign', 'language', 'content', 'into', 'Greek', '

In [54]:
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from _RE import join_text
import pickle
import importlib

# Token-classification (NER) model used to build `ner_pipeline` below.
model_name = 'dslim/distilbert-NER'
tokenizer = AutoTokenizer.from_pretrained(model_name)
ner_model = AutoModelForTokenClassification.from_pretrained(model_name)

# Import the local helper module, reload it so edits on disk are picked up without a
# kernel restart, and only THEN bind names from it. Doing a `from prepare_data import ...`
# before the reload reads from the stale cached module. That is consistent with the
# recorded "cannot import name 'move_to_root'" error, and it stops the reload from running.
# NOTE(review): if move_to_root is genuinely absent from prepare_data.py this will still fail.
import prepare_data
prepare_data = importlib.reload(prepare_data)
from prepare_data import identify_entities, create_entity_pairs, get_keywords, move_to_root

# NOTE(review): ner_pipeline is not referenced elsewhere in this notebook view — confirm it is needed.
ner_pipeline = pipeline('ner', model=ner_model, tokenizer=tokenizer)

def make_paragraphs(sents):
    """Turn each tokenised sentence into a single space-separated string.

    Args:
        sents: iterable of word lists (e.g. a DocRED document's ``sents`` field).

    Returns:
        list[str]: one joined string per input sentence, in the same order.
    """
    return [' '.join(words) for words in sents]

# DocRED splits to process; each becomes a key of `train_test_data`.
datasets = [
    'validation',
    'test',
    'train_annotated',
    'train_distant',
]

# Documents processed per split. Set to None for a full run. This replaces the
# previous silent `length = 10` override of `len(ds)`.
N_DOCS = 10
OUTPUT_PATH = 'knowledge_extraction/relation_extraction/data/docred_context_and_pairs.pkl'

custom_keywords = get_keywords()
train_test_data = {}

for datatype in datasets:
    print('Now doing:', datatype)
    # Index the Hugging Face split directly; converting the whole split to a
    # DataFrame (101,873 rows for train_distant) just to read a few rows is wasteful.
    ds = dataset[datatype]
    length = len(ds) if N_DOCS is None else min(N_DOCS, len(ds))
    context_and_pairs = []

    for i in range(length):
        print(f"{(i/length)*100:.3f}%")
        sents = ds[i]['sents']
        paragraphs = make_paragraphs(sents)
        entities = identify_entities(paragraphs, custom_keywords)
        pairs = create_entity_pairs(entities)
        context_and_pairs.append({
            'context': join_text(sents, fancy=False),
            'pairs': pairs,
        })

    train_test_data[datatype] = context_and_pairs

move_to_root()
with open(OUTPUT_PATH, 'wb') as file:
    # Dump all splits ({split_name: [{'context', 'pairs'}, ...]}). Dumping
    # `context_and_pairs` would save only the last split processed.
    # NOTE(review): readers of this .pkl that expected a bare list need updating.
    pickle.dump(train_test_data, file)


ImportError: cannot import name 'move_to_root' from 'prepare_data' (/Users/tiril/Documents/nuclear_repo/knowledge_extraction/relation_extraction/prepare_data.py)

In [21]:
# Copied directly for simplicity. Presumably this mirrors prepare_data.get_keywords(),
# so keep the two in sync.
# Several entries had been fused by missing commas (e.g. 'uranium-238uranium compound'),
# so neither term could ever match. They are split into separate keywords here.
# NOTE(review): if get_keywords() contains the same fused strings, fix them there too.
custom_keywords = {
    'FUEL': [
        'u235', 'u238', 'uranium-235', 'uranium-238', 'uranium compound',
        'uranium oxide', 'uranium dioxide', 'uranium fuel', 'nuclear fuel',
        'mox', 'mox fuel', 'mixed oxide fuel', 'plutonium', 'pu239',
        'plutonium-239', 'thorium', 'actinides', 'light water', 'heavy water',
    ],
    'FUEL_CYCLE': [
        'uranium oxide', 'uranium hexafluoride', 'hex', 'wet process', 'dry process',
        'uranium enrichment', 'gas centrifuge', 'fuel rods', 'fuel assembly',
        'low enriched fuel', 'leu', 'highly enriched fuel', 'heu',
        'high assay low enriched uranium', 'haleu', 'triso', 'spent fuel',
        'spent nuclear fuel', 'nuclear waste', 'radioactive waste', 'spent oxide fuel',
        'spent reactor fuel', 'spent fuel pools', 'spent fuel ponds',
    ],
    'SMR_DESIGN': [
        'water-cooled', 'water cooled', 'light water reactor', 'lwr',
        'heavy water reactor', 'hwr', 'boiling water reactor',
        'pressurized water reactor', 'pwr', 'high temperature gas reactor', 'htgr',
        'gas reactor', 'gas-cooled', 'gas cooled', 'pebble bed reactor', 'pbmr',
        'liquid metal cooled', 'liquid-metal-cooled', 'liquid metal-cooled',
        'lead-bismuth', 'sodium cooled', 'sodium-cooled', 'molten salt reactor',
        'molten salt', 'msr', 'microreactor', 'micro reactor',
        'micro modular reactor', 'micro nuclear reactor',
    ],
    'REACTOR': ['nuclear reactor', 'nuclear power plant', 'nuclear power reactor', 'fast reactor'],
    'SMR': ['smr', 'small modular reactor', 'small nuclear reactor'],
    'POLITICAL': ['safety', 'security', 'nuclear regulation', 'proliferation', 'safeguards'],
}

In [2]:
# Summarise the loaded DatasetDict: splits, features and row counts.
print(f'{dataset}')

DatasetDict({
    validation: Dataset({
        features: ['title', 'sents', 'vertexSet', 'labels'],
        num_rows: 998
    })
    test: Dataset({
        features: ['title', 'sents', 'vertexSet', 'labels'],
        num_rows: 1000
    })
    train_annotated: Dataset({
        features: ['title', 'sents', 'vertexSet', 'labels'],
        num_rows: 3053
    })
    train_distant: Dataset({
        features: ['title', 'sents', 'vertexSet', 'labels'],
        num_rows: 101873
    })
})


In [None]:
'''
Data Format:
(source: https://github.com/thunlp/DocRED/blob/master/data/README.md)
{
  'title',
  'sents':     [
                  [word in sent 0], # list of lists of words forming sentences
                  [word in sent 1]
               ]
  'vertexSet': [
                  [
                    { 'name': mention_name, # name of entity mention
                      'sent_id': mention in which sentence, --> index of the sentence where the mention occurs
                      'pos': position of mention in a sentence,  --> start and end position (indices) of the mention in the sentence
                      'type': NER_type} #the NER type, e.g. PERSON, LOCATION
                    {another mention}
                  ], 
                  [another entity]
                ]
  'labels':   [
                {
                  'h': idx of head entity in vertexSet,
                  't': idx of tail entity in vertexSet,
                  'r': relation,
                  'evidence': evidence sentences' id --> the sentences from which the relation is supported
                }
              ]
}'''

In [12]:
import importlib
import _RE

# Pick up edits to the local _RE module without restarting the kernel, then bind
# the helpers from the freshly reloaded module object.
_RE = importlib.reload(_RE)
join_text = _RE.join_text
make_triplets = _RE.make_triplets

# Join the first annotated document's tokenised sentences into text
# (the captured output shows one sentence per line).
sents = annotated.iloc[0]['sents']
text = join_text(sents, fancy=False)
print(text)

Zest Airways , Inc. operated as AirAsia Zest ( formerly Asian Spirit and Zest Air ) , was a low - cost airline based at the Ninoy Aquino International Airport in Pasay City , Metro Manila in the Philippines .
It operated scheduled domestic and international tourist services , mainly feeder services linking Manila and Cebu with 24 domestic destinations in support of the trunk route operations of other airlines .
In 2013 , the airline became an affiliate of Philippines AirAsia operating their brand separately .
Its main base was Ninoy Aquino International Airport , Manila .
The airline was founded as Asian Spirit , the first airline in the Philippines to be run as a cooperative .
On August 16 , 2013 , the Civil Aviation Authority of the Philippines ( CAAP ) , the regulating body of the Government of the Republic of the Philippines for civil aviation , suspended Zest Air flights until further notice because of safety issues .
Less than a year after AirAsia and Zest Air 's strategic allian

In [22]:
# Inspect the relation triplets of one annotated document.
index = 6
document = annotated.iloc[index]
# Read both columns from the same positional row so they cannot drift apart.
# Label-based `annotated['vertexSet'][index]` only matches `.iloc[index]` while the
# frame has a default RangeIndex.
labels = document['labels']
vertexSet = document['vertexSet']

triplet = make_triplets(vertexSet, labels)

for t in triplet:
    print(t)

# TODO: the first triplet (Mississippi River, P131 "located in the administrative
# territorial entity", Illinois) looks wrong. Check whether it comes from the DocRED
# annotation itself or from make_triplets.

[['Mississippi River'], ['P131', 'located in the administrative territorial entity'], ['Illinois']]
[['Mississippi River'], ['P17', 'country'], ['United States']]
[['Madison County'], ['P131', 'located in the administrative territorial entity'], ['Illinois']]
[['Madison County'], ['P17', 'country'], ['United States']]
[['Illinois'], ['P206', 'located in or next to body of water'], ['Mississippi River']]
[['Illinois'], ['P150', 'contains administrative territorial entity'], ['Madison County']]
[['Illinois'], ['P131', 'located in the administrative territorial entity'], ['United States']]
[['Illinois'], ['P17', 'country'], ['United States']]
[['United States'], ['P150', 'contains administrative territorial entity'], ['Illinois']]
[['United States'], ['P150', 'contains administrative territorial entity'], ['Missouri']]
[['Missouri'], ['P131', 'located in the administrative territorial entity'], ['United States']]
[['Missouri'], ['P17', 'country'], ['United States']]
[['Greater St. Louis',

In [68]:
import json
from pathlib import Path

# DocRED metadata files to rewrite in place as indented, human-readable JSON.
# Other metadata files that can be added here when needed:
#   docred_metadata/char2id.json, docred_metadata/ner2id.json,
#   docred_metadata/rel2id.json, docred_metadata/word2id.json
filenames = [
    'docred_metadata/rel_info.json',
]

for filename in filenames:
    path = Path(filename)
    data = json.loads(path.read_text(encoding='utf-8'))
    # Serialize completely BEFORE touching the file. Opening with 'w' truncates
    # immediately, so an interruption during a streamed json.dump would leave the
    # file truncated or corrupt.
    pretty = json.dumps(data, indent=4)
    path.write_text(pretty, encoding='utf-8')

In [13]:
import json

# Load the relation-id -> relation-name mapping (e.g. P17 -> country).
with open('docred_metadata/rel_info.json', 'r') as f:
    data = json.load(f)

print('Number of classes:', len(data))
for position, (relation_id, relation_name) in enumerate(data.items(), start=1):
    print(f"{position}. {relation_id}: {relation_name}")

Number of classes: 96
1. P6: head of government
2. P17: country
3. P19: place of birth
4. P20: place of death
5. P22: father
6. P25: mother
7. P26: spouse
8. P27: country of citizenship
9. P30: continent
10. P31: instance of
11. P35: head of state
12. P36: capital
13. P37: official language
14. P39: position held
15. P40: child
16. P50: author
17. P54: member of sports team
18. P57: director
19. P58: screenwriter
20. P69: educated at
21. P86: composer
22. P102: member of political party
23. P108: employer
24. P112: founded by
25. P118: league
26. P123: publisher
27. P127: owned by
28. P131: located in the administrative territorial entity
29. P136: genre
30. P137: operator
31. P140: religion
32. P150: contains administrative territorial entity
33. P155: follows
34. P156: followed by
35. P159: headquarters location
36. P161: cast member
37. P162: producer
38. P166: award received
39. P170: creator
40. P171: parent taxon
41. P172: ethnic group
42. P175: performer
43. P176: manufacturer
4