In [2]:
from datasets import load_dataset
# Load the 'docred' dataset through the `datasets` library.
# NOTE(review): trust_remote_code=True permits the dataset's own loading script
# to execute locally — confirm the source is trusted before re-running.
dataset = load_dataset('docred', trust_remote_code=True)
# Print the returned object to check the available splits, their feature
# columns and their row counts.
print(dataset)

DatasetDict({
    validation: Dataset({
        features: ['title', 'sents', 'vertexSet', 'labels'],
        num_rows: 998
    })
    test: Dataset({
        features: ['title', 'sents', 'vertexSet', 'labels'],
        num_rows: 1000
    })
    train_annotated: Dataset({
        features: ['title', 'sents', 'vertexSet', 'labels'],
        num_rows: 3053
    })
    train_distant: Dataset({
        features: ['title', 'sents', 'vertexSet', 'labels'],
        num_rows: 101873
    })
})


In [3]:
import pandas as pd
from ipydatagrid import DataGrid
import itertools

# Materialise the 'train_annotated' split as a pandas DataFrame so single
# documents can be selected by position later on.
ds = pd.DataFrame(dataset['train_annotated'])

# Wrap the frame in an ipydatagrid DataGrid; the bare `grid` expression on the
# last line is what makes the notebook display it.
grid = DataGrid(ds)
grid


DataGrid(auto_fit_params={'area': 'all', 'padding': 30, 'numCols': None}, corner_renderer=None, default_render…

In [18]:
def get_info(docred_instance):
    """Unpack one DocRED record into its title, plain text, mentions and labels.

    Args:
        docred_instance: Mapping-like record (for example a DataFrame row)
            exposing the keys 'title', 'sents', 'vertexSet' and 'labels'.

    Returns:
        An 8-tuple ``(title, text, entities, head, tail, r_id, r_text, evidence)``:
        the title as stored; the text built by joining each sentence's tokens
        with a single space and the sentences with newlines; the 'vertexSet'
        groups flattened one level into a single list; and the 'head', 'tail',
        'relation_id', 'relation_text' and 'evidence' entries of 'labels',
        returned unchanged.
    """
    title = docred_instance['title']

    # Tokens -> sentence strings -> one newline-separated document string.
    text = '\n'.join(' '.join(tokens) for tokens in docred_instance['sents'])

    # Flatten the nested mention groups by exactly one level.
    entities = list(itertools.chain.from_iterable(docred_instance['vertexSet']))

    # Fetch the label mapping once and pull the five fields in return order.
    labels = docred_instance['labels']
    label_keys = ('head', 'tail', 'relation_id', 'relation_text', 'evidence')

    return (title, text, entities, *(labels[key] for key in label_keys))

In [15]:
from transformers import AutoTokenizer, pipeline

# Model identifier; reused for the tokenizer and the pipeline below, and passed
# to merge_result in a later cell.
model_name = 'dslim/distilbert-NER'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 'ner' pipeline built from the model identifier plus the tokenizer loaded above.
ner_pipeline = pipeline('ner', model=model_name, tokenizer=tokenizer)

In [21]:
import importlib
import _NER
# Reload the _NER module (not shown here) so edits made to it since the first
# import take effect without restarting the kernel.
importlib.reload(_NER)
# Re-bind the helper names after the reload so they refer to the reloaded module.
from _NER import merge_result, combine_entities

# Use the document at positional index 4. Only `text` and `original_entities`
# are used below; `title` and the five label fields are discarded in this cell.
# NOTE(review): binding `_` here shadows IPython's "last output" variable — confirm
# nothing later relies on it.
title, text, original_entities, _, _, _, _, _ = get_info(ds.iloc[4])
# Raw pipeline predictions for the whole document text.
result = ner_pipeline(text)
# NOTE(review): merge_result / combine_entities are defined in _NER and are not
# visible here — presumably they merge sub-token predictions into whole entity
# spans; confirm against that module.
merged_result = merge_result(result, model_name)
distilbert_entities = combine_entities(merged_result)

# Print the dataset's mentions and then the model-derived entities, one per
# line, for a manual side-by-side comparison.
print('ORIGINAL:')
for e in original_entities:
    print(e)

print('DISTILBERT:')
for e in distilbert_entities:
    print(e)



ORIGINAL:
{'name': 'Ministry for Home Security', 'sent_id': 1, 'pos': [1, 5], 'type': 'ORG'}
{'name': 'Ministry of Home Security', 'sent_id': 0, 'pos': [1, 5], 'type': 'ORG'}
{'name': 'Ministry', 'sent_id': 2, 'pos': [1, 2], 'type': 'ORG'}
{'name': 'Ministry', 'sent_id': 4, 'pos': [1, 2], 'type': 'ORG'}
{'name': 'Ministry', 'sent_id': 7, 'pos': [7, 8], 'type': 'ORG'}
{'name': 'British', 'sent_id': 0, 'pos': [7, 8], 'type': 'LOC'}
{'name': '1939', 'sent_id': 0, 'pos': [12, 13], 'type': 'TIME'}
{'name': 'Second World War', 'sent_id': 0, 'pos': [29, 32], 'type': 'MISC'}
{'name': 'John Anderson', 'sent_id': 1, 'pos': [9, 11], 'type': 'PER'}
{'name': 'John Anderson', 'sent_id': 6, 'pos': [5, 7], 'type': 'PER'}
{'name': 'Women ’s Voluntary Service', 'sent_id': 2, 'pos': [27, 31], 'type': 'ORG'}
{'name': 'ARP', 'sent_id': 3, 'pos': [9, 10], 'type': 'ORG'}
{'name': 'Home Office', 'sent_id': 4, 'pos': [9, 11], 'type': 'ORG'}
{'name': 'Fire Guards', 'sent_id': 5, 'pos': [28, 30], 'type': 'ORG'}
