In [2]:
import xml
import xml.etree.ElementTree as ET
import os
from os import path
import glob
from collections import defaultdict, OrderedDict
import spacy
import json

from utils import get_ent_info, get_clusters_from_xml
from transformers import BertTokenizer

I0625 16:06:51.128863 140383852373824 file_utils.py:32] TensorFlow version 2.0.0 available.
I0625 16:06:51.129541 140383852373824 file_utils.py:39] PyTorch version 1.2.0 available.
I0625 16:06:51.499200 140383852373824 modeling_xlnet.py:194] Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .


In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

I0625 16:07:19.624500 140383852373824 tokenization_utils.py:373] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /home/shtoshni/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1


In [4]:
def tokenize_string(string):
    string = string.strip()
    if string == "":
        return []
    else:
        tokenized_sent = tokenizer.tokenize(string)
        return tokenized_sent

In [5]:
def tokenize_doc(doc_str, ent_list):
    """Tokenizes a document given in string format.
    doc_str: Document string
    ent_list: List of entities with each entry being ((span_start, span_end), ent_id) 
        where spans are provided in the character space.
    
    Returns:
    tokenized_doc: List of tokens
    ent_id_to_token_spans: Entity ID to span indices in token space.
    """
    tokenized_doc = []
    token_counter = 0  
    char_offset = 0  # Till what point has the document been processed
    ent_id_to_token_spans = OrderedDict()

    for (span_start, span_end), ent_id in ent_list:
        # Tokenize the string before the span and after the last span
        before_span_str = doc_str[char_offset: span_start]
        before_span_tokens = tokenize_string(before_span_str)
        tokenized_doc.extend(before_span_tokens)
        token_counter += len(before_span_tokens)

        # Tokenize the span
        span_tokens = tokenize_string(doc_str[span_start: span_end])
        ent_id_to_token_spans[ent_id] = (token_counter, token_counter + len(span_tokens))
        tokenized_doc.extend(span_tokens)
        char_offset = span_end
        token_counter += len(span_tokens)

    # Add the tokens after the last span
    rem_doc = doc_str[char_offset:]
    rem_tokens = tokenize_string(rem_doc)
    token_counter += len(rem_tokens)

    tokenized_doc.extend(rem_tokens)
    return tokenized_doc, ent_id_to_token_spans

In [6]:
def get_cluster_spans(clusters_ent_id, ent_id_to_token_spans):
    clusters = []
    for cluster in clusters_ent_id:
        cluster_spans = []
        for ent_id in cluster:
            cluster_spans.append(ent_id_to_token_spans[ent_id])
        clusters.append(cluster_spans)
        
    return clusters

In [7]:
def get_dummy_speaker(tokenized_sents):
    speakers = []
    for sent in tokenized_sents:
        speakers.append(["spk1"] * len(sent))
    return speakers

In [8]:
def load_splits_file(list_file):
    return set([file_name.strip() for file_name in open(list_file).readlines()])

In [9]:
data_dir = "/home/shtoshni/Research/events/data/red/data/source"
source_files = glob.glob("{}/*/*".format(data_dir))

ann_dir = "/home/shtoshni/Research/events/data/red/data/mod_annotation"
ann_files = glob.glob("{}/*/*".format(ann_dir))

# Load the file splits
dev_list_file = "/home/shtoshni/Research/events/data/red/docs/dev.txt"
dev_set = load_splits_file(dev_list_file)

test_list_file = "/home/shtoshni/Research/events/data/red/docs/test.txt"
test_set = load_splits_file(test_list_file)

# Output directory
output_dir = "/home/shtoshni/Research/events/data/red/split-ref"

In [14]:
# Test ground
source_file = "/home/shtoshni/Research/events/data/red/data/source/deft/04debcc4da342dc971bdef4210fe468a.mpdf"
source_lines = open(source_file).readlines()
doc_str = "".join(source_lines)


# Read the annotation file
base_name = path.basename(source_file)
dir_name = path.basename(path.dirname(source_file))
red_file_name = path.join(dir_name, base_name)
print(red_file_name)

ann_file = path.join(path.join(ann_dir, dir_name), base_name + ".RED-Relation.gold.completed.xml")    
tree = ET.parse(ann_file)
root = tree.getroot()

# Get entity and cluster information from the annotation file
ent_map, ent_list = get_ent_info(root)
clusters_ent_id = get_clusters_from_xml(root, ent_map)

# Tokenize the doc
tokenized_doc, ent_id_to_token_spans = tokenize_doc(doc_str, ent_list)

# Break the document into sentences.
print(tokenized_doc)
print(ent_id_to_token_spans)

deft/04debcc4da342dc971bdef4210fe468a.mpdf
['<', 'do', '##c', 'id', '=', '"', '04', '##de', '##b', '##cc', '##4', '##da', '##34', '##2', '##d', '##c', '##9', '##7', '##1', '##b', '##de', '##f', '##42', '##10', '##fe', '##46', '##8', '##a', '"', '>', '<', 'headline', '>', '<', '/', 'headline', '>', '<', 'post', 'author', '=', '"', 'Dick', 'here', '"', 'date', '##time', '=', '"', '2008', '-', '01', '-', '11', '##T', '##12', ':', '18', ':', '00', '"', 'id', '=', '"', 'p', '##1', '"', '>', 'Don', "'", 't', 'order', 'anything', 'online', 'if', 'Amtrak', 'are', 'delivering', 'it', '-', 'here', "'", 's', 'my', 'experience', '.', 'Order', '##ed', 'a', '32', '"', 'TV', 'online', ',', 'cheaper', 'than', 'A', '##rgo', '##s', '-', 'who', 'didn', "'", 't', 'have', 'it', 'in', 'stock', '-', 'but', 'with', 'the', 'delivery', 'charge', 'the', 'cost', 'was', 'the', 'same', '.', 'Ad', '##vise', '##d', 'that', 'it', 'would', 'be', 'delivered', 'by', 'Amtrak', 'on', 'Tuesday', '.', 'Tuesday', 'came', 'and

In [35]:
import re

x = re.compile(r'## *')
if x.match('##'):
    print("yoyo")

yoyo


In [23]:
train_data = []
dev_data = []
test_data = []

for source_file in source_files:
    # Read the source doc
    source_lines = open(source_file).readlines()
    doc_str = "".join(source_lines)
    
    # ADDED NEW RULES
    doc_str = doc_str.replace('<', '~')
    doc_str = doc_str.replace('>', '^')
    
    # Read the annotation file
    base_name = path.basename(source_file)
    dir_name = path.basename(path.dirname(source_file))
    red_file_name = path.join(dir_name, base_name)
    print(red_file_name)
    
    ann_file = path.join(path.join(ann_dir, dir_name), base_name + ".RED-Relation.gold.completed.xml")    
    tree = ET.parse(ann_file)
    root = tree.getroot()
    
    # Get entity and cluster information from the annotation file
    ent_map, ent_list = get_ent_info(root)
    clusters_ent_id = get_clusters_from_xml(root, ent_map)
    
    # Tokenize the doc
    tokenized_doc, ent_id_to_token_spans = tokenize_doc(doc_str, ent_list)
    
    # Break the document into sentences.
    tokenized_sents = []
    tokenized_doc_str = " ".join(tokenized_doc)
    reproc_doc = spacy_nlp(tokenized_doc_str)
    
    # SWITCH BACK TO THE ORIGINAL TOKENS
    tokenized_doc_str = tokenized_doc_str.replace('~', '<')
    tokenized_doc_str = tokenized_doc_str.replace('^', '>')
    tokenized_doc = tokenized_doc_str.split()
    
    for sent in reproc_doc.sents:
        sent_text = sent.text
        # SWITCH BACK TO THE ORIGINAL TOKENS
        sent_text = sent_text.replace('~', '<')
        sent_text = sent_text.replace('^', '>')
        tokenized_sents.append(sent_text.split())
#         tokenized_sents.append(sent_text)

    cluster_spans = get_cluster_spans(clusters_ent_id, ent_id_to_token_spans)
    
    try:
        # Check the retokenized doc is same as tokenized doc
        retokenized_doc = []
        for sent in tokenized_sents:
            retokenized_doc.extend(sent)
        assert(tokenized_doc == retokenized_doc)
    except AssertionError:
        print(len(tokenized_doc))
        print(len(retokenized_doc))
        break

    doc_info = {}
    doc_info["doc_key"] = red_file_name
    doc_info["sentences"] = tokenized_sents
    doc_info["clusters"] = cluster_spans
    doc_info["speakers"] = get_dummy_speaker(tokenized_sents)
    
    file_name = path.join(dir_name, base_name)
    if red_file_name in dev_set:
        dev_data.append(doc_info)
    elif red_file_name in test_set:
        test_data.append(doc_info)
    else:
        train_data.append(doc_info)

deft/NYT_ENG_20130426.0143
deft/NYT_ENG_20131225.0200
deft/NYT_ENG_20130525.0040
deft/APW_ENG_20101231.0037
deft/7677d625b58ce649c8aeda2ff4a56389.mpdf
deft/NYT_ENG_20131029.0091
deft/NYT_ENG_20131003.0269
deft/362f9d9707c4da0c8068bc7034aae4b4.mpdf
deft/NYT_ENG_20130613.0153
deft/565fa81d640f451b20955887a43b3a23.mpdf
deft/aa003ea934a97bac86cee52b7122f1f8.mpdf
deft/5e3fbf49f8301654bb4954c0f1e386a9.mpdf
deft/NYT_ENG_20130619.0092
deft/17a2dc40635ec239e9e16d10b6dd45e8.mpdf
deft/AFP_ENG_20100414.0615
deft/d4698e3ad06f896058ade2e8f3a09577.mpdf
deft/dd0b65f632f64369c530f9bbb4b024b4.mpdf
deft/5c7ea2b51202d80ee37eba8a182afad3.mpdf
deft/NYT_ENG_20130703.0214
deft/2d2a4ddb1c8f4a669541704f9fb78472.mpdf
deft/NYT_ENG_20131128.0177
deft/635bde2afdaaf20a0bcdc3b5f79578c9.mpdf
deft/4798bc0e166fe93893bdf2d922f06258.mpdf
deft/648abb9000309b9807cc8b212c11254f.mpdf
deft/37b56b6dd846ad0dd6e8cd00ba2efaf4.mpdf
deft/af18d29036ab0a9f8cf2742a5a1b4804.mpdf
deft/ca2a6fbf721ca102c149ad6a90d5b00a.mpdf
deft/1d2911e09a

In [24]:
for split, data in zip(['train', 'dev', 'test'], [train_data, dev_data, test_data]):
    with open(path.join(output_dir, "{}.english.jsonlines".format(split)), 'w') as f:
        for instance in data:
            f.write(json.dumps(instance) + "\n")

In [32]:
# Get stats on train data
max_num_sentences = 0
num_chains = []
span_length = []
chain_length = []

for instance in train_data:
    max_num_sentences = max(max_num_sentences, len(instance["sentences"]))
    num_chains.append(len(instance["clusters"]))
    for chain in instance["clusters"]:
        chain_length.append(len(chain))
        for mention in chain:
            span_start, span_end = mention
            span_length.append(span_end - span_start)

In [33]:
print("Max sentences:", max_num_sentences)
print("Max chains:", max(num_chains))
print("Max span:", max(span_length))
print("Max chain length:", max(chain_length))


Max sentences: 258
Max chains: 60
Max span: 9
Max chain length: 71


In [27]:
import matplotlib.pyplot as plt

plt.hist(chain_length)

(array([196.,   7.,   3.,   5.,   3.,   1.,   0.,   1.,   0.,   1.]),
 array([ 2. ,  6.3, 10.6, 14.9, 19.2, 23.5, 27.8, 32.1, 36.4, 40.7, 45. ]),
 <a list of 10 Patch objects>)